"vscode:/vscode.git/clone" did not exist on "2cf007c9d120583ee9ad7ad39b276f2ff81eeb62"
Unverified commit cc92a4b4 authored by Jithun Nair, committed by GitHub

Merge pull request #55 from ROCmSoftwarePlatform/IFU-master-2021-10-15

IFU-2021-10-15 (+ remove redundant defines + C10_CUDA_CHECK)
parents 1e0f9bc6 fec3141c
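
The functional change repeated throughout the diff below is the move from the legacy THC error-check macros to their ATen/c10 equivalents, plus dropping the THC includes that only existed to provide them. A minimal sketch of the pattern (illustrative only; the headers named here are the usual homes of these macros and are an assumption, not copied from this diff):

    #include <ATen/cuda/Exceptions.h>    // provides TORCH_CUDABLAS_CHECK
    #include <c10/cuda/CUDAException.h>  // provides C10_CUDA_CHECK

    // before: THCublasCheck(rocblas_gemm_ex(handle, /* ... */));
    TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, /* ... */));

    // before: THCudaCheck(cudaGetLastError());
    C10_CUDA_CHECK(cudaGetLastError());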
#include <vector>
#include <math.h>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
-#include "THC/THC.h"
+#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
@@ -89,9 +86,9 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Q Fwd
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_q_dim,
@@ -117,7 +114,7 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
// Input Linear KV Fwd
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_kv_dim,
@@ -230,7 +227,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches);
// Output Linear
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
@@ -254,6 +251,7 @@ std::vector<torch::Tensor> fwd_cuda(
algo,
solution_index,
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_q_results,
@@ -333,8 +331,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -360,7 +360,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Output Linear Wgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
@@ -497,7 +497,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Input Linear Q Dgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -523,7 +523,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Input Linear Q Wgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
@@ -549,7 +549,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Input Linear KV Dgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -575,7 +575,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Input Linear KV Wgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
@@ -599,7 +599,7 @@ std::vector<torch::Tensor> bwd_cuda(
algo,
solution_index,
flags));
// TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_q_grads,
input_kv_grads,
...
#include <vector>
#include <math.h>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
-#include "THC/THC.h"
+//#include <cuda_profiler_api.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
@@ -29,12 +25,12 @@ namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
@@ -99,6 +95,7 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm
HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
@@ -112,7 +109,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Q Fwd
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_q_dim,
@@ -139,7 +136,7 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
// Input Linear KV Fwd
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_kv_dim,
@@ -252,7 +249,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches);
// Output Linear
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
@@ -294,6 +291,8 @@ std::vector<torch::Tensor> fwd_cuda(
total_tokens_q);
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
lyr_nrm_results,
lyr_nrm_mean,
@@ -386,7 +385,9 @@ std::vector<torch::Tensor> bwd_cuda(
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()),
@@ -396,7 +397,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -422,7 +423,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Output Linear Wgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
@@ -559,7 +560,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Input Linear Q Dgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -586,7 +587,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Input Linear Q Wgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
@@ -612,7 +613,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Input Linear KV Dgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -638,7 +639,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Input Linear KV Wgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
@@ -680,6 +681,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_q_grads,
...
#include <vector>
#include <math.h>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
-#include "THC/THC.h"
+//#include <cuda_profiler_api.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "softmax.h"
#include "dropout.h"
...
#include <vector>
#include <math.h>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
@@ -85,9 +82,10 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
input_lin_results.copy_(input_biases);
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
@@ -187,7 +185,7 @@ std::vector<torch::Tensor> fwd_cuda(
outputs.copy_(output_biases);
// Output Linear
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
@@ -211,6 +209,7 @@ std::vector<torch::Tensor> fwd_cuda(
algo,
solution_index,
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_results,
@@ -280,8 +279,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -307,7 +308,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Output Linear Wgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
@@ -441,7 +442,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Input Linear Dgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -467,7 +468,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Input Linear Wgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
@@ -493,6 +494,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
...
#include <vector>
#include <math.h>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
@@ -83,10 +80,11 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_t{'t'};
char a_layout_n{'n'};
char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
input_lin_results.copy_(input_biases);
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
@@ -199,7 +197,7 @@ std::vector<torch::Tensor> fwd_cuda(
outputs.copy_(output_biases);
// Output Linear
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
@@ -223,6 +221,7 @@ std::vector<torch::Tensor> fwd_cuda(
algo,
solution_index,
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_results,
@@ -291,8 +290,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -318,7 +319,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Output Linear Wgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
@@ -448,7 +449,7 @@ std::vector<torch::Tensor> bwd_cuda(
batch_stride,
attn_batches);
// Input Linear Dgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -474,7 +475,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Input Linear Wgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
@@ -500,6 +501,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
...
#include <vector>
#include <math.h>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
@@ -81,8 +78,9 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
@@ -195,7 +193,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches);
// Output Linear
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
@@ -219,6 +217,7 @@ std::vector<torch::Tensor> fwd_cuda(
algo,
solution_index,
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_results,
@@ -286,9 +285,11 @@ std::vector<torch::Tensor> bwd_cuda(
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -314,7 +315,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Output Linear Wgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
@@ -451,7 +452,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Input Linear Dgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -477,7 +478,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Input Linear Wgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
@@ -501,7 +502,8 @@ std::vector<torch::Tensor> bwd_cuda(
algo,
solution_index,
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
input_weight_grads,
...
#include <vector>
#include <math.h>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
-#include "THC/THC.h"
+//#include <cuda_profiler_api.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
@@ -106,7 +102,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Fwd
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
@@ -221,7 +217,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches);
// Output Linear
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
@@ -264,6 +260,8 @@ std::vector<torch::Tensor> fwd_cuda(
total_tokens);
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
lyr_nrm_results,
lyr_nrm_mean,
@@ -346,6 +344,8 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()),
@@ -355,7 +355,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -381,7 +381,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Output Linear Wgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
@@ -518,7 +518,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Input Linear Dgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -545,7 +545,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Input Linear Wgrad
-THCublasCheck(rocblas_gemm_ex(handle,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
@@ -588,6 +588,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
...
@@ -5,9 +5,10 @@
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
@@ -28,7 +29,7 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
else if (trans == 'n') return CUBLAS_OP_N;
else if (trans == 'c') return CUBLAS_OP_C;
else {
-THError("trans must be one of: t, n, c");
+AT_ERROR("trans must be one of: t, n, c");
return CUBLAS_OP_T;
}
}
@@ -44,7 +45,8 @@ void RocblasStridedBatchedGemm(THCState *state, char transa, char transb, long m
cublasSetStream(handle, stream);
float fAlpha = alpha;
float fBeta = beta;
-THCublasCheck(rocblas_gemm_strided_batched_ex(handle,
+//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, a_type, (int)lda, strideA,
b, b_type, (int)ldb, strideB,
@@ -112,7 +114,7 @@ void HgemmStridedBatched(THCState *state, char transa, char transb, long m, long
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
{
-THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
+AT_ERROR("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
"with the bound [val] <= %d", INT_MAX);
}
...
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <stdio.h> #include <stdio.h>
#include <cmath> #include <cmath>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include "ATen/TensorUtils.h" #include "ATen/TensorUtils.h"
// #include "ATen/Type.h" // #include "ATen/Type.h"
#include "ATen/AccumulateType.h" #include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#include "multi_tensor_apply.cuh" #include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
...@@ -275,7 +276,7 @@ void fused_adam_cuda( ...@@ -275,7 +276,7 @@ void fused_adam_cuda(
decay); decay);
); );
} }
THCudaCheck(cudaGetLastError()); C10_CUDA_CHECK(cudaGetLastError());
} }
...@@ -382,7 +383,7 @@ void fused_adam_cuda_mt( ...@@ -382,7 +383,7 @@ void fused_adam_cuda_mt(
); );
} }
} }
THCudaCheck(cudaGetLastError()); C10_CUDA_CHECK(cudaGetLastError());
} }
template <typename FROM_T, typename TO_T> template <typename FROM_T, typename TO_T>
...@@ -807,7 +808,7 @@ void fused_strided_check_finite( ...@@ -807,7 +808,7 @@ void fused_strided_check_finite(
stride, stride,
clear_overflow_first); clear_overflow_first);
); );
THCudaCheck(cudaGetLastError()); C10_CUDA_CHECK(cudaGetLastError());
} }
void fused_reversible_adam_cuda( void fused_reversible_adam_cuda(
...@@ -908,7 +909,7 @@ void fused_reversible_adam_cuda( ...@@ -908,7 +909,7 @@ void fused_reversible_adam_cuda(
decay); decay);
); );
} }
THCudaCheck(cudaGetLastError()); C10_CUDA_CHECK(cudaGetLastError());
} }
void maybe_cast_cuda( void maybe_cast_cuda(
...@@ -932,7 +933,7 @@ void maybe_cast_cuda( ...@@ -932,7 +933,7 @@ void maybe_cast_cuda(
p_in.DATA_PTR<scalar_t_0>(), p_in.DATA_PTR<scalar_t_0>(),
p_out.DATA_PTR<scalar_t_1>(), p_out.DATA_PTR<scalar_t_1>(),
tsize); )) tsize); ))
THCudaCheck(cudaGetLastError()); C10_CUDA_CHECK(cudaGetLastError());
} }
void maybe_cast_cuda_mt( void maybe_cast_cuda_mt(
...@@ -954,7 +955,7 @@ void maybe_cast_cuda_mt( ...@@ -954,7 +955,7 @@ void maybe_cast_cuda_mt(
overflow_flag, overflow_flag,
tensor_lists, tensor_lists,
MaybeCastFunctor<2, scalar_t_0, scalar_t_1>()); )) MaybeCastFunctor<2, scalar_t_0, scalar_t_1>()); ))
THCudaCheck(cudaGetLastError()); C10_CUDA_CHECK(cudaGetLastError());
} }
void fused_maybe_adam_undo_cuda( void fused_maybe_adam_undo_cuda(
...@@ -1032,5 +1033,5 @@ void fused_maybe_adam_undo_cuda( ...@@ -1032,5 +1033,5 @@ void fused_maybe_adam_undo_cuda(
decay); decay);
); );
} }
THCudaCheck(cudaGetLastError()); C10_CUDA_CHECK(cudaGetLastError());
} }
@@ -2,7 +2,6 @@
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
-#include <THC/THCGeneral.h>
// Another possibility:
// #include <torch/all.h>
@@ -225,5 +224,5 @@ void multi_tensor_fused_adam_cuda(
(adamMode_t) mode);
);
}
-THCudaCheck(cudaGetLastError());
+C10_CUDA_CHECK(cudaGetLastError());
}
#include <torch/extension.h>
#include <ATen/Functions.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> transducer_joint_cuda_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int64_t packedBatch,
int opt,
bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize);
std::vector<torch::Tensor> transducer_joint_cuda_backward(
std::vector<torch::Tensor> in,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput,
float scale);
std::vector<torch::Tensor> transducer_joint_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int64_t packedBatch,
int opt,
bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize) {
CHECK_INPUT(f);
CHECK_INPUT(g);
CHECK_INPUT(fLen);
CHECK_INPUT(gLen);
if (packOutput)
CHECK_INPUT(batchOffset);
return transducer_joint_cuda_forward(
f,
g,
fLen,
gLen,
batchOffset,
packedBatch,
opt,
packOutput,
relu,
dropout,
dropoutProb,
tileSize);
}
std::vector<torch::Tensor> transducer_joint_backward(
std::vector<torch::Tensor> in,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput,
float scale) {
for (auto t : in){
CHECK_INPUT(t);
}
CHECK_INPUT(fLen);
CHECK_INPUT(gLen);
if (packOutput)
CHECK_INPUT(batchOffset);
return transducer_joint_cuda_backward(
in,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
packOutput,
scale);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &transducer_joint_forward, "transducer joint forward (CUDA)");
m.def("backward", &transducer_joint_backward, "transducer joint backward (CUDA)");
}
\ No newline at end of file
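// A hypothetical host-side usage sketch for the unpacked forward path (not part of this file;
// assumes the extension above is built and linked against libtorch, and B, T, U, H are defined):
//   auto opt  = torch::dtype(torch::kFloat16).device(torch::kCUDA);
//   auto iopt = torch::dtype(torch::kInt32).device(torch::kCUDA);
//   auto f    = torch::randn({B, T, H}, opt);
//   auto g    = torch::randn({B, U, H}, opt);
//   auto fLen = torch::full({B}, T, iopt);
//   auto gLen = torch::full({B}, U, iopt);
//   // With packOutput=false, batchOffset may stay undefined and packedBatch is unused.
//   auto out = transducer_joint_forward(f, g, fLen, gLen, torch::Tensor(), /*packedBatch=*/0,
//                                       /*opt=*/1, /*packOutput=*/false, /*relu=*/false,
//                                       /*dropout=*/false, /*dropoutProb=*/0.f, /*tileSize=*/4);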
#include <cuda.h>
#include <curand_kernel.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/AccumulateType.h>
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <c10/macros/Macros.h>
#include "philox.h"
// Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.
// width should be a power of 2 and should be less than warpSize.
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width=C10_WARP_SIZE){
for (unsigned offset = width/2; offset > 0; offset /= 2){
x += __shfl_down_sync(0xffffffff, x, offset, width);
}
return x;
}
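// Usage note (assumption: width is a power of two and <= warpSize): after warpReduce returns,
// the first lane of each width-sized group holds the sum of that group's inputs; the remaining
// lanes hold partial sums and should be ignored.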
inline int largestPowerOfTwo(int x){
int y = 1;
while (y <= x)
y <<= 1;
return y >> 1;
}
/*
Figure out vectorization type for masks.
Similar to how PyTorch figures out acc_t here:
aten/src/ATen/AccumulateType.h
*/
template <int V>
struct MaskVecType { };
template <> struct MaskVecType<1> { using type = uint8_t; };
template <> struct MaskVecType<2> { using type = uint16_t; };
template <> struct MaskVecType<4> { using type = uint32_t; };
template<int V>
using mvec_type = typename MaskVecType<V>::type;
// Helper class to calculate pointer offset that can be shared by different flavors of kernels.
// For fwd, batch offset and stride are different for packing and non-packing mode.
struct OffsetCalFwd{
__device__ __forceinline__ OffsetCalFwd(
int64_t batch,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t gLen,
int64_t hiddenSize,
bool packOutput) :
batch(batch),
batchOffset(batchOffset),
maxFLen(maxFLen),
maxGLen(maxGLen),
gLen(gLen),
hiddenSize(hiddenSize),
packOutput(packOutput)
{}
int64_t batch;
const int64_t *batchOffset;
int64_t maxFLen;
int64_t maxGLen;
int64_t gLen;
int64_t hiddenSize;
bool packOutput;
__device__ __forceinline__ int64_t getBatchOffset(){
return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize
: batch*maxFLen*maxGLen*hiddenSize;
}
__device__ __forceinline__ int64_t getStrideF(){
return packOutput ? gLen*hiddenSize : maxGLen*hiddenSize;
}
};
// Helper class to calculate pointer offset that can be shared by different flavors of kernels
// For bwd, batch offset and stride are different for packing and non-packing mode.
// The reduction is done for two input tensors. Therefore, generating two sets of offsets
// according to bwdFasterDim can lead to a unified implementation in the actual kernel.
struct OffsetCalBwd{
__device__ __forceinline__ OffsetCalBwd(
int64_t batch,
const int64_t *batchOffset,
const int *fLen,
const int *gLen,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
bool bwdFasterDim) :
batch(batch),
batchOffset(batchOffset),
maxFLen(maxFLen),
maxGLen(maxGLen),
fLen(fLen),
gLen(gLen),
hiddenSize(hiddenSize),
packOutput(packOutput),
bwdFasterDim(bwdFasterDim)
{}
int64_t batch;
const int64_t *batchOffset;
const int *fLen;
const int *gLen;
int64_t maxFLen;
int64_t maxGLen;
int64_t hiddenSize;
bool packOutput;
bool bwdFasterDim; // whether doing bwd on the faster moving dimension
__device__ __forceinline__ int64_t getBatchOffset(){
return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize
: batch*maxFLen*maxGLen*hiddenSize;
}
__device__ __forceinline__ int64_t getMaxXLen(){
return bwdFasterDim ? maxGLen : maxFLen;
}
__device__ __forceinline__ auto getMyXLen() -> decltype(gLen[batch]){
return bwdFasterDim ? gLen[batch] : fLen[batch];
}
__device__ __forceinline__ auto getMyYLen() -> decltype(gLen[batch]){
return bwdFasterDim ? fLen[batch] : gLen[batch];
}
__device__ __forceinline__ int64_t getStrideX(){
return bwdFasterDim ? hiddenSize : ((packOutput ? gLen[batch] : maxGLen) * hiddenSize);
}
__device__ __forceinline__ int64_t getStrideY(){
return bwdFasterDim ? ((packOutput ? gLen[batch] : maxGLen) * hiddenSize) : hiddenSize;
}
};
// Vanilla transducer joint forward kernel
// Detail of this joint function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// f is a tensor of shape [batch, T, H]
// g is a tensor of shape [batch, U, H]
// the transducer joint does
// sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
// The resultant tensor is of shape [batch, T, U, H]
// Each thread block is working on one "batch" of data in the output tensor, [batch, t, u, :]
// This joint function can optionally pack the output where the output tensor with a shape of
// [B, T, U, H] is packed into [B_packed, H].
// Don't-care region (t > fLen) or (u > gLen) is removed.
// To enable packing, the starting offset for each batch needs to be specified with batchOffset.
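// For orientation: ignoring the don't-care fill, the unpacked output produced below is
// equivalent to the ATen expression (a sketch for illustration, not code used by this kernel):
//   auto sum = f.unsqueeze(2) + g.unsqueeze(1);  // [B, T, 1, H] + [B, 1, U, H] -> [B, T, U, H]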
template <typename scalar_t, class OffsetCal>
__global__ void transducer_joint_forward(
const scalar_t *f,
const scalar_t *g,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
scalar_t *sum) {
const int batch = blockIdx.z;
const int t = blockIdx.y;
const int u = blockIdx.x;
const auto myFLen = fLen[batch];
const auto myGLen = gLen[batch];
OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideF = offsetCal.getStrideF();
scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize;
scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize;
scalar_t *mySum = sum + myBatchOffset + t*strideF + u * hiddenSize;
if (t < myFLen and u < myGLen){
#pragma unroll
for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){
if (h < hiddenSize){
mySum[h] = myF[h] + myG[h];
}
}
}
else if (packOutput == false and t < maxFLen and u < maxGLen){
// Need to write finite data to don't-care region because we instantiate the result tensor
// with torch::empty for performance reasons. Even though it is don't-care region, the
// contents need to be finite, otherwise could lead to NaN in WGRAD.
// In packing mode, this write is no longer necessary as we remove the don't-care region
// from the output.
// Picking -1 (over 0) here for ease of testing.
#pragma unroll
for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){
if (h < hiddenSize){
mySum[h] = -1;
}
}
}
}
/*
Tiled version of the joint forward kernel
Detail of this joint function can be found in:
[1] Sequence Transduction with Recurrent Neural Networks.
f is a tensor of shape [batch, T, H]
g is a tensor of shape [batch, U, H]
the transducer joint does
sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
The resultant tensor is of shape [batch, T, U, H]
Each thread is working on a tile of the shape of tileF x tileG in the result tensor.
The input for the tile is first loaded in the register and is reused tileG and tileF times.
This joint function can optionally pack the output where the output tensor with a shape of
[B, T, U, H] is packed into [B_packed, H].
Don't-care region (t > fLen) or (u > gLen) is removed.
To enable packing, the starting offset for each batch needs to be specified with batchOffset.
Optionally this joint function performs ReLU and/or dropout on the joint output, which is
controlled by the arguments relu and dropout, respectively. philoxArgs is the argument used for
generating pseudorandom numbers. When at least one of ReLU and dropout is enabled, the joint
function becomes a masked operation, which is controlled by the template argument masked. In this
case, masks are saved for the backward pass.
*/
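// Conceptually, the masked path below computes, per output element:
//   localMask = (relu ? (out > 0) : 1) & (dropout ? (rand < keepProb) : 1);
//   out       = dropout ? out * localMask / keepProb : out * localMask;
// where keepProb is what the launcher passes in as p (it forwards 1 - dropoutProb).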
template <typename scalar_t, int tileF, int tileG, int U, class OffsetCal, bool masked>
__global__ void transducer_joint_tiled_forward(
const scalar_t *f,
const scalar_t *g,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
int64_t hiddenPerBlock,
bool packOutput,
bool relu,
bool dropout,
float p,
at::PhiloxCudaState philoxArgs,
scalar_t *sum,
uint8_t *mask) {
static_assert(U == 4, "U has to be 4, as random numbers are generated in batch of 4");
const int batch = blockIdx.z;
const int t = blockIdx.y * tileF;
const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;
const int u = blockIdx.x / hiddenBlock * tileG;
const int hOffset = (blockIdx.x % hiddenBlock) * hiddenPerBlock;
const int h = threadIdx.x;
const auto myFLen = fLen[batch];
const auto myGLen = gLen[batch];
OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideF = offsetCal.getStrideF();
scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize + hOffset;
scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize + hOffset;
scalar_t *mySum = sum + myBatchOffset + t*strideF + u*hiddenSize + hOffset;
uint8_t *myMask = mask + myBatchOffset + t*strideF + u*hiddenSize + hOffset;
// The following code is only needed for dropout. We try to bypass it as much as possible.
auto seeds = masked ? at::cuda::philox::unpack(philoxArgs)
: std::make_tuple(static_cast<uint64_t>(0), static_cast<uint64_t>(0));
uint64_t tid = masked ? (static_cast<uint64_t>(blockIdx.z)*gridDim.y*gridDim.x +
blockIdx.y*gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x
: 0;
Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));
scalar_t scale = masked ? ((p == 0) ? 0 : 1 / p) : 0;
bool dropoutMask[U];
if (t < myFLen and u < myGLen and hOffset+h < hiddenSize){
// register buffers for tiled input reuse
scalar_t fBuffer[tileF], gBuffer[tileG];
for (int i = 0; i < tileF; ++i){
if (t + i < myFLen)
fBuffer[i] = myF[i*hiddenSize + h];
}
for (int j = 0; j < tileG; ++j){
if (u + j < myGLen)
gBuffer[j] = myG[j*hiddenSize + h];
}
#pragma unroll
for (int i = 0; i < tileF; ++i){
if (t + i < myFLen){
#pragma unroll
for (int j = 0; j < tileG; ++j){
int idx = i*tileG + j;
if (masked and dropout and idx % U == 0){
// For performance, generate 4 random numbers in one shot
// auto rand4 = curand_uniform4(&state);
auto rand4 = uniform4(ph());
dropoutMask[0] = rand4.x < p;
dropoutMask[1] = rand4.y < p;
dropoutMask[2] = rand4.z < p;
dropoutMask[3] = rand4.w < p;
}
if (u + j < myGLen){
scalar_t out = fBuffer[i] + gBuffer[j];
if (masked){
// Apply ReLU here when relu is True
bool localMask = relu ? (out>0) : 1;
localMask = dropout ? localMask & dropoutMask[idx%U] : localMask;
out = dropout ? out*localMask*scale : out*localMask;
myMask[i*strideF + j*hiddenSize + h] = static_cast<uint8_t>(localMask);
}
mySum[i*strideF + j*hiddenSize + h] = out;
}
else if (packOutput == false and u + j < maxGLen)
mySum[i*strideF + j*hiddenSize + h] = -1;
}
}
else if (packOutput == false and t + i < maxFLen){
// Again need to write finite data to don't-care region
#pragma unroll
for (int j = 0; j < tileG; ++j){
if (u + j < maxGLen)
mySum[i*strideF + j*hiddenSize + h] = -1;
}
}
}
}
else if (packOutput == false and t < maxFLen and u < maxGLen and hOffset+h < hiddenSize){
// Only need to ensure finiteness in normal (non-packing) mode
#pragma unroll
for (int i = 0; i < tileF; ++i){
if (t + i < maxFLen){
#pragma unroll
for (int j = 0; j < tileG; ++j){
if (u + j < maxGLen)
mySum[i*strideF + j*hiddenSize + h] = -1;
}
}
}
}
}
/*
Bwd operation (reduction) on one input tensor. Since the operations performed for the two input
tensors are exactly the same, only one kernel is needed, and the different indexing offsets
and strides are handled by OffsetCalBwd.
When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a
non-packed form.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
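// In the unpacked, unmasked case the reduction below is, conceptually:
//   fGrad[b, t, h] = sum over u of grad[b, t, u, h]   (bwdFasterDim == false)
//   gGrad[b, u, h] = sum over t of grad[b, t, u, h]   (bwdFasterDim == true)
// In the masked case each term is additionally multiplied by mask and scale.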
template <typename scalar_t, typename acc_t, class OffsetCal, bool masked>
__device__ void transducer_joint_single_backward(
const scalar_t *grad,
const uint8_t *mask,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
bool bwdFasterDim, // whether bwd on the faster moving dimension (u)
float scale,
scalar_t *inGrad,
int yBlockOffset=0) {
const int batch = blockIdx.z;
// For the second input tensor, this offset needs to be subtracted because the first yBlockOffset
// sets of thread blocks are for the first input tensor.
const int x = blockIdx.y-yBlockOffset;
const int hOffset = blockIdx.x*C10_WARP_SIZE;
const int wid = threadIdx.y;
const int lid = threadIdx.x;
const int numWarp = blockDim.y;
extern __shared__ char smem8[];
auto smem = reinterpret_cast<acc_t*>(smem8);
OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput,
bwdFasterDim);
const auto maxXLen = offsetCal.getMaxXLen();
const auto myXLen = offsetCal.getMyXLen();
const auto myYLen = offsetCal.getMyYLen();
scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset;
if (x < myXLen){
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideX = offsetCal.getStrideX();
const auto strideY = offsetCal.getStrideY();
const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;
const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset : nullptr;
// Each warp reduces numYPerWarp "y" first
acc_t warpSum = 0;
auto numYPerWarp = (myYLen+numWarp-1)/numWarp;
#pragma unroll
for (int warpY = 0; warpY < numYPerWarp; ++warpY){
auto y = wid*numYPerWarp + warpY;
if (y < myYLen and (hOffset+lid) < hiddenSize)
if (masked)
warpSum += static_cast<acc_t>(myGrad[y*strideY + lid]) * myMask[y*strideY + lid] * scale;
else
warpSum += myGrad[y*strideY + lid];
}
// transpose partial sum in SMEM and reduce further using warpReduce
smem[lid*numWarp + wid] = warpSum;
__syncthreads();
auto sum = smem[wid*C10_WARP_SIZE + lid];
sum = warpReduce(sum, numWarp);
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// example of 4 warps (a, b, c, d) with 8 threads per warp
// Each warp needs 8 / 4 = 2 threads to write the results.
if (hOffset+wid*C10_WARP_SIZE/numWarp+lid/numWarp < hiddenSize){
if (lid % numWarp == 0){
myInGrad[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = sum;
}
}
}
else if (wid == 0 and hOffset + lid < hiddenSize){
// Need to ensure the grad is zero for don't care region
myInGrad[lid] = 0;
}
}
/*
The actual bwd (reduction) kernel that gets launched.
Call transducer_joint_single_backward twice on two input tensors.
The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op
uses the rest.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, class OffsetCal, bool masked>
__global__ void transducer_joint_combined_backward(
const scalar_t *grad,
const uint8_t *mask,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
float scale,
scalar_t *fGrad,
scalar_t *gGrad) {
if (blockIdx.y < maxFLen){
transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(
grad,
mask,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
false,
scale,
fGrad);
}
else{
transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(
grad,
mask,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
true,
scale,
gGrad,
maxFLen);
}
}
/*
Vectorized version of transducer_joint_single_backward
It does exactly the same operation as transducer_joint_single_backward, except that the loads and stores are
vectorized.
When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a
non-packed form.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>
__device__ void transducer_joint_single_vec_backward(
const scalar_t *grad,
const uint8_t *mask,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
bool bwdFasterDim,
float scale,
scalar_t *inGrad,
int yBlockOffset=0){
const int batch = blockIdx.z;
const int x = blockIdx.y - yBlockOffset;
const int hOffset = blockIdx.x*C10_WARP_SIZE*V;
const int wid = threadIdx.y;
const int lid = threadIdx.x;
const int numWarp = blockDim.y;
// Figure out the vectorization type for mask
using mvec_t = mvec_type<V>;
OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput,
bwdFasterDim);
const auto maxXLen = offsetCal.getMaxXLen();
const auto myXLen = offsetCal.getMyXLen();
const auto myYLen = offsetCal.getMyYLen();
scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset;
extern __shared__ char smem8[];
auto smem = reinterpret_cast<acc_t*>(smem8);
acc_t warpSum[V];
scalar_t inBuffer[V];
uint8_t maskBuffer[V];
scalar_t outBuffer[V];
auto myInGradVec = reinterpret_cast<vec_t*>(myInGrad);
auto outBufferVec = reinterpret_cast<vec_t*>(outBuffer);
if (x < myXLen){
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideX = offsetCal.getStrideX();
const auto strideY = offsetCal.getStrideY();
const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;
const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset
:nullptr;
for (int i = 0; i < V; ++i)
warpSum[i] = 0;
// Each warp reduces numYPerWarp "y" first
auto numYPerWarp = (myYLen+numWarp-1)/numWarp;
for (int warpY = 0; warpY < numYPerWarp; ++warpY){
auto y = wid*numYPerWarp + warpY;
auto myGradVec = reinterpret_cast<vec_t const *>(myGrad + y*strideY);
auto myMaskVec = masked ? reinterpret_cast<mvec_t const *>(myMask + y*strideY)
: nullptr;
auto inBufferVec = reinterpret_cast<vec_t*>(inBuffer);
auto maskBufferVec = reinterpret_cast<mvec_t*>(maskBuffer);
if (hOffset + lid*V < hiddenSize and y < myYLen){
*inBufferVec = myGradVec[lid]; // vectorized load
if (masked){
*maskBufferVec = myMaskVec[lid];
#pragma unroll
for (int i = 0; i < V; ++i)
warpSum[i] += static_cast<acc_t>(inBuffer[i]) * maskBuffer[i] * scale;
}
else{
#pragma unroll
for (int i = 0; i < V; ++i)
warpSum[i] += inBuffer[i];
}
}
}
// transpose partial sum in SMEM and reduce further using warpReduce
for (int i = 0; i < V; ++i){
smem[lid*numWarp + wid] = warpSum[i];
__syncthreads();
auto sum = smem[wid*C10_WARP_SIZE + lid];
if (hOffset+(wid*C10_WARP_SIZE/numWarp)*V < hiddenSize){
sum = warpReduce(sum, numWarp);
if (lid % numWarp == 0){
outBuffer[i] = sum;
}
}
__syncthreads();
}
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// example of 4 warps (a, b, c, d) with 8 threads per warp
// Each warp needs 8 / 4 = 2 threads to write the results.
if (lid % numWarp == 0 and hOffset+(wid*C10_WARP_SIZE/numWarp + lid/numWarp)*V < hiddenSize)
myInGradVec[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = *outBufferVec;
}
else if (wid == 0 and hOffset + lid*V < hiddenSize){
// Need to ensure the grad is zero for don't care region
myInGradVec[lid] = 0;
}
}
/*
Vectorized version of transducer_joint_combined_backward
Call transducer_joint_single_vec_backward twice on two input tensors.
The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op
uses the rest.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>
__global__ void transducer_joint_combined_vec_backward(
const scalar_t *grad,
const uint8_t *mask,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
float scale,
scalar_t *fGrad,
scalar_t *gGrad) {
if (blockIdx.y < maxFLen){
transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(
grad,
mask,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
false,
scale,
fGrad);
}
else{
transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(
grad,
mask,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
true,
scale,
gGrad,
maxFLen);
}
}
std::vector<torch::Tensor> transducer_joint_cuda_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int64_t packedBatch,
int opt,
bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize){
auto tensorOpt = f.options();
auto dtype = f.scalar_type();
const auto batchSize = f.size(0);
const auto maxFLen = f.size(1);
const auto maxGLen = g.size(1);
const auto hiddenSize = f.size(2);
bool masked = dropout or relu;
int64_t *batchOffsetPtr = nullptr;
torch::Tensor sum, mask;
auto maskOpt = tensorOpt.dtype(torch::kUInt8);
if (!packOutput){
sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt);
batchOffsetPtr = nullptr;
if (masked)
mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt);
}
else{
sum = torch::empty({packedBatch, hiddenSize}, tensorOpt);
batchOffsetPtr = batchOffset.data_ptr<int64_t>();
if (masked)
mask = torch::empty({packedBatch, hiddenSize}, maskOpt);
}
uint8_t *maskPtr = masked ? mask.data_ptr<uint8_t>() : nullptr;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(opt == 0 or opt == 1, "Got an invalid optimization level ", opt);
// Simple heuristics
const int numThread = std::min(128, (static_cast<int>(hiddenSize)+C10_WARP_SIZE-1)
/ C10_WARP_SIZE * C10_WARP_SIZE);
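// Worked example (illustrative): with C10_WARP_SIZE = 64 (ROCm) and hiddenSize = 200,
// ceil(200/64)*64 = 256 and numThread = min(128, 256) = 128; with hiddenSize = 50,
// numThread = 64, i.e. a single rounded-up warp.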
if (opt == 0){
// vanilla kernel
const int threads = numThread;
const dim3 blocks(maxGLen, maxFLen, batchSize);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] {
transducer_joint_forward<scalar_t, OffsetCalFwd>
<<<blocks, threads, 0, stream>>>(
f.data_ptr<scalar_t>(),
g.data_ptr<scalar_t>(),
fLen.data_ptr<int>(),
gLen.data_ptr<int>(),
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
sum.data_ptr<scalar_t>());
}));
}
if (opt == 1){
// tiled version. For simplicity, assume tileF == tileG, even though the kernel can
// support more general cases.
const int threads = numThread;
const int hiddenPerBlock = numThread;
const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;
const dim3 blocks( (maxGLen+tileSize-1)/tileSize * hiddenBlock,
(maxFLen+tileSize-1)/tileSize,
batchSize);
TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4,
"Expected tileSize to be in [1, 2, 4], but got ", tileSize);
at::PhiloxCudaState rng_engine_inputs;
if (masked){
// Set up the PRNG only when the input is masked. For non-masked calls, rng_engine_inputs is
// merely a placeholder argument and therefore does not need to be initialized.
c10::optional<at::Generator> gen_;
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_,
at::cuda::detail::getDefaultCUDAGenerator());
// counterOffset records how many cuRAND calls each thread makes. For a tiled kernel,
// each thread processes tileF * tileG output elements.
int64_t counterOffset = tileSize * tileSize;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_cuda_state(counterOffset);
}
}
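// Illustrative budget: with tileSize = 4, counterOffset = 16, i.e. each thread reserves at most
// one Philox draw per output element of its 4x4 tile.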
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] {
void(*kernel)(const scalar_t*, const scalar_t*, const int*, const int*, const int64_t*,
int64_t, int64_t, int64_t, int64_t, bool, bool, bool, float,
at::PhiloxCudaState, scalar_t*, uint8_t*);
if (masked){
switch (tileSize){
case 2:
kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd,
true>;
break;
case 4:
kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd,
true>;
break;
}
}
else{
switch (tileSize){
case 1:
kernel = &transducer_joint_tiled_forward<scalar_t, 1, 1, 4, OffsetCalFwd,
false>;
break;
case 2:
kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd,
false>;
break;
case 4:
kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd,
false>;
break;
}
}
kernel<<<blocks, threads, 0, stream>>>(
f.data_ptr<scalar_t>(),
g.data_ptr<scalar_t>(),
fLen.data_ptr<int>(),
gLen.data_ptr<int>(),
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
hiddenPerBlock,
packOutput,
relu,
dropout,
1.0f - dropoutProb,
rng_engine_inputs,
sum.data_ptr<scalar_t>(),
maskPtr);
}));
}
C10_CUDA_CHECK(cudaGetLastError());
if (masked)
return {sum, mask};
else
return {sum};
}
std::vector<torch::Tensor> transducer_joint_cuda_backward(
std::vector<torch::Tensor> in,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput,
float scale){
auto grad = in[0];
bool masked = (in.size() == 2);
uint8_t *maskPtr = masked ? in[1].data_ptr<uint8_t>() : nullptr;
auto tensorOpt = grad.options();
auto dtype = grad.scalar_type();
const int batchSize = fLen.size(0);
const int hiddenSize = grad.size(-1);
const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE;
torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt);
torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt);
int64_t *batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr<int64_t>();
// The number of "y" rows each thread should work on
const int workPerThread = 32;
// Since the bwd for f and g have the same thread block size, we need to use the max of the two.
int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread-1) / workPerThread);
// Would like to have at least 2 warps
numWarp = std::max(2, numWarp);
// cap on the maximum number of warps allowed
numWarp = std::min(maxNumWarp, numWarp);
// Need smem for transposing the partial sum. The partial sum is in a matrix of the shape
// numWarp x warpSize
const int smemSize = numWarp * C10_WARP_SIZE;
const dim3 threads(C10_WARP_SIZE, numWarp, 1);
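// Worked example (illustrative, assuming largestPowerOfTwo() returns the largest power of two not
// exceeding its argument): maxFLen = 500, maxGLen = 100, workPerThread = 32 gives
// (500 + 31) / 32 = 16, so numWarp = 16 (already >= 2 and typically <= maxNumWarp), and the block is
// (C10_WARP_SIZE, 16, 1) threads with 16 * C10_WARP_SIZE shared-memory accumulators.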
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_cuda_backward_kernel", ([&] {
auto gradPtr = grad.data_ptr<scalar_t>();
auto fLenPtr = fLen.data_ptr<int>();
auto gLenPtr = gLen.data_ptr<int>();
auto fGradPtr = fGrad.data_ptr<scalar_t>();
auto gGradPtr = gGrad.data_ptr<scalar_t>();
// resolve the acc_t type
using acc_t = at::acc_type<scalar_t, true>;
using vec_t = uint64_t;
constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);
constexpr int vecAlignment = std::alignment_of<vec_t>::value;
// if all input and output tensors meet the alignment requirement
bool memAlign = (reinterpret_cast<uint64_t>(gradPtr) % vecAlignment == 0)
and (reinterpret_cast<uint64_t>(fGradPtr) % vecAlignment == 0)
and (reinterpret_cast<uint64_t>(gGradPtr) % vecAlignment == 0);
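// Illustrative: with scalar_t = half (2 bytes) and vec_t = uint64_t, vectFactor = 4, so the
// vectorized path below needs hiddenSize % 4 == 0 and 8-byte-aligned grad/fGrad/gGrad pointers;
// with scalar_t = float, vectFactor = 2.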
if (vectFactor > 1 and hiddenSize%vectFactor == 0 and memAlign){
// If vectorization helps and the alignment requirement is met, use the vectorized
// kernel. For simplicity, hiddenSize needs to be a multiple of vectFactor.
const dim3 blocks( (hiddenSize+C10_WARP_SIZE*vectFactor-1)/(C10_WARP_SIZE*vectFactor),
maxFLen+maxGLen,
batchSize);
if (masked){
transducer_joint_combined_vec_backward
<scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, true>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
maskPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
scale,
fGradPtr,
gGradPtr);
}
else{
transducer_joint_combined_vec_backward
<scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, false>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
maskPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
scale,
fGradPtr,
gGradPtr);
}
}
else{
const dim3 blocks((hiddenSize+C10_WARP_SIZE-1)/C10_WARP_SIZE,
maxFLen + maxGLen, batchSize);
if (masked){
transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, true>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
maskPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
scale,
fGradPtr,
gGradPtr);
}
else{
transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, false>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
maskPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
scale,
fGradPtr,
gGradPtr);
}
}
}));
return {fGrad, gGrad};
}
#include <torch/extension.h>
#include <vector>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> transducer_loss_cuda_forward(
torch::Tensor x,
torch::Tensor label,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool packedInput);
torch::Tensor transducer_loss_cuda_backward(
torch::Tensor x,
torch::Tensor lossGrad,
torch::Tensor alpha,
torch::Tensor beta,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor label,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool fuseSoftmaxBackward,
bool packedInput);
std::vector<torch::Tensor> transducer_loss_forward(
torch::Tensor x,
torch::Tensor label,
torch::Tensor fLen,
torch::Tensor yLen,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool packedInput
) {
CHECK_INPUT(x);
CHECK_INPUT(label);
CHECK_INPUT(fLen);
CHECK_INPUT(yLen);
if (packedInput)
CHECK_INPUT(batchOffset);
return transducer_loss_cuda_forward(
x,
label,
fLen,
yLen,
batchOffset,
maxFLen,
blankIdx,
opt,
packedInput);
}
torch::Tensor transducer_loss_backward(
torch::Tensor x,
torch::Tensor lossGrad,
torch::Tensor alpha,
torch::Tensor beta,
torch::Tensor fLen,
torch::Tensor yLen,
torch::Tensor label,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool fuseSoftmaxBackward,
bool packedInput){
CHECK_INPUT(x);
CHECK_INPUT(label);
CHECK_INPUT(lossGrad);
CHECK_INPUT(alpha);
CHECK_INPUT(beta);
CHECK_INPUT(fLen);
CHECK_INPUT(yLen);
if (packedInput)
CHECK_INPUT(batchOffset);
return transducer_loss_cuda_backward(
x,
lossGrad,
alpha,
beta,
fLen,
yLen,
label,
batchOffset,
maxFLen,
blankIdx,
opt,
fuseSoftmaxBackward,
packedInput);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)");
m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)");
}
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
template<typename scalar_t>
__device__ __forceinline__ scalar_t logSumExp(scalar_t a, scalar_t b) {
// standard log-sum-exp trick is used here to provide better numerical stability
return (a >= b) ? a + std::log1p(exp(b-a)) : b + std::log1p(exp(a-b));
}
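// Worked example (illustrative): for a = log(0.2) and b = log(0.3),
//   logSumExp(a, b) = log(0.3) + log1p(exp(log(0.2) - log(0.3)))
//                   = log(0.3) + log(1 + 2/3) = log(0.5),
// i.e. log(exp(a) + exp(b)) without evaluating exp() directly on large-magnitude arguments.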
// Vanilla transducer loss function (i.e. forward-backward algorithm)
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// Forward (alpha) and backward (beta) paths are launched together. The input is assumed to have been
// converted to log scale by the preceding log_softmax layer.
// Diagonal wavefront advancing, as commonly used in dynamic programming, is leveraged here.
// alpha and beta are of acc_t type, as they are essentially accumulators.
// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into
// [B_packed, H].
// Don't-care region (t > audLen) or (u > txtLen) is removed.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
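// In log space, the recursions from [1] that the kernel below implements are (illustrative summary):
//   alpha(t, u) = logSumExp(alpha(t-1, u) + x(t-1, u, blank), alpha(t, u-1) + x(t, u-1, label[u-1]))
//   beta(t, u)  = logSumExp(beta(t+1, u)  + x(t, u, blank),   beta(t, u+1)  + x(t, u, label[u]))
// with alpha(0, 0) = 0, beta(T-1, U-1) = x(T-1, U-1, blank), and per-sample loss = -beta(0, 0).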
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_forward(
const scalar_t* x,
const int* label,
const int* audLen,
const int* txtLen,
const int64_t* batchOffset,
int64_t dictSize, // 64-bit indexing for data tensor
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
acc_t* alpha,
acc_t* beta,
scalar_t* loss) {
const int batch = blockIdx.y;
const int tid = threadIdx.x;
const auto myFLen = audLen[batch];
// Note that the start-of-sequence symbol accounts for the +1 here
const auto myGLen = txtLen[batch] + 1;
const auto myLabel = label + batch * (maxGLen-1);
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
const scalar_t* myX = x + myBatchOffset * dictSize;
int u = tid;
if (blockIdx.x == 0){
// alpha path
acc_t* myAlpha = alpha + batch*maxFLen*maxGLen;
if (u == 0)
myAlpha[0] = 0;
__syncthreads();
for (int64_t step = 1; step < myFLen+myGLen-1; ++step){
// Move along the diagonal wavefront to leverage available parallelism
for (u = tid; u < myGLen; u += blockDim.x){
int64_t t = step - u;
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
// Eq(16) in [1]
if (u == 0){
// alpha(t, u) = alpha(t-1, u) * null(t-1, u)
myAlpha[t*maxGLen + u] = myAlpha[(t-1)*maxGLen]
+ myX[((t-1)*myStrideT) * dictSize + blankIdx];
}
else if (t == 0){
// alpha(t, u) = alpha(t, u-1) * y(t, u-1)
myAlpha[u] = myAlpha[u - 1] + myX[(u - 1) * dictSize + myLabel[u - 1]];
}
else{
// alpha(t, u) = alpha(t-1, u) * null(t-1, u) + alpha(t, u-1) * y(t, u-1)
acc_t current = myAlpha[(t-1)*maxGLen + u]
+ myX[((t-1)*myStrideT + u) * dictSize + blankIdx];
acc_t next = myAlpha[t*maxGLen + u - 1]
+ myX[(t*myStrideT + u - 1) * dictSize + myLabel[u - 1]];
myAlpha[t*maxGLen + u] = logSumExp(next, current);
}
}
}
__syncthreads();
}
}
else if (blockIdx.x == 1){
// beta path
acc_t* myBeta = beta + batch*maxFLen*maxGLen;
if (u == 0){
myBeta[(myFLen-1)*maxGLen + myGLen - 1] = myX[((myFLen-1)*myStrideT
+ myGLen - 1) * dictSize + blankIdx];
}
__syncthreads();
for (int64_t step = myFLen+myGLen - 3; step >= 0; --step){
for (u = tid; u < myGLen; u += blockDim.x){
int64_t t = step - u;
if (t >= 0 and t < myFLen and u >=0 and u < myGLen){
// Eq(18) in [1]
if (u == myGLen - 1){
// beta(t, u) = beta(t+1, u) * null(t, u)
myBeta[t*maxGLen + u] = myBeta[(t+1)*maxGLen + u]
+ myX[(t*myStrideT + u) * dictSize + blankIdx];
}
else if (t == myFLen - 1){
// beta(t, u) = beta(t, u+1) * y(t, u)
myBeta[t*maxGLen + u] = myBeta[t*maxGLen + u + 1]
+ myX[(t*myStrideT + u) * dictSize + myLabel[u]];
}
else{
// beta(t, u) = beta(t+1, u)*null(t, u) + beta(t, u+1)*y(t, u)
acc_t current = myBeta[(t+1)*maxGLen + u]
+ myX[(t*myStrideT + u) * dictSize + blankIdx];
acc_t next = myBeta[t*maxGLen + u + 1]
+ myX[(t*myStrideT + u) * dictSize + myLabel[u]];
myBeta[t*maxGLen + u] = logSumExp(next, current);
}
}
}
__syncthreads();
}
if (tid == 0)
loss[batch] = -myBeta[0];
}
}
// Transducer loss function (i.e. forward-backward algorithm) with batch-loading optimization.
// Compared to the vanilla version, there are two optimizations:
// 1. Load x in batches through loop unrolling to reduce latency.
// 2. Use registers and shared memory to hold alpha and beta values passed from one step to the next.
// For simplicity, this kernel currently only supports U <= maxThread, which should be the common
// case. For cases where U > maxThread, the vanilla kernel is used as a fallback option.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// Forward (alpha) and backward (beta) paths are launched together. The input is assumed to have been
// converted to log scale by the preceding log_softmax layer.
// Diagonal wavefront advancing, as commonly used in dynamic programming, is leveraged here.
// alpha and beta are of acc_t type, as they are essentially accumulators.
// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into
// [B_packed, H].
// Don't-care region (t > audLen) or (u > txtLen) is removed.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
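// Double-buffering sketch (illustrative): two maxGLen-sized SMEM buffers alternate by wavefront parity,
//   acc_t* rd = sharedAlpha[(step + i - 1) % 2];   // wavefront written at the previous step
//   acc_t* wr = sharedAlpha[(step + i) % 2];       // wavefront being produced at this step
// so no thread overwrites a value that another thread still needs to read from the previous step.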
template <typename scalar_t, typename acc_t, int batchLdSize>
__global__ void transducer_loss_batch_load_forward(
const scalar_t* x,
const int* label,
const int* audLen,
const int* txtLen,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
acc_t* alpha,
acc_t* beta,
scalar_t* loss) {
const int batch = blockIdx.y;
int u = threadIdx.x;
const auto myFLen = audLen[batch];
const auto myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
const scalar_t* myX = x + myBatchOffset * dictSize;
scalar_t next[batchLdSize], current[batchLdSize];
extern __shared__ char smem8[];
auto smem = reinterpret_cast<acc_t*>(smem8);
if (blockIdx.x == 0){
// alpha path
acc_t* myAlpha = alpha + batch*maxFLen*maxGLen;
// two SMEM regions for double buffering: reads and writes go to different regions to avoid a data race
acc_t * const sharedAlpha[2] = {smem, smem+maxGLen};
sharedAlpha[0][u] = 0;
__syncthreads();
if (u == 0)
myAlpha[0] = 0;
auto myAlphaLabel = (u == 0) ? 0 : label[batch*(maxGLen-1) + u - 1];
// register used to pass value to the next step for the same thread
acc_t prvStepAlpha = 0;
for (int64_t step = 1; step < myFLen+myGLen-1+batchLdSize; step += batchLdSize){
// Move along the diagonal wavefront to leverage available parallelism
// Batch loading X through loop unrolling
#pragma unroll
for (int i = 0; i < batchLdSize; ++i){
if (step+i<myFLen+myGLen-1){
// index computing
int64_t t = step + i - u;
int64_t currentId = ((t-1)*myStrideT + u) * dictSize + blankIdx;
int64_t nextId = (t*myStrideT + u - 1) * dictSize + myAlphaLabel;
// main loading loop
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
if (u == 0){
current[i] = myX[currentId];
}
else if (t == 0){
next[i] = myX[nextId];
}
else{
current[i] = myX[currentId];
next[i] = myX[nextId];
}
}
}
}
// main computing loop
for (int i = 0; i < batchLdSize; ++i){
// swap the pointer for double buffering
auto sharedAlphaRd = sharedAlpha[(step+i-1)%2];
auto sharedAlphaWr = sharedAlpha[(step+i)%2];
if (step+i<myFLen+myGLen-1){
int64_t t = step + i - u;
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
// Eq(16) in [1]
if (u == 0)
prvStepAlpha = prvStepAlpha+current[i];
else if (t == 0)
prvStepAlpha = sharedAlphaRd[u-1]+next[i];
else
prvStepAlpha = logSumExp(prvStepAlpha+current[i], sharedAlphaRd[u-1]
+ next[i]);
sharedAlphaWr[u] = prvStepAlpha;
myAlpha[t*maxGLen + u] = prvStepAlpha;
}
}
__syncthreads();
}
}
}
else if (blockIdx.x == 1){
// beta path
acc_t* myBeta = beta + batch*maxFLen*maxGLen;
// two SMEM regions for double buffering: reads and writes go to different regions to avoid a data race
acc_t * const sharedBeta[2] = {smem, smem + maxGLen};
sharedBeta[0][u] = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];
__syncthreads();
auto myBetaLabel = (u == maxGLen - 1) ? 0 : label[batch*(maxGLen-1) + u];
// register used to pass value to the next step for the same thread
acc_t prvStepBeta = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];
if (u == 0)
myBeta[(myFLen-1)*maxGLen + myGLen - 1] = prvStepBeta;
for (int64_t step = 1; step < myFLen+myGLen-1; step += batchLdSize){
// Move along the diagonal wavefront to leverage available parallelism
// Batch loading X
#pragma unroll
for (int i = 0; i < batchLdSize; ++i){
if (step+i<myFLen+myGLen-1){
// index computing
int64_t t = myFLen+myGLen - (step + i) - 2 - u;
int64_t currentId = (t*myStrideT + u) * dictSize + blankIdx;
int64_t nextId = (t*myStrideT + u) * dictSize + myBetaLabel;
// main loading loop
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
if (u == myGLen - 1){
current[i] = myX[currentId];
}
else if (t == myFLen - 1){
next[i] = myX[nextId];
}
else{
current[i] = myX[currentId];
next[i] = myX[nextId];
}
}
}
}
// main computing loop
for (int i = 0; i < batchLdSize; ++i){
// swap the pointer for double buffering
auto sharedBetaRd = sharedBeta[(step+i-1)%2];
auto sharedBetaWr = sharedBeta[(step+i)%2];
if (step+i<myFLen+myGLen-1){
int64_t t = myFLen+myGLen - (step + i) - 2 - u;
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
// Eq(18) in [1]
if (u == myGLen - 1)
prvStepBeta = prvStepBeta+current[i];
else if (t == myFLen - 1)
prvStepBeta = sharedBetaRd[u+1]+next[i];
else
prvStepBeta = logSumExp(prvStepBeta+current[i], sharedBetaRd[u+1]
+ next[i]);
sharedBetaWr[u] = prvStepBeta;
myBeta[t*maxGLen + u] = prvStepBeta;
}
}
__syncthreads();
}
}
if (u == 0)
loss[batch] = -prvStepBeta;
}
}
// Vanilla transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere,
// hence only Eq(20) in [1] is implemented in this kernel.
// Each thread block works on [batch, t, :, :] of data. Each thread works on a specific u at a time.
// Since only gradients for the correct token and null token need to be updated, gradients at other
// locations are initialized to 0.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
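// In log space, the updates below amount to (illustrative summary; c = log(lossGrad) + alpha(t, u) - beta(0, 0)):
//   dL/dx(t, u, label[u])  = -exp(c + beta(t, u+1) + x(t, u, label[u]))   for u < U-1
//   dL/dx(t, u, blank)     = -exp(c + beta(t+1, u) + x(t, u, blank))      for t < T-1
//   dL/dx(T-1, U-1, blank) = -exp(c + x(T-1, U-1, blank))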
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_backward(
const scalar_t* x,
const scalar_t* lossGrad,
const int* audLen,
const int* txtLen,
const int* label,
const acc_t* alpha,
const acc_t* beta,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
scalar_t* xGrad) {
const int tid = threadIdx.x;
const int t = blockIdx.x;
const int batch = blockIdx.y;
const int64_t myFLen = audLen[batch];
const int64_t myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
auto myX = x + (myBatchOffset + t*myStrideT)*dictSize;
auto myAlpha = alpha + batch*maxFLen*maxGLen;
auto myBeta = beta + batch*maxFLen*maxGLen;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT)*dictSize;
auto myLabel = label + batch*(maxGLen-1);
int64_t u = tid;
while (t < myFLen and u < myGLen){
// Do the update
// loss = -ln(Pr(y*|x))
acc_t grad = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
if (u != myGLen - 1)
myXGrad[u*dictSize + myLabel[u]] = -std::exp(grad + myBeta[t*maxGLen + u + 1]
+ myX[u*dictSize + myLabel[u]]);
if (t == myFLen - 1 and u == myGLen - 1)
myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myX[u*dictSize + blankIdx]);
else if (t != myFLen - 1)
myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myBeta[(t+1)*maxGLen + u]
+ myX[u*dictSize + blankIdx]);
u += blockDim.x;
}
}
// Fused transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The bwd op of the preceding softmax layer is fused in this kernel.
// Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
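// With the preceding log_softmax backward fused in, the per-logit gradient computed below is
// (illustrative summary; c = log(lossGrad) + alpha(t, u) - beta(0, 0)):
//   g(t, u, h) = exp(c + x(t, u, h) + beta(t, u))
//                - [u < U-1 and h == label[u]] * exp(c + x(t, u, h) + beta(t, u+1))
//                - [h == blank] * exp(c + x(t, u, h) + (0 if (t, u) == (T-1, U-1) else beta(t+1, u)))
// where the blank correction is skipped entirely for t == T-1 with u < U-1.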
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_fused_backward(
const scalar_t* x,
const scalar_t* lossGrad,
const int* audLen,
const int* txtLen,
const int* label,
const acc_t* alpha,
const acc_t* beta,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
scalar_t* xGrad) {
const int tid = threadIdx.x;
const int u = blockIdx.x;
const int t = blockIdx.y;
const int batch = blockIdx.z;
const int64_t myFLen = audLen[batch];
const int64_t myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
__shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize;
if (t < myFLen and u < myGLen){
auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize;
auto myAlpha = alpha + batch*maxFLen*maxGLen;
auto myBeta = beta + batch*maxFLen*maxGLen;
auto myLabel = label + batch*(maxGLen-1);
// load and store shared variables in SMEM
if (tid == 0){
commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
myBetaTU = myBeta[t*maxGLen + u];
myBetaTUp1 = myBeta[t*maxGLen + u + 1];
myBetaTp1U = myBeta[(t+1)*maxGLen + u];
myLabelShared = myLabel[u];
}
__syncthreads();
for (int64_t h = tid; h < dictSize; h += blockDim.x){
// Do the update
acc_t grad = commonFactor + myX[h]; // loss = -ln(Pr(y*|x))
acc_t myGrad = std::exp(grad + myBetaTU);
if (u != myGLen - 1 and h == myLabelShared){
myGrad -= std::exp(grad + myBetaTUp1);
}
else if (h == blankIdx){
if (t == myFLen - 1 and u == myGLen - 1)
myGrad -= std::exp(grad);
else if (t != myFLen - 1)
myGrad -= std::exp(grad + myBetaTp1U);
}
myXGrad[h] = myGrad;
}
}
else if (!packedInput){
// In non-pack mode, need to make sure the gradients for don't-care regions are zero.
for (int64_t h = tid; h < dictSize; h += blockDim.x){
myXGrad[h] = 0;
}
}
}
// Vectorized version of the fused transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The bwd op of the preceding softmax layer is fused in this kernel.
// Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
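// Vectorization sketch (illustrative): each thread loads V consecutive logits with a single vec_t
// (uint64_t) load into myXBuffer, applies the same update as the scalar fused kernel to the buffered
// values, and writes V gradients back with a single vectorized store, cutting the number of global
// memory transactions by roughly a factor of V.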
template <typename scalar_t, typename acc_t, typename vec_t, int V>
__global__ void transducer_loss_fused_vec_backward(
const scalar_t* x,
const scalar_t* lossGrad,
const int* audLen,
const int* txtLen,
const int* label,
const acc_t* alpha,
const acc_t* beta,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
scalar_t* xGrad) {
const int tid = threadIdx.x;
const int u = blockIdx.x;
const int t = blockIdx.y;
const int batch = blockIdx.z;
const int64_t myFLen = audLen[batch];
const int64_t myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
__shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize;
auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize;
auto myAlpha = alpha + batch*maxFLen*maxGLen;
auto myBeta = beta + batch*maxFLen*maxGLen;
auto myLabel = label + batch*(maxGLen-1);
// Variables for vectorization
scalar_t myXBuffer[V], myXGradBuffer[V];
auto myXVec = reinterpret_cast<vec_t const *>(myX);
auto myXGradVec = reinterpret_cast<vec_t*>(myXGrad);
auto myXBufferVec = reinterpret_cast<vec_t*>(myXBuffer);
auto myXGradBufferVec = reinterpret_cast<vec_t*>(myXGradBuffer);
if (t < myFLen and u < myGLen){
// load and store shared variables in SMEM
if (tid == 0){
commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
myBetaTU = myBeta[t*maxGLen + u];
if (t != myFLen - 1)
myBetaTp1U = myBeta[(t+1)*maxGLen + u];
if (u != myGLen - 1){
myBetaTUp1 = myBeta[t*maxGLen + u + 1];
myLabelShared = myLabel[u];
}
}
__syncthreads();
#pragma unroll
for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){
// Load myX in a vector form
*myXBufferVec = myXVec[h0/V];
// Do the update for a vector of input
#pragma unroll
for (int i = 0; i < V; ++i){
auto h = h0 + i;
acc_t grad = commonFactor + myXBuffer[i]; // loss = -ln(Pr(y*|x))
acc_t myGrad = std::exp(grad + myBetaTU);
if (u != myGLen - 1 and h == myLabelShared){
myGrad -= std::exp(grad + myBetaTUp1);
}
else if (h == blankIdx){
if (t == myFLen - 1 and u == myGLen - 1)
myGrad -= std::exp(grad);
else if (t != myFLen - 1)
myGrad -= std::exp(grad + myBetaTp1U);
}
myXGradBuffer[i] = myGrad;
}
// Store myXGrad in a vector form
myXGradVec[h0/V] = *myXGradBufferVec;
}
}
else if (!packedInput){
// In non-pack mode, need to make sure the gradients for don't-care regions are zero.
for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){
myXGradVec[h0/V] = 0;
}
}
}
std::vector<torch::Tensor> transducer_loss_cuda_forward(
torch::Tensor x,
torch::Tensor label,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool packedInput){
auto scalarType = x.scalar_type();
auto tensorOpt = x.options();
const int batchSize = label.size(0);
const int maxGLen = label.size(1) + 1;
const int dictSize = x.size(-1);
TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize,
"Expected blank index to be in the range of 0 to ",
dictSize-1,
", but got ",
blankIdx);
TORCH_CHECK(opt == -1 or opt == 0 or opt == 1,
"Got an invalid optimization level ",
opt);
// The data type of alpha and beta will be resolved at dispatch time,
// hence defined here and assigned later
torch::Tensor alpha;
torch::Tensor beta;
torch::Tensor loss = torch::empty({batchSize}, tensorOpt);
const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;
const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock;
const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr<int64_t>() : nullptr;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(scalarType, "transducer_loss_cuda_forward", ([&] {
// resolve accumulation type
using acc_t = at::acc_type<scalar_t, true>;
auto accType = c10::CppTypeToScalarType<acc_t>::value;
auto accTensorOpt = tensorOpt.dtype(accType);
alpha = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);
beta = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);
// decide what kernel to launch based on the problem size
// If the required SMEM size or number of threads exceeds the device limit, fall back to the
// vanilla kernel.
const auto smemSize = 2*maxGLen*sizeof(acc_t);
const auto optFallBack = (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0
: (opt == -1) ? 1 : opt;
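// Illustrative sizing (assumption: acc_t = float): maxGLen = 512 needs 2 * 512 * 4 = 4 KB of
// dynamic SMEM per block, comfortably below a typical 48-64 KB sharedMemPerBlock limit, so the
// batch-load kernel is chosen; only very long label sequences fall back to the vanilla kernel.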
const int threads = std::min(maxThreadPerBlock, maxGLen);
const dim3 blocks(2, batchSize, 1);
if (optFallBack == 0)
transducer_loss_forward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
label.data_ptr<int>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
loss.data_ptr<scalar_t>());
else if (optFallBack == 1)
transducer_loss_batch_load_forward<scalar_t, acc_t, 4>
<<<blocks, threads, smemSize, stream>>>(
x.data_ptr<scalar_t>(),
label.data_ptr<int>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
loss.data_ptr<scalar_t>());
}));
C10_CUDA_CHECK(cudaGetLastError());
return {alpha, beta, loss};
}
torch::Tensor transducer_loss_cuda_backward(
torch::Tensor x,
torch::Tensor lossGrad,
torch::Tensor alpha,
torch::Tensor beta,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor label,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool fuseSoftmaxBackward,
bool packedInput){
auto dtype = x.scalar_type();
torch::Tensor xGrad;
const int batchSize = label.size(0);
const int maxGLen = label.size(1) + 1;
const int dictSize = x.size(-1);
const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
const int maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;
const int warpSize = deviceProperties->warpSize;
const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr<int64_t>() : nullptr;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (fuseSoftmaxBackward){
// Allocate empty tensors for performance; the kernel therefore has to write zeros to the
// don't-care region itself.
xGrad = torch::empty_like(x);
// Would like each thread to work on 4 hidden units
const int workPerThread = 4;
// Don't want to have more than 128 threads per thread block
const int maxThreadPerElmt = std::min(128, maxThreadPerBlock);
const int threads = std::min(maxThreadPerElmt, std::max(warpSize,
(dictSize+workPerThread-1)/workPerThread));
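// Worked example (illustrative): warpSize = 32 and dictSize = 30000 give (30000 + 3) / 4 = 7500,
// so threads = min(128, max(32, 7500)) = 128; each of the maxGLen x maxFLen x batchSize blocks
// then strides over the dictionary 128 (or 128 * V, vectorized) logits at a time.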
const dim3 blocks(maxGLen, maxFLen, batchSize);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] {
using vec_t = uint64_t;
using acc_t = at::acc_type<scalar_t, true>;
constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);
constexpr int vecAlignment = std::alignment_of<vec_t>::value;
// if all input and output tensors meet the alignment requirement
bool memAlign = reinterpret_cast<uint64_t>(x.data_ptr<scalar_t>()) % vecAlignment == 0
and reinterpret_cast<uint64_t>(xGrad.data_ptr<scalar_t>())
% vecAlignment == 0;
if (vectFactor > 1 and dictSize%vectFactor == 0 and memAlign){
transducer_loss_fused_vec_backward<scalar_t, acc_t, vec_t, vectFactor>
<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}
else{
transducer_loss_fused_backward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}
}));
}
else{
// For the non-fused kernel, the gradients that need to be written are very sparse, hence
// initialize the tensor with all zeros.
xGrad = torch::zeros_like(x);
// don't launch more threads than needed.
const int threads = std::min(maxThreadPerBlock, maxGLen);
const dim3 blocks(maxFLen, batchSize);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] {
using acc_t = at::acc_type<scalar_t, true>;
transducer_loss_backward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}));
}
C10_CUDA_CHECK(cudaGetLastError());
return xGrad;
}
...@@ -76,10 +76,6 @@ ...@@ -76,10 +76,6 @@
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <ATen/cuda/NumericLimits.cuh> #include <ATen/cuda/NumericLimits.cuh>
#include <THC/THC.h>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
#include "type_shim.h" #include "type_shim.h"
#include "compat.h" #include "compat.h"
...@@ -638,7 +634,7 @@ std::vector<Tensor> host_softmax_xentropy( ...@@ -638,7 +634,7 @@ std::vector<Tensor> host_softmax_xentropy(
} }
); );
THCudaCheck(cudaGetLastError()); C10_CUDA_CHECK(cudaGetLastError());
std::vector<at::Tensor> ret = {losses, max_log_sum_exp}; std::vector<at::Tensor> ret = {losses, max_log_sum_exp};
return ret; return ret;
...@@ -708,7 +704,7 @@ Tensor host_softmax_xentropy_backward( ...@@ -708,7 +704,7 @@ Tensor host_softmax_xentropy_backward(
} }
); );
THCudaCheck(cudaGetLastError()); C10_CUDA_CHECK(cudaGetLastError());
return gI; return gI;
} }
......
from .fmha import FMHAFun
###############################################################################
# Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
###############################################################################
import torch
import torch.nn.functional as F
import fmhalib as mha
class FMHAFun(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):
batch_size = cu_seqlens.numel() - 1
if batch_size < 4:
context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
else:
context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
ctx.save_for_backward(qkv, S_dmask)
ctx.cu_seqlens = cu_seqlens
ctx.p_dropout = p_dropout
ctx.max_s = max_s
return context
@staticmethod
def backward(ctx, dout):
qkv, S_dmask = ctx.saved_tensors
batch_size = ctx.cu_seqlens.numel() - 1
if batch_size < 4:
dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
else:
dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
return dqkv, None, None, None, None, None, None
class FMHA(torch.nn.Module):
def __init__(self, config):
super(FMHA, self).__init__()
self.p_dropout = config.attention_probs_dropout_prob
self.h = config.num_attention_heads
self.hidden_size = config.hidden_size
self.d = self.hidden_size // self.h
assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"
def forward(self, qkv, cu_seqlens, max_s, is_training=True):
ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training)
return ctx.view(-1, self.hidden_size)
import os
import math import math
import torch import torch
import importlib import importlib
...@@ -266,13 +267,48 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -266,13 +267,48 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._total_param_size = p_offset self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
def _lazy_init_stage1(self):
if self._lazy_init_stage1_done: return
p_i = 0
#self._model_params = []
#self._grad_accs = []
#self._group_properties = []
for group in self.param_groups:
for p in group['params']:
torch.distributed.broadcast(p, 0)
if not p.requires_grad:
continue
def wrapper(param, param_i):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
if not self._set_flat_param_view:
if self._first_step:
# first time
self._param_order.add(param_i)
else:
idx = self._param_order.order.index(param_i)
self._do_overlapped_reduction(idx, param)
else:
if not self._first_step:
idx = self._param_order.order.index(param_i)
self._do_overlapped_reduction(idx, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
wrapper(p, p_i)
p_i += 1
self._block_size = self._total_param_size // self._num_blocks self._block_size = self._total_param_size // self._num_blocks
self._chunk_size = self._block_size // self._num_chunks self._chunk_size = self._block_size // self._num_chunks
self._shard_size = self._chunk_size // self._group_size self._shard_size = self._chunk_size // self._group_size
#print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size)) #print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda') self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
# initialize master weights, moments buffers if not loaded from checkpoint # initialize master weights, moments buffers if not loaded from checkpoint
if self._fp32_p is None: if self._fp32_p is None:
...@@ -291,11 +327,18 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -291,11 +327,18 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)] return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
def __shardify(p): def __shardify(p):
return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)] return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
list_of_blocks = __blockify(self._flat_grads) list_of_blocks = __blockify(p)
list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks] list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks] list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads) def _flat_split_no_shards(p):
def __blockify(p):
return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
list_of_blocks = __blockify(self._flat_grads)
list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks
def _full_packed_split(p): def _full_packed_split(p):
def __shardify(p): def __shardify(p):
return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)] return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
...@@ -307,7 +350,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -307,7 +350,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards] list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks] list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
def _packed_split(p): def _packed_split(p):
def __packed_blockify(p): def __packed_blockify(p):
packed_block_size = self._num_chunks*self._shard_size packed_block_size = self._num_chunks*self._shard_size
...@@ -318,12 +360,86 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -318,12 +360,86 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
list_of_blocks = __packed_blockify(p) list_of_blocks = __packed_blockify(p)
list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks] list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks return list_of_blocks, list_of_list_of_chunks
def _split_assign(shards):
packed_block_size = self._num_chunks*self._shard_size
list_of_list_of_chunks=[]
for block_id in range(self._num_blocks):
list_of_chunks=[]
for chunk_id in range(self._num_chunks):
#self._fp16_g[block_id*packed_block_size+chunk_id*self._shard_size:block_id*packed_block_size+(chunk_id+1)*self._shard_size] = shards[block_id][chunk_id][self._rank_in_group]
list_of_chunks.append( shards[block_id][chunk_id][self._rank_in_group])
list_of_list_of_chunks.append(list_of_chunks)
return list_of_list_of_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
# this splitting scheme is needed when allgather needs to be split into multiple chunks in a contiguous way
self._new_params2_blocks, self._new_params2_chunks, self._new_params2_shards = _flat_split(self._new_params)
self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p) self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m) self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v) self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u) self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u)
self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p) self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
if self._full_ar:
# for gradient all-reduce
self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
# for weight update
self._fp16_g_chunks = _split_assign(self._flat_grads_shards)
else:
self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
self._lazy_init_stage1_done = True
def _lazy_init_stage2(self):
if self._lazy_init_stage2_done: return
if not self._set_flat_param_view:
# reversing is needed for overlapping allreduce and backprop, but currently not supported for flat param view
self._param_order.order.reverse()
# re-order model_params, grad_accs, group_properties lists
self._model_params = [self._model_params[i] for i in self._param_order.order]
self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
self._group_properties = [self._group_properties[i] for i in self._param_order.order]
def _get_flat_view(param):
if param.is_contiguous(memory_format=torch.channels_last):
K, C, H, W = param.shape
pv = param.as_strided(size=(K,H,W,C), stride=(H*W*C, W*C, C, 1))
elif param.is_contiguous(memory_format=torch.channels_last_3d):
K, C, D, H, W = param.shape
pv = param.as_strided(size=(K,D,H,W,C), stride=(D*H*W*C, H*W*C, W*C, C, 1))
else:
pv = param
return pv.view(-1)
# re-collect grads info (size, offset) after ordering
prev = None
p_offset = 0
self._grads_info = []
self._individual_flat_grads = []
for i, p in enumerate(self._model_params):
p_grads_size = p.numel()
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))
# for the first iteration
self._do_overlapped_reduction(i, p)
p_offset += p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters:
# (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
#print("self._low_param_i", self._low_param_i)
self._lazy_init_stage1_done = True self._lazy_init_stage1_done = True
...@@ -392,7 +508,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -392,7 +508,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
grad_offset = clipped_start - flat_grad_start grad_offset = clipped_start - flat_grad_start
grad_length = clipped_end - clipped_start grad_length = clipped_end - clipped_start
shard_offset = clipped_start - flat_shard_start shard_offset = clipped_start - flat_shard_start
model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length] pf = _get_flat_view(p)
model_param_fragment = pf[grad_offset:grad_offset+grad_length]
new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length] new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
if model_param_fragment.dtype == torch.float16: if model_param_fragment.dtype == torch.float16:
self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) ) self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) )
...@@ -633,19 +750,34 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -633,19 +750,34 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
def _flatten_grad_mt(self, scale): def _flatten_grad_mt(self, scale):
if len(self._grads_fp16) > 0: if len(self._grads_fp16) > 0:
self._overflow_buf.zero_() self._overflow_buf.zero_()
multi_tensor_applier( if not self._fused_norm:
amp_C.multi_tensor_scale, multi_tensor_applier(
self._overflow_buf, amp_C.multi_tensor_scale,
list(zip(*self._grads_fp16)), self._overflow_buf,
scale) list(zip(*self._grads_fp16)),
scale)
else:
self._L2_grad_norm=multi_tensor_applier(
amp_C.multi_tensor_l2norm_scale,
self._overflow_buf,
list(zip(*self._grads_fp16)),
scale, False)[0].float()
self._grads_fp16 = [] self._grads_fp16 = []
if len(self._grads_fp32) > 0: if len(self._grads_fp32) > 0:
self._overflow_buf.zero_() self._overflow_buf.zero_()
multi_tensor_applier( if not self._fused_norm:
amp_C.multi_tensor_scale, multi_tensor_applier(
self._overflow_buf, amp_C.multi_tensor_scale,
list(zip(*self._grads_fp32)), self._overflow_buf,
scale) list(zip(*self._grads_fp32)),
scale)
else:
self._L2_grad_norm=multi_tensor_applier(
amp_C.multi_tensor_l2norm_scale,
self._overflow_buf,
list(zip(*self._grads_fp32)),
scale, False)[0].float()
self._grads_fp32 = [] self._grads_fp32 = []
def _do_overlapped_reduction(self, param_i, param): def _do_overlapped_reduction(self, param_i, param):
......
...@@ -111,7 +111,7 @@ class FusedLAMB(torch.optim.Optimizer): ...@@ -111,7 +111,7 @@ class FusedLAMB(torch.optim.Optimizer):
continue continue
if p.dtype == torch.float32: if p.dtype == torch.float32:
g_all_32.append(p.grad.data) g_all_32.append(p.grad.data)
elif p.dytpe == torch.float16: elif p.dtype == torch.float16:
g_all_16.append(p.grad.data) g_all_16.append(p.grad.data)
else: else:
raise RuntimeError('FusedLAMB only support fp16 and fp32.') raise RuntimeError('FusedLAMB only support fp16 and fp32.')
......
...@@ -9,7 +9,7 @@ from apex.contrib.sparsity import ASP ...@@ -9,7 +9,7 @@ from apex.contrib.sparsity import ASP
## Initializing ASP ## Initializing ASP
Apart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/infercence: Apart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/inference:
``` ```
ASP.prune_trained_model(model, optimizer) ASP.prune_trained_model(model, optimizer)
``` ```
......