Unverified Commit db92ee13 authored by Jithun Nair, committed by GitHub

Merge pull request #64 from ROCmSoftwarePlatform/IFU-master-2021-12-08

IFU-master-2021-12-08
parents d150afdc 68364b49
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "softmax.h"

namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
  const int attn_batches = input.size(0);
  const int sequences = attn_batches / heads;
  const int q_seq_len = input.size(1);
  const int k_seq_len = q_seq_len;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = input.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *input_ptr = static_cast<void *>(input.data_ptr());
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
        attn_batches * q_seq_len);
  } else {
    softmax_success = dispatch_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }

  // Matmul2

  return {dropout_results, dropout_mask, softmax_results};
}

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask,
                       const uint8_t *padding_mask, float dropout_prob) {
  const int attn_batches = output_grads.size(0);
  const int q_seq_len = output_grads.size(1);
  const int k_seq_len = q_seq_len;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  // torch::Tensor input_grads = torch::empty_like(output_grads);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  if (padding_mask == nullptr) {
    dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
        static_cast<half *>(output_grads.data_ptr()),
        static_cast<half *>(output_grads.data_ptr()),
        reinterpret_cast<half const *>(softmax_results.data_ptr()),
        static_cast<uint8_t const *>(dropout_mask.data_ptr()),
        1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
        attn_batches * q_seq_len, stream);
  } else {
    dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,
                                                             false>(
        static_cast<half *>(output_grads.data_ptr()),
        static_cast<half *>(output_grads.data_ptr()),
        reinterpret_cast<half const *>(softmax_results.data_ptr()),
        static_cast<uint8_t const *>(dropout_mask.data_ptr()),
        static_cast<uint8_t const *>(padding_mask), 1.0 / (1.0 - dropout_prob),
        k_seq_len, k_seq_len, attn_batches * q_seq_len, heads, stream);
  }

  // backward pass is completely in-place
  return output_grads;
}

} // namespace mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
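For reference, here is a minimal CPU sketch of the per-row gradient that the fused dispatch_masked_scale_softmax_backward_stream call above is assumed to compute. The kernel name, its arguments, and the in-place behaviour come from this diff; the element-wise math below is the standard softmax-plus-dropout backward and should be read as an assumption, not something the commit states.

#include <cstddef>
#include <cstdint>
#include <vector>

// Reference-only sketch (not part of this commit): gradient of
// y = dropout(softmax(x)) with respect to x, for a single row.
// grad_out, softmax, and mask all have length n; scale = 1/(1 - dropout_prob).
std::vector<float> softmax_dropout_backward_row(const std::vector<float> &grad_out,
                                                const std::vector<float> &softmax,
                                                const std::vector<uint8_t> &mask,
                                                float scale) {
  const std::size_t n = softmax.size();
  std::vector<float> g(n), grad_in(n);
  float dot = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    // Undo the dropout applied after the softmax: keep-mask times 1/(1-p).
    g[i] = grad_out[i] * static_cast<float>(mask[i]) * scale;
    dot += g[i] * softmax[i];
  }
  for (std::size_t i = 0; i < n; ++i) {
    // Softmax Jacobian-vector product: y_i * (g_i - sum_j g_j * y_j).
    grad_in[i] = softmax[i] * (g[i] - dot);
  }
  return grad_in;
}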
#pragma once
// Philox CUDA.

class Philox {
public:
@@ -15,28 +15,30 @@ public:
    incr_n(offset / 4);
  }

  __device__ inline uint4 operator()() {
    if (STATE == 0) {
      uint4 counter_ = counter;
      uint2 key_ = key;
      // 7-round philox
      for (int i = 0; i < 6; i++) {
        counter_ = single_round(counter_, key_);
        key_.x += (kPhilox10A);
        key_.y += (kPhilox10B);
      }
      output = single_round(counter_, key_);
      incr();
    }
    // return a float4 directly
    // unsigned long ret;
    // switch(STATE) {
    //  case 0: ret = output.x; break;
    //  case 1: ret = output.y; break;
    //  case 2: ret = output.z; break;
    //  case 3: ret = output.w; break;
    //}
    // STATE = (STATE + 1) % 4;
    return output;
  }

private:
  uint4 counter;
  uint4 output;
@@ -67,7 +69,7 @@ private:
  __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
                                    unsigned int *result_high) {
    *result_high = __umulhi(a, b);
    return a * b;
  }

  __device__ inline uint4 single_round(uint4 ctr, uint2 key) {
    unsigned int hi0;
@@ -84,7 +86,7 @@ private:
};

// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__ __inline__ float4 uniform4(uint4 x) {
  return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
                     x.w * M_RAN_INVM32);
}
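Purely as an illustration (not part of this diff), the Philox generator and the uniform4 helper above are typically consumed four outputs at a time inside a dropout-style kernel. The sketch below assumes a Philox(seed, subsequence, offset) constructor, which is elided from the hunk shown here.

// Illustrative sketch only: one Philox counter per thread, four uniform
// samples per call, producing one keep/drop byte per element.
__global__ void philox_keep_mask_sketch(uint8_t *keep, int n, float drop_prob,
                                        unsigned long long seed,
                                        unsigned long long offset) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  Philox rng(seed, tid, offset); // assumed constructor, see note above
  float4 r = uniform4(rng());    // four uniforms in [0, 1) via M_RAN_INVM32
  int base = 4 * tid;
  if (base + 0 < n) keep[base + 0] = r.x >= drop_prob;
  if (base + 1 < n) keep[base + 1] = r.y >= drop_prob;
  if (base + 2 < n) keep[base + 2] = r.z >= drop_prob;
  if (base + 3 < n) keep[base + 3] = r.w >= drop_prob;
}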
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace self_bias_additive_mask {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    torch::Tensor const &input_biases,
                                    torch::Tensor const &output_biases,
                                    const half *pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    // torch::Tensor const& softmax_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    // torch::Tensor const& input_biases,
    // torch::Tensor const& output_biases,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface
#define CHECK_CUDA(x)                                                         \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                   \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                        \
  CHECK_CUDA(x);                                                              \
  CHECK_CONTIGUOUS(x)
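These validation macros are not exercised in the hunks shown here. Purely as an illustration (not part of this diff), a hypothetical caller could wrap them as below; note that CHECK_INPUT expands to two statements, so it should not be used as the body of an unbraced if.

// Hypothetical helper, illustration only, not part of this commit.
void validate_attention_input(torch::Tensor const &t) {
  CHECK_INPUT(t); // expands to CHECK_CUDA(t); CHECK_CONTIGUOUS(t)
}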
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &input_biases,
    torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(use_mask, "no mask is not supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
               "Only Half is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
                  output_weights, input_biases, output_biases,
                  use_mask ? static_cast<const half *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  bmm1_results, pad_mask, input_lin_results, inputs,
                  input_weights, output_weights, dropout_mask, dropout_prob);
}

} // end namespace rocblas_gemmex
@@ -140,4 +112,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
  m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
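For context only: the additive-mask frontend above requires a 2-D fp16 pad_mask, while the byte-mask variant further down expects at::ScalarType::Byte. A sketch of how an additive fp16 mask could be built from a boolean key-padding mask is given below; the convention (zero for visible positions, a large negative value for padded ones, added to the attention logits before softmax) is an assumption based on the additive_mask naming and is not stated by this diff.

#include <torch/extension.h>

// Illustration only, not part of this commit: build an additive fp16 pad mask
// from a boolean key-padding mask of shape [batch, k_seq_len].
torch::Tensor make_additive_pad_mask_sketch(torch::Tensor const &key_padding_mask) {
  auto mask = torch::zeros_like(
      key_padding_mask, key_padding_mask.options().dtype(torch::kHalf));
  // Padded positions get a large negative logit so softmax sends them to ~0.
  mask.masked_fill_(key_padding_mask, -10000.0);
  return mask;
}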
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"

namespace multihead_attn {
namespace self_bias_additive_mask {
@@ -55,28 +52,36 @@ std::vector<torch::Tensor> fwd_cuda(
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor bmm1_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *bmm1_results_ptr = static_cast<void *>(bmm1_results.data_ptr());
  void *dropout_results_ptr = static_cast<void *>(dropout_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
@@ -111,8 +116,7 @@ std::vector<torch::Tensor> fwd_cuda(
                          flags));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum( a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -136,32 +140,28 @@ std::vector<torch::Tensor> fwd_cuda(
  // Padded Softmax
  bool softmax_success = false;
  if (is_training) {
    softmax_success =
        dispatch_additive_masked_softmax_dropout<half, half, float>(
            reinterpret_cast<half *>(dropout_results_ptr),
            (is_training)
                ? reinterpret_cast<uint8_t *>(dropout_mask.data_ptr<uint8_t>())
                : nullptr,
            reinterpret_cast<const half *>(bmm1_results_ptr), pad_mask,
            attn_batches * q_seq_len * q_seq_len, k_seq_len, k_seq_len,
            attn_batches * q_seq_len, attn_batches * q_seq_len / sequences,
            1.0f - dropout_prob, stream);
  } else {
    softmax_success = dispatch_additive_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(
            dropout_results_ptr), // this is actually softmax results, but
                                  // making it consistent for the next function
        reinterpret_cast<const half *>(bmm1_results_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }

  // Matmul2
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -211,73 +211,63 @@ std::vector<torch::Tensor> fwd_cuda(
                          flags));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_results, bmm1_results, dropout_results,
          dropout_mask, matmul2_results, outputs};
}
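As a reference for the fused call above, here is a minimal CPU sketch of what dispatch_additive_masked_softmax_dropout is assumed to compute for one attention row. That the kernel fuses add-mask, softmax, and dropout is taken from its name and arguments in this diff; the element-wise math itself is the standard formulation and should be read as an assumption.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference-only sketch (not part of this commit), one row:
// out = dropout(softmax(logits + mask), p), with keep[] holding the dropout
// decisions and 1/(1 - p) rescaling the surviving elements.
std::vector<float> additive_masked_softmax_dropout_row(
    const std::vector<float> &logits, const std::vector<float> &mask,
    const std::vector<uint8_t> &keep, float dropout_prob) {
  const std::size_t n = logits.size();
  std::vector<float> out(n);
  float max_val = -INFINITY;
  for (std::size_t i = 0; i < n; ++i)
    max_val = std::max(max_val, logits[i] + mask[i]);
  float sum = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = std::exp(logits[i] + mask[i] - max_val); // shifted for stability
    sum += out[i];
  }
  const float rescale = 1.0f / (1.0f - dropout_prob);
  for (std::size_t i = 0; i < n; ++i)
    out[i] = (out[i] / sum) * static_cast<float>(keep[i]) * rescale;
  return out;
}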
std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};

  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
@@ -335,8 +325,7 @@ std::vector<torch::Tensor> bwd_cuda(
  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);

  // MatMul2 Dgrad1
  gemm_switch_fp32accum( a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -358,8 +347,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -396,8 +384,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         stream);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -419,8 +406,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);

  // Matmul1 Dgrad2
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -496,13 +482,8 @@ std::vector<torch::Tensor> bwd_cuda(
  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, input_weight_grads, output_weight_grads,
          input_bias_grads, output_bias_grads};
}

} // end namespace rocblas_gemmex
...
@@ -5,127 +5,102 @@ namespace multihead_attn {
namespace self_bias {
namespace rocblas_gemmex {

std::vector<torch::Tensor>
fwd_cuda(bool use_time_mask, bool is_training, int heads,
         torch::Tensor const &inputs, torch::Tensor const &input_weights,
         torch::Tensor const &output_weights, torch::Tensor const &input_biases,
         torch::Tensor const &output_biases, const uint8_t *pad_mask,
         float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    // torch::Tensor const& input_biases,
    // torch::Tensor const& output_biases,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface
#define CHECK_CUDA(x)                                                         \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                   \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                        \
  CHECK_CUDA(x);                                                              \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &input_biases,
    torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
                  output_weights, input_biases, output_biases,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_results, inputs, input_weights,
                  output_weights, dropout_mask, dropout_prob);
}

} // end namespace rocblas_gemmex
@@ -136,4 +111,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self_bias::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
  m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
#include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h" #include "dropout.h"
#include "layer_norm.h" #include "layer_norm.h"
#include "softmax.h"
// symbol to be automatically resolved by PyTorch libs #include "strided_batched_gemm.h"
extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace self_bias { namespace self_bias {
namespace rocblas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor>
bool use_time_mask, fwd_cuda(bool use_time_mask, bool is_training, int heads,
bool is_training, torch::Tensor const &inputs, torch::Tensor const &input_weights,
int heads, torch::Tensor const &output_weights, torch::Tensor const &input_biases,
torch::Tensor const& inputs, torch::Tensor const &output_biases, const uint8_t *pad_mask,
torch::Tensor const& input_weights, float dropout_prob) {
torch::Tensor const& output_weights, const int embed_dim = inputs.size(2);
torch::Tensor const& input_biases, const int sequences = inputs.size(1);
torch::Tensor const& output_biases, const int q_seq_len = inputs.size(0);
const uint8_t* pad_mask, const int k_seq_len = q_seq_len;
float dropout_prob const int batches = sequences * q_seq_len;
) const int head_dim = embed_dim / heads;
{ const int output_lin_dim = 3 * embed_dim;
const int embed_dim = inputs.size(2); const int attn_batches = heads * sequences;
const int sequences = inputs.size(1); const int lead_dim = attn_batches * 3 * head_dim;
const int q_seq_len = inputs.size(0); const int batch_stride = 3 * head_dim;
const int k_seq_len = q_seq_len; const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const int batches = sequences * q_seq_len; const float alpha = 1.0;
const int head_dim = embed_dim / heads; const float beta_zero = 0.0;
const int output_lin_dim = 3 * embed_dim; const float beta_one = 1.0;
const int attn_batches = heads * sequences; const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim; // There is no reason to use more than one stream as every kernel is
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent // sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream); cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code) // 3 Intermediate Results + Output (Note: dropout intermediates are generated
auto act_options = inputs.options().requires_grad(false); // by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8); auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); torch::Tensor input_lin_results =
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::Tensor softmax_results =
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); torch::Tensor dropout_results =
torch::Tensor outputs = torch::empty_like(inputs, act_options); torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Input Linear Results Pointers to Q, K, and V of interviewed activations // Input Linear Results Pointers to Q, K, and V of interviewed activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr()); void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim); void *k_lin_results_ptr = static_cast<void *>(
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim); static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr()); void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'}; char a_layout_t{'t'};
char a_layout_n{'n'}; char a_layout_n{'n'};
...@@ -110,8 +108,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -110,8 +108,7 @@ std::vector<torch::Tensor> fwd_cuda(
flags)); flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state, gemm_switch_fp32accum( a_layout_t,
a_layout_t,
b_layout_n, b_layout_n,
k_seq_len, k_seq_len,
q_seq_len, q_seq_len,
...@@ -136,44 +133,35 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -136,44 +133,35 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false; bool softmax_success = false;
if (pad_mask == nullptr) { if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>( softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, k_seq_len, attn_batches * q_seq_len);
k_seq_len,
attn_batches*q_seq_len);
} else { } else {
if (use_time_mask) { if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>( softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
q_seq_len);
} else { } else {
softmax_success = dispatch_masked_softmax<half, half, float>( softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len,
k_seq_len, attn_batches * q_seq_len / sequences);
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
} }
} }
if (is_training) { if (is_training) {
//use at:: function so that C++ version generates the same random mask as python version // use at:: function so that C++ version generates the same random mask as
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob); // python version
auto dropout_tuple =
at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
dropout_results = std::get<0>(dropout_tuple); dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple); dropout_mask = std::get<1>(dropout_tuple);
} }
// Matmul2 // Matmul2
gemm_switch_fp32accum( state, gemm_switch_fp32accum( a_layout_n,
a_layout_n,
b_layout_n, b_layout_n,
head_dim, head_dim,
q_seq_len, q_seq_len,
...@@ -223,72 +211,63 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -223,72 +211,63 @@ std::vector<torch::Tensor> fwd_cuda(
flags)); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {input_lin_results, softmax_results, dropout_results,
input_lin_results, dropout_mask, matmul2_results, outputs};
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs
};
} }
std::vector<torch::Tensor> bwd_cuda( std::vector<torch::Tensor> bwd_cuda(
int heads, int heads, torch::Tensor const &output_grads,
torch::Tensor const& output_grads, torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const& matmul2_results, torch::Tensor const &softmax_results,
torch::Tensor const& dropout_results, torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const& softmax_results, torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const& input_lin_results, torch::Tensor const &dropout_mask, float dropout_prob) {
torch::Tensor const& inputs, const int embed_dim = inputs.size(2);
torch::Tensor const& input_weights, const int sequences = inputs.size(1);
torch::Tensor const& output_weights, const int q_seq_len = inputs.size(0);
torch::Tensor const& dropout_mask, const int k_seq_len = q_seq_len;
float dropout_prob const int batches = sequences * q_seq_len;
) const int head_dim = embed_dim / heads;
{ const int output_lin_dim = 3 * embed_dim;
const int embed_dim = inputs.size(2); const int attn_batches = heads * sequences;
const int sequences = inputs.size(1); const int lead_dim = attn_batches * 3 * head_dim;
const int q_seq_len = inputs.size(0); const int batch_stride = 3 * head_dim;
const int k_seq_len = q_seq_len; const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const int batches = sequences * q_seq_len; const float alpha = 1.0;
const int head_dim = embed_dim / heads; const float beta = 0.0;
const int output_lin_dim = 3 * embed_dim; const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};

  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
@@ -346,8 +325,7 @@ std::vector<torch::Tensor> bwd_cuda(
  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
  // MatMul2 Dgrad1
  gemm_switch_fp32accum( a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -369,8 +347,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
  // Matmul2 Dgrad2
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -393,19 +370,16 @@ std::vector<torch::Tensor> bwd_cuda(
  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
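  // The fused kernel below folds the dropout backward into the softmax
  // backward: the incoming gradient is masked and rescaled by
  // 1 / (1 - dropout_prob), then pushed through the softmax Jacobian,
  // i.e. dx = y * (g - sum_j g_j * y_j) per attention row of length k_seq_len.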
  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()),
      static_cast<uint8_t const *>(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
      attn_batches * q_seq_len, stream);
  // Matmul1 Dgrad1
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -427,8 +401,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
  // Matmul1 Dgrad2
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -503,15 +476,11 @@ std::vector<torch::Tensor> bwd_cuda(
  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, input_weight_grads, output_weight_grads,
          input_bias_grads, output_bias_grads};
}

} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
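For reference, the fused backward dispatched above amounts to the following per-row computation. This is a minimal CPU sketch of the math, assuming the dropout mask stores 1 for kept elements; it is not the kernel declared in softmax.h, and it ignores half precision and the batched memory layout used on the GPU.

#include <cstdint>
#include <vector>

// Reference for rows of attention probabilities of length `cols`:
//   g    = grad_out * mask / (1 - p)        (undo inverted dropout)
//   dx_c = y_c * (g_c - sum_j g_j * y_j)    (softmax Jacobian applied to g)
void masked_scale_softmax_backward_ref(float *grad_in, const float *grad_out,
                                       const float *softmax_out,
                                       const uint8_t *mask, float dropout_prob,
                                       int rows, int cols) {
  const float scale = 1.0f / (1.0f - dropout_prob);
  for (int r = 0; r < rows; ++r) {
    const float *g_out = grad_out + r * cols;
    const float *y = softmax_out + r * cols;
    const uint8_t *m = mask + r * cols;
    float *g_in = grad_in + r * cols;
    std::vector<float> g(cols);
    float dot = 0.0f;
    for (int c = 0; c < cols; ++c) {
      g[c] = g_out[c] * scale * static_cast<float>(m[c]);
      dot += g[c] * y[c];
    }
    for (int c = 0; c < cols; ++c)
      g_in[c] = y[c] * (g[c] - dot);
  }
}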
@@ -5,120 +5,98 @@ namespace multihead_attn {
namespace self {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface
#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  }

  return fwd_cuda(
      use_time_mask, is_training, heads, inputs, input_weights, output_weights,
      use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
      dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_results, inputs, input_weights,
                  output_weights, dropout_mask, dropout_prob);
}

} // end namespace rocblas_gemmex
@@ -129,4 +107,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self::rocblas_gemmex::fwd, "Self Multihead Attention Forward.");
  m.def("backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward.");
}
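For orientation, a hypothetical driver for the forward binding is sketched below. It is not part of the extension: the [3 * embed_dim, embed_dim] input-projection shape and the [sequences, seq_len] byte pad mask are assumptions inferred from the kernel code, and the call assumes the fwd function above is visible in the current translation unit.

#include <torch/torch.h>

// Hypothetical smoke test for the forward binding (shapes are assumptions).
void self_attn_fwd_smoke_test() {
  const int q_seq_len = 16, sequences = 4, heads = 8, embed_dim = 64;
  auto half_cuda = torch::dtype(torch::kHalf).device(torch::kCUDA);

  // [q_seq_len, sequences, embed_dim] activations, as asserted in fwd().
  auto inputs = torch::randn({q_seq_len, sequences, embed_dim}, half_cuda);
  // Assumed packed QKV projection with 3 * embed_dim output features.
  auto input_weights = torch::randn({3 * embed_dim, embed_dim}, half_cuda);
  auto output_weights = torch::randn({embed_dim, embed_dim}, half_cuda);
  // Assumed [sequences, k_seq_len] byte padding mask (k_seq_len == q_seq_len).
  auto pad_mask = torch::zeros({sequences, q_seq_len},
                               torch::dtype(torch::kByte).device(torch::kCUDA));

  auto results = multihead_attn::self::rocblas_gemmex::fwd(
      /*use_mask=*/true, /*use_time_mask=*/false, /*is_training=*/true, heads,
      inputs, input_weights, output_weights, pad_mask, /*dropout_prob=*/0.1f);
  // Per the fwd_cuda below, results holds {input_lin_results, softmax_results,
  // dropout_results, dropout_mask, matmul2_results, outputs}.
}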
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"

namespace multihead_attn {
namespace self {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
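  // Layout note: the input linear packs, for every head, a contiguous
  // [Q | K | V] block of 3 * head_dim values, so the K and V views are simply
  // the Q pointer offset by head_dim and 2 * head_dim; batch_stride and
  // lead_dim above describe that same layout to the strided-batched GEMMs.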
  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};
@@ -106,8 +106,7 @@ std::vector<torch::Tensor> fwd_cuda(
                                       flags));
  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum( a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -132,46 +131,35 @@ std::vector<torch::Tensor> fwd_cuda(
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
        k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }
  assert(softmax_success);
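  // Attention-probability dropout (training only): the kernel below is driven
  // by the keep probability (1 - dropout_prob), records the kept/dropped
  // pattern in dropout_mask for the backward pass, and writes the rescaled
  // probabilities into dropout_results (standard inverted dropout).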
  if (is_training) {
    apex_fused_dropout_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(softmax_results.data_ptr()),
        static_cast<at::Half *>(dropout_results.data_ptr()),
        static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
        (1.0f - dropout_prob));
  }
  // Matmul2
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -219,67 +207,58 @@ std::vector<torch::Tensor> fwd_cuda(
                                       flags));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_results, softmax_results, dropout_results,
          dropout_mask, matmul2_results, outputs};
}

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
@@ -341,8 +320,7 @@ std::vector<torch::Tensor> bwd_cuda(
                                       flags));
  // MatMul2 Dgrad1
  gemm_switch_fp32accum( a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -364,8 +342,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
  // Matmul2 Dgrad2
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -397,17 +374,14 @@ std::vector<torch::Tensor> bwd_cuda(
  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward<half, half, float>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
      k_seq_len, attn_batches * q_seq_len);
  assert(softmax_success);
  // Matmul1 Dgrad1
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -429,8 +403,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
  // Matmul1 Dgrad2
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -514,3 +487,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
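The fwd_cuda above chains the input projection, Matmul1 with the 1/sqrt(head_dim) scale, the (optionally masked) softmax, attention dropout, Matmul2, and the output projection. A minimal float CPU sketch of the attention core alone (one head, no mask, no dropout), useful only for cross-checking the math, not the kernels:

#include <algorithm>
#include <cmath>
#include <vector>

// Single-head scaled dot-product attention on row-major float buffers.
// q, k, v: [seq_len, head_dim]; out: [seq_len, head_dim].
void attention_ref(float *out, const float *q, const float *k, const float *v,
                   int seq_len, int head_dim) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  std::vector<float> scores(seq_len);
  for (int i = 0; i < seq_len; ++i) {
    // Matmul1 row: scores[j] = scale * <q_i, k_j>
    float max_score = -1e30f;
    for (int j = 0; j < seq_len; ++j) {
      float s = 0.0f;
      for (int d = 0; d < head_dim; ++d)
        s += q[i * head_dim + d] * k[j * head_dim + d];
      scores[j] = s * scale;
      max_score = std::max(max_score, scores[j]);
    }
    // Numerically stable softmax over the row.
    float denom = 0.0f;
    for (int j = 0; j < seq_len; ++j) {
      scores[j] = std::exp(scores[j] - max_score);
      denom += scores[j];
    }
    // Matmul2 row: out_i = sum_j softmax_ij * v_j
    for (int d = 0; d < head_dim; ++d) {
      float acc = 0.0f;
      for (int j = 0; j < seq_len; ++j)
        acc += (scores[j] / denom) * v[j * head_dim + d];
      out[i * head_dim + d] = acc;
    }
  }
}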
@@ -5,169 +5,145 @@ namespace multihead_attn {
namespace self_norm_add {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &lyr_nrm_gamma_weights,
                                    torch::Tensor const &lyr_nrm_beta_weights,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results,
    torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
    torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
    torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
    float dropout_prob);

// C++ interface
#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &pad_mask, float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  }

  return fwd_cuda(
      use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights,
      lyr_nrm_beta_weights, input_weights, output_weights,
      use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
      dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results,
    torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
    torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
    torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
    float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
  AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_results, lyr_nrm_results,
                  lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights,
                  lyr_nrm_beta_weights, input_weights, output_weights,
                  dropout_mask, dropout_add_mask, dropout_prob);
}

} // end namespace rocblas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self_norm_add::rocblas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
  m.def("backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
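The self_norm_add forward below begins by layer-normalizing the inputs over embed_dim (HostApplyLayerNorm with epsilon 1.0e-5) before the packed QKV projection. A minimal float CPU sketch of that normalization, saving the per-row mean and inverse standard deviation the way the kernel saves lyr_nrm_mean and lyr_nrm_invvar for its backward pass:

#include <cmath>

// y = (x - mean) / sqrt(var + eps) * gamma + beta, computed row by row.
// x, y: [rows, cols]; gamma, beta: [cols]; mean, invvar: [rows].
void layer_norm_ref(float *y, float *mean, float *invvar, const float *x,
                    const float *gamma, const float *beta, int rows, int cols,
                    float eps = 1.0e-5f) {
  for (int r = 0; r < rows; ++r) {
    const float *xr = x + r * cols;
    float mu = 0.0f;
    for (int c = 0; c < cols; ++c)
      mu += xr[c];
    mu /= cols;
    float var = 0.0f;
    for (int c = 0; c < cols; ++c)
      var += (xr[c] - mu) * (xr[c] - mu);
    var /= cols;
    const float rstd = 1.0f / std::sqrt(var + eps);
    mean[r] = mu;
    invvar[r] = rstd;
    for (int c = 0; c < cols; ++c)
      y[r * cols + c] = (xr[c] - mu) * rstd * gamma[c] + beta[c];
  }
}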
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"

namespace multihead_attn {
namespace self_norm_add {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &lyr_nrm_gamma_weights,
                                    torch::Tensor const &lyr_nrm_beta_weights,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int total_tokens = batches * embed_dim;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs.options().requires_grad(false);
  auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor lyr_nrm_mean = torch::empty({batches}, lyr_nrm_options);
  torch::Tensor lyr_nrm_invvar = torch::empty({batches}, lyr_nrm_options);
  torch::Tensor lyr_nrm_results = torch::empty_like(inputs, act_options);

  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor output_lin_results = torch::empty_like(inputs, act_options);
  torch::Tensor dropout_add_mask = torch::empty_like(inputs, mask_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};

  //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Layer Norm
  HostApplyLayerNorm<at::Half, float>(
      static_cast<at::Half *>(lyr_nrm_results.data_ptr()),
      static_cast<float *>(lyr_nrm_mean.data_ptr()),
      static_cast<float *>(lyr_nrm_invvar.data_ptr()),
      static_cast<const at::Half *>(inputs.data_ptr()),
      static_cast<int>(batches),   // n1
      static_cast<int>(embed_dim), // n2
      1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
      static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));

  // Input Linear Fwd
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -129,8 +128,7 @@ std::vector<torch::Tensor> fwd_cuda(
                                       flags));
  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum( a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -155,46 +153,35 @@ std::vector<torch::Tensor> fwd_cuda(
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
        k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(softmax_results.data_ptr()),
        static_cast<at::Half *>(dropout_results.data_ptr()),
        static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
        (1.0f - dropout_prob));
  }
  // Matmul2
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -245,99 +232,84 @@ std::vector<torch::Tensor> fwd_cuda(
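  // End-of-block residual path: during training the attention output is
  // dropped out (mask saved in dropout_add_mask, keep probability 1 - p)
  // and added to the original inputs in one fused kernel; at inference the
  // plain elementwise add is used instead.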
  // End-of-block Dropout-Add
  if (is_training) {
    apex_dropout_add_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(output_lin_results.data_ptr()),
        static_cast<at::Half const *>(inputs.data_ptr()),
        static_cast<at::Half *>(outputs.data_ptr()),
        static_cast<uint8_t *>(dropout_add_mask.data_ptr()), total_tokens,
        (1.0f - dropout_prob));
  } else {
    apex_add_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(output_lin_results.data_ptr()),
        static_cast<at::Half const *>(inputs.data_ptr()),
        static_cast<at::Half *>(outputs.data_ptr()), total_tokens);
  }
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, input_lin_results,
          softmax_results, dropout_results, dropout_mask, matmul2_results,
          dropout_add_mask, outputs};
}

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results,
    torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
    torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
    torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
    float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int total_tokens = batches * embed_dim;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights);
  torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  torch::Tensor dropout_add_grads = torch::empty_like(output_grads);
  torch::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  torch::Tensor matmul2_grads = torch::empty_like(dropout_results);
  torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
  torch::Tensor input_lin_grads = torch::empty_like(inputs);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
@@ -346,14 +318,13 @@ std::vector<torch::Tensor> bwd_cuda(
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()), static_cast<at::Half const *>(output_grads.data_ptr()),
static_cast<at::Half*>(dropout_add_grads.data_ptr()), static_cast<at::Half *>(dropout_add_grads.data_ptr()),
static_cast<uint8_t const*>(dropout_add_mask.data_ptr()), static_cast<uint8_t const *>(dropout_add_mask.data_ptr()), total_tokens,
total_tokens, (1.0 / (1.0 - dropout_prob)));
(1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad // Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
...@@ -407,8 +378,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -407,8 +378,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( state, gemm_switch_fp32accum( a_layout_t,
a_layout_t,
b_layout_n, b_layout_n,
k_seq_len, k_seq_len,
q_seq_len, q_seq_len,
...@@ -430,8 +400,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -430,8 +400,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches); attn_batches);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
gemm_switch_fp32accum( state, gemm_switch_fp32accum( a_layout_n,
a_layout_n,
b_layout_t, b_layout_t,
head_dim, head_dim,
k_seq_len, k_seq_len,
...@@ -463,17 +432,14 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -463,17 +432,14 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad // Softmax Grad
bool softmax_success = false; bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>( softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()), reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, k_seq_len, attn_batches * q_seq_len);
k_seq_len,
attn_batches*q_seq_len);
assert(softmax_success); assert(softmax_success);
// Matmul1 Dgrad1 // Matmul1 Dgrad1
gemm_switch_fp32accum( state, gemm_switch_fp32accum( a_layout_n,
a_layout_n,
b_layout_n, b_layout_n,
head_dim, head_dim,
q_seq_len, q_seq_len,
...@@ -495,8 +461,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -495,8 +461,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches); attn_batches);
// Matmul1 Dgrad2 // Matmul1 Dgrad2
gemm_switch_fp32accum( state, gemm_switch_fp32accum( a_layout_n,
a_layout_n,
b_layout_t, b_layout_t,
head_dim, head_dim,
k_seq_len, k_seq_len,
...@@ -572,33 +537,26 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -572,33 +537,26 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Fused Layer Norm Bwd with Residual Add // Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>( HostLayerNormGradient<half, float>(
static_cast<const half*>(input_lin_grads.data_ptr()), static_cast<const half *>(input_lin_grads.data_ptr()),
static_cast<half const*>(output_grads.data_ptr()), static_cast<half const *>(output_grads.data_ptr()),
static_cast<const float*>(lyr_nrm_mean.data_ptr()), static_cast<const float *>(lyr_nrm_mean.data_ptr()),
static_cast<const float*>(lyr_nrm_invvar.data_ptr()), static_cast<const float *>(lyr_nrm_invvar.data_ptr()), inputs,
inputs, static_cast<int>(batches), // n1
static_cast<int>(batches), // n1 static_cast<int>(embed_dim), // n2
static_cast<int>(embed_dim), // n2 static_cast<const half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const half*>(lyr_nrm_gamma_weights.data_ptr()), static_cast<const half *>(lyr_nrm_beta_weights.data_ptr()), 1.0e-5,
static_cast<const half*>(lyr_nrm_beta_weights.data_ptr()), static_cast<half *>(input_grads.data_ptr()),
1.0e-5, static_cast<half *>(lyr_nrm_gamma_grads.data_ptr()),
static_cast<half*>(input_grads.data_ptr()), static_cast<half *>(lyr_nrm_beta_grads.data_ptr()));
static_cast<half*>(lyr_nrm_gamma_grads.data_ptr()),
static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {input_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads,
input_grads, input_weight_grads, output_weight_grads};
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_grads,
output_weight_grads
};
} }
} // end namespace rocblas_gemmex } // end namespace rocblas_gemmex
} // end namespace self_norm_add } // end namespace self_norm_add
} // end namespace multihead_attn } // end namespace multihead_attn
This source diff could not be displayed because it is too large.
#include <vector>
#include <iostream> #include <iostream>
#include <vector>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
//#include <ATen/ATen.h> //#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
//#include "cutlass/cutlass.h"
//#include "cutlass/gemm/gemm.h"
//#include "cutlass/gemm/wmma_gemm_traits.h"
// symbol to be automatically resolved by PyTorch libs // symbol to be automatically resolved by PyTorch libs
extern THCState *state;
rocblas_datatype a_type = rocblas_datatype_f16_r; rocblas_datatype a_type = rocblas_datatype_f16_r;
rocblas_datatype b_type = rocblas_datatype_f16_r; rocblas_datatype b_type = rocblas_datatype_f16_r;
...@@ -25,16 +28,19 @@ rocblas_int flags = 0; ...@@ -25,16 +28,19 @@ rocblas_int flags = 0;
cublasOperation_t convertTransToCublasOperation(char trans) { cublasOperation_t convertTransToCublasOperation(char trans) {
if (trans == 't') return CUBLAS_OP_T; if (trans == 't')
else if (trans == 'n') return CUBLAS_OP_N; return CUBLAS_OP_T;
else if (trans == 'c') return CUBLAS_OP_C; else if (trans == 'n')
return CUBLAS_OP_N;
else if (trans == 'c')
return CUBLAS_OP_C;
else { else {
AT_ERROR("trans must be one of: t, n, c"); AT_ERROR("trans must be one of: t, n, c");
return CUBLAS_OP_T; return CUBLAS_OP_T;
} }
} }
void RocblasStridedBatchedGemm(THCState *state, char transa, char transb, long m, long n, long k, void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB, float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo) { float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo) {
cublasOperation_t opa = convertTransToCublasOperation(transa); cublasOperation_t opa = convertTransToCublasOperation(transa);
...@@ -55,151 +61,73 @@ void RocblasStridedBatchedGemm(THCState *state, char transa, char transb, long m ...@@ -55,151 +61,73 @@ void RocblasStridedBatchedGemm(THCState *state, char transa, char transb, long m
(int)batchCount, compute_type, algo, solution_index, flags)); (int)batchCount, compute_type, algo, solution_index, flags));
} }
void gemm_switch_fp32accum(THCState *state, char transa, char transb, long m, long n, long k, void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB, float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount) { float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount) {
auto stream = c10::cuda::getCurrentCUDAStream(); auto stream = c10::cuda::getCurrentCUDAStream();
if ( (transa == 't') && (transb == 'n') ) { if ( (transa == 't') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); } if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
else { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); } else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
} else if ( (transa == 'n') && (transb == 'n') ) { } else if ( (transa == 'n') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); } if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
else { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); } else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
} else if ( (transa == 'n') && (transb == 't') ) { } else if ( (transa == 'n') && (transb == 't') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); } if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
else { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); } else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
} else { } else {
AT_ASSERTM(false, "TransA and TransB are invalid"); AT_ASSERTM(false, "TransA and TransB are invalid");
} }
} }
void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t *lda, int64_t *ldb, int64_t *ldc) void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k,
{ int64_t *lda, int64_t *ldb, int64_t *ldc) {
int transa_ = ((transa == 't') || (transa == 'T')); int transa_ = ((transa == 't') || (transa == 'T'));
int transb_ = ((transb == 't') || (transb == 'T')); int transb_ = ((transb == 't') || (transb == 'T'));
// Note: leading dimensions generally are checked that they are > 0 and at least as big the result // Note: leading dimensions generally are checked that they are > 0 and at
// requires (even if the value won't be used). // least as big the result requires (even if the value won't be used).
if(n <= 1) if (n <= 1)
*ldc = std::max<int64_t>(m, 1); *ldc = std::max<int64_t>(m, 1);
if(transa_) if (transa_) {
{ if (m <= 1)
if(m <= 1)
*lda = std::max<int64_t>(k, 1); *lda = std::max<int64_t>(k, 1);
} } else {
else if (k <= 1)
{
if(k <= 1)
*lda = std::max<int64_t>(m, 1); *lda = std::max<int64_t>(m, 1);
} }
if(transb_) if (transb_) {
{ if (k <= 1)
if(k <= 1)
*ldb = std::max<int64_t>(n, 1); *ldb = std::max<int64_t>(n, 1);
} } else {
else if (n <= 1)
{
if(n <= 1)
*ldb = std::max<int64_t>(k, 1); *ldb = std::max<int64_t>(k, 1);
} }
} }
void HgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k, void HgemmStridedBatched(char transa, char transb, long m,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB, long n, long k, float alpha, const half *a, long lda,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount) long strideA, const half *b, long ldb, long strideB,
{ float beta, half *c, long ldc, long strideC,
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) half *d, long ldd, long strideD, long batchCount) {
if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) ||
(ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX))
{ {
AT_ERROR("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount" AT_ERROR("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, "
"with the bound [val] <= %d", INT_MAX); "batchCount"
"with the bound [val] <= %d",
INT_MAX);
} }
adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
//gemm_switch(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); // gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
gemm_switch_fp32accum(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount); // b, ldb, strideB, beta, c, ldc, strideC, batchCount);
gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount);
} }
/******
at::Tensor strided_batched_gemm_cuda(
float beta,
at::Tensor in_result,
float alpha,
at::Tensor batch1,
at::Tensor batch2) {
bool transpose_result;
char transpose_batch1, transpose_batch2;
int64_t lda, ldb, ldc;
at::Tensor result, input1, input2;
if (in_result.stride(1) == 1)
{
transpose_result = false;
result = in_result;
ldc = result.stride(2);
}
else if (in_result.stride(2) == 1)
{
transpose_result = true;
at::Tensor swap = batch2;
batch2 = batch1;
batch1 = swap;
result = in_result;
ldc = result.stride(1);
} else {
AT_ASSERTM(false, "result should be contiguous");
}
if (batch1.stride(transpose_result ? 2 : 1) == 1 &&
batch1.stride(transpose_result ? 1 : 2) != 0) {
transpose_batch1 = 'n';
input1 = batch1;
lda = input1.stride(transpose_result ? 1 : 2);
} else if (batch1.stride(transpose_result ? 1 : 2) == 1 &&
batch1.stride(transpose_result ? 2 : 1) != 0) {
transpose_batch1 = 't';
input1 = batch1;
lda = input1.stride(transpose_result ? 2 : 1);
} else {
AT_ASSERTM(false, "input1 should be contiguous");
}
if (batch2.stride(transpose_result ? 2 : 1) == 1 &&
batch2.stride(transpose_result ? 1 : 2) != 0) {
transpose_batch2 = 'n';
input2 = batch2;
ldb = input2.stride(transpose_result ? 1 : 2);
} else if (batch2.stride(transpose_result ? 1 : 2) == 1 &&
batch2.stride(transpose_result ? 2 : 1) != 0) {
transpose_batch2 = 't';
input2 = batch2;
ldb = input2.stride(transpose_result ? 2 : 1);
} else {
AT_ASSERTM(false, "input2 should be contiguous");
}
int64_t num_batches = result.size(0);
HgemmStridedBatched(
state,
transpose_batch1,
transpose_batch2,
result.size(transpose_result ? 2 : 1),
result.size(transpose_result ? 1 : 2),
input1.size(transpose_result ? 1 : 2),
alpha,
static_cast<const half*>(input1.data_ptr()), lda, input1.stride(0),
static_cast<const half*>(input2.data_ptr()), ldb, input2.stride(0),
beta,
static_cast<half*>(result.data_ptr()), ldc, result.stride(0),
num_batches);
return in_result;
}
***/
import torch import torch
from torch.nn import init from torch.nn import init
from apex._autocast_utils import _cast_if_autocast_enabled
import fast_layer_norm import fast_layer_norm
class FastLayerNormFN(torch.autograd.Function): class FastLayerNormFN(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, x, gamma, beta, epsilon): def forward(ctx, x, gamma, beta, epsilon):
...@@ -14,23 +16,30 @@ class FastLayerNormFN(torch.autograd.Function): ...@@ -14,23 +16,30 @@ class FastLayerNormFN(torch.autograd.Function):
ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon) ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon)
ctx.save_for_backward(x, gamma, mu, rsigma) ctx.save_for_backward(x, gamma, mu, rsigma)
return ymat.view(x.shape) return ymat.view(x.shape)
@staticmethod @staticmethod
def backward(ctx, dy): def backward(ctx, dy):
#assert dy.is_contiguous() # assert dy.is_contiguous()
dy = dy.contiguous() # this happens! dy = dy.contiguous() # this happens!
x, gamma, mu, rsigma = ctx.saved_tensors x, gamma, mu, rsigma = ctx.saved_tensors
hidden_size = gamma.numel() hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size)) xmat = x.view((-1, hidden_size))
dymat = dy.view(xmat.shape) dymat = dy.view(xmat.shape)
dxmat, dgamma, dbeta = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma) dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma)
dx = dxmat.view(x.shape) dx = dxmat.view(x.shape)
return dx, dgamma, dbeta, None return dx, dgamma, dbeta, None
def _fast_layer_norm(x, weight, bias, epsilon):
args = _cast_if_autocast_enabled(x, weight, bias, epsilon)
with torch.cuda.amp.autocast(enabled=False):
return FastLayerNormFN.apply(*args)
class FastLayerNorm(torch.nn.Module): class FastLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5): def __init__(self, hidden_size, eps=1e-5):
super(FastLayerNorm, self).__init__() super().__init__()
self.epsilon = eps self.epsilon = eps
self.weight = torch.nn.Parameter(torch.Tensor(hidden_size)) self.weight = torch.nn.Parameter(torch.Tensor(hidden_size))
self.bias = torch.nn.Parameter(torch.Tensor(hidden_size)) self.bias = torch.nn.Parameter(torch.Tensor(hidden_size))
...@@ -41,4 +50,4 @@ class FastLayerNorm(torch.nn.Module): ...@@ -41,4 +50,4 @@ class FastLayerNorm(torch.nn.Module):
init.zeros_(self.bias) init.zeros_(self.bias)
def forward(self, x): def forward(self, x):
return FastLayerNormFN.apply(x, self.weight, self.bias, self.epsilon) return _fast_layer_norm(x, self.weight, self.bias, self.epsilon)
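if __name__ == "__main__":
    # Hedged usage sketch (illustration only, appended here for documentation).
    # Assumes the `fast_layer_norm` CUDA/HIP extension is built and a GPU is
    # available; the shape (512, 32, 768) is just one of the configurations
    # exercised by the unit tests below.
    ln = FastLayerNorm(768).cuda().half()
    x = torch.randn(512, 32, 768, device="cuda", dtype=torch.float16, requires_grad=True)
    y = ln(x)                    # routes through _fast_layer_norm -> fast_layer_norm.ln_fwd
    y.float().sum().backward()   # gradients come from fast_layer_norm.ln_bwd
    print(y.shape, x.grad.dtype, ln.weight.grad.dtype)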
import torch
import unittest import unittest
import numpy as np import sys
import os
import torch.nn.functional as F import numpy as np
import torch
from apex.contrib.layer_norm import FastLayerNorm
import fast_layer_norm as fln import fast_layer_norm as fln
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
class GPUTimer: class GPUTimer:
...@@ -14,146 +14,262 @@ class GPUTimer: ...@@ -14,146 +14,262 @@ class GPUTimer:
self.start_ = torch.cuda.Event(enable_timing=True) self.start_ = torch.cuda.Event(enable_timing=True)
self.stop_ = torch.cuda.Event(enable_timing=True) self.stop_ = torch.cuda.Event(enable_timing=True)
self.stream_ = stream self.stream_ = stream
def start(self): def start(self):
self.stream_.record_event(self.start_) self.stream_.record_event(self.start_)
def stop(self): def stop(self):
self.stream_.record_event(self.stop_) self.stream_.record_event(self.stop_)
def sync(self): def sync(self):
self.stream_.synchronize() self.stream_.synchronize()
def millis(self): def millis(self):
return self.start_.elapsed_time(self.stop_) return self.start_.elapsed_time(self.stop_)
def size_in_bytes(t): def size_in_bytes(t):
return torch.numel(t) * t.element_size() return torch.numel(t) * t.element_size()
def abs_err(x, y):
xf = x.float()
yf = y.float()
return ((xf-yf).abs().sum() / yf.abs().sum()).item()
def metrics(y_ref, y, epsilon=1e-6):
y_ref = y_ref.float()
y = y.float()
relerr, mse = (
(y_ref - y).abs().sum() / (y_ref.abs().sum() + epsilon),
(y_ref - y).square().mean(),
)
return relerr.item(), mse.item()
device = torch.device("cuda")
fp32 = torch.float32
fp16 = torch.float16
bf16 = torch.bfloat16
def backward_(dz, x, mu, rs, gamma):
wtype = gamma.dtype
itype = x.dtype
otype = dz.dtype
ctype = mu.dtype
mu = mu.unsqueeze(1)
rs = rs.unsqueeze(1)
hidden_size = gamma.numel()
y = rs * (x.to(ctype) - mu)
dbeta = dz.view(-1, hidden_size).sum(0, dtype=ctype)
dgamma = (dz * y).view(-1, hidden_size).sum(0, dtype=ctype)
dy = dz.view(-1, hidden_size).to(ctype) * gamma.unsqueeze(0).to(ctype)
mdy = dy.mean(1, keepdim=True, dtype=ctype)
mdyy = (dy * y).mean(1, keepdim=True, dtype=ctype)
dx = rs * (dy - mdyy * y - mdy)
return dx.to(itype), dgamma.to(wtype), dbeta.to(wtype)
def benchmark_(S, B, hidden_size, itype, wtype, runs=100):
epsilon = 1e-5
x = torch.randn((S * B, hidden_size), dtype=itype, device=device)
beta = torch.randn(hidden_size, dtype=wtype, device=device)
gamma = torch.randn(hidden_size, dtype=wtype, device=device)
dz = torch.randn(x.shape, dtype=wtype, device=device)
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
timer = GPUTimer(stream)
# warmup
for r in range(runs):
z, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)
timer.start()
for r in range(runs):
z, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)
timer.stop()
timer.sync()
total_bytes_fwd = sum([size_in_bytes(t) for t in [x, z, gamma, beta, mu, rsigma]])
ms_fwd = timer.millis() / runs
print(
"[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec".format(
ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd
)
)
timer.start()
for r in range(runs):
dx, dgamma, dbeta, dbp, dgp = fln.ln_bwd(dz, x, mu, rsigma, gamma)
timer.stop()
timer.sync()
total_bytes_bwd = sum(
[
size_in_bytes(t)
for t in [dz, x, mu, rsigma, gamma, dx, dgamma, dbeta, dbp, dbp, dgp, dgp]
]
)
ms_bwd = timer.millis() / runs
print(
"[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec".format(
ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd
)
)
def test_(S, B, hidden_size, itype, wtype, ctype=fp32):
seed = 1243
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
otype = wtype
print("========================================================")
print(f"S={S} B={B} Hidden={hidden_size} {itype} {wtype}")
print("--------------------------------------------------------")
x = torch.randn(S * B, hidden_size, dtype=itype, device=device)
gamma = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2
beta = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2
epsilon = 1e-5
x.requires_grad = True
gamma.requires_grad = True
beta.requires_grad = True
mu_ref = x.mean(1, dtype=ctype, keepdim=True)
v = torch.square(x - mu_ref).mean(1, dtype=ctype, keepdim=True)
rs_ref = torch.rsqrt(v + epsilon)
y_ref = rs_ref * (x.to(ctype) - mu_ref)
z_ref = (gamma.unsqueeze(0) * (y_ref).to(otype) + beta.unsqueeze(0)).to(otype)
mu_ref = mu_ref.flatten()
rs_ref = rs_ref.flatten()
dz = torch.randn_like(z_ref)
# z_ref.backward(dz)
# dx_ref = x.grad
# dgamma_ref = gamma.grad
# dbeta_ref = beta.grad
dx_ref, dg_ref, db_ref = backward_(dz, x, mu_ref, rs_ref, gamma)
z, mu, rs = fln.ln_fwd(x, gamma, beta, epsilon)
dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, x, mu, rs, gamma)
re_z, mse_z = metrics(z_ref, z)
re_mu, mse_mu = metrics(mu_ref, mu)
re_rs, mse_rs = metrics(rs_ref, rs)
re_dx, mse_dx = metrics(dx_ref, dx)
re_dg, mse_dg = metrics(dg_ref, dg)
re_db, mse_db = metrics(db_ref, db)
print(f" z: relerr={re_z :.4e} mse={mse_z :.4e}")
print(f"mu: relerr={re_mu:.4e} mse={mse_mu:.4e}")
print(f"rs: relerr={re_mu:.4e} mse={mse_mu:.4e}")
print(f"dx: relerr={re_dx:.4e} mse={mse_dx:.4e}")
print(f"dg: relerr={re_dg:.4e} mse={mse_dg:.4e}")
print(f"db: relerr={re_db:.4e} mse={mse_db:.4e}")
def check_err(x, relerr):
tol = 1e-3 if x.dtype == torch.float16 else 5e-6
return relerr < tol
return [
check_err(x, re)
for x, re in zip([z, mu, rs, dx, dg, db], [re_z, re_mu, re_rs, re_dx, re_dg, re_db])
]
class TestFastLayerNorm(unittest.TestCase): class TestFastLayerNorm(unittest.TestCase):
def assertAll(self, l):
def setUp(self, seed=1234): if not all(l):
seed = 1234 print(l)
torch.manual_seed(seed) for x in l:
torch.cuda.manual_seed_all(seed) self.assertTrue(x)
def test_ln_fp32(self): def test_all_configs(self):
self.run_test_layer_norm(torch.float32, atol=1e-5)
def test_ln_fp16(self): hidden_sizes = [
self.run_test_layer_norm(torch.float16, atol=1e-2, rtol=1e-3) 768,
1024,
def run_test_layer_norm(self, dtype, atol, rtol=1e-5): 1536,
device = torch.device('cuda') 2048,
s = 512 2304,
b = 32 3072,
hidden_size = 1024 3840,
epsilon = 1e-5 4096,
5120,
x = torch.randn((s,b,hidden_size), dtype=dtype, device=device) 6144,
beta = torch.randn(hidden_size, dtype=dtype, device=device) 8192,
gamma = torch.randn(hidden_size, dtype=dtype, device=device) 10240,
x.requires_grad = True 12288,
beta.requires_grad = True 12800,
gamma.requires_grad = True 15360,
16384,
x2 = x.clone().detach() 18432,
beta2 = beta.clone().detach() 20480,
gamma2 = gamma.clone().detach() 24576,
x2.requires_grad = True 25600,
beta2.requires_grad = True 30720,
gamma2.requires_grad = True 32768,
40960,
dummy_label = torch.randn_like(x) 49152,
65536,
y = F.layer_norm(x, [hidden_size], gamma, beta, epsilon) ]
diff = y-dummy_label for h in hidden_sizes:
l = (diff * diff).sum() / b with self.subTest(f"hidden_size={h}"):
l.backward() self.assertAll(test_(256, 2, h, fp32, fp32))
self.assertAll(test_(256, 2, h, fp16, fp16))
fln = FastLayerNorm(hidden_size).cuda() self.assertAll(test_(256, 2, h, fp32, fp16))
fln.load_state_dict({'bias': beta2, 'weight':gamma2}) self.assertAll(test_(256, 2, h, bf16, bf16))
if dtype == torch.float16: self.assertAll(test_(256, 2, h, fp32, bf16))
fln = fln.half()
def test_run_benchmark(self):
y2 = fln(x2) for (S, B, hidden_size, runs) in (
diff2 = (y2 - dummy_label) (512, 32, 768, 1000),
l2 = (diff2 * diff2).sum() / b (512, 32, 1024, 1000),
(512, 8, 4096, 1000),
l2.backward() (512, 8, 5120, 1000),
(512, 8, 6144, 1000),
self.assertTrue(torch.allclose(y2, y, atol=atol, rtol=rtol)) (256, 2, 20480, 500),
self.assertTrue(torch.allclose(x2.grad, x.grad, atol=atol,rtol=rtol)) (256, 2, 25600, 500),
self.assertTrue(torch.allclose(fln.bias.grad, beta.grad, atol=atol, rtol=rtol)) (256, 2, 40960, 250),
self.assertTrue(torch.allclose(fln.weight.grad, gamma.grad, atol=atol, rtol=rtol)) (256, 2, 65536, 250),
):
with self.subTest(f"(S, B, hidden_size)=({S}, {B}, {hidden_size})"):
benchmark_(S, B, hidden_size, fp16, fp16, runs)
def test_performance(self):
print() def test_compat_with_autocast(self):
runs = 1000 autocast_dtypes = (
device = torch.device('cuda') (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
dtype =torch.float16 )
s = 512 input_shape = (512, 32, 768)
b = 32 layer_norm = FastLayerNorm(input_shape[-1]).cuda()
hidden_size = 1024 input = torch.randn(input_shape).cuda()
epsilon = 1e-5
for dtype in autocast_dtypes:
x = torch.randn((s*b,hidden_size), dtype=dtype, device=device) layer_norm.zero_grad(set_to_none=True)
beta = torch.randn(hidden_size, dtype=dtype, device=device) with self.subTest(f"autocast_dtype={dtype}"):
gamma = torch.randn(hidden_size, dtype=dtype, device=device) with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
dy = torch.randn_like(x) out = layer_norm(input)
self.assertEqual(dtype, out.dtype)
grad = torch.randn_like(out)
stream = torch.cuda.Stream() out.backward(grad)
with torch.cuda.stream(stream): self.assertEqual(torch.float32, layer_norm.weight.grad.dtype)
timer = GPUTimer(stream)
if __name__ == "__main__":
#warmup
for r in range(runs):
y, mu, rsigma = fln.ln_fwd(x, gamma, beta, 1e-5)
timer.start()
for r in range(runs):
y, mu, rsigma = fln.ln_fwd(x, gamma, beta, 1e-5)
timer.stop()
timer.sync()
total_bytes_fwd = (size_in_bytes(x)
+ size_in_bytes(y)
+ size_in_bytes(gamma)
+ size_in_bytes(beta)
+ size_in_bytes(mu)
+ size_in_bytes(rsigma)
)
ms_fwd = timer.millis() / runs
print('[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd ))
timer.start()
for r in range(runs):
dx, dgamma, dbeta = fln.ln_bwd(dy, x, mu, rsigma, gamma)
timer.stop()
timer.sync()
total_bytes_bwd = (size_in_bytes(x)
+ size_in_bytes(dx)
+ size_in_bytes(dy)
+ size_in_bytes(gamma)
+ size_in_bytes(dgamma)
+ size_in_bytes(dbeta)
+ size_in_bytes(mu)
+ size_in_bytes(rsigma)
)
ms_bwd = timer.millis() / runs
print('[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd ))
if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -2,4 +2,5 @@ from .fused_sgd import FusedSGD ...@@ -2,4 +2,5 @@ from .fused_sgd import FusedSGD
from .fused_adam import FusedAdam from .fused_adam import FusedAdam
from .fused_novograd import FusedNovoGrad from .fused_novograd import FusedNovoGrad
from .fused_lamb import FusedLAMB from .fused_lamb import FusedLAMB
from .fused_adagrad import FusedAdagrad from .fused_adagrad import FusedAdagrad
\ No newline at end of file from .fused_mixed_precision_lamb import FusedMixedPrecisionLamb
import torch
from copy import deepcopy
from itertools import chain
from collections import defaultdict, abc as container_abcs
from apex.multi_tensor_apply import multi_tensor_applier
class FusedMixedPrecisionLamb(torch.optim.Optimizer):
def __init__(self, params, lr=1e-3, step=0, bias_correction=True,
betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
amsgrad=False, adam_w_mode=True,
grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False,
reduced_precision_dtype=None):
if amsgrad:
raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
# The learning rate (lr) and optimizer step (step) should be located on device
# in order to facilitate device-sync-free execution
defaults = dict(lr=torch.tensor(lr, dtype=torch.float32),
step=torch.tensor([step], dtype=torch.int),
bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm)
tensor_state = ['lr', 'step']
super(FusedMixedPrecisionLamb, self).__init__(params, defaults)
device = self.param_groups[0]['params'][0].device
for idx,group in enumerate(self.param_groups):
for item in tensor_state:
self.param_groups[idx][item] = group[item].to(device=device)
if multi_tensor_applier.available:
import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm_mp
# Skip buffer
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=device)
self.multi_tensor_lamb = amp_C.multi_tensor_lamb_mp
else:
raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')
# Mixed Precision support
self.reduced_precision_dtype = reduced_precision_dtype
self.param_groups_full_precision = []
self._step_supports_amp_scaling = True
self.adam_w_mode = 1 if adam_w_mode else 0
self.use_nvlamb = use_nvlamb
# This method is overridden from the parent class because there is no way to override
# the nested function cast() that copies a saved piece of state to the device without
# redundantly doing the copy.
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
Args:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)
# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict['param_groups']
if len(groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of "
"parameter groups")
param_lens = (len(g['params']) for g in groups)
saved_lens = (len(g['params']) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError("loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group")
# Update the state
id_map = {old_id: p for old_id, p in
zip(chain.from_iterable((g['params'] for g in saved_groups)),
chain.from_iterable((g['params'] for g in groups)))}
def cast(param, value):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor):
# The original version cast the saved value to the param's dtype
# This doesn't work for mixed precision Lamb where the momentum and
# velocity are expected to be in full precision while the params are
# in reduced precision
value = value.to(value.device)
return value
elif isinstance(value, dict):
return {k: cast(param, v) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable):
return type(value)(cast(param, v) for v in value)
else:
return value
# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
for k, v in state_dict['state'].items():
if k in id_map:
param = id_map[k]
state[param] = cast(param, v)
else:
state[k] = v
# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
new_group['params'] = group['params']
return new_group
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})
def _setup_full_precision_params(self):
for i, pg in enumerate(self.param_groups):
param_list = pg['params']
self.param_groups_full_precision.append({
'params': [
p.clone().detach().to(dtype=torch.float32)
if (self.reduced_precision_dtype is not None) and (p.dtype == self.reduced_precision_dtype)
else None
for p in param_list
],
})
# add_param_group() is overridden because default items can be tensors. The
# parent version does not clone tensor defaults, so two param groups could
# accidentally share the same default tensor even though they belong to
# separate groups and may need to diverge.
def add_param_group(self, param_group):
super().add_param_group(param_group)
for name, default in self.defaults.items():
if isinstance(default, torch.Tensor):
self.param_groups[len(self.param_groups) - 1][name] = default.clone()
@torch.no_grad()
def step(self, closure=None, grad_scaler=None):
loss = None
if closure is not None:
loss = closure()
# The full precision params are set up in the first step of the optimizer
# instead of in the constructor because the full precision params will get
# out of sync with the model params if DDP syncs the model params across devices
# after the optimizer is constructed.
if len(self.param_groups_full_precision) == 0 :
self._setup_full_precision_params()
# create separate grad lists for params
grad_list = []
for gid,group in enumerate(self.param_groups):
for pid,p in enumerate(group['params']):
assert group['params'][0].dtype == p.dtype, \
"Error: Parameters are not of the identical type: {} != {}".format(
group['params'][0].dtype, p.dtype)
if p.grad is None:
continue
grad_list.append(p.grad)
# Overflow check of gradients
device = self.param_groups[0]["params"][0].device
found_inf = (
grad_scaler._check_inf_per_device(self)[device]
if grad_scaler is not None else torch.zeros((1,), device=device)
)
self._dummy_overflow_buf.copy_(found_inf)
# Get unscale scale factor
scale, inv_scale = None, None
if grad_scaler:
scale = grad_scaler._get_scale_async()
inv_scale = scale.double().reciprocal().float()
else:
scale = torch.ones((1,), device=device)
inv_scale = torch.ones((1,), device=device)
# grad_norm is of scaled gradients.
# So, multiply `max_grad_norm` by scale.
max_grad_norm = self.defaults['max_grad_norm'] * scale
grad_norm = multi_tensor_applier(
self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[grad_list],
False,
)[0]
# Run LAMB optimization math
for gid, (group, group_full) in enumerate(zip(self.param_groups, self.param_groups_full_precision)):
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
grad_averaging = 1 if group['grad_averaging'] else 0
# assume the same step across the group for now to simplify things;
# a per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
group['step'] += (self._dummy_overflow_buf != 1).to(torch.int)
state_lists = [ [], # (0) grads
[], # (1) params
[], # (2) momentum state
[], # (3) velocity state
]
if self.reduced_precision_dtype is not None:
state_lists.append([]) # (4) params reduced_dtype
for p, p_full in zip(group['params'], group_full['params']):
if p.grad is None:
continue
assert not p.grad.is_sparse
state = self.state[p]
# State initialization
if len(state) == 0:
dtype = p.dtype
if self.reduced_precision_dtype is not None and p.dtype == self.reduced_precision_dtype :
dtype = torch.float32
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data, dtype=dtype)
# Exponential moving average of gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=dtype)
if self.reduced_precision_dtype is not None :
state_lists[0].append(p.grad.data)
state_lists[1].append(p_full.data)
state_lists[2].append(state['exp_avg'])
state_lists[3].append(state['exp_avg_sq'])
state_lists[4].append(p.data)
else :
state_lists[0].append(p.grad.data)
state_lists[1].append(p.data)
state_lists[2].append(state['exp_avg'])
state_lists[3].append(state['exp_avg_sq'])
multi_tensor_applier(
self.multi_tensor_lamb,
self._dummy_overflow_buf,
state_lists,
group['lr'],
beta1,
beta2,
group['eps'],
group['step'],
bias_correction,
group['weight_decay'],
grad_averaging,
self.adam_w_mode,
grad_norm,
max_grad_norm,
self.use_nvlamb,
found_inf,
inv_scale)
return loss
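if __name__ == "__main__":
    # Hedged usage sketch (illustration only, not from the apex docs).
    # Assumes a CUDA/ROCm build of apex with the amp_C multi-tensor kernels
    # (including the *_mp variants used above) and a GPU; the model, shapes
    # and hyperparameters are placeholders.
    model = torch.nn.Linear(1024, 1024).cuda().half()
    optimizer = FusedMixedPrecisionLamb(
        model.parameters(),
        lr=1e-3,
        weight_decay=0.01,
        # Model params are fp16; optimizer state and master weights stay fp32.
        reduced_precision_dtype=torch.float16,
    )
    scaler = torch.cuda.amp.GradScaler()

    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        data = torch.randn(32, 1024, device="cuda", dtype=torch.float16)
        loss = model(data).float().pow(2).mean()
        scaler.scale(loss).backward()
        # step() consumes the scaler directly (see _step_supports_amp_scaling):
        # it performs its own inf check and unscaling inside the fused kernel.
        optimizer.step(grad_scaler=scaler)
        scaler.update()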
...@@ -2,4 +2,80 @@ ...@@ -2,4 +2,80 @@
`apex.transformer` is a module which enables efficient large Transformer models at scale. `apex.transformer` is a module which enables efficient large Transformer models at scale.
`apex.transformer.tensor_parallel` is based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s `megatron.mpu` module. `apex.transformer.tensor_parallel` and `apex.transformer.pipeline_parallel` are both based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s module.
The former is based on `megatron.mpu` and the latter on `megatron.schedules` and `megatron.p2p_communication`.
## Tensor Model Parallel (TP)
APEX's tensor model parallel utilities provide some `torch.nn.Module`s, custom fused kernels, and PRNG state handling.
See Appendix B.2 of [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) for the details of
PRNG state handling.
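Below is a minimal, hedged sketch of initializing TP and using a tensor-parallel layer. It assumes a `torch.distributed` process group is already initialized (one process per GPU), and that `ColumnParallelLinear` and `model_parallel_cuda_manual_seed` are exported from `apex.transformer.tensor_parallel` in the same way as Megatron-LM's `megatron.mpu`; treat it as an illustration, not the canonical API.
```python
import torch
from apex.transformer import parallel_state, tensor_parallel

# Assumes torch.distributed.init_process_group(...) has already been called,
# one process per GPU.
parallel_state.initialize_model_parallel(2)  # 2-way tensor model parallelism
tensor_parallel.model_parallel_cuda_manual_seed(1234)  # per-rank PRNG state

# The output features of this layer are split across the TP group.
layer = tensor_parallel.ColumnParallelLinear(1024, 4096).cuda()
x = torch.randn(8, 1024, device="cuda")
out, _ = layer(x)  # Megatron-style parallel layers return (output, bias)
print(parallel_state.get_tensor_model_parallel_rank(), out.shape)
```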
## Pipeline Model Parallel (PP)
APEX's pipeline model parallel functions require models to have a `.set_input_tensor` method because
the input tensor for the `.forward` method can be `None`.
The following is a rough sketch of a training script with apex PP.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func
class Model(nn.Module):
...
def __init__(self, *args, **kwargs):
super().__init__()
pre_process = kwargs.pop("pre_process")
post_process = kwargs.pop("post_process")
def set_input_tensor(self, tensor):
self.input_tensor = tensor
def forward(self, x, ...):
if parallel_state.is_pipeline_first_stage():
input = x
else:
input = self.input_tensor
...
def model_provider_func(*args, **kwargs):
return Model(*args, **kwargs)
def loss_func(pred, label):
loss = ...
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'nice_loss': averaged_loss}
def forward_step_func(batch, model):
input, label = process_batch(batch)
out = model(input)
return out, partial(loss_func, label)
forward_backward_func = get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size)
parallel_state.initialize_model_parallel(
tensor_model_parallel_size,
pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size,
)
# The following line basically is equivalent to `build_model(Model, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs)`
model = build_model(model_provider_func, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs)
optimizer = ...
data_loader = ...
for epoch in range(num_epochs):
for batch in data_loader:
forward_backward_func(forward_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape)
optimizer.step()
```
from . import tensor_parallel from apex.transformer import amp
from . import functional from apex.transformer import functional
from .enums import LayerType from apex.transformer import parallel_state
from .enums import AttnType from apex.transformer import pipeline_parallel
from .enums import AttnMaskType from apex.transformer import tensor_parallel
from .parallel_state import ( from apex.transformer import utils
is_unitialized, from apex.transformer.enums import LayerType
destroy_model_parallel, from apex.transformer.enums import AttnType
get_data_parallel_group, from apex.transformer.enums import AttnMaskType
get_data_parallel_rank,
get_data_parallel_world_size,
get_embedding_group, __all__ = [
get_model_parallel_group, "amp",
get_tensor_model_parallel_group, "functional",
get_pipeline_model_parallel_group, "parallel_state",
get_tensor_model_parallel_rank, "pipeline_parallel",
set_tensor_model_parallel_rank, "tensor_parallel",
get_pipeline_model_parallel_rank, "utils",
set_pipeline_model_parallel_rank, # enums.py
is_pipeline_first_stage, "LayerType",
is_pipeline_last_stage, "AttnType",
get_tensor_model_parallel_src_rank, "AttnMaskType",
get_pipeline_model_parallel_first_rank, ]
get_pipeline_model_parallel_last_rank,
get_pipeline_model_parallel_next_rank,
get_pipeline_model_parallel_prev_rank,
get_tensor_model_parallel_world_size,
set_tensor_model_parallel_world_size,
get_pipeline_model_parallel_world_size,
set_pipeline_model_parallel_world_size,
get_virtual_pipeline_model_parallel_rank,
set_virtual_pipeline_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from apex.transformer._data._batchsampler import MegatronPretrainingRandomSampler
from apex.transformer._data._batchsampler import MegatronPretrainingSampler
__all__ = [
"MegatronPretrainingRandomSampler",
"MegatronPretrainingSampler",
]
"""BatchSampler implementations for POC of dynamic batch size or rampup_batch_size support.
Implementations are based on https://github.com/NVIDIA/Megatron-LM/blob/bcd605f8570ebeeb0436c115ebbfafc3c5a40ae5/megatron/data/data_samplers.py.
""" # NOQA
import abc
import torch
__all__ = [
"MegatronPretrainingSampler",
"MegatronPretrainingRandomSampler",
]
class _Base:
"""Base class for Megatron style BatchSampler."""
@abc.abstractmethod
def __len__(self) -> int:
...
@abc.abstractmethod
def __iter__(self):
...
@property
@abc.abstractmethod
def local_minibatch_size(self) -> int:
...
@local_minibatch_size.setter
@abc.abstractmethod
def local_minibatch_size(self, new_local_minibatch_size) -> None:
...
class MegatronPretrainingSampler(_Base):
def __init__(
self,
total_samples: int,
consumed_samples: int,
local_minibatch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool = True,
):
# Sanity checks.
if total_samples <= 0:
raise RuntimeError('no sample to consume: {}'.format(total_samples))
if consumed_samples >= total_samples:
raise RuntimeError('no samples left to consume: {}, {}'.format(consumed_samples, total_samples))
if local_minibatch_size <= 0:
raise RuntimeError(f"local minibatch size must be greater than 0: {local_minibatch_size}")
if data_parallel_size <= 0:
raise RuntimeError(f"data parallel size must be greater than 0: {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
raise RuntimeError('data_parallel_rank should be smaller than data parallel size: {}, {}'.format(data_parallel_rank, data_parallel_size))
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self._local_minibatch_size = local_minibatch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * data_parallel_size
self.drop_last = drop_last
def __len__(self):
return self.total_samples
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.local_minibatch_size
end_idx = start_idx + self.local_minibatch_size
return start_idx, end_idx
@property
def local_minibatch_size(self) -> int:
return self._local_minibatch_size
@local_minibatch_size.setter
def local_minibatch_size(self, new_local_minibatch_size) -> None:
self._local_minibatch_size = new_local_minibatch_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
def __iter__(self):
batch = []
# The last batch will be dropped unless drop_last is set to False
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.local_minibatch_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []
# Yield the last partial batch unless drop_last is set
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
class MegatronPretrainingRandomSampler(_Base):
"""Megatron style Random Batch Sampler.
The major difference from `MegatronPretrainingSampler` is that `__iter__` yields a local minibatch, not a microbatch.
A local minibatch consists of `global_batch_size / data_parallel_size` samples (a usage sketch follows the class definitions below).
Args:
total_samples: The number of data samples, i.e. ``len(dataset)``.
consumed_samples: The number of samples already consumed in pretraining.
local_minibatch_size: The number of samples in each batch returned from `__iter__`. Basically
`local_minibatch_size = global_batch_size / data_parallel_size`.
data_parallel_rank: The rank of this process within the data parallel group.
data_parallel_size: The world size of the data parallel group.
"""
def __init__(
self,
total_samples: int,
consumed_samples: int,
local_minibatch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
) -> None:
if total_samples <= 0:
raise ValueError(f"no sample to consume: total_samples of {total_samples}")
if local_minibatch_size <= 0:
raise ValueError(f"Invalid local_minibatch_size: {local_minibatch_size}")
if data_parallel_size <= 0:
raise ValueError(f"Invalid data_parallel_size: {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
raise ValueError(
f"data_parallel_rank should be smaller than data parallel size: {data_parallel_rank} < {data_parallel_size}"
)
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self._local_minibatch_size = local_minibatch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
self.last_batch_size = self.total_samples % self.local_minibatch_times_data_parallel_size
def __len__(self) -> int:
return self.total_samples
@property
def local_minibatch_size(self) -> int:
return self._local_minibatch_size
@local_minibatch_size.setter
def local_minibatch_size(self, new_local_minibatch_size) -> None:
self._local_minibatch_size = new_local_minibatch_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
# note(mkozuki): might be better to uncomment
# assert current_epoch_samples % (self.data_parallel_size * apex.transformer.pipeline_parallel.utils.get_micro_batch_size()) == 0
# data sharding and random sampling
bucket_size = (self.total_samples // self.local_minibatch_times_data_parallel_size) * self.local_minibatch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# Last batch if not complete will be dropped.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.local_minibatch_size:
self.consumed_samples += self.local_minibatch_times_data_parallel_size
yield batch
batch = []
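if __name__ == "__main__":
    # Hedged usage sketch (illustration only): plug the sampler into a
    # DataLoader via `batch_sampler`. The dataset and the data-parallel layout
    # are placeholders; in real use data_parallel_rank/data_parallel_size come
    # from the process group (e.g. apex.transformer.parallel_state).
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(1024).unsqueeze(1))  # stand-in dataset
    sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=0,
        local_minibatch_size=8,  # global_batch_size / data_parallel_size
        data_parallel_rank=0,
        data_parallel_size=1,
    )
    # Each iteration yields only the index slice owned by this data-parallel rank.
    loader = DataLoader(dataset, batch_sampler=sampler, num_workers=0)
    for (batch,) in loader:
        pass  # run forward/backward on this rank's local minibatch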