Unverified commit 6c2babf9 authored by Burc Eryilmaz, committed by GitHub
Fuse dropout and softmax in the backward pass, add bias support to the C++ MHA, add additive-mask support, and add separate Q/K/V parameters (#854)
Co-authored-by: Sukru Eryilmaz <seryilmaz@computelab-dgx1v-32.nvidia.com>
parent 36c9e904
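A minimal usage sketch of the new options (assuming the contrib extensions are built with `--fast_multihead_attn`, a CUDA device, and the module's usual `query, key, value, key_padding_mask` forward signature; shapes follow the unit test later in this diff):

```python
import torch
from apex.contrib.multihead_attn import SelfMultiheadAttn, fast_mask_softmax_dropout_func

seq_len, batch, hidden, heads = 80, 10, 1024, 16

# Self-attention with bias, separate Q/K/V parameters, and an additive (half) mask.
attn = SelfMultiheadAttn(hidden, heads, dropout=0.1, bias=True, include_norm_add=False,
                         impl='fast', separate_qkv_params=True, mask_additive=True).cuda().half()
attn.reset_parameters()

inputs = torch.randn(seq_len, batch, hidden, dtype=torch.float16, device='cuda')
# Additive padding mask: 0 keeps a key position, a large negative value masks it out.
pad_mask = (torch.randn(batch, seq_len, device='cuda') > 0).half() * -10000

outputs, _ = attn(inputs, inputs, inputs, key_padding_mask=pad_mask)

# The fused masked softmax + dropout can also be used directly on raw attention scores.
scores = torch.randn(heads * batch, seq_len, seq_len, dtype=torch.float16, device='cuda',
                     requires_grad=True)
probs = fast_mask_softmax_dropout_func(True, heads, scores, pad_mask, True, 0.1)
```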
#include <torch/extension.h>
#include <cuda_fp16.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const half* pad_mask,
float dropout_prob
);
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool is_training,
int heads,
torch::Tensor const& input,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
}
return fwd_cuda(
is_training,
heads,
input,
use_mask ? static_cast<const half*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
torch::Tensor bwd(
bool use_mask,
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
softmax_results,
dropout_mask,
dropout_prob
);
}
} // end namespace additive_mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd, "Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward.");
}
#include <vector>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "softmax.h"
#include "dropout.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const half* pad_mask,
float dropout_prob
)
{
const int attn_batches = input.size(0);
const int sequences = attn_batches / heads;
const int q_seq_len = input.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
} else {
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
}
if (is_training) {
//use at:: function so that C++ version generates the same random mask as python version
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
return {
dropout_results,
dropout_mask,
softmax_results
};
}
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
//backward pass is completely in-place
return output_grads;
}
} // end namespace additive_mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
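For reference, here is a rough PyTorch-only sketch of what the additive-mask kernels above compute; the helper names are illustrative, not the extension API. The forward adds the per-sequence half-precision mask to every head and query row, softmaxes over keys, and draws the dropout mask with `torch._fused_dropout` (the same ATen call the C++ path uses); the backward rescales by the saved dropout mask and applies the softmax Jacobian:

```python
import torch
import torch.nn.functional as F

def additive_mask_softmax_dropout_ref(scores, mask, heads, p, training=True):
    # scores: [heads * batch, q_len, k_len] fp16; mask: [batch, k_len] fp16 additive mask (0 or -10000)
    hb, q_len, k_len = scores.shape
    batch = hb // heads
    x = scores.view(batch, heads, q_len, k_len) + mask.unsqueeze(1).unsqueeze(2)
    probs = F.softmax(x.reshape(hb, q_len, k_len), dim=-1)        # softmax_results
    if training:
        out, keep_mask = torch._fused_dropout(probs, 1.0 - p)     # dropout_results, dropout_mask
    else:
        out, keep_mask = probs, None
    return out, keep_mask, probs

def additive_mask_softmax_dropout_bwd_ref(grad_out, probs, keep_mask, p):
    # Fused backward: undo dropout (keep mask plus 1/(1-p) scale), then the softmax Jacobian
    # dL/dx = probs * (g - sum_k(g * probs)); the kernel writes this in place over grad_out.
    g = grad_out * keep_mask.half() * (1.0 / (1.0 - p))
    return probs * (g - (g * probs).sum(dim=-1, keepdim=True))
```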
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const uint8_t* pad_mask,
float dropout_prob
);
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
const uint8_t *padding_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool is_training,
int heads,
torch::Tensor const& input,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
}
return fwd_cuda(
is_training,
heads,
input,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
torch::Tensor bwd(
bool use_mask,
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
torch::Tensor const& padding_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
softmax_results,
dropout_mask,
use_mask ? static_cast<const uint8_t*>(padding_mask.data_ptr()) : nullptr,
dropout_prob
);
}
} // end namespace mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd, "Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward.");
}
#include <vector>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "softmax.h"
#include "dropout.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const uint8_t* pad_mask,
float dropout_prob
)
{
const int attn_batches = input.size(0);
const int sequences = attn_batches / heads;
const int q_seq_len = input.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
}
if (is_training) {
//use at:: function so that C++ version generates the same random mask as python version
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
return {
dropout_results,
dropout_mask,
softmax_results
};
}
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
const uint8_t *padding_mask,
float dropout_prob
)
{
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
if (padding_mask == nullptr) {
dispatch_masked_scale_softmax_backward<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
} else{
dispatch_masked_scale_softmax_backward_masked_out<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
static_cast<uint8_t const*>(padding_mask),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
heads);
}
//backward pass is completely in-place
return output_grads;
}
} // end namespace mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
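The byte-mask variant above differs mainly in how padding is expressed: `pad_mask` marks padded key positions with 1 (see the module docstring later in this diff), and each mask row is reused for all `heads * q_seq_len` softmax rows of its sequence, which is what the `attn_batches*q_seq_len/sequences` argument encodes. When a padding mask is also passed to the backward, the `_masked_out` dispatch takes it into account as well, presumably zeroing gradients at padded positions. A rough forward-only sketch with illustrative names:

```python
import torch
import torch.nn.functional as F

def byte_mask_softmax_dropout_ref(scores, pad_mask, heads, p, training=True):
    # scores: [heads * batch, q_len, k_len] fp16; pad_mask: [batch, k_len] uint8, 1 = padded key
    hb, q_len, k_len = scores.shape
    batch = hb // heads
    x = scores.view(batch, heads, q_len, k_len)
    x = x.masked_fill(pad_mask.bool().view(batch, 1, 1, k_len), float('-inf'))
    probs = F.softmax(x.reshape(hb, q_len, k_len), dim=-1)
    if training:
        out, keep_mask = torch._fused_dropout(probs, 1.0 - p)
    else:
        out, keep_mask = probs, None
    return out, keep_mask, probs
```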
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace self_bias {
namespace cublas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
//torch::Tensor const& input_biases,
//torch::Tensor const& output_biases,
torch::Tensor const& dropout_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs, torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases, torch::Tensor const& output_biases,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob
);
}
} // end namespace cublas_gemmex
} // end namespace self_bias
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
#include <torch/extension.h>
#include <vector>
#include <cuda_fp16.h>
namespace multihead_attn {
namespace self_bias_additive_mask {
namespace cublas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const half* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
//torch::Tensor const& input_biases,
//torch::Tensor const& output_biases,
torch::Tensor const& dropout_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs, torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases, torch::Tensor const& output_biases,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half, "Only Half is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
use_mask ? static_cast<const half*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob
);
}
} // end namespace cublas_gemmex
} // end namespace self_bias_additive_mask
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
from .self_multihead_attn import SelfMultiheadAttn
from .encdec_multihead_attn import EncdecMultiheadAttn
from .mask_softmax_dropout_func import fast_mask_softmax_dropout_func
import torch
import fast_self_multihead_attn
import fast_self_multihead_attn_bias
import fast_self_multihead_attn_bias_additive_mask

class FastSelfAttnFunc(torch.autograd.Function) :
    @staticmethod
    def forward(ctx, use_time_mask, is_training, heads, inputs, input_weights, output_weights, input_biases, output_biases, pad_mask, mask_additive, dropout_prob):
        use_biases_t = torch.tensor([input_biases is not None])
        heads_t = torch.tensor([heads])
        dropout_prob_t = torch.tensor([dropout_prob])
        null_tensor = torch.tensor([])
        use_mask = (pad_mask is not None)

        if use_biases_t[0]:
            if not mask_additive:
                input_lin_results, \
                softmax_results, \
                dropout_results, \
                dropout_mask, \
                matmul2_results, \
                outputs = \
                    fast_self_multihead_attn_bias.forward( \
                        use_mask, \
                        use_time_mask, \
                        is_training, \
                        heads, \
                        inputs, \
                        input_weights, \
                        output_weights, \
                        input_biases, \
                        output_biases, \
                        pad_mask if use_mask else null_tensor, \
                        dropout_prob)
            else:
                input_lin_results, \
                softmax_results, \
                dropout_results, \
                dropout_mask, \
                matmul2_results, \
                outputs = \
                    fast_self_multihead_attn_bias_additive_mask.forward( \
                        use_mask, \
                        use_time_mask, \
                        is_training, \
                        heads, \
                        inputs, \
                        input_weights, \
                        output_weights, \
                        input_biases, \
                        output_biases, \
                        pad_mask if use_mask else null_tensor, \
                        dropout_prob)
        else:
            input_lin_results, \
            softmax_results, \
            dropout_results, \
@@ -27,7 +71,8 @@ class FastSelfAttnFunc(torch.autograd.Function) :
                pad_mask if use_mask else null_tensor, \
                dropout_prob)

        ctx.save_for_backward(use_biases_t, \
                              heads_t, \
                              matmul2_results, \
                              dropout_results, \
                              softmax_results, \
@@ -38,10 +83,12 @@ class FastSelfAttnFunc(torch.autograd.Function) :
                              dropout_mask, \
                              dropout_prob_t)
        return outputs.detach()

    @staticmethod
    def backward(ctx, output_grads):
        use_biases_t, \
        heads_t, \
        matmul2_results, \
        dropout_results, \
@@ -53,6 +100,28 @@ class FastSelfAttnFunc(torch.autograd.Function) :
        dropout_mask, \
        dropout_prob_t = ctx.saved_tensors

        if use_biases_t[0]:
            input_grads, \
            input_weight_grads, \
            output_weight_grads, \
            input_bias_grads, \
            output_bias_grads = \
                fast_self_multihead_attn_bias.backward( \
                    heads_t[0], \
                    output_grads, \
                    matmul2_results, \
                    dropout_results, \
                    softmax_results, \
                    input_lin_results, \
                    inputs, \
                    input_weights, \
                    output_weights, \
                    dropout_mask, \
                    dropout_prob_t[0])
        else:
            input_bias_grads = None
            output_bias_grads = None
            input_grads, \
            input_weight_grads, \
            output_weight_grads = \
@@ -68,7 +137,6 @@ class FastSelfAttnFunc(torch.autograd.Function) :
                    output_weights, \
                    dropout_mask, \
                    dropout_prob_t[0])

        return None, None, None, input_grads, input_weight_grads, output_weight_grads, input_bias_grads, output_bias_grads, None, None, None

fast_self_attn_func = FastSelfAttnFunc.apply
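A hedged sketch of calling the extended autograd function directly with the new bias and additive-mask arguments (tensor shapes are illustrative; the argument order mirrors `forward()` above and the call made by `SelfMultiheadAttn` further below):

```python
import torch

seq_len, batch, hidden, heads = 80, 10, 1024, 16
dev, dt = 'cuda', torch.float16

inputs = torch.randn(seq_len, batch, hidden, dtype=dt, device=dev)
input_weights = torch.randn(3 * hidden, hidden, dtype=dt, device=dev, requires_grad=True)
output_weights = torch.randn(hidden, hidden, dtype=dt, device=dev, requires_grad=True)
input_biases = torch.zeros(3 * hidden, dtype=dt, device=dev, requires_grad=True)
output_biases = torch.zeros(hidden, dtype=dt, device=dev, requires_grad=True)
pad_mask = (torch.randn(batch, seq_len, device=dev) > 0).half() * -10000  # additive mask

# (use_time_mask, is_training, heads, inputs, input_weights, output_weights,
#  input_biases, output_biases, pad_mask, mask_additive, dropout_prob)
outputs = fast_self_attn_func(False, True, heads, inputs, input_weights, output_weights,
                              input_biases, output_biases, pad_mask, True, 0.1)
```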
import torch
import fast_mask_softmax_dropout
import fast_additive_mask_softmax_dropout

class MaskSoftmaxDropout(torch.autograd.Function) :
    @staticmethod
    def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, dropout_prob):
        heads_t = torch.tensor([heads])
        dropout_prob_t = torch.tensor([dropout_prob])
        null_tensor = torch.tensor([])
        use_mask = (pad_mask is not None)
        use_mask_t = torch.tensor([use_mask])
        mask_additive_t = torch.tensor([mask_additive])

        if mask_additive:
            dropout_results, \
            dropout_mask, \
            softmax_results = \
                fast_additive_mask_softmax_dropout.forward( \
                    use_mask, \
                    is_training, \
                    heads, \
                    inputs, \
                    pad_mask if use_mask else null_tensor, \
                    dropout_prob)
        else:
            dropout_results, \
            dropout_mask, \
            softmax_results = \
                fast_mask_softmax_dropout.forward( \
                    use_mask, \
                    is_training, \
                    heads, \
                    inputs, \
                    pad_mask if use_mask else null_tensor, \
                    dropout_prob)

        ctx.save_for_backward(
            use_mask_t, \
            heads_t, \
            softmax_results, \
            dropout_mask, \
            pad_mask if use_mask else null_tensor, \
            mask_additive_t, \
            dropout_prob_t)

        return dropout_results.detach()

    @staticmethod
    def backward(ctx, output_grads):
        use_mask_t, \
        heads_t, \
        softmax_results, \
        dropout_mask, \
        pad_mask, \
        mask_additive_t, \
        dropout_prob_t = ctx.saved_tensors

        if mask_additive_t[0]:
            input_grads = \
                fast_additive_mask_softmax_dropout.backward( \
                    use_mask_t[0], \
                    heads_t[0], \
                    output_grads, \
                    softmax_results, \
                    dropout_mask, \
                    dropout_prob_t[0])
        else:
            input_grads = \
                fast_mask_softmax_dropout.backward( \
                    use_mask_t[0], \
                    heads_t[0], \
                    output_grads, \
                    softmax_results, \
                    dropout_mask, \
                    pad_mask, \
                    dropout_prob_t[0])
        return None, None, input_grads, None, None, None

fast_mask_softmax_dropout_func = MaskSoftmaxDropout.apply
@@ -21,7 +21,7 @@ class SelfMultiheadAttn(nn.Module):
    See "Attention Is All You Need" for more details.
    """
    def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_add=False, impl='fast', separate_qkv_params=False, mask_additive=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
@@ -32,17 +32,38 @@ class SelfMultiheadAttn(nn.Module):
        self.include_norm_add = include_norm_add
        self.impl = impl
        self.scaling = self.head_dim**-0.5
        self.separate_qkv_params = separate_qkv_params
        self.mask_additive = mask_additive
        if mask_additive:
            assert self.include_norm_add == False, "additive mask not supported with layer norm"
            assert impl == 'default' or (impl == 'fast' and bias), "additive mask not supported for fast mode without bias"
        if separate_qkv_params:
            self.q_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.v_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        else:
            self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
        self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        if self.bias:
            if separate_qkv_params:
                self.q_bias = Parameter(torch.Tensor(embed_dim))
                self.k_bias = Parameter(torch.Tensor(embed_dim))
                self.v_bias = Parameter(torch.Tensor(embed_dim))
            else:
                self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
            self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
        else:
            if separate_qkv_params:
                self.register_parameter('q_bias', None)
                self.register_parameter('k_bias', None)
                self.register_parameter('v_bias', None)
                self.q_bias = None
                self.k_bias = None
                self.v_bias = None
            else:
                self.register_parameter('in_proj_bias', None)
                self.in_proj_bias = None
            self.register_parameter('out_proj_bias', None)
            self.out_proj_bias = None
        if self.include_norm_add:
            if impl == 'fast':
@@ -67,9 +88,19 @@ class SelfMultiheadAttn(nn.Module):
            else : assert False, "Unsupported impl: {} !".format(impl)

    def reset_parameters(self):
        if self.separate_qkv_params:
            nn.init.xavier_uniform_(self.q_weight)
            nn.init.xavier_uniform_(self.k_weight)
            nn.init.xavier_uniform_(self.v_weight)
        else:
            nn.init.xavier_uniform_(self.in_proj_weight)
        nn.init.xavier_uniform_(self.out_proj_weight)
        if self.bias:
            if self.separate_qkv_params:
                nn.init.constant_(self.q_bias, 0.)
                nn.init.constant_(self.k_bias, 0.)
                nn.init.constant_(self.v_bias, 0.)
            else:
                nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj_bias, 0.)
        if self.include_norm_add:
@@ -88,10 +119,22 @@ class SelfMultiheadAttn(nn.Module):
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        if self.separate_qkv_params:
            input_weights = torch.cat([self.q_weight.view(self.num_heads,1,self.head_dim,self.embed_dim), self.k_weight.view(self.num_heads,1,self.head_dim,self.embed_dim), self.v_weight.view(self.num_heads,1,self.head_dim,self.embed_dim)], dim=1).reshape(3*self.embed_dim,self.embed_dim).contiguous()
        else:
            input_weights = self.in_proj_weight
        if self.bias:
            if self.separate_qkv_params:
                input_bias = torch.cat([self.q_bias.view(self.num_heads,1,self.head_dim), self.k_bias.view(self.num_heads,1,self.head_dim), self.v_bias.view(self.num_heads,1,self.head_dim)], dim=1).reshape(3*self.embed_dim).contiguous()
            else:
                input_bias = self.in_proj_bias
        else:
            input_bias = None
        if key_padding_mask is not None:
            assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not be both defined!"
            mask = key_padding_mask
        elif attn_mask is not None:
            assert self.mask_additive == False, "additive mask not supported for time mask"
            mask = attn_mask
        else:
            mask = None
@@ -100,12 +143,12 @@ class SelfMultiheadAttn(nn.Module):
            if self.impl == 'fast':
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
                                         self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
                                         input_weights, self.out_proj_weight, mask, self.dropout)
            else:
                lyr_nrm_results = self.lyr_nrm(query)
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results,
                                         input_weights, self.out_proj_weight,
                                         input_bias, self.out_proj_bias,
                                         mask, self.dropout)
            if is_training:
                outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
@@ -114,11 +157,11 @@ class SelfMultiheadAttn(nn.Module):
        else:
            if self.impl == 'fast':
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
                                         input_weights, self.out_proj_weight, input_bias, self.out_proj_bias, mask, self.mask_additive, self.dropout)
            else:
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query,
                                         input_weights, self.out_proj_weight,
                                         input_bias, self.out_proj_bias,
                                         mask, self.mask_additive, self.dropout)

        return outputs, None
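The `separate_qkv_params` path flattens the three projection matrices into a single `3*embed_dim x embed_dim` weight whose rows interleave each head's Q, K, and V blocks, matching the layout the fused input projection appears to expect. A small standalone sketch of that reshape (toy sizes, not the module itself):

```python
import torch

embed_dim, num_heads = 8, 2
head_dim = embed_dim // num_heads

q_w = torch.arange(embed_dim * embed_dim, dtype=torch.float32).view(embed_dim, embed_dim)
k_w = q_w + 100
v_w = q_w + 200

# Same expression as in forward(): head 0's Q rows come first, then head 0's K rows,
# then head 0's V rows, before moving on to head 1, and so on.
input_weights = torch.cat([q_w.view(num_heads, 1, head_dim, embed_dim),
                           k_w.view(num_heads, 1, head_dim, embed_dim),
                           v_w.view(num_heads, 1, head_dim, embed_dim)],
                          dim=1).reshape(3 * embed_dim, embed_dim)

assert torch.equal(input_weights[0:head_dim],            q_w[0:head_dim])  # head 0, Q
assert torch.equal(input_weights[head_dim:2*head_dim],   k_w[0:head_dim])  # head 0, K
assert torch.equal(input_weights[2*head_dim:3*head_dim], v_w[0:head_dim])  # head 0, V
```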
import torch
import unittest
import torch.nn.functional as F

from apex.contrib.multihead_attn import fast_mask_softmax_dropout_func

class FusedSoftmaxTest(unittest.TestCase):
    def setUp(self, seed=1234):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        self.seq_length = 80
        self.sequences = 10
        self.hidden_dim = 1024
        self.heads = 16
        self.dropout_prob = 0.0

        self.mask = (torch.randn(self.sequences, self.seq_length) > 0).cuda()
        self.mask = self.mask.half() * -10000

        self.ref_inputs = torch.randn(self.heads * self.sequences, self.seq_length, self.seq_length,
                                      dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
        self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True)

    def test_fused_softmax(self):
        grads = torch.randn_like(self.tst_inputs)
        y_ref = self.ref_inputs.view(self.sequences, self.heads, self.seq_length, self.seq_length)
        y_ref = y_ref + self.mask.unsqueeze(1).unsqueeze(2)
        y_ref = y_ref.view(self.sequences * self.heads, self.seq_length, self.seq_length)
        y_ref = F.softmax(y_ref, dim=-1)
        y_ref = torch._fused_dropout(y_ref, 1.0)

        y_tst = fast_mask_softmax_dropout_func(True, self.heads, self.tst_inputs, self.mask, True, 0.0)
        y_ref[0].backward(grads)
        y_tst.backward(grads)

        self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
        self.assertTrue(torch.allclose(y_ref[0], y_tst, atol=1e-3, rtol=1e-3))
        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))

if __name__ == '__main__':
    unittest.main()
@@ -243,6 +243,58 @@ if "--fast_multihead_attn" in sys.argv:
        raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
        ext_modules.append(
            CUDAExtension(name='fast_additive_mask_softmax_dropout',
                          sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp',
                                   'apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu'],
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                              'nvcc':['-O3',
                                                      '-gencode', 'arch=compute_70,code=sm_70',
                                                      '-I./apex/contrib/csrc/multihead_attn/cutlass/',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda',
                                                      '--use_fast_math'] + version_dependent_macros + generator_flag}))
        ext_modules.append(
            CUDAExtension(name='fast_mask_softmax_dropout',
                          sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp',
                                   'apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu'],
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                              'nvcc':['-O3',
                                                      '-gencode', 'arch=compute_70,code=sm_70',
                                                      '-I./apex/contrib/csrc/multihead_attn/cutlass/',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda',
                                                      '--use_fast_math'] + version_dependent_macros + generator_flag}))
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp',
                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu'],
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                              'nvcc':['-O3',
                                                      '-gencode', 'arch=compute_70,code=sm_70',
                                                      '-I./apex/contrib/csrc/multihead_attn/cutlass/',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda',
                                                      '--use_fast_math'] + version_dependent_macros + generator_flag}))
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn_bias',
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp',
                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu'],
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                              'nvcc':['-O3',
                                                      '-gencode', 'arch=compute_70,code=sm_70',
                                                      '-I./apex/contrib/csrc/multihead_attn/cutlass/',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda',
                                                      '--use_fast_math'] + version_dependent_macros + generator_flag}))
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn',
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
                                   ...