Unverified Commit 8da5eaaf authored by Guolin Ke, committed by GitHub

support softmax with bias and mask (#2)

* try

* fix bugs

* code clean

* support mask in softmax

* code clean

* check for shapes
parent 0da9683c
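The change lets the fused softmax apply an additive attention mask and bias in a single kernel. A minimal usage sketch of the new Python API (shapes are illustrative and mirror the unit tests added below; assumes a CUDA device and the built extension):

import torch
from unicore.modules import softmax_dropout

x = torch.rand(4, 8, 128, 64, dtype=torch.float16, device="cuda")     # (batch, heads, q_len, k_len)
mask = torch.zeros(4, 1, 1, 64, dtype=torch.float16, device="cuda")   # additive mask, broadcast over heads and queries
bias = torch.rand(4, 8, 128, 64, dtype=torch.float16, device="cuda")  # additive bias
# fused equivalent of F.dropout(F.softmax(x + mask + bias, dim=-1), p=0.1) for k_len <= 2048
out = softmax_dropout(x, 0.1, is_training=True, mask=mask, bias=bias)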
@@ -5,64 +5,85 @@
std::vector<c10::optional<torch::Tensor>> fwd_cuda(
bool is_training,
torch::Tensor &input,
const c10::optional<torch::Tensor> &attn_mask,
const c10::optional<torch::Tensor> &bias,
float dropout_prob,
c10::optional<at::Generator> gen_
);
c10::optional<at::Generator> gen_);
torch::Tensor bwd_cuda(
torch::Tensor &output_grads,
const torch::Tensor &softmax_results,
const c10::optional<torch::Tensor> &dropout_mask,
float dropout_prob
);
float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<c10::optional<torch::Tensor>> fwd(
bool is_training,
torch::Tensor &input,
const c10::optional<torch::Tensor> &attn_mask,
const c10::optional<torch::Tensor> &bias,
float dropout_prob,
c10::optional<at::Generator> gen_
) {
c10::optional<at::Generator> gen_)
{
CHECK_INPUT(input);
if (attn_mask)
{
CHECK_INPUT(attn_mask.value());
AT_ASSERTM(attn_mask->dim() == 3, "expected 3D tensor");
}
if (bias)
{
CHECK_INPUT(bias.value());
AT_ASSERTM(bias->dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.size(0) % bias->size(0) == 0, "wrong first dim of bias.");
AT_ASSERTM(bias->size(1) == input.size(1) && bias->size(2) == input.size(2), "the last two dims of bias and input should be the same.");
}
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16 ||
input.scalar_type() == at::ScalarType::Float, "Only HALF/BFloat16/Float is supported");
return fwd_cuda(is_training, input, dropout_prob, gen_);
input.scalar_type() == at::ScalarType::BFloat16 ||
input.scalar_type() == at::ScalarType::Float,
"Only HALF/BFloat16/Float is supported");
return fwd_cuda(is_training, input, attn_mask, bias, dropout_prob, gen_);
}
torch::Tensor bwd(
torch::Tensor &output_grads,
const torch::Tensor &softmax_results,
const c10::optional<torch::Tensor> &dropout_mask,
float dropout_prob
) {
float dropout_prob)
{
CHECK_INPUT(output_grads);
CHECK_INPUT(softmax_results);
if (dropout_mask) {
if (dropout_mask)
{
CHECK_INPUT(dropout_mask.value());
}
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(!dropout_mask || dropout_mask->dim() == 1, "expected 1D tensor");
AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half ||
output_grads.scalar_type() == at::ScalarType::BFloat16 ||
output_grads.scalar_type() == at::ScalarType::Float, "Only HALF/BFloat16/Float is supported");
output_grads.scalar_type() == at::ScalarType::BFloat16 ||
output_grads.scalar_type() == at::ScalarType::Float,
"Only HALF/BFloat16/Float is supported");
AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half ||
softmax_results.scalar_type() == at::ScalarType::BFloat16 ||
softmax_results.scalar_type() == at::ScalarType::Float, "Only HALF/BFloat16/Float is supported");
softmax_results.scalar_type() == at::ScalarType::BFloat16 ||
softmax_results.scalar_type() == at::ScalarType::Float,
"Only HALF/BFloat16/Float is supported");
AT_ASSERTM(output_grads.scalar_type() == softmax_results.scalar_type(), "the types mismatch");
return bwd_cuda(output_grads, softmax_results, dropout_mask, dropout_prob);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("forward", &fwd, "softmax dropout -- Forward.");
m.def("backward", &bwd, "softmax dropout -- Backward.");
}
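For reference, the checks above define the extension's shape contract: a contiguous 3-D CUDA tensor of shape (batches, q_len, k_len) in half/bfloat16/float, an optional 3-D additive mask, and an optional 3-D bias whose last two dims equal the input's and whose first dim divides the batch dim. A hedged sketch against the raw extension (assuming it has been built and is importable as unicore_fused_softmax_dropout, the module name used by the Python wrapper in this commit; shapes are illustrative):

import torch
import unicore_fused_softmax_dropout as fused

x = torch.rand(64, 128, 256, dtype=torch.half, device="cuda")    # 64 rows = 8 batches x 8 heads, k_len <= 2048
mask = torch.zeros(8, 1, 256, dtype=torch.half, device="cuda")   # one additive mask row per batch
bias = torch.rand(8, 128, 256, dtype=torch.half, device="cuda")  # per-head bias; 64 % 8 == 0
# training mode returns (dropped-out probs, packed dropout mask, softmax results);
# note the softmax results are written in place over x
probs, dropout_mask, softmax_out = fused.forward(True, x, mask, bias, 0.1, None)
grad_in = fused.backward(torch.ones_like(probs), softmax_out, dropout_mask, 0.1)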
@@ -14,22 +14,37 @@
#include <torch/extension.h>
#include <math.h>
#include "type_shim.h"
#include "softmax_fast.h"
std::vector<c10::optional<torch::Tensor>> fwd_cuda(
bool is_training,
torch::Tensor &input,
const c10::optional<torch::Tensor> &attn_mask,
const c10::optional<torch::Tensor> &bias,
float dropout_prob,
c10::optional<at::Generator> gen_
) {
const int attn_batches = input.size(0);
const int q_seq_len = input.size(1);
const int k_seq_len = input.size(2);
auto act_options = input.options().requires_grad(false);
c10::optional<at::Generator> gen_)
{
const int64_t attn_batches = input.size(0);
const int q_seq_len = input.size(1);
const int k_seq_len = input.size(2);
void *bias_ptr = nullptr;
int64_t bias_batches = 0;
if (bias)
{
bias_ptr = reinterpret_cast<void *>(bias->data_ptr());
bias_batches = bias->size(0);
}
void *attn_mask_ptr = nullptr;
int64_t mask_inner_skip = 0;
if (attn_mask)
{
attn_mask_ptr = reinterpret_cast<void *>(attn_mask->data_ptr());
mask_inner_skip = static_cast<int64_t>(attn_batches / attn_mask->size(0) * q_seq_len / attn_mask->size(1));
}
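// mask_inner_skip counts how many consecutive rows of the flattened
// (attn_batches, q_seq_len) problem share one mask row of length k_seq_len.
// Example (hypothetical shapes): an input viewed as (batch * heads, q, k) with a
// per-batch mask viewed as (batch, 1, k) gives
//   mask_inner_skip = (batch * heads / batch) * (q / 1) = heads * q.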
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(softmax_mask_dtype(k_seq_len));
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *input_ptr = reinterpret_cast<void *>(input.data_ptr());
void *softmax_results_ptr = reinterpret_cast<void *>(input.data_ptr());
@@ -37,11 +52,11 @@ std::vector<c10::optional<torch::Tensor>> fwd_cuda(
// Padded Softmax
bool softmax_success = false;
auto scalar_type = input.scalar_type();
if (is_training && dropout_prob > 0.0f) {
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty(
{softmax_mask_size(attn_batches * q_seq_len, k_seq_len)}, mask_options
);
if (is_training && dropout_prob > 0.0f)
{
torch::Tensor dropout_results = torch::empty({static_cast<int64_t>(attn_batches), q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty(
{softmax_mask_size(static_cast<int64_t>(attn_batches * q_seq_len), k_seq_len)}, mask_options);
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
std::pair<uint64_t, uint64_t> rng_engine_inputs;
@@ -52,153 +67,213 @@ std::vector<c10::optional<torch::Tensor>> fwd_cuda(
}
uint64_t seed = std::get<0>(rng_engine_inputs);
uint64_t offset = std::get<1>(rng_engine_inputs);
if (scalar_type == at::ScalarType::BFloat16){
softmax_success = dispatch_softmax_forward<nv_bfloat16, nv_bfloat16, float, true>(
reinterpret_cast<nv_bfloat16 *>(dropout_results.data_ptr()),
reinterpret_cast<nv_bfloat16 *>(softmax_results_ptr),
reinterpret_cast<const nv_bfloat16 *>(input_ptr),
reinterpret_cast<void *>(dropout_mask.data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches*q_seq_len, seed, offset);
} else if (scalar_type == at::ScalarType::Half){
softmax_success = dispatch_softmax_forward<half, half, float, true>(
reinterpret_cast<half *>(dropout_results.data_ptr()),
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr),
reinterpret_cast<void *>(dropout_mask.data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches*q_seq_len, seed, offset);
} else if (scalar_type == at::ScalarType::Float){
softmax_success = dispatch_softmax_forward<float, float, float, true>(
reinterpret_cast<float *>(dropout_results.data_ptr()),
reinterpret_cast<float *>(softmax_results_ptr),
reinterpret_cast<const float *>(input_ptr),
reinterpret_cast<void *>(dropout_mask.data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches*q_seq_len, seed, offset);
} else {
softmax_success = false;
if (bias)
{
if (attn_mask)
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, true, true, true>(
reinterpret_cast<scalar_t_0 *>(dropout_results.data_ptr()),
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
reinterpret_cast<const scalar_t_0 *>(input_ptr),
reinterpret_cast<const scalar_t_0 *>(attn_mask_ptr),
reinterpret_cast<const scalar_t_0 *>(bias_ptr),
reinterpret_cast<void *>(dropout_mask.data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
seed, offset);)
}
else
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, true, true, false>(
reinterpret_cast<scalar_t_0 *>(dropout_results.data_ptr()),
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
reinterpret_cast<const scalar_t_0 *>(input_ptr),
nullptr,
reinterpret_cast<const scalar_t_0 *>(bias_ptr),
reinterpret_cast<void *>(dropout_mask.data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
seed, offset);)
}
}
if (softmax_success) {
else
{
if (attn_mask)
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, true, false, true>(
reinterpret_cast<scalar_t_0 *>(dropout_results.data_ptr()),
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
reinterpret_cast<const scalar_t_0 *>(input_ptr),
reinterpret_cast<const scalar_t_0 *>(attn_mask_ptr),
nullptr,
reinterpret_cast<void *>(dropout_mask.data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
seed, offset);)
}
else
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, true, false, false>(
reinterpret_cast<scalar_t_0 *>(dropout_results.data_ptr()),
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
reinterpret_cast<const scalar_t_0 *>(input_ptr),
nullptr,
nullptr,
reinterpret_cast<void *>(dropout_mask.data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
seed, offset);)
}
}
if (softmax_success)
{
return {dropout_results, dropout_mask, input};
} else {
}
else
{
return {c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>()};
}
} else {
if (scalar_type == at::ScalarType::BFloat16){
softmax_success = dispatch_softmax_forward<nv_bfloat16, nv_bfloat16, float, false>(
reinterpret_cast<nv_bfloat16 *>(softmax_results_ptr),
nullptr,
reinterpret_cast<const nv_bfloat16 *>(input_ptr),
nullptr,
1.0,
k_seq_len,
attn_batches*q_seq_len, 0, 0);
} else if (scalar_type == at::ScalarType::Half){
softmax_success = dispatch_softmax_forward<half, half, float, false>(
reinterpret_cast<half *>(softmax_results_ptr),
nullptr,
reinterpret_cast<const half *>(input_ptr),
nullptr,
1.0,
k_seq_len,
attn_batches*q_seq_len, 0, 0);
} else if (scalar_type == at::ScalarType::Float){
softmax_success = dispatch_softmax_forward<float, float, float, false>(
reinterpret_cast<float *>(softmax_results_ptr),
nullptr,
reinterpret_cast<const float *>(input_ptr),
nullptr,
1.0,
k_seq_len,
attn_batches*q_seq_len, 0, 0);
} else {
softmax_success = false;
}
else
{
if (bias)
{
if (attn_mask)
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, false, true, true>(
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
nullptr,
reinterpret_cast<const scalar_t_0 *>(input_ptr),
reinterpret_cast<const scalar_t_0 *>(attn_mask_ptr),
reinterpret_cast<const scalar_t_0 *>(bias_ptr),
nullptr,
1.0,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
0, 0);)
}
else
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, false, true, false>(
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
nullptr,
reinterpret_cast<const scalar_t_0 *>(input_ptr),
nullptr,
reinterpret_cast<const scalar_t_0 *>(bias_ptr),
nullptr,
1.0,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
0, 0);)
}
}
else
{
if (attn_mask)
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, false, false, true>(
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
nullptr,
reinterpret_cast<const scalar_t_0 *>(input_ptr),
reinterpret_cast<const scalar_t_0 *>(attn_mask_ptr),
nullptr,
nullptr,
1.0,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
0, 0);)
}
else
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, false, false, false>(
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
nullptr,
reinterpret_cast<const scalar_t_0 *>(input_ptr),
nullptr,
nullptr,
nullptr,
1.0,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
0, 0);)
}
}
if (softmax_success) {
if (softmax_success)
{
return {input, c10::optional<torch::Tensor>(), input};
} else {
}
else
{
return {c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>()};
}
}
}
torch::Tensor bwd_cuda(
torch::Tensor &output_grads,
const torch::Tensor &softmax_results,
const c10::optional<torch::Tensor> &dropout_mask,
float dropout_prob) {
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = output_grads.size(2);
float dropout_prob)
{
const int64_t attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = output_grads.size(2);
auto scalar_type = output_grads.scalar_type();
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
if (dropout_mask) {
if (scalar_type == at::ScalarType::BFloat16){
dispatch_softmax_backward<nv_bfloat16, nv_bfloat16, float, false, true>(
reinterpret_cast<nv_bfloat16 *>(output_grads.data_ptr()),
reinterpret_cast<const nv_bfloat16 *>(output_grads.data_ptr()),
reinterpret_cast<const nv_bfloat16 *>(softmax_results.data_ptr()),
reinterpret_cast<const void *>(dropout_mask->data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches*q_seq_len);
} else if (scalar_type == at::ScalarType::Half){
dispatch_softmax_backward<half, half, float, false, true>(
reinterpret_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<const half *>(output_grads.data_ptr()),
reinterpret_cast<const half *>(softmax_results.data_ptr()),
reinterpret_cast<const void *>(dropout_mask->data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches*q_seq_len);
} else if (scalar_type == at::ScalarType::Float){
dispatch_softmax_backward<float, float, float, false, true>(
reinterpret_cast<float *>(output_grads.data_ptr()),
reinterpret_cast<const float *>(output_grads.data_ptr()),
reinterpret_cast<const float *>(softmax_results.data_ptr()),
reinterpret_cast<const void *>(dropout_mask->data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches*q_seq_len);
}
} else {
if (scalar_type == at::ScalarType::BFloat16){
dispatch_softmax_backward<nv_bfloat16, nv_bfloat16, float, false, false>(
reinterpret_cast<nv_bfloat16 *>(output_grads.data_ptr()),
reinterpret_cast<nv_bfloat16 *>(output_grads.data_ptr()),
reinterpret_cast<const nv_bfloat16 *>(softmax_results.data_ptr()),
nullptr,
1.0f,
k_seq_len,
attn_batches*q_seq_len);
} else if (scalar_type == at::ScalarType::Half){
dispatch_softmax_backward<half, half, float, false, false>(
reinterpret_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<const half *>(softmax_results.data_ptr()),
nullptr,
1.0f,
k_seq_len,
attn_batches*q_seq_len);
} else if (scalar_type == at::ScalarType::Float){
dispatch_softmax_backward<float, float, float, false, false>(
reinterpret_cast<float *>(output_grads.data_ptr()),
reinterpret_cast<float *>(output_grads.data_ptr()),
reinterpret_cast<const float *>(softmax_results.data_ptr()),
nullptr,
1.0f,
k_seq_len,
attn_batches*q_seq_len);
}
if (dropout_mask)
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_backward",
dispatch_softmax_backward<scalar_t_0, scalar_t_0, float, false, true>(
reinterpret_cast<scalar_t_0 *>(output_grads.data_ptr()),
reinterpret_cast<const scalar_t_0 *>(output_grads.data_ptr()),
reinterpret_cast<const scalar_t_0 *>(softmax_results.data_ptr()),
reinterpret_cast<const void *>(dropout_mask->data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches * q_seq_len);)
}
//backward pass is completely in-place
else
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_backward",
dispatch_softmax_backward<scalar_t_0, scalar_t_0, float, false, false>(
reinterpret_cast<scalar_t_0 *>(output_grads.data_ptr()),
reinterpret_cast<scalar_t_0 *>(output_grads.data_ptr()),
reinterpret_cast<const scalar_t_0 *>(softmax_results.data_ptr()),
nullptr,
1.0f,
k_seq_len,
attn_batches * q_seq_len);)
}
// backward pass is completely in-place
return output_grads;
}
@@ -9,16 +9,14 @@
#include "util.h"
template <int N>
using IntegerBits = typename std::conditional<N <= 8, uint8_t,
typename std::conditional<N <= 16, uint16_t,
typename std::conditional<N <= 32, uint32_t,
typename std::conditional<N <= 64, uint64_t, void>::type
>::type
>::type
>::type;
using IntegerBits = typename std::conditional<N <= 8, uint8_t,
typename std::conditional<N <= 16, uint16_t,
typename std::conditional<N <= 32, uint32_t,
typename std::conditional<N <= 64, uint64_t, void>::type>::type>::type>::type;
template <int LogElements>
struct SoftmaxParameters {
struct SoftmaxParameters
{
static_assert(LogElements <= 11, "");
static constexpr int Elements = 1 << LogElements;
static constexpr int WarpBatch = Elements <= 128 ? 2 : 1;
@@ -28,31 +26,41 @@ struct SoftmaxParameters {
static constexpr int MaskStride = WarpSize;
};
inline int log2_ceil(int value) {
inline int log2_ceil(int value)
{
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
while ((1 << log2_value) < value)
++log2_value;
return log2_value;
}
inline at::ScalarType softmax_mask_dtype(int elements) {
if (elements > 1024) {
inline at::ScalarType softmax_mask_dtype(int elements)
{
if (elements > 1024)
{
return torch::kInt64;
} else if (elements > 512) {
}
else if (elements > 512)
{
return torch::kInt32;
} else if (elements > 256) {
}
else if (elements > 256)
{
return torch::kInt16;
}
return torch::kInt8;
}
inline int softmax_mask_size(int batch_size, int elements) {
inline int softmax_mask_size(int batch_size, int elements)
{
int log2_elements = log2_ceil(elements);
int e = 1 << log2_elements;
int warp_size = e < 32 ? e : 32;
return batch_size * warp_size;
}
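// Worked example: for k_seq_len = 1024 the next power of two is 1024, each of the
// 32 warp threads covers 1024 / 32 = 32 elements, so softmax_mask_dtype picks
// kInt32 (one keep-bit per element handled by a thread) and softmax_mask_size
// returns batch_size * 32 mask words -- one word per participating thread.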
inline int softmax_rng_delta_offset(int elements) {
inline int softmax_rng_delta_offset(int elements)
{
int log2_elements = log2_ceil(elements);
int e = 1 << log2_elements;
int warp_iterations = e <= 32 ? 1 : e / 32;
@@ -62,156 +70,215 @@ inline int softmax_rng_delta_offset(int elements) {
template <
typename input_t, typename output_t, typename acc_t,
typename Parameters, bool NeedMask
>
__global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const output_t *src,
typename Parameters::MaskType *mask, acc_t p, int batch_size, int element_count, uint64_t seed, uint64_t rand_offset) {
typename Parameters, bool NeedMask, bool NeedBias, bool NeedAttnMask>
__global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const output_t *src, const input_t *attn_mask, const input_t *bias,
typename Parameters::MaskType *mask, acc_t p, int64_t batch_size, int64_t attn_inner_skip_batch, int64_t bias_batch_size, int element_count, uint64_t seed, uint64_t rand_offset)
{
using MaskType = typename Parameters::MaskType;
curandStatePhilox4_32_10_t state;
int64_t first_batch = (static_cast<int64_t>(blockDim.y) * static_cast<int64_t>(blockIdx.x) + threadIdx.y) * Parameters::WarpBatch;
// there might be multiple batches per warp. compute the index within the batch
int64_t local_idx = threadIdx.x;
const int64_t thread_offset = first_batch * element_count + local_idx;
if IF_CONSTEXPR (NeedMask) {
if IF_CONSTEXPR (NeedMask)
{
curand_init(seed, thread_offset, rand_offset, &state);
}
// batch_size might not be a multiple of Parameters::WarpBatch. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > Parameters::WarpBatch)
local_batches = Parameters::WarpBatch;
src += thread_offset;
dst += thread_offset;
if IF_CONSTEXPR (NeedMask) {
if IF_CONSTEXPR (NeedMask)
{
dst_orig += thread_offset;
mask += first_batch * Parameters::MaskStride;
}
int64_t bias_mod_size = bias_batch_size * element_count;
int64_t attn_mask_div_size = element_count;
if IF_CONSTEXPR (NeedAttnMask)
{
attn_mask_div_size = attn_inner_skip_batch * element_count;
}
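// bias_mod_size / attn_mask_div_size implement the broadcast over the flat element
// index: the bias tile of bias_batch_size rows repeats along the leading batch
// dimension (global_idx % bias_mod_size below), while one mask row of element_count
// entries is shared by attn_inner_skip_batch consecutive rows
// (global_idx / attn_mask_div_size selects the mask row).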
// load data from global memory
input_t elements_input[Parameters::WarpBatch][Parameters::WarpIterations];
for (int i = 0; i < Parameters::WarpBatch; ++i) {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < Parameters::WarpIterations; ++it) {
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
int element_index = local_idx + it * Parameters::WarpSize;
elements_input[i][it] = -std::numeric_limits<float>::infinity();
if (element_index < batch_element_count) {
if (element_index < batch_element_count)
{
elements_input[i][it] = src[i * element_count + it * Parameters::WarpSize];
}
}
}
// convert input_t to acc_t
acc_t elements[Parameters::WarpBatch][Parameters::WarpIterations];
for (int i = 0; i < Parameters::WarpBatch; ++i) {
for (int it = 0; it < Parameters::WarpIterations; ++it) {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
elements[i][it] = elements_input[i][it];
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < batch_element_count)
{
int64_t global_idx = thread_offset + i * element_count + it * Parameters::WarpSize;
if IF_CONSTEXPR (NeedAttnMask)
{
auto attn_mask_idx = static_cast<int64_t>(global_idx / attn_mask_div_size) * element_count + (global_idx % element_count);
elements[i][it] += attn_mask[attn_mask_idx];
}
if IF_CONSTEXPR (NeedBias)
{
elements[i][it] += bias[global_idx % bias_mod_size];
}
}
}
}
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[Parameters::WarpBatch];
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1; it < Parameters::WarpIterations; ++it) {
for (int i = 0; i < Parameters::WarpBatch; ++i) {
#pragma unroll
for (int it = 1; it < Parameters::WarpIterations; ++it)
{
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
// reduction max_value
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2)
{
float val[Parameters::WarpBatch];
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
val[i] = SHFL_XOR(max_value[i], offset, Parameters::WarpSize);
}
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[Parameters::WarpBatch] { 0.0f };
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
for (int it = 0; it < Parameters::WarpIterations; ++it) {
acc_t sum[Parameters::WarpBatch]{0.0f};
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
// reduction sum
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2)
{
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
sum[i] += SHFL_XOR(sum[i], offset, Parameters::WarpSize);
}
}
// store result
if IF_CONSTEXPR (NeedMask) {
if IF_CONSTEXPR (NeedMask)
{
const acc_t pinv = 1.0 / p;
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
if (i >= local_batches)
break;
MaskType m = 0;
if IF_CONSTEXPR (Parameters::WarpIterations == 1) {
if IF_CONSTEXPR (Parameters::WarpIterations == 1)
{
float rand = curand_uniform(&state);
m = rand < p;
} else if IF_CONSTEXPR (Parameters::WarpIterations == 2) {
}
else if IF_CONSTEXPR (Parameters::WarpIterations == 2)
{
m = curand_uniform(&state) < p;
m |= (curand_uniform(&state) < p) << 1;
} else {
#pragma unroll
for (int j = 0; j < DIV_CELL(Parameters::WarpIterations, 4); ++j) {
}
else
{
#pragma unroll
for (int j = 0; j < DIV_CELL(Parameters::WarpIterations, 4); ++j)
{
float4 rand4 = curand_uniform4(&state);
m |= (((MaskType)(rand4.x < p)) << (j * 4))
| (((MaskType)(rand4.y < p)) << (j * 4 + 1))
| (((MaskType)(rand4.z < p)) << (j * 4 + 2))
| (((MaskType)(rand4.w < p)) << (j * 4 + 3));
m |= (((MaskType)(rand4.x < p)) << (j * 4)) | (((MaskType)(rand4.y < p)) << (j * 4 + 1)) | (((MaskType)(rand4.z < p)) << (j * 4 + 2)) | (((MaskType)(rand4.w < p)) << (j * 4 + 3));
}
}
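// m packs one keep-bit per warp iteration: bit `it` is 1 when the element handled
// in iteration `it` survives dropout (rand < p, the keep probability). One mask
// word is stored per (row, thread) below and is re-read in the backward pass to
// rescale the gradient by 1/p.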
mask[i * Parameters::MaskStride + local_idx] = m;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it) {
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < element_count) {
if (element_index < element_count)
{
const output_t d = elements[i][it] / sum[i];
dst[i * element_count + it * Parameters::WarpSize] = (acc_t)d * ((acc_t)((m >> it) & 1) * pinv);
dst_orig[i * element_count + it * Parameters::WarpSize] = d;
}
else {
else
{
break;
}
}
}
} else {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
}
else
{
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it) {
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < element_count) {
if (element_index < element_count)
{
dst[i * element_count + it * Parameters::WarpSize] = elements[i][it] / sum[i];
}
else {
else
{
break;
}
}
@@ -219,22 +286,24 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
}
}
#define LAUNCH_FORWARD_KERNEL(l) \
softmax_warp_forward<input_t, output_t, acc_t, SoftmaxParameters<l>, NeedMask> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
dst, dst_orig, src, (typename SoftmaxParameters<l>::MaskType *)mask, p, \
batch_count, softmax_elements, seed, offset \
); \
return true;
template<typename input_t, typename output_t, typename acc_t, bool NeedMask>
bool dispatch_softmax_forward(output_t *dst, output_t *dst_orig, const input_t *src, void *mask, acc_t p,
int softmax_elements, int batch_count, uint64_t seed, uint64_t offset)
#define LAUNCH_FORWARD_KERNEL(l) \
softmax_warp_forward<input_t, output_t, acc_t, SoftmaxParameters<l>, NeedMask, NeedBias, NeedAttnMask> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
dst, dst_orig, src, attn_mask, bias, (typename SoftmaxParameters<l>::MaskType *)mask, p, \
batch_count, attn_inner_skip_batch, bias_batch_count, softmax_elements, seed, offset); \
return true;
template <typename input_t, typename output_t, typename acc_t, bool NeedMask, bool NeedBias, bool NeedAttnMask>
bool dispatch_softmax_forward(output_t *dst, output_t *dst_orig, const input_t *src, const input_t *attn_mask, const input_t *bias, void *mask, acc_t p,
int softmax_elements, int64_t batch_count, int64_t attn_inner_skip_batch, int64_t bias_batch_count, uint64_t seed, uint64_t offset)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return false;
} else {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
if (softmax_elements == 0)
{
return false;
}
else
{
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
@@ -252,20 +321,34 @@ bool dispatch_softmax_forward(output_t *dst, output_t *dst_orig, const input_t *
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: LAUNCH_FORWARD_KERNEL(0)
case 1: LAUNCH_FORWARD_KERNEL(1)
case 2: LAUNCH_FORWARD_KERNEL(2)
case 3: LAUNCH_FORWARD_KERNEL(3)
case 4: LAUNCH_FORWARD_KERNEL(4)
case 5: LAUNCH_FORWARD_KERNEL(5)
case 6: LAUNCH_FORWARD_KERNEL(6)
case 7: LAUNCH_FORWARD_KERNEL(7)
case 8: LAUNCH_FORWARD_KERNEL(8)
case 9: LAUNCH_FORWARD_KERNEL(9)
case 10: LAUNCH_FORWARD_KERNEL(10)
case 11: LAUNCH_FORWARD_KERNEL(11)
default: return false;
switch (log2_elements)
{
case 0:
LAUNCH_FORWARD_KERNEL(0)
case 1:
LAUNCH_FORWARD_KERNEL(1)
case 2:
LAUNCH_FORWARD_KERNEL(2)
case 3:
LAUNCH_FORWARD_KERNEL(3)
case 4:
LAUNCH_FORWARD_KERNEL(4)
case 5:
LAUNCH_FORWARD_KERNEL(5)
case 6:
LAUNCH_FORWARD_KERNEL(6)
case 7:
LAUNCH_FORWARD_KERNEL(7)
case 8:
LAUNCH_FORWARD_KERNEL(8)
case 9:
LAUNCH_FORWARD_KERNEL(9)
case 10:
LAUNCH_FORWARD_KERNEL(10)
case 11:
LAUNCH_FORWARD_KERNEL(11)
default:
return false;
}
}
return false;
@@ -273,10 +356,9 @@ bool dispatch_softmax_forward(output_t *dst, output_t *dst_orig, const input_t *
template <
typename input_t, typename output_t, typename acc_t, typename Parameters,
bool IsLogSoftmax, bool NeedMask
>
bool IsLogSoftmax, bool NeedMask>
__global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output,
const typename Parameters::MaskType *mask, acc_t p, int batch_size, int element_count)
const typename Parameters::MaskType *mask, acc_t p, int64_t batch_size, int element_count)
{
using MaskType = typename Parameters::MaskType;
int64_t first_batch = (static_cast<int64_t>(blockDim.y) * static_cast<int64_t>(blockIdx.x) + threadIdx.y) * Parameters::WarpBatch;
@@ -295,7 +377,8 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
if IF_CONSTEXPR (NeedMask) {
if IF_CONSTEXPR (NeedMask)
{
mask += first_batch * Parameters::MaskStride;
}
@@ -306,52 +389,64 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
// load data from global memory
acc_t grad_reg[Parameters::WarpBatch][Parameters::WarpIterations];
acc_t output_reg[Parameters::WarpBatch][Parameters::WarpIterations] ;
if IF_CONSTEXPR (NeedMask) {
acc_t output_reg[Parameters::WarpBatch][Parameters::WarpIterations];
if IF_CONSTEXPR (NeedMask)
{
MaskType mask_reg[Parameters::WarpBatch];
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
if (i >= local_batches)
break;
mask_reg[i] = mask[i * Parameters::MaskStride + local_idx];
}
const acc_t pinv = 1.0 / p;
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
int batch_element_count = (i >= local_batches) ? 0 : element_count;
MaskType m = mask_reg[i];
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it) {
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < batch_element_count) {
if (element_index < batch_element_count)
{
grad_reg[i][it] =
(input_t)(
(acc_t)((m >> it) & 1) *
(acc_t)grad[i * element_count + it * Parameters::WarpSize] *
pinv
) *
(input_t)((acc_t)((m >> it) & 1) *
(acc_t)grad[i * element_count + it * Parameters::WarpSize] *
pinv) *
output[i * element_count + it * Parameters::WarpSize];
output_reg[i][it] = output[i * element_count + it * Parameters::WarpSize];
} else {
}
else
{
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = acc_t(0);
}
}
}
} else {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
}
else
{
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it) {
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < batch_element_count) {
if (element_index < batch_element_count)
{
grad_reg[i][it] = grad[i * element_count + it * Parameters::WarpSize] *
output[i * element_count + it * Parameters::WarpSize];
output[i * element_count + it * Parameters::WarpSize];
output_reg[i][it] = output[i * element_count + it * Parameters::WarpSize];
} else {
}
else
{
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = acc_t(0);
}
@@ -360,37 +455,47 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
}
acc_t sum[Parameters::WarpBatch];
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < Parameters::WarpIterations; ++it) {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < Parameters::WarpIterations; ++it)
{
sum[i] += grad_reg[i][it];
}
}
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2)
{
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
sum[i] += SHFL_XOR(sum[i], offset, Parameters::WarpSize);
}
}
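// Standard softmax backward: with y = softmax(x) and upstream gradient g,
// dL/dx_i = y_i * (g_i - sum_j g_j * y_j).  grad_reg already holds g * y
// (rescaled by the dropout mask and 1/p when NeedMask), so sum[] is
// sum_j g_j * y_j and the store below computes grad_reg - output_reg * sum.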
// store result
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
// store result
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it) {
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < element_count) {
if (element_index < element_count)
{
// compute gradients
if IF_CONSTEXPR (IsLogSoftmax) {
if IF_CONSTEXPR (IsLogSoftmax)
{
gradInput[i * element_count + it * Parameters::WarpSize] =
(grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
} else {
}
else
{
gradInput[i * element_count + it * Parameters::WarpSize] =
(grad_reg[i][it] - output_reg[i][it] * sum[i]);
}
@@ -399,22 +504,24 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
}
}
#define LAUNCH_BACKWARD_KERNEL(l) \
softmax_warp_backward<input_t, output_t, acc_t, SoftmaxParameters<l>, IsLogSoftmax, NeedMask> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
grad_input, grad, output, (const typename SoftmaxParameters<l>::MaskType *)mask, p, \
batch_count, softmax_elements \
); \
break;
#define LAUNCH_BACKWARD_KERNEL(l) \
softmax_warp_backward<input_t, output_t, acc_t, SoftmaxParameters<l>, IsLogSoftmax, NeedMask> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
grad_input, grad, output, (const typename SoftmaxParameters<l>::MaskType *)mask, p, \
batch_count, softmax_elements); \
break;
template<typename input_t, typename output_t, typename acc_t, bool IsLogSoftmax, bool NeedMask>
template <typename input_t, typename output_t, typename acc_t, bool IsLogSoftmax, bool NeedMask>
void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output,
const void *mask, acc_t p, int softmax_elements, int batch_count)
const void *mask, acc_t p, int softmax_elements, int64_t batch_count)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
if (softmax_elements == 0)
{
return;
}
else
{
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
@@ -432,20 +539,34 @@ void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: LAUNCH_BACKWARD_KERNEL(0)
case 1: LAUNCH_BACKWARD_KERNEL(1)
case 2: LAUNCH_BACKWARD_KERNEL(2)
case 3: LAUNCH_BACKWARD_KERNEL(3)
case 4: LAUNCH_BACKWARD_KERNEL(4)
case 5: LAUNCH_BACKWARD_KERNEL(5)
case 6: LAUNCH_BACKWARD_KERNEL(6)
case 7: LAUNCH_BACKWARD_KERNEL(7)
case 8: LAUNCH_BACKWARD_KERNEL(8)
case 9: LAUNCH_BACKWARD_KERNEL(9)
case 10: LAUNCH_BACKWARD_KERNEL(10)
case 11: LAUNCH_BACKWARD_KERNEL(11)
default: break;
switch (log2_elements)
{
case 0:
LAUNCH_BACKWARD_KERNEL(0)
case 1:
LAUNCH_BACKWARD_KERNEL(1)
case 2:
LAUNCH_BACKWARD_KERNEL(2)
case 3:
LAUNCH_BACKWARD_KERNEL(3)
case 4:
LAUNCH_BACKWARD_KERNEL(4)
case 5:
LAUNCH_BACKWARD_KERNEL(5)
case 6:
LAUNCH_BACKWARD_KERNEL(6)
case 7:
LAUNCH_BACKWARD_KERNEL(7)
case 8:
LAUNCH_BACKWARD_KERNEL(8)
case 9:
LAUNCH_BACKWARD_KERNEL(9)
case 10:
LAUNCH_BACKWARD_KERNEL(10)
case 11:
LAUNCH_BACKWARD_KERNEL(11)
default:
break;
}
}
}
import torch
import torch.nn.functional as F
from unicore.modules import softmax_dropout
def gen_attn_mask(mask, neg_inf):
assert neg_inf < -1e4
attn_mask = torch.zeros_like(mask)
attn_mask[mask == 0] = neg_inf
return attn_mask
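# Converts a {0, 1} keep-mask into an additive mask: kept positions stay 0 and
# masked positions become neg_inf (e.g. -3e4, which is safe for fp16/bf16), so the
# softmax assigns them (near-)zero probability.
# e.g. gen_attn_mask(torch.tensor([[1., 0.]]), -3e4) -> tensor([[0., -30000.]])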
def normal_softmax(a, mask, bias):
return F.softmax(a + mask + bias, dim=-1)
def fused_softmax(a, mask, bias):
return softmax_dropout(a, 0, True, mask=mask, bias=bias)
def wrap_forward_backward(func, a1, mask, bias1):
a = a1.clone()
bias = bias1.clone()
a.requires_grad = True
bias.requires_grad = True
output = func(a, mask, bias)
o = output.float().sum()
o.backward()
return output, a.grad, bias.grad
def check_diff(a, b, name, eps=1e-3):
assert (a - b).abs().max() < eps, "name {}, diff {}".format(
name, (a - b).abs().max()
)
def test_softmax():
n_batch = 4
n_heads = 8
n_query = 128
test_dims = [64, 128, 256, 512, 1024]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
for last_dim in test_dims:
for dtype in test_dtype:
x = torch.rand(
n_batch,
n_heads,
n_query,
last_dim,
dtype=dtype,
device=test_device,
)
mask = gen_attn_mask(
(
torch.rand(
n_batch,
1,
1,
last_dim,
dtype=dtype,
device=test_device,
)
> 0.2
).type(x.dtype),
-3e4,
)
bias = torch.rand(
n_batch, n_heads, n_query, last_dim, dtype=dtype, device=test_device
)
out_a1, out_b1, out_c1 = wrap_forward_backward(
normal_softmax, x, mask, bias
)
out_a2, out_b2, out_c2 = wrap_forward_backward(fused_softmax, x, mask, bias)
check_diff(out_a1, out_a2, "output")
check_diff(out_b1, out_b2, "grad_input")
check_diff(out_c1, out_c2, "grad_bias")
def test_tri_softmax1():
n_batch = 2
n_groups = 32
n_heads = 8
n_query = 128
test_dims = [64, 128, 256, 512, 1024]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
for last_dim in test_dims:
for dtype in test_dtype:
x = torch.rand(
n_batch,
n_groups,
n_heads,
n_query,
last_dim,
dtype=dtype,
device=test_device,
)
mask = gen_attn_mask(
(
torch.rand(
n_batch,
n_groups,
1,
1,
last_dim,
dtype=dtype,
device=test_device,
)
> 0.2
).type(x.dtype),
-3e4,
)
bias = torch.rand(
1, 1, n_heads, n_query, last_dim, dtype=dtype, device=test_device
)
out_a1, out_b1, out_c1 = wrap_forward_backward(
normal_softmax, x, mask, bias
)
out_a2, out_b2, out_c2 = wrap_forward_backward(fused_softmax, x, mask, bias)
check_diff(out_a1, out_a2, "output")
check_diff(out_b1, out_b2, "grad_input")
check_diff(out_c1, out_c2, "grad_bias")
def test_tri_softmax2():
n_batch = 2
n_groups = 32
n_heads = 8
n_query = 128
test_dims = [64, 128, 256, 512, 1024]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
for last_dim in test_dims:
for dtype in test_dtype:
x = torch.rand(
n_batch,
n_groups,
n_heads,
n_query,
last_dim,
dtype=dtype,
device=test_device,
)
mask = gen_attn_mask(
(
torch.rand(
n_batch,
n_groups,
n_heads,
1,
last_dim,
dtype=dtype,
device=test_device,
)
> 0.2
).type(x.dtype),
-3e4,
)
bias = torch.rand(
1, n_groups, n_heads, n_query, last_dim, dtype=dtype, device=test_device
)
out_a1, out_b1, out_c1 = wrap_forward_backward(
normal_softmax, x, mask, bias
)
out_a2, out_b2, out_c2 = wrap_forward_backward(fused_softmax, x, mask, bias)
check_diff(out_a1, out_a2, "output")
check_diff(out_b1, out_b2, "grad_input")
check_diff(out_c1, out_c2, "grad_bias")
@@ -8,6 +8,7 @@ import torch
from torch import Tensor, nn
from .softmax_dropout import softmax_dropout
class SelfMultiheadAttention(nn.Module):
def __init__(
self,
@@ -37,7 +38,7 @@ class SelfMultiheadAttention(nn.Module):
query,
key_padding_mask: Optional[Tensor] = None,
attn_bias: Optional[Tensor] = None,
return_attn: bool=False,
return_attn: bool = False,
) -> Tensor:
bsz, tgt_len, embed_dim = query.size()
@@ -46,18 +47,25 @@ class SelfMultiheadAttention(nn.Module):
q, k, v = self.in_proj(query).chunk(3, dim=-1)
q = (
q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
.contiguous().view(bsz * self.num_heads, -1, self.head_dim) * self.scaling
q.view(bsz, tgt_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz * self.num_heads, -1, self.head_dim)
* self.scaling
)
if k is not None:
k = (
k.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
.contiguous().view(bsz * self.num_heads, -1, self.head_dim)
k.view(bsz, -1, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz * self.num_heads, -1, self.head_dim)
)
if v is not None:
v = (
v.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
.contiguous().view(bsz * self.num_heads, -1, self.head_dim)
v.view(bsz, -1, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz * self.num_heads, -1, self.head_dim)
)
assert k is not None
@@ -72,37 +80,38 @@ class SelfMultiheadAttention(nn.Module):
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights.masked_fill_(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf")
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if attn_bias is not None:
attn_weights += attn_bias
attn_probs = softmax_dropout(attn_weights, self.dropout, self.training)
attn_probs = softmax_dropout(
attn_weights, self.dropout, self.training, bias=attn_bias
)
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2).contiguous().view(bsz, tgt_len, embed_dim)
attn = (
attn.view(bsz, self.num_heads, tgt_len, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz, tgt_len, embed_dim)
)
attn = self.out_proj(attn)
if not return_attn:
return attn
else:
return attn, attn_weights, attn_probs
class CrossMultiheadAttention(nn.Module):
def __init__(
self,
@@ -147,18 +156,25 @@ class CrossMultiheadAttention(nn.Module):
v = self.v_proj(value)
q = (
q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
.contiguous().view(bsz * self.num_heads, -1, self.head_dim) * self.scaling
q.view(bsz, tgt_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz * self.num_heads, -1, self.head_dim)
* self.scaling
)
if k is not None:
k = (
k.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
.contiguous().view(bsz * self.num_heads, -1, self.head_dim)
k.view(bsz, -1, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz * self.num_heads, -1, self.head_dim)
)
if v is not None:
v = (
v.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
.contiguous().view(bsz * self.num_heads, -1, self.head_dim)
v.view(bsz, -1, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz * self.num_heads, -1, self.head_dim)
)
assert k is not None
@@ -173,30 +189,28 @@ class CrossMultiheadAttention(nn.Module):
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights.masked_fill_(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf")
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if attn_bias is not None:
attn_weights += attn_bias
attn_probs = softmax_dropout(attn_weights, self.dropout, self.training)
attn_probs = softmax_dropout(attn_weights, self.dropout, self.training, bias=attn_bias)
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2).contiguous().view(bsz, tgt_len, embed_dim)
attn = (
attn.view(bsz, self.num_heads, tgt_len, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz, tgt_len, embed_dim)
)
attn = self.out_proj(attn)
return attn
\ No newline at end of file
return attn
@@ -6,16 +6,23 @@ import torch
import unicore_fused_softmax_dropout
import torch.nn.functional as F
class SoftmaxDropoutFast(torch.autograd.Function):
@staticmethod
def forward(ctx, is_training, inputs, dropout_prob):
# don't use ctx.save_for_backward to save dropout_prob
# allocating space for a tensor is time-consuming
dropout_results, dropout_mask, softmax_results = unicore_fused_softmax_dropout.forward(is_training,
inputs, dropout_prob, None)
def forward(ctx, is_training, inputs, mask, bias, dropout_prob):
(
dropout_results,
dropout_mask,
softmax_results,
) = unicore_fused_softmax_dropout.forward(
is_training, inputs, mask, bias, dropout_prob, None
)
if is_training:
ctx.dropout_prob = dropout_prob
ctx.save_for_backward(softmax_results, dropout_mask)
ctx.has_bias = bias is not None and bias.requires_grad
if ctx.has_bias:
ctx.bias_batch_dim = bias.shape[0]
return dropout_results
@staticmethod
@@ -23,15 +30,87 @@ class SoftmaxDropoutFast(torch.autograd.Function):
softmax_results, dropout_mask = ctx.saved_tensors
dropout_prob = ctx.dropout_prob
grad_output = grad_output.contiguous()
grad_input = unicore_fused_softmax_dropout.backward(grad_output, softmax_results,
dropout_mask, dropout_prob)
return None, grad_input, None
grad_input = unicore_fused_softmax_dropout.backward(
grad_output, softmax_results, dropout_mask, dropout_prob
)
if ctx.has_bias:
grad_bias = grad_input.view(
-1, ctx.bias_batch_dim, grad_input.shape[-2], grad_input.shape[-1]
).sum(dim=0)
else:
grad_bias = None
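# grad_input has shape (total_batches, q_len, k_len); viewing it as
# (-1, bias_batch_dim, q_len, k_len) groups the rows that shared one bias slice,
# and summing over dim 0 folds their gradients back into the broadcast bias
# (which softmax_dropout below has already flattened to (bias_batch_dim, q_len, k_len)).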
return None, grad_input, None, grad_bias, None
def _check_mask(mask, input):
assert mask.dtype == input.dtype, "mask and input must have the same dtype"
assert len(mask.shape) == len(input.shape), "wrong length of mask.shape"
assert (
mask.shape[-3] == 1 or mask.shape[-3] == input.shape[-3]
), "mask.shape[-3] must be 1 or input.shape[-3]"
if mask.shape[-3] == 1:
assert mask.shape[-2] == 1, "when mask.shape[-3] == 1, mask.shape[-2] must be 1"
else:
assert (
mask.shape[-2] == 1 or mask.shape[-2] == input.shape[-2]
), "mask.shape[-2] must be 1 or input.shape[-2]"
def _check_bias(bias, input):
assert bias.dtype == input.dtype, "bias and input must have the same dtype"
assert len(bias.shape) == len(input.shape), "wrong length of bias.shape"
assert bias.shape[-1] == input.shape[-1], "bias.shape[-1] must be input.shape[-1]"
assert bias.shape[-2] == input.shape[-2], "bias.shape[-2] must be input.shape[-2]"
len_shape = len(input.shape)
if len_shape > 3:
# head dim should be the same
assert (
bias.shape[-3] == input.shape[-3]
), "bias.shape[-3] must be input.shape[-3]"
offset = 3
else:
offset = 2
prev_non_one = True
for i in range(len_shape - offset - 1, -1, -1):
if prev_non_one:
assert (
bias.shape[i] == input.shape[i] or bias.shape[i] == 1
), "bias.shape[{}] must be input.shape[{}] or 1".format(i, i)
else:
assert bias.shape[i] == 1, "bias.shape[{}] must be 1".format(i)
prev_non_one = bias.shape[i] != 1
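# Accepted bias shapes: the last two (or, for >3-D inputs, three) dims must match
# the input, and any broadcast 1s in the remaining leading dims must form a
# contiguous prefix. For a (B, G, H, Q, K) input, (1, 1, H, Q, K), (1, G, H, Q, K)
# and (B, G, H, Q, K) pass, while (B, 1, H, Q, K) does not.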
def softmax_dropout(input, dropout_prob, is_training=True):
def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None):
"""Softmax with fused dropout; mask and bias are optional additive inputs.
Args:
input (torch.Tensor): input tensor
dropout_prob (float): dropout probability
is_training (bool, optional): whether in training mode. Defaults to True.
mask (torch.Tensor, optional): additive mask tensor, applied as input + mask. Defaults to None.
bias (torch.Tensor, optional): additive bias tensor, applied as input + bias. Defaults to None.
Returns:
torch.Tensor: the result after softmax (and dropout)
"""
input = input.contiguous()
input_size = input.size()
if mask is not None:
_check_mask(mask, input)
mask = mask.contiguous().view(-1, mask.shape[-2], mask.shape[-1])
if bias is not None:
_check_bias(bias, input)
bias = bias.contiguous().view(-1, input_size[-2], input_size[-1])
input = input.view(-1, input_size[-2], input_size[-1])
if input.is_cuda and input.shape[-1] <= 2048:
return SoftmaxDropoutFast.apply(is_training, input, dropout_prob).view(*input_size)
return SoftmaxDropoutFast.apply(
is_training, input, mask, bias, dropout_prob
).view(*input_size)
else:
return F.dropout(F.softmax(input, dim=-1), p=dropout_prob, training=is_training).view(*input_size)
if mask is not None:
input += mask
if bias is not None:
input += bias
return F.dropout(
F.softmax(input, dim=-1), p=dropout_prob, training=is_training
).view(*input_size)
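# Example (hedged sketch mirroring the unit tests in this commit; requires a CUDA
# device and the built extension). Gradients flow to a broadcast bias through the
# fused path, with grad_bias reduced over the broadcast dims in SoftmaxDropoutFast.backward:
#   import torch
#   from unicore.modules import softmax_dropout
#   x = torch.rand(4, 8, 128, 64, device="cuda", dtype=torch.float16)
#   mask = torch.zeros(4, 1, 1, 64, device="cuda", dtype=torch.float16)
#   bias = torch.rand(1, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
#   out = softmax_dropout(x, 0.1, is_training=True, mask=mask, bias=bias)
#   out.float().sum().backward()
#   assert bias.grad.shape == bias.shape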