Unverified Commit 6c18b2ed authored by Guolin Ke, committed by GitHub

support softmax for large columns (#6)

* support softmax for large columns

* more tests
parent 31fe887e
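In short: the warp-level kernels previously covered rows of up to 2^11 = 2048 columns (the old `case 11:` below), and anything wider fell through to `default: return false;`, so the Python wrapper rejected inputs with a larger last dimension. This commit trims the warp path to 2^10 and adds a CUB-based block-wide kernel that stages an entire row in dynamic shared memory, routing all wider rows to it. A minimal sketch of the new dispatch rule (illustrative helper, not part of the diff; the real dispatchers switch on log2_elements):

// Hedged sketch only, names hypothetical.
inline bool use_block_kernel(int softmax_elements) {
    // the warp kernel keeps cases 2^0 .. 2^10; wider rows go to the new
    // shared-memory block kernel
    return softmax_elements > (1 << 10);
}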
@@ -6,6 +6,7 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <curand_kernel.h>
#include <cub/cub.cuh>
#include "util.h"
template <int N>
@@ -68,6 +69,141 @@ inline int softmax_rng_delta_offset(int elements)
return warp_iterations * warp_batch;
}
inline cudaError_t GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves,
int *num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
}
int sm_count;
{
cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) {
return err;
}
}
int tpm;
{
cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
if (err != cudaSuccess) {
return err;
}
}
*num_blocks =
std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * tpm / block_size * waves));
return cudaSuccess;
}
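// Worked example of the waves heuristic above (numbers illustrative): on a
// GPU with sm_count = 108 and tpm = 2048, block_size = 128 gives
// 2048 / 128 = 16 resident blocks per SM, so 108 * 16 * 32 = 55296 blocks
// for waves = 32; *num_blocks is then min(max_blocks, 55296), floored at 1.
// Each block walks rows with a grid stride, so capping the grid bounds
// launch overhead while still saturating the device.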
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const { return a + b; }
};
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const { return max(a, b); }
};
template <template <typename> class ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
if (threadIdx.x == 0) {
result_broadcast = result;
}
__syncthreads();
return result_broadcast;
}
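// BlockAllReduce returns the reduced value to *every* thread: thread 0 of
// the CUB reduction writes it to shared memory and __syncthreads() publishes
// it. A hedged usage sketch inside a kernel launched with block_size threads:
//
//   acc_t local_max = ...;  // per-thread partial over its columns
//   acc_t row_max = BlockAllReduce<MaxOp, acc_t, block_size>(local_max);
//   // all threads of the block now hold the same row_max
//
// Because it contains __syncthreads(), every thread of the block must reach
// the call; it cannot be invoked from a divergent branch.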
// modified from https://github.com/Oneflow-Inc/oneflow/blob/5d74efa4d07adfd0acbc8e0074778687f1006b86/oneflow/core/cuda/softmax.cuh#L480-L529
// Copyright 2020 The OneFlow Authors. All rights reserved.
template <typename input_t, typename output_t, typename acc_t, int block_size, bool NeedBias, bool NeedAttnMask>
__global__ void softmax_block_forward(const input_t *input, output_t *output, const input_t *attn_mask, const input_t *bias,
int64_t rows, int cols, int64_t attn_inner_skip_batch, int64_t bias_batch_size) {
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto *buf = reinterpret_cast<acc_t *>(shared_buf);
const int tid = threadIdx.x;
auto element_count = cols;
int64_t bias_mod_size = bias_batch_size * cols;
int64_t attn_mask_div_size = element_count;
if IF_CONSTEXPR (NeedAttnMask)
{
attn_mask_div_size = attn_inner_skip_batch * element_count;
}
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
acc_t thread_max = -std::numeric_limits<acc_t>::infinity();
int64_t idx_offset = row * cols;
const input_t* input_ptr = input + idx_offset;
output_t* output_ptr = output + idx_offset;
const input_t* attn_mask_ptr = nullptr;
if IF_CONSTEXPR (NeedAttnMask) {
attn_mask_ptr = attn_mask + static_cast<int64_t>(idx_offset / attn_mask_div_size) * element_count;
}
const input_t* bias_ptr = nullptr;
if IF_CONSTEXPR (NeedBias) {
bias_ptr = bias + idx_offset % bias_mod_size;
}
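// Pointer setup above (hedged reading of the indexing): idx_offset /
// attn_mask_div_size == row / attn_inner_skip_batch, so each group of
// attn_inner_skip_batch consecutive rows shares one attn_mask row, while
// idx_offset % bias_mod_size cycles every bias_batch_size rows, broadcasting
// the bias tensor over the leading batch dimension.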
// TODO: enable vectorized (packed) loads as in OneFlow
for (int col = tid; col < cols; col += block_size) {
buf[col] = static_cast<acc_t>(input_ptr[col]);
if IF_CONSTEXPR (NeedAttnMask)
{
buf[col] += attn_mask_ptr[col];
}
if IF_CONSTEXPR (NeedBias)
{
buf[col] += bias_ptr[col];
}
thread_max = max(thread_max, buf[col]);
}
const acc_t row_max = BlockAllReduce<MaxOp, acc_t, block_size>(thread_max);
acc_t thread_sum = 0;
for (int col = tid; col < cols; col += block_size) {
buf[col] = std::exp(buf[col] - row_max);
thread_sum += buf[col];
}
const acc_t row_sum = BlockAllReduce<SumOp, acc_t, block_size>(thread_sum);
for (int col = tid; col < cols; col += block_size) {
output_ptr[col] = static_cast<output_t>(buf[col] / row_sum);
}
}
}
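// The block kernel above is the standard three-pass, numerically stable
// softmax, with the whole row staged in dynamic shared memory (buf):
//   pass 1: buf[c] = x[c] (+ mask)(+ bias),  m = max_c buf[c]
//   pass 2: buf[c] = exp(buf[c] - m),        s = sum_c buf[c]
//   pass 3: y[c] = buf[c] / s
// Subtracting the row max keeps exp() in range; the two BlockAllReduce calls
// hand every thread m and s, so each pass is a block_size-strided loop over
// the columns.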
template<typename input_t, typename output_t, typename acc_t, int block_size>
__global__ void softmax_block_backward(output_t* store, const input_t* dy, const input_t* y,
const int64_t rows, const int64_t cols) {
extern __shared__ __align__(sizeof(double)) unsigned char grad_shared_buf[];
auto* dy_buf = reinterpret_cast<acc_t*>(grad_shared_buf);
auto* y_buf = reinterpret_cast<input_t*>(dy_buf + cols);
const int tid = threadIdx.x;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
acc_t thread_sum = 0;
auto dy_ptr = dy + row * cols;
auto y_ptr = y + row * cols;
auto store_ptr = store + row * cols;
for (int col = tid; col < cols; col += block_size) {
y_buf[col] = y_ptr[col];
dy_buf[col] = dy_ptr[col] * (acc_t)y_ptr[col];
}
for (int col = tid; col < cols; col += block_size) {
thread_sum += dy_buf[col];
}
const acc_t row_sum = BlockAllReduce<SumOp, acc_t, block_size>(thread_sum);
for (int col = tid; col < cols; col += block_size) {
store_ptr[col] = static_cast<output_t>(dy_buf[col] - y_buf[col] * row_sum);
}
}
}
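// Softmax backward identity used above: for y = softmax(x),
//   dx[c] = y[c] * (dy[c] - sum_j dy[j] * y[j])
// The kernel stages dy[c] * y[c] in dy_buf (acc_t) and y[c] in y_buf
// (input_t), reduces dy_buf to row_sum, then emits
//   dx[c] = dy_buf[c] - y_buf[c] * row_sum,
// which is exactly the identity with dy_buf[c] = dy[c] * y[c].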
template <
typename input_t, typename output_t, typename acc_t,
typename Parameters, bool NeedMask, bool NeedBias, bool NeedAttnMask>
@@ -113,6 +249,7 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
int batch_element_count = (i >= local_batches) ? 0 : element_count;
auto src_ptr = src + i * element_count;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
@@ -121,7 +258,7 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
if (element_index < batch_element_count)
{
- elements_input[i][it] = src[i * element_count + it * Parameters::WarpSize];
+ elements_input[i][it] = src_ptr[it * Parameters::WarpSize];
}
}
}
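// (The src_ptr hoist above replaces per-iteration i * element_count +
// it * Parameters::WarpSize indexing with a fixed base pointer, so the
// unrolled loop only varies it * Parameters::WarpSize; the same pattern is
// applied to dst, dst_orig, attn_mask, and bias in the hunks below.)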
@@ -132,6 +269,15 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int64_t idx_offset = (first_batch + i) * element_count;
const input_t* attn_mask_ptr = nullptr;
if IF_CONSTEXPR (NeedAttnMask) {
attn_mask_ptr = attn_mask + static_cast<int64_t>(idx_offset / attn_mask_div_size) * element_count + local_idx;
}
const input_t* bias_ptr = nullptr;
if IF_CONSTEXPR (NeedBias) {
bias_ptr = bias + idx_offset % bias_mod_size + local_idx;
}
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
@@ -139,15 +285,13 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < batch_element_count)
{
- int64_t global_idx = thread_offset + i * element_count + it * Parameters::WarpSize;
if IF_CONSTEXPR (NeedAttnMask)
{
- auto attn_mask_idx = static_cast<int64_t>(global_idx / attn_mask_div_size) * element_count + (global_idx % element_count);
- elements[i][it] += attn_mask[attn_mask_idx];
+ elements[i][it] += attn_mask_ptr[it * Parameters::WarpSize];
}
if IF_CONSTEXPR (NeedBias)
{
- elements[i][it] += bias[global_idx % bias_mod_size];
+ elements[i][it] += bias_ptr[it * Parameters::WarpSize];
}
}
}
@@ -245,6 +389,8 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
}
}
mask[i * Parameters::MaskStride + local_idx] = m;
auto dst_ptr = dst + i * element_count;
auto dst_orig_ptr = dst_orig + i * element_count;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
@@ -252,8 +398,8 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
if (element_index < element_count)
{
const output_t d = elements[i][it] / sum[i];
- dst[i * element_count + it * Parameters::WarpSize] = (acc_t)d * ((acc_t)((m >> it) & 1) * pinv);
- dst_orig[i * element_count + it * Parameters::WarpSize] = d;
+ dst_ptr[it * Parameters::WarpSize] = (acc_t)d * ((acc_t)((m >> it) & 1) * pinv);
+ dst_orig_ptr[it * Parameters::WarpSize] = d;
}
else
{
@@ -267,6 +413,7 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
auto dst_ptr = dst + i * element_count;
if (i >= local_batches)
break;
#pragma unroll
@@ -275,7 +422,7 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < element_count)
{
- dst[i * element_count + it * Parameters::WarpSize] = elements[i][it] / sum[i];
+ dst_ptr[it * Parameters::WarpSize] = elements[i][it] / sum[i];
}
else
{
@@ -323,32 +470,42 @@ bool dispatch_softmax_forward(output_t *dst, output_t *dst_orig, const input_t *
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements)
{
- case 0:
- LAUNCH_FORWARD_KERNEL(0)
- case 1:
- LAUNCH_FORWARD_KERNEL(1)
- case 2:
- LAUNCH_FORWARD_KERNEL(2)
- case 3:
- LAUNCH_FORWARD_KERNEL(3)
- case 4:
- LAUNCH_FORWARD_KERNEL(4)
- case 5:
- LAUNCH_FORWARD_KERNEL(5)
- case 6:
- LAUNCH_FORWARD_KERNEL(6)
- case 7:
- LAUNCH_FORWARD_KERNEL(7)
- case 8:
- LAUNCH_FORWARD_KERNEL(8)
- case 9:
- LAUNCH_FORWARD_KERNEL(9)
- case 10:
- LAUNCH_FORWARD_KERNEL(10)
- case 11:
- LAUNCH_FORWARD_KERNEL(11)
- default:
- return false;
+ case 0:
+ LAUNCH_FORWARD_KERNEL(0)
+ case 1:
+ LAUNCH_FORWARD_KERNEL(1)
+ case 2:
+ LAUNCH_FORWARD_KERNEL(2)
+ case 3:
+ LAUNCH_FORWARD_KERNEL(3)
+ case 4:
+ LAUNCH_FORWARD_KERNEL(4)
+ case 5:
+ LAUNCH_FORWARD_KERNEL(5)
+ case 6:
+ LAUNCH_FORWARD_KERNEL(6)
+ case 7:
+ LAUNCH_FORWARD_KERNEL(7)
+ case 8:
+ LAUNCH_FORWARD_KERNEL(8)
+ case 9:
+ LAUNCH_FORWARD_KERNEL(9)
+ case 10:
+ LAUNCH_FORWARD_KERNEL(10)
+ default:
+ {
+ int grid_dim;
+ constexpr int block_size = 128;
+ constexpr int waves = 32;
+ auto cols = softmax_elements;
+ auto rows = batch_count;
+ GetNumBlocks(block_size, rows, waves, &grid_dim);
+ dim3 block(block_size);
+ const size_t smem = cols * sizeof(acc_t);
+ softmax_block_forward<input_t, output_t, acc_t, block_size, NeedBias, NeedAttnMask><<<grid_dim, block, smem>>>(
+ src, dst, attn_mask, bias, rows, cols, attn_inner_skip_batch, bias_batch_count);
+ return true;
+ }
}
}
return false;
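// A practical caveat for the default branch above (hedged, not checked in
// this diff): the dynamic shared-memory request is cols * sizeof(acc_t),
// e.g. 2048 columns * 4 bytes (acc_t = float) = 8 KB per block. Column
// counts exceeding the device's default dynamic-smem limit (48 KB on most
// GPUs, i.e. roughly 12K float columns) would need
// cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem)
// before launch, or the launch fails.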
@@ -389,7 +546,7 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
// load data from global memory
acc_t grad_reg[Parameters::WarpBatch][Parameters::WarpIterations];
- acc_t output_reg[Parameters::WarpBatch][Parameters::WarpIterations];
+ input_t output_reg[Parameters::WarpBatch][Parameters::WarpIterations];
if IF_CONSTEXPR (NeedMask)
{
MaskType mask_reg[Parameters::WarpBatch];
@@ -408,6 +565,8 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
{
int batch_element_count = (i >= local_batches) ? 0 : element_count;
MaskType m = mask_reg[i];
auto output_ptr = output + i * element_count;
auto grad_ptr = grad + i * element_count;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
@@ -415,16 +574,16 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
if (element_index < batch_element_count)
{
grad_reg[i][it] =
- (input_t)((acc_t)((m >> it) & 1) *
- (acc_t)grad[i * element_count + it * Parameters::WarpSize] *
- pinv) *
- output[i * element_count + it * Parameters::WarpSize];
- output_reg[i][it] = output[i * element_count + it * Parameters::WarpSize];
+ (acc_t)((m >> it) & 1) *
+ (acc_t)grad_ptr[it * Parameters::WarpSize] *
+ pinv *
+ output_ptr[it * Parameters::WarpSize];
+ output_reg[i][it] = output_ptr[it * Parameters::WarpSize];
}
else
{
grad_reg[i][it] = acc_t(0);
- output_reg[i][it] = acc_t(0);
+ output_reg[i][it] = input_t(0);
}
}
}
@@ -435,20 +594,22 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
int batch_element_count = (i >= local_batches) ? 0 : element_count;
auto output_ptr = output + i * element_count;
auto grad_ptr = grad + i * element_count;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < batch_element_count)
{
- grad_reg[i][it] = grad[i * element_count + it * Parameters::WarpSize] *
- output[i * element_count + it * Parameters::WarpSize];
- output_reg[i][it] = output[i * element_count + it * Parameters::WarpSize];
+ output_reg[i][it] = output_ptr[it * Parameters::WarpSize];
+ grad_reg[i][it] = grad_ptr[it * Parameters::WarpSize] *
+ (acc_t)output_ptr[it * Parameters::WarpSize];
}
else
{
grad_reg[i][it] = acc_t(0);
- output_reg[i][it] = acc_t(0);
+ output_reg[i][it] = output_t(0);
}
}
}
@@ -482,6 +643,7 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
{
if (i >= local_batches)
break;
auto gradInput_ptr = gradInput + i * element_count;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
@@ -491,12 +653,12 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
// compute gradients
if IF_CONSTEXPR (IsLogSoftmax)
{
- gradInput[i * element_count + it * Parameters::WarpSize] =
- (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
+ gradInput_ptr[it * Parameters::WarpSize] =
+ (grad_reg[i][it] - std::exp((acc_t)output_reg[i][it]) * sum[i]);
}
else
{
- gradInput[i * element_count + it * Parameters::WarpSize] =
+ gradInput_ptr[it * Parameters::WarpSize] =
(grad_reg[i][it] - output_reg[i][it] * sum[i]);
}
}
@@ -541,32 +703,41 @@ void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements)
{
- case 0:
- LAUNCH_BACKWARD_KERNEL(0)
- case 1:
- LAUNCH_BACKWARD_KERNEL(1)
- case 2:
- LAUNCH_BACKWARD_KERNEL(2)
- case 3:
- LAUNCH_BACKWARD_KERNEL(3)
- case 4:
- LAUNCH_BACKWARD_KERNEL(4)
- case 5:
- LAUNCH_BACKWARD_KERNEL(5)
- case 6:
- LAUNCH_BACKWARD_KERNEL(6)
- case 7:
- LAUNCH_BACKWARD_KERNEL(7)
- case 8:
- LAUNCH_BACKWARD_KERNEL(8)
- case 9:
- LAUNCH_BACKWARD_KERNEL(9)
- case 10:
- LAUNCH_BACKWARD_KERNEL(10)
- case 11:
- LAUNCH_BACKWARD_KERNEL(11)
- default:
- break;
+ case 0:
+ LAUNCH_BACKWARD_KERNEL(0)
+ case 1:
+ LAUNCH_BACKWARD_KERNEL(1)
+ case 2:
+ LAUNCH_BACKWARD_KERNEL(2)
+ case 3:
+ LAUNCH_BACKWARD_KERNEL(3)
+ case 4:
+ LAUNCH_BACKWARD_KERNEL(4)
+ case 5:
+ LAUNCH_BACKWARD_KERNEL(5)
+ case 6:
+ LAUNCH_BACKWARD_KERNEL(6)
+ case 7:
+ LAUNCH_BACKWARD_KERNEL(7)
+ case 8:
+ LAUNCH_BACKWARD_KERNEL(8)
+ case 9:
+ LAUNCH_BACKWARD_KERNEL(9)
+ case 10:
+ LAUNCH_BACKWARD_KERNEL(10)
+ default:
+ {
+ int grid_dim;
+ constexpr int block_size = 128;
+ constexpr int waves = 32;
+ auto cols = softmax_elements;
+ auto rows = batch_count;
+ GetNumBlocks(block_size, rows, waves, &grid_dim);
+ dim3 block(block_size);
+ const size_t smem = cols * sizeof(acc_t) + cols * sizeof(input_t);
+ softmax_block_backward<input_t, output_t, acc_t, block_size><<<grid_dim, block, smem>>>(
+ grad_input, grad, output, rows, cols);
+ }
}
}
}
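// Hedged note on the backward default branch: its shared-memory request is
// cols * (sizeof(acc_t) + sizeof(input_t)) because the row's dy * y products
// (acc_t) and its y values (input_t) are staged side by side; with
// acc_t = float and input_t = half, 2048 columns cost 2048 * (4 + 2) = 12 KB.
// Unlike the forward path, this dispatcher returns void, so the default
// branch now handles all larger sizes instead of silently doing nothing.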
@@ -39,7 +39,7 @@ def test_softmax():
n_batch = 4
n_heads = 8
n_query = 128
- test_dims = [64, 128, 256, 512, 1024]
+ test_dims = [64, 128, 256, 512, 1024, 1536, 2048]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
for last_dim in test_dims:
@@ -83,7 +83,7 @@ def test_tri_softmax1():
n_groups = 32
n_heads = 8
n_query = 128
- test_dims = [64, 128, 256, 512, 1024]
+ test_dims = [64, 128, 256, 512, 1024, 1536, 2048]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
for last_dim in test_dims:
@@ -129,7 +129,7 @@ def test_tri_softmax2():
n_groups = 32
n_heads = 8
n_query = 128
- test_dims = [64, 128, 256, 512, 1024]
+ test_dims = [64, 128, 256, 512, 1024, 1536, 2048]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
for last_dim in test_dims:
......
@@ -94,7 +94,7 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None)
torch.Tensor: the result after softmax
"""
input = input.contiguous()
- if input.is_cuda and input.shape[-1] <= 2048:
+ if input.is_cuda:
input_size = input.size()
if mask is not None:
_check_mask(mask, input)
@@ -103,9 +103,14 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None)
_check_bias(bias, input)
bias = bias.contiguous().view(-1, input_size[-2], input_size[-1])
input = input.view(-1, input_size[-2], input_size[-1])
- return SoftmaxDropoutFast.apply(
- is_training, input, mask, bias, dropout_prob
- ).view(*input_size)
+ if dropout_prob <= 0.0 or input_size[-1] <= 1024:
+ return SoftmaxDropoutFast.apply(
+ is_training, input, mask, bias, dropout_prob
+ ).view(*input_size)
+ else:
+ return F.dropout(SoftmaxDropoutFast.apply(
+ is_training, input, mask, bias, 0.0
+ ).view(*input_size), p=dropout_prob, training=is_training)
else:
if mask is not None:
input += mask
......
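# Why the dropout_prob / 1024 split above (hedged reading of this diff): the
# fused CUDA path only produces dropout masks in the warp-level kernel, which
# now covers rows up to 1024 columns. For wider rows the new block kernel
# computes softmax (+ mask/bias) but not dropout, so dropout falls back to
# the stock torch implementation:
#
#   out = SoftmaxDropoutFast.apply(is_training, input, mask, bias, 0.0)
#   out = F.dropout(out.view(*input_size), p=dropout_prob, training=is_training)
#
# Functionally equivalent to the fused softmax-dropout, at the cost of one
# extra elementwise kernel and its mask tensor.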