Unverified commit ed719967 authored by yjk21, committed by GitHub

Adds small-batch kernels (#1126)

parent c1378e6f
...@@ -30,28 +30,6 @@ ...@@ -30,28 +30,6 @@
#include "fmha.h" #include "fmha.h"
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void set_params(Fused_multihead_attention_fprop_params &params, void set_params(Fused_multihead_attention_fprop_params &params,
// sizes // sizes
const size_t b, const size_t b,
...@@ -61,7 +39,6 @@ void set_params(Fused_multihead_attention_fprop_params &params, ...@@ -61,7 +39,6 @@ void set_params(Fused_multihead_attention_fprop_params &params,
// device pointers // device pointers
void *qkv_packed_d, void *qkv_packed_d,
void *cu_seqlens_d, void *cu_seqlens_d,
void *seqlens_d,
void *o_packed_d, void *o_packed_d,
void *s_d, void *s_d,
float p_dropout) { float p_dropout) {
...@@ -79,7 +56,6 @@ void set_params(Fused_multihead_attention_fprop_params &params, ...@@ -79,7 +56,6 @@ void set_params(Fused_multihead_attention_fprop_params &params,
params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type); params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);
params.cu_seqlens = static_cast<int *>(cu_seqlens_d); params.cu_seqlens = static_cast<int *>(cu_seqlens_d);
params.seqlens = static_cast<int *>(seqlens_d);
// S = softmax(P) // S = softmax(P)
params.s_ptr = s_d; params.s_ptr = s_d;
...@@ -107,13 +83,9 @@ void set_params(Fused_multihead_attention_fprop_params &params, ...@@ -107,13 +83,9 @@ void set_params(Fused_multihead_attention_fprop_params &params,
set_alpha(params.scale_dropout, params.rp_dropout, data_type); set_alpha(params.scale_dropout, params.rp_dropout, data_type);
} }
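Note on the API change above: the separate seqlens tensor is gone because the per-sequence lengths are fully determined by the cumulative offsets in cu_seqlens (BlockInfoPadded now derives actual_seqlen from cu_seqlens directly, see further down). A minimal host-side sketch of that relationship; the helper name is illustrative and not part of this commit:

#include <vector>

// cu_seqlens = [0, s_0, s_0 + s_1, ..., total] has length b + 1, so
// s_i = cu_seqlens[i + 1] - cu_seqlens[i] recovers every sequence length.
std::vector<int> seqlens_from_cu_seqlens(const std::vector<int> &cu_seqlens) {
    std::vector<int> seqlens(cu_seqlens.size() - 1);
    for (size_t i = 0; i + 1 < cu_seqlens.size(); ++i) {
        seqlens[i] = cu_seqlens[i + 1] - cu_seqlens[i];
    }
    return seqlens;
}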
constexpr uint32_t NUM_HEADS_DIM = 2;
constexpr uint32_t THREE_DIM = 1;
std::vector<at::Tensor> std::vector<at::Tensor>
mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1 const at::Tensor &cu_seqlens, // b+1
const at::Tensor &seqlens, // b
const float p_dropout, const float p_dropout,
const int max_seq_len, const int max_seq_len,
const bool is_training, const bool is_training,
...@@ -149,17 +121,14 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \ ...@@ -149,17 +121,14 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
TORCH_CHECK(qkv.dtype() == torch::kFloat16); TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32); TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda()) TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda()) TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous()) TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous()) TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1); TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4); TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes(); const auto sizes = qkv.sizes();
...@@ -167,10 +136,9 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \ ...@@ -167,10 +136,9 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
TORCH_CHECK(sizes[THREE_DIM] == 3); TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1; const int batch_size = cu_seqlens.numel() - 1;
TORCH_CHECK(seqlens.numel() == batch_size); const int total = sizes[TOTAL_DIM];
const int total = sizes[0]; const int num_heads = sizes[H_DIM];
const int num_heads = sizes[NUM_HEADS_DIM]; const int head_size = sizes[D_DIM];
const int head_size = sizes[3];
TORCH_CHECK(batch_size > 0); TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64); TORCH_CHECK(head_size == 64);
auto opts = qkv.options(); auto opts = qkv.options();
...@@ -191,7 +159,6 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \ ...@@ -191,7 +159,6 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
head_size, head_size,
qkv.data_ptr(), qkv.data_ptr(),
cu_seqlens.data_ptr(), cu_seqlens.data_ptr(),
seqlens.data_ptr(),
ctx.data_ptr(), ctx.data_ptr(),
s.data_ptr(), s.data_ptr(),
p_dropout); p_dropout);
...@@ -217,7 +184,6 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size ...@@ -217,7 +184,6 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1 const at::Tensor &cu_seqlens, // b+1
const at::Tensor &seqlens, // b
const float p_dropout, // probability to drop const float p_dropout, // probability to drop
const int max_seq_len // max sequence length to choose the kernel const int max_seq_len // max sequence length to choose the kernel
) { ) {
...@@ -247,17 +213,14 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size ...@@ -247,17 +213,14 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
TORCH_CHECK(dout.dtype() == torch::kFloat16); TORCH_CHECK(dout.dtype() == torch::kFloat16);
TORCH_CHECK(softmax.dtype() == torch::kFloat16); TORCH_CHECK(softmax.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32); TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda()); TORCH_CHECK(qkv.is_cuda());
TORCH_CHECK(cu_seqlens.is_cuda()); TORCH_CHECK(cu_seqlens.is_cuda());
TORCH_CHECK(qkv.is_contiguous()); TORCH_CHECK(qkv.is_contiguous());
TORCH_CHECK(cu_seqlens.is_contiguous()); TORCH_CHECK(cu_seqlens.is_contiguous());
TORCH_CHECK(seqlens.is_contiguous());
TORCH_CHECK(cu_seqlens.dim() == 1); TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4); TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes(); const auto sizes = qkv.sizes();
...@@ -265,9 +228,8 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size ...@@ -265,9 +228,8 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
TORCH_CHECK(sizes[THREE_DIM] == 3); TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1; const int batch_size = cu_seqlens.numel() - 1;
TORCH_CHECK(seqlens.numel() == batch_size); const int num_heads = sizes[H_DIM];
const int num_heads = sizes[NUM_HEADS_DIM]; const int head_size = sizes[D_DIM];
const int head_size = sizes[3];
TORCH_CHECK(batch_size > 0); TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64); TORCH_CHECK(head_size == 64);
...@@ -282,12 +244,11 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size ...@@ -282,12 +244,11 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
head_size, head_size,
qkv.data_ptr(), qkv.data_ptr(),
cu_seqlens.data_ptr(), cu_seqlens.data_ptr(),
seqlens.data_ptr(),
dout.data_ptr(), // we set o_ptr to dout dout.data_ptr(), // we set o_ptr to dout
softmax.data_ptr(), // softmax gets overwritten by dP! softmax.data_ptr(), // softmax gets overwritten by dP!
p_dropout); p_dropout);
// we're re-using these scales
Data_type acc_type = DATA_TYPE_FP32; Data_type acc_type = DATA_TYPE_FP32;
set_alpha(params.scale_bmm1, 1.f, acc_type); set_alpha(params.scale_bmm1, 1.f, acc_type);
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type); set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
...@@ -298,8 +259,174 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size ...@@ -298,8 +259,174 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
return { dqkv, softmax }; return { dqkv, softmax };
} }
std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const float p_dropout,
const int max_seq_len,
const bool is_training,
c10::optional<at::Generator> gen_) {
int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80_nl;
TORCH_CHECK(max_seq_len == seq_len);
constexpr int warps_m = 1;
constexpr int warps_n = 4; // this leads to an upper bound on elts_per_thread
const int mmas_m = seq_len / 16 / warps_m;
const int mmas_n = seq_len / 16 / warps_n;
// static_assert( mmas_m == 32 );
// static_assert( mmas_n == 4 );
const int elts_per_thread = 8 * mmas_m * mmas_n;
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto opts = qkv.options();
auto ctx = torch::empty({ total, num_heads, head_size }, opts);
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
ctx.data_ptr(),
s.data_ptr(),
p_dropout);
// number of times random will be generated per thread, to offset the Philox counter in the
// THC random state
int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
if( is_training ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}
int num_chunks = 3;
if(batch_size == 3) {
num_chunks = 2;
}
launch(params, is_training, num_chunks, stream);
return { ctx, s };
}
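A small host-side illustration of how the small-batch forward path above scales the launch: each (head, batch, chunk) triple gets its own CTA via blockIdx.z, and each chunk covers roughly seq_len / num_chunks rows of the sequence. The helper below just restates the heuristic from mha_fwd_nl; it is a sketch, not part of the extension:

#include <cstdio>

int forward_num_chunks(int batch_size) {
    // Mirrors mha_fwd_nl: 3 chunks by default, 2 when batch_size == 3.
    return batch_size == 3 ? 2 : 3;
}

int main() {
    const int num_heads = 16, seq_len = 512;
    for (int batch_size = 1; batch_size <= 4; ++batch_size) {
        const int num_chunks = forward_num_chunks(batch_size);
        printf("batch=%d: grid = (%d, %d, %d), ~%d rows per chunk\n",
               batch_size, num_heads, batch_size, num_chunks, seq_len / num_chunks);
    }
    return 0;
}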
std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1
const float p_dropout, // probability to drop
const int max_seq_len // max sequence length to choose the kernel
) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
int seq_len = 512;
auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl;
auto opts = qkv.options();
auto dqkv = torch::empty_like(qkv);
int num_chunks = 2;
if( batch_size == 1 ) {
num_chunks = 4;
}else if( batch_size == 2 ) {
num_chunks = 3;
}
auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts);
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
dout.data_ptr(), // o_ptr = dout
softmax.data_ptr(), // softmax gets overwritten by dP!
p_dropout);
params.dkv_ptr = dkv.data_ptr();
Data_type acc_type = DATA_TYPE_FP32;
set_alpha(params.scale_bmm1, 1.f, acc_type);
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
params.dqkv_ptr = dqkv.data_ptr();
launch(params, num_chunks, stream);
// Split-K reduction of the num_chunks dK, dV parts.
// The equivalent of the following PyTorch code:
// using namespace torch::indexing;
// at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});
// torch::sum_out(view_out, dkv, 1);
const int hidden_size = num_heads * head_size;
fmha_run_noloop_reduce(
dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total, num_chunks, stream);
return { dqkv, softmax, dkv };
}
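For reference, the split-K reduction that fmha_run_noloop_reduce performs at the end of mha_bwd_nl matches the PyTorch snippet quoted in the comment above. A hedged ATen sketch of those semantics (dqkv packed as total x 3 x h x d, dkv as total x num_chunks x 2 x h x d, as allocated in this function); the kernel does the same reduction on the GPU:

#include <torch/torch.h>

// Reference semantics only; the real work is done by fmha_run_noloop_reduce.
void reduce_dkv_reference(at::Tensor &dqkv, const at::Tensor &dkv) {
    using namespace torch::indexing;
    // View of the dK/dV slots inside the packed dQKV tensor: dqkv[:, 1:, ...].
    at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});
    // Accumulate the per-chunk partial gradients over the chunk dimension of dkv.
    torch::sum_out(view_out, dkv, 1);
}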
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "Fused Multi-head Self-attention for BERT";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("fwd_nl", &mha_fwd_nl, "Forward pass (small-batch)");
m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
}
...@@ -35,6 +35,12 @@ ...@@ -35,6 +35,12 @@
#include <fmha_utils.h> #include <fmha_utils.h>
constexpr int TOTAL_DIM = 0;
constexpr int THREE_DIM = 1;
constexpr int H_DIM = 2;
constexpr int D_DIM = 3;
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params { struct Qkv_params {
...@@ -43,6 +49,9 @@ struct Qkv_params { ...@@ -43,6 +49,9 @@ struct Qkv_params {
// The stride between rows of the Q, K and V matrices. // The stride between rows of the Q, K and V matrices.
size_t qkv_stride_in_bytes; size_t qkv_stride_in_bytes;
// The number of heads.
int h;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
...@@ -52,6 +61,9 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params { ...@@ -52,6 +61,9 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
// The dQKV matrices. // The dQKV matrices.
void *dqkv_ptr; void *dqkv_ptr;
// Temporary for dKV.
void *dkv_ptr;
// The O matrix (output). // The O matrix (output).
void *o_ptr; void *o_ptr;
...@@ -64,7 +76,7 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params { ...@@ -64,7 +76,7 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
int64_t s_stride_in_bytes; int64_t s_stride_in_bytes;
// The dimensions. // The dimensions.
int b, h, s, d; int b, s, d;
// The scaling factors for the kernel. // The scaling factors for the kernel.
uint32_t scale_bmm1, scale_softmax, scale_bmm2; uint32_t scale_bmm1, scale_softmax, scale_bmm2;
...@@ -72,9 +84,6 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params { ...@@ -72,9 +84,6 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
// array of length b+1 holding starting offset of each sequence. // array of length b+1 holding starting offset of each sequence.
int *cu_seqlens; int *cu_seqlens;
// array of length b holding the actual sequence lenghts.
int *seqlens;
// The dropout probability (probability of keeping an activation). // The dropout probability (probability of keeping an activation).
float p_dropout; float p_dropout;
...@@ -90,3 +99,27 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params { ...@@ -90,3 +99,27 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream);
void fmha_run_noloop_reduce(void *out,
const void *in,
const int *cu_seqlens,
const int hidden_size,
const int batch_size,
const int total,
const int num_chunks,
cudaStream_t stream);
...@@ -39,7 +39,9 @@ template< ...@@ -39,7 +39,9 @@ template<
// The number of rows of Q, K or V loaded by this tile. // The number of rows of Q, K or V loaded by this tile.
int ROWS, int ROWS,
// The number of columns. // The number of columns.
int COLS,
// The number of matrices.
int NUM_MATS = 3
> >
struct Gmem_tile_qkv { struct Gmem_tile_qkv {
...@@ -74,7 +76,7 @@ struct Gmem_tile_qkv { ...@@ -74,7 +76,7 @@ struct Gmem_tile_qkv {
// The row offset in the batched GEMM. For each seq element, we store QKV in that order. // The row offset in the batched GEMM. For each seq element, we store QKV in that order.
int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes; int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
// Add the block index. // Add the block index.
row_offset += (int64_t)((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW; row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
// Assemble the final pointer. // Assemble the final pointer.
qkv_ptr_ += row_offset + col * BYTES_PER_LDG; qkv_ptr_ += row_offset + col * BYTES_PER_LDG;
......
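To make the pointer arithmetic in the hunk above concrete: for packed QKV, each sequence element stores NUM_MATS matrices (Q, K, V) of h head rows, so the block offset is ((sum_s * NUM_MATS + qkv_offset) * h + bidh) * BYTES_PER_ROW. A small worked example with illustrative values (the constants below are examples, not taken from this diff):

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t BYTES_PER_ROW = 64 * 2; // head_size = 64 fp16 values per head row (example)
    const int NUM_MATS = 3;               // Q, K and V
    const int h = 16;                     // number of heads (example)
    const int sum_s = 128;                // token offset of this sequence, from cu_seqlens
    const int qkv_offset = 1;             // 0 = Q, 1 = K, 2 = V
    const int bidh = 3;                   // head index of this CTA
    const int64_t off = (int64_t)((sum_s * NUM_MATS + qkv_offset) * h + bidh) * BYTES_PER_ROW;
    printf("byte offset of the K rows for head 3 of the sequence starting at token 128: %lld\n",
           (long long)off);
    return 0;
}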
...@@ -1217,7 +1217,8 @@ struct Smem_tile_mma_epilogue : public Base { ...@@ -1217,7 +1217,8 @@ struct Smem_tile_mma_epilogue : public Base {
enum { WARPS_M = Base::WARPS_M }; enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N }; enum { WARPS_N = Base::WARPS_N };
static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
using Fragment = typename Base::Fragment;
using Acc = fmha::Fragment_accumulator;
inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) { inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) {
const int read_row = tidx / THREADS_PER_ROW; const int read_row = tidx / THREADS_PER_ROW;
...@@ -1233,6 +1234,40 @@ struct Smem_tile_mma_epilogue : public Base { ...@@ -1233,6 +1234,40 @@ struct Smem_tile_mma_epilogue : public Base {
} }
} }
template<int M, int N>
inline __device__ void store(const Acc (&acc)[M][N]){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// 1st row - 4 elements per row.
float tmp00 = acc[mi][ni].elt(0);
float tmp01 = acc[mi][ni].elt(1);
float tmp02 = acc[mi][ni].elt(4);
float tmp03 = acc[mi][ni].elt(5);
// 2nd row - 4 elements per row.
float tmp10 = acc[mi][ni].elt(2);
float tmp11 = acc[mi][ni].elt(3);
float tmp12 = acc[mi][ni].elt(6);
float tmp13 = acc[mi][ni].elt(7);
uint32_t x = fmha::float2_to_half2(tmp00, tmp01);
uint32_t y = fmha::float2_to_half2(tmp02, tmp03);
uint32_t z = fmha::float2_to_half2(tmp10, tmp11);
uint32_t w = fmha::float2_to_half2(tmp12, tmp13);
size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z);
offset ^= 4 * Base::BYTES_PER_STS;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);
}
}
}
template<int M, int N> template<int M, int N>
inline __device__ void store(const uint4 (&regs)[M][N]) { inline __device__ void store(const uint4 (&regs)[M][N]) {
for( int mi = 0; mi < M; mi++ ) { for( int mi = 0; mi < M; mi++ ) {
......
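The new store() overload added above moves the accumulator-to-half2 packing out of the kernels (the corresponding blocks are deleted from compute_dv_1xN and compute_dq_dk_1xN further down) and into the epilogue tile. A minimal sketch of that per-thread packing, assuming the usual HMMA accumulator layout and using the CUDA intrinsic __floats2half2_rn in place of the in-repo fmha::float2_to_half2:

#include <cuda_fp16.h>

__device__ void pack_acc_rows(const float (&elt)[8],
                              __half2 &x, __half2 &y, __half2 &z, __half2 &w) {
    // 1st row held by this thread - 4 elements (accumulator slots 0, 1, 4, 5).
    x = __floats2half2_rn(elt[0], elt[1]);
    y = __floats2half2_rn(elt[4], elt[5]);
    // 2nd row held by this thread - 4 elements (accumulator slots 2, 3, 6, 7).
    z = __floats2half2_rn(elt[2], elt[3]);
    w = __floats2half2_rn(elt[6], elt[7]);
}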
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "fmha.h" #include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h" #include "fmha_dgrad_kernel_1xN_reload.h"
#include "fmha_dgrad_kernel_1xN_reload_nl.h"
using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>; using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;
...@@ -35,6 +36,13 @@ extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_at ...@@ -35,6 +36,13 @@ extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_at
fmha::compute_dq_dk_1xN<Kernel_traits>(params); fmha::compute_dq_dk_1xN<Kernel_traits>(params);
} }
template<int CHUNKS>
__global__
void fmha_dgrad_fp16_512_64_sm80_nl_kernel(Fused_multihead_attention_fprop_params params){
fmha::compute_dv_1xN_nl<CHUNKS, Kernel_traits>(params);
fmha::compute_dq_dk_1xN_nl<CHUNKS, Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) { void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
...@@ -58,3 +66,40 @@ void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_param ...@@ -58,3 +66,40 @@ void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_param
dim3 grid(params.h, params.b); dim3 grid(params.h, params.b);
fmha_dgrad_fp16_512_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params); fmha_dgrad_fp16_512_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
} }
void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 512 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;
if( num_chunks == 2 ) {
kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;
}else if( num_chunks == 3 ) {
kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<3>;
} else {
assert(false && "Unsupported number of chunks");
}
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b, num_chunks);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
...@@ -156,8 +156,10 @@ inline __device__ void compute_dv_1xN(const Params &params) { ...@@ -156,8 +156,10 @@ inline __device__ void compute_dv_1xN(const Params &params) {
fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N]; fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv); fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);
enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
// Load over the entire sequence length. // Load over the entire sequence length.
for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) { for( int l = 0; l < STEPS; l++ ) {
const int loop = l * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen ) if( loop >= binfo.actual_seqlen )
break; break;
...@@ -185,6 +187,13 @@ inline __device__ void compute_dv_1xN(const Params &params) { ...@@ -185,6 +187,13 @@ inline __device__ void compute_dv_1xN(const Params &params) {
int ki = Mma_tile_p::MMAS_K; int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
} }
// Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
if( l < STEPS - 1) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
// Convert from the accumulator type to FP32 for Softmax. // Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p); softmax.unpack(acc_p);
...@@ -203,8 +212,6 @@ inline __device__ void compute_dv_1xN(const Params &params) { ...@@ -203,8 +212,6 @@ inline __device__ void compute_dv_1xN(const Params &params) {
} }
} }
float d_s[2 * M][4 * N];
#pragma unroll #pragma unroll
for( int mi = 0; mi < M; mi++ ) { for( int mi = 0; mi < M; mi++ ) {
#pragma unroll #pragma unroll
...@@ -213,10 +220,11 @@ inline __device__ void compute_dv_1xN(const Params &params) { ...@@ -213,10 +220,11 @@ inline __device__ void compute_dv_1xN(const Params &params) {
for( int ni = 0; ni < N; ni++ ) { for( int ni = 0; ni < N; ni++ ) {
#pragma unroll #pragma unroll
for( int jj = 0; jj < 4; jj++ ) { for( int jj = 0; jj < 4; jj++ ) {
const float s_dmask = s_mat[2 * mi + ii][4 * ni + jj]; float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj];
const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000; const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000;
d_s[2 * mi + ii][4 * ni + jj] = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout; const float d_s = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;
softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s[2 * mi + ii][4 * ni + jj] * fabsf(s_dmask); s_dmask = fabsf(s_dmask);
softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * fabsf(s_dmask);
} }
} }
} }
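Scalar sketch of the sign-bit dropout trick handled in the hunk above (illustration only, names are mine): the forward pass stores -p instead of p for dropped elements, so the backward pass can read both the keep/drop decision (the sign bit) and |p| from the same value.

#include <cmath>
#include <cstdint>
#include <cstring>

inline bool is_dropped(float s_dmask) {
    uint32_t bits;
    std::memcpy(&bits, &s_dmask, sizeof(bits)); // same test as the reinterpret_cast above
    return (bits & 0x80000000u) != 0;
}

inline float dgrad_softmax_elt(float s_dmask, float dp_elt, float rp_dropout) {
    const float d_s = is_dropped(s_dmask) ? 0.f : dp_elt * rp_dropout;
    return d_s * std::fabs(s_dmask);            // mirrors the softmax.elt_ update above
}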
...@@ -225,6 +233,7 @@ inline __device__ void compute_dv_1xN(const Params &params) { ...@@ -225,6 +233,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
float p_sum[2 * M]; float p_sum[2 * M];
softmax.template reduce<fmha::Sum_>(p_sum); softmax.template reduce<fmha::Sum_>(p_sum);
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
#pragma unroll #pragma unroll
for( int mi = 0; mi < M; mi++ ) { for( int mi = 0; mi < M; mi++ ) {
#pragma unroll #pragma unroll
...@@ -233,20 +242,12 @@ inline __device__ void compute_dv_1xN(const Params &params) { ...@@ -233,20 +242,12 @@ inline __device__ void compute_dv_1xN(const Params &params) {
for( int ni = 0; ni < N; ni++ ) { for( int ni = 0; ni < N; ni++ ) {
#pragma unroll #pragma unroll
for( int jj = 0; jj < 4; jj++ ) { for( int jj = 0; jj < 4; jj++ ) {
const float scalef = reinterpret_cast<const float &>(params.scale_softmax); softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]) ;
softmax.elt_[2 * mi + ii][4 * ni + jj] = (d_s[2 * mi + ii][4 * ni + jj] - p_sum[2 * mi + ii]) * softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;
fabsf(s_mat[2 * mi + ii][4 * ni + jj]) * scalef;
} }
} }
} }
} }
// Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M]; typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];
smem_s.load(frag_s); smem_s.load(frag_s);
for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) { for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) {
...@@ -275,7 +276,7 @@ inline __device__ void compute_dv_1xN(const Params &params) { ...@@ -275,7 +276,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
} }
// Commit the values for Q into shared memory. // Commit the values for Q into shared memory.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) { if(l < STEPS - 1) {
gmem_q.commit(smem_q); gmem_q.commit(smem_q);
} }
...@@ -295,36 +296,15 @@ inline __device__ void compute_dv_1xN(const Params &params) { ...@@ -295,36 +296,15 @@ inline __device__ void compute_dv_1xN(const Params &params) {
// Epilogue swizzle for dV // Epilogue swizzle for dV
Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx); Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);
uint4 dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N]; smem_dv.store(acc_dv);
#pragma unroll
for( int mi = 0; mi < Mma_tile_dv::MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_dv::MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
float tmp00 = acc_dv[mi][ni].elt(0);
float tmp01 = acc_dv[mi][ni].elt(1);
float tmp02 = acc_dv[mi][ni].elt(4);
float tmp03 = acc_dv[mi][ni].elt(5);
// 2nd row - 4 elements per row.
float tmp10 = acc_dv[mi][ni].elt(2);
float tmp11 = acc_dv[mi][ni].elt(3);
float tmp12 = acc_dv[mi][ni].elt(6);
float tmp13 = acc_dv[mi][ni].elt(7);
dv[mi][ni].x = fmha::float2_to_half2(tmp00, tmp01);
dv[mi][ni].y = fmha::float2_to_half2(tmp02, tmp03);
dv[mi][ni].z = fmha::float2_to_half2(tmp10, tmp11);
dv[mi][ni].w = fmha::float2_to_half2(tmp12, tmp13);
}
}
smem_dv.store(dv);
__syncthreads(); __syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS]; uint4 dv_out[Smem_tile_dv::NUM_LDS];
smem_dv.load(dv_out); smem_dv.load(dv_out);
Qkv_params dv_params; Qkv_params dv_params;
dv_params.qkv_ptr = params.dqkv_ptr; dv_params.qkv_ptr = params.dqkv_ptr;
dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes; dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dv_params.h = params.h;
Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx); Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
gmem_dv.store(dv_out); gmem_dv.store(dv_out);
} }
...@@ -447,13 +427,15 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) { ...@@ -447,13 +427,15 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
enum { THREADS_PER_ROW = 32 }; enum { THREADS_PER_ROW = 32 };
enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
// Declare the accumulators for the 2nd gemm. // Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N]; fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk); fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);
// Load over the entire sequence length. // Load over the entire sequence length.
for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) { for( int l=0;l<STEPS;l++) {
const int loop = l * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen ) if( loop >= binfo.actual_seqlen )
break; break;
...@@ -492,7 +474,7 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) { ...@@ -492,7 +474,7 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
// Store dP to smem for transpose // Store dP to smem for transpose
smem_s.store(s_regs); smem_s.store(s_regs);
if( loop + Cta_tile_p::M < Cta_tile_p::N ) { if(l < STEPS - 1) {
// Load next part of S // Load next part of S
gmem_s.load(s_regs, mask); gmem_s.load(s_regs, mask);
gmem_s.move(); gmem_s.move();
...@@ -544,7 +526,7 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) { ...@@ -544,7 +526,7 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
} }
// Commit the values for Q into shared memory. // Commit the values for Q into shared memory.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) { if( l < STEPS - 1) {
gmem_q.commit(smem_q); gmem_q.commit(smem_q);
} }
...@@ -559,37 +541,14 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) { ...@@ -559,37 +541,14 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
// Epilogue swizzle for dK // Epilogue swizzle for dK
Smem_tile_dk smem_dk(&smem_[0], tidx); Smem_tile_dk smem_dk(&smem_[0], tidx);
uint4 dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N]; smem_dk.store(acc_dk);
#pragma unroll
for( int mi = 0; mi < Mma_tile_dk::MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_dk::MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
float tmp00 = acc_dk[mi][ni].elt(0);
float tmp01 = acc_dk[mi][ni].elt(1);
float tmp02 = acc_dk[mi][ni].elt(4);
float tmp03 = acc_dk[mi][ni].elt(5);
// 2nd row - 4 elements per row.
float tmp10 = acc_dk[mi][ni].elt(2);
float tmp11 = acc_dk[mi][ni].elt(3);
float tmp12 = acc_dk[mi][ni].elt(6);
float tmp13 = acc_dk[mi][ni].elt(7);
dk[mi][ni].x = fmha::float2_to_half2(tmp00, tmp01);
dk[mi][ni].y = fmha::float2_to_half2(tmp02, tmp03);
dk[mi][ni].z = fmha::float2_to_half2(tmp10, tmp11);
dk[mi][ni].w = fmha::float2_to_half2(tmp12, tmp13);
}
}
smem_dk.store(dk);
__syncthreads(); __syncthreads();
uint4 dk_out[Smem_tile_dk::NUM_LDS]; uint4 dk_out[Smem_tile_dk::NUM_LDS];
smem_dk.load(dk_out); smem_dk.load(dk_out);
Qkv_params dk_params; Qkv_params dk_params;
dk_params.qkv_ptr = params.dqkv_ptr; dk_params.qkv_ptr = params.dqkv_ptr;
dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes; dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dk_params.h = params.h;
Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx); Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
gmem_dk.store(dk_out); gmem_dk.store(dk_out);
} }
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "fmha.h" #include "fmha.h"
#include "fmha_fprop_kernel_1xN.h" #include "fmha_fprop_kernel_1xN.h"
#include "fmha_fprop_kernel_1xN_nl.h"
using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>; using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;
...@@ -38,6 +39,17 @@ extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_predict_kernel(Fused_mult ...@@ -38,6 +39,17 @@ extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_predict_kernel(Fused_mult
fmha::device_1xN<Kernel_traits, false>(params); fmha::device_1xN<Kernel_traits, false>(params);
} }
template<int CHUNKS>
__global__ void fmha_fprop_fp16_512_64_sm80_train_nl_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN_nl<CHUNKS,Kernel_traits, true>(params);
}
template<int CHUNKS>
__global__ void fmha_fprop_fp16_512_64_sm80_predict_nl_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN_nl<CHUNKS, Kernel_traits, false>(params);
}
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) { void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_kernel : &fmha_fprop_fp16_512_64_sm80_predict_kernel; auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_kernel : &fmha_fprop_fp16_512_64_sm80_predict_kernel;
...@@ -54,3 +66,33 @@ void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &par ...@@ -54,3 +66,33 @@ void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &par
dim3 grid(params.h, params.b); dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params); kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
} }
void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
if( num_chunks == 2 ) {
kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2>
: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
} else if( num_chunks == 3 ) {
kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<3>
: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<3>;
} else if( num_chunks == 4 ) {
kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<4>
: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<4>;
} else {
assert(false && "Unsupported num_chunks");
}
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b, num_chunks);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
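Both nl launchers follow the same CUDA pattern for large dynamic shared memory: any request at or above 48 KB must be enabled per kernel with cudaFuncSetAttribute before the launch. A self-contained sketch of that pattern (my_kernel and SMEM_BYTES are placeholders, not part of this diff):

#include <cuda_runtime.h>

__global__ void my_kernel(float *out) {
    extern __shared__ char smem[];
    smem[threadIdx.x] = 1;
    __syncthreads();
    out[threadIdx.x] = static_cast<float>(smem[threadIdx.x]);
}

void launch_my_kernel(float *out, cudaStream_t stream) {
    constexpr int SMEM_BYTES = 96 * 1024; // anything >= 48 KB needs the opt-in below
    if (SMEM_BYTES >= 48 * 1024) {
        cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_BYTES);
    }
    my_kernel<<<dim3(1), dim3(128), SMEM_BYTES, stream>>>(out);
}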
...@@ -174,9 +174,11 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de ...@@ -174,9 +174,11 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx); Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
enum { THREADS_PER_ROW = 32 }; enum { THREADS_PER_ROW = 32 };
enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
// Load over the entire sequence length. // Load over the entire sequence length.
for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) { for( int l = 0; l < STEPS; l++ ) {
const int loop = l * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen ) if( loop >= binfo.actual_seqlen )
break; break;
...@@ -200,12 +202,8 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de ...@@ -200,12 +202,8 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
} }
// Store the P matrix.
#if defined(STORE_P)
gmem_p.store(acc_p);
#endif
// Load the mask for that iteration. // Load the mask for that iteration.
mask.load(outer); mask.load(l);
// Convert from the accumulator type to FP32 for Softmax. // Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p); softmax.unpack(acc_p);
...@@ -213,7 +211,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de ...@@ -213,7 +211,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
// Apply the mask. // Apply the mask.
softmax.apply_mask(mask); softmax.apply_mask(mask);
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && loop == 0 ) { if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
__syncthreads(); __syncthreads();
} }
...@@ -261,7 +259,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de ...@@ -261,7 +259,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
} }
// Trigger the load for the next Q values. // Trigger the load for the next Q values.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) { if(l < STEPS - 1) {
smem_q.move_to_next_write_buffer(); smem_q.move_to_next_write_buffer();
gmem_q.move(); gmem_q.move();
gmem_q.load(smem_q); gmem_q.load(smem_q);
...@@ -320,7 +318,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de ...@@ -320,7 +318,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
gmem_o.move(); gmem_o.move();
// Commit the values for Q into shared memory. // Commit the values for Q into shared memory.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) { if(l < STEPS - 1) {
gmem_q.commit(smem_q); gmem_q.commit(smem_q);
} }
......
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int CHUNKS, typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN_nl(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
// The global memory tile to store S/D.
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;
// Shared memory.
extern __shared__ char smem_[];
const int bidc = blockIdx.z;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
Noloop nl_traits(bidc);
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));
fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[0], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v;
char *smem_v_ = nullptr;
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];
} else {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
}
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
nl_traits.move_all(gmem_q, gmem_o, gmem_s);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
// Trigger the loads for V.
gmem_v.load(smem_v);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Commit the data for V to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_v.commit(smem_v);
}
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
smem_q.load(frag_q[0], 0);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
smem_k.load(frag_k[ki], ki);
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_v.commit(smem_v);
// Make sure the data is in shared memory.
__syncthreads();
}
// Load the fragments for V. We keep the data in registers during the entire kernel.
typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
smem_v.load(frag_v[ki], ki);
}
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
// Create the object to do the softmax.
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
// The number of threads per row.
enum { THREADS_PER_ROW = 32 };
// Load over the entire sequence length.
for(int l = 0; l < nl_traits.num_steps_;l++) {
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_q.load(frag_q[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Trigger the load for the next Q values.
if( l < nl_traits.num_steps_- 1) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
// Load the mask for that iteration.
mask.load(nl_traits.loop_offset_ + l);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
__syncthreads();
}
// Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Max_>(p_max);
// Make sure we are done reading shared memory.
__syncthreads();
// Compute the exponential value.
softmax.apply_exp(p_max);
// Compute the sum.
float p_sum[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Sum_>(p_sum);
// Finalize softmax on the accumulators of P^T.
softmax.scale(p_sum);
if( Is_training ) {
auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
float4 tmp = uniform4(ph());
// Encode the dropout pattern in the sign bit of the (non-negative) softmax values, so dropped elements can be distinguished from pre-existing zeros.
softmax.elt_[2 * mi + ii][4 * ni + 0] =
encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
softmax.elt_[2 * mi + ii][4 * ni + 1] =
encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
softmax.elt_[2 * mi + ii][4 * ni + 2] =
encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
softmax.elt_[2 * mi + ii][4 * ni + 3] =
encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
}
}
}
gmem_s.store(softmax.elt_, mask);
gmem_s.move();
}
using Frag_p = fmha::Fragment_a<fmha::Row>;
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
softmax.pack(frag_p);
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
//"Apply" the dropout.
frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
}
}
}
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
// Do this part of O = P^T * V^T.
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_o::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_o.store(out, ii);
}
// Move to the next part of the output.
gmem_o.move();
// Commit the values for Q into shared memory.
if( l < nl_traits.num_steps_- 1) {
gmem_q.commit(smem_q);
__syncthreads();
smem_q.load(frag_q[0], 0);
}
} // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
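Scalar reference for the dropout encoding used in the training branch of device_1xN_nl above (illustration only): kept softmax values keep their sign, dropped ones are negated, and the later hmul2/hrelu2 pair both rescales by rp_dropout and zeroes the negated values.

#include <algorithm>

inline float encode_dropout(bool keep, float p) {
    return keep ? p : -p;                        // the sign bit carries the mask
}

inline float apply_dropout(float encoded, float rp_dropout) {
    return std::max(encoded * rp_dropout, 0.f);  // scalar equivalent of hmul2 + hrelu2
}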
...@@ -40,94 +40,130 @@ namespace fmha { ...@@ -40,94 +40,130 @@ namespace fmha {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template <int FMHA_VERSION> struct BlockInfo {}; template<int THREADS_PER_CTA>
struct BlockInfoPadded {
template <> struct BlockInfo<1> {
int actual_seqlen;
int bidx;
int sum_s;
int bidh;
int bidb;
template<typename Params> template<typename Params>
__device__ BlockInfo( const Params &params, __device__ BlockInfoPadded(const Params &params,
const int bidb, const int bidb,
const int bidh, const int bidh,
const int tidx ) const int tidx)
: bidb( bidb ), bidh( bidh ) { : bidb(bidb), bidh(bidh), h(params.h) {
// The block index. // The block index.
sum_s = params.b * params.s; sum_s = params.cu_seqlens[bidb];
actual_seqlen = params.s; actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;
bidx = bidb * params.h + bidh; bidx = sum_s * params.h + bidh;
tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
} }
__device__ bool stop_early() const { __device__ bool stop_early() const {
return false; return actual_seqlen == 0;
} }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <> struct BlockInfo<2> {
int actual_seqlen; int actual_seqlen;
int bidx; int bidx;
int sum_s; int sum_s;
int bidh; int bidh;
int bidb; int bidb;
int tidx_global;
int h;
};
template<typename Params> ////////////////////////////////////////////////////////////////////////////////////////////////////
__device__ BlockInfo( const Params &params,
const int bidb,
const int bidh,
const int tidx )
: bidb( bidb ), bidh( bidh ) {
// The block index. template<int CHUNKS, typename Cta_tile>
sum_s = params.cu_seqlens[bidb]; struct Noloop_traits{
actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s; // Interpretation of Cta_tile dims, i.e. Cta_tile_p:
bidx = sum_s * params.h + bidh; enum{ STEP = Cta_tile::M };
enum{ SEQLEN = Cta_tile::N };
// The size of the subsequence this CTA is processing
enum { SUBSEQ = SEQLEN / CHUNKS };
static_assert(SUBSEQ * CHUNKS == SEQLEN);
// The number of steps to process the subsequence
enum { NUM_STEPS = SUBSEQ / STEP };
static_assert(NUM_STEPS * Cta_tile::M == SUBSEQ);
inline __device__ Noloop_traits(const int bidc)
: loop_offset_(NUM_STEPS * bidc)
, bidc_(bidc) {
} }
__device__ bool stop_early() const { template<typename ... Tiles>
return actual_seqlen == 0; inline __device__ void move_all(Tiles & ... tiles) const {
using expand_type = int[];
for( int s = 0; s < loop_offset_; s++ ) {
expand_type{ (tiles.move(), 0)... };
}
}
inline __device__ int get_idx_dk() const {
//return bidc_;
return bidc_ * 2 + 0;
}
inline __device__ int get_idx_dv() const {
//return CHUNKS + bidc_;
return bidc_ * 2 + 1;
}
inline __device__ int offset_loop_count(const int l) {
// convert loop counter to position in the outer sequence
return (loop_offset_ + l) * STEP;
} }
const int loop_offset_;
const uint32_t bidc_;
const int num_steps_ = NUM_STEPS;
}; };
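A compile-time check of the even split implemented by the primary Noloop_traits template, using the 512 x 64 kernel's tile (STEP = 16, SEQLEN = 512); the helper names are illustrative:

constexpr int subseq(int seqlen, int chunks) { return seqlen / chunks; }
constexpr int num_steps(int seqlen, int chunks, int step) { return subseq(seqlen, chunks) / step; }

static_assert(num_steps(512, 2, 16) == 16, "2 chunks -> 256 rows -> 16 steps per CTA");
static_assert(num_steps(512, 4, 16) == 8,  "4 chunks -> 128 rows -> 8 steps per CTA");
// 512 is not divisible by 3, so CHUNKS == 3 cannot satisfy the even-split static_asserts
// and gets the explicit specialization below.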
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS_PER_CTA> template<typename Cta_tile>
struct BlockInfoPadded { struct Noloop_traits<3, Cta_tile>{
// Interpretation of Cta_tile dims, i.e. Cta_tile_p:
enum{ STEP = Cta_tile::M };
enum{ SEQLEN = Cta_tile::N };
template<typename Params> static_assert(STEP == 16 && SEQLEN == 512);
__device__ BlockInfoPadded( const Params &params,
const int bidb,
const int bidh,
const int tidx )
: bidb( bidb ), bidh( bidh ), h(params.h) {
// The block index. inline __device__ Noloop_traits(const int bidc)
sum_s = params.cu_seqlens[bidb]; : bidc_(bidc)
actual_seqlen = params.seqlens[bidb]; , num_steps_(bidc < 2 ? 11 : 10)
bidx = sum_s * params.h + bidh; , loop_offset_(bidc * 11) {
}
tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx; template<typename ... Tiles>
inline __device__ void move_all(Tiles & ... tiles) const {
using expand_type = int[];
for( int s = 0; s < loop_offset_; s++ ) {
expand_type{ (tiles.move(), 0)... };
}
} }
__device__ bool stop_early() const { inline __device__ int get_idx_dk() const {
return actual_seqlen == 0; //return bidc_;
return bidc_ * 2 + 0;
} }
int actual_seqlen; inline __device__ int get_idx_dv() const {
int bidx; //return CHUNKS + bidc_;
int sum_s; return bidc_ * 2 + 1;
int bidh; }
int bidb;
int tidx_global; inline __device__ int offset_loop_count(const int l) {
int h; // convert loop counter to position in the outer sequence
return (loop_offset_ + l) * STEP;
}
const int loop_offset_;
const uint32_t bidc_;
const int num_steps_;
}; };
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha } // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
inline __device__ float4 ldg128(const void *ptr) {
return *static_cast<const float4 *>(ptr);
}
inline __device__ void stg128(void *ptr, const float4 &data) {
*static_cast<float4 *>(ptr) = data;
}
template<typename T, int THREADS, int HIDDEN_SIZE, int CHUNKS>
__global__ __launch_bounds__(THREADS) void fmha_noloop_reduce_kernel(void *__restrict__ out,
const void *__restrict__ in,
const int *__restrict__ cu_seqlens,
const int batch_size) {
enum { BYTES_PER_LDG = 16 };
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(T) };
// One CTA hidden vector for K and V
enum { BYTES_PER_ROW = HIDDEN_SIZE * sizeof(T) * 2 };
// The stride in bytes in dQKV
enum { OUT_STRIDE_BYTES = 3 * HIDDEN_SIZE * sizeof(T) };
// The offset in bytes in dQKV to the dKV part for non-interleaved heads
enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) };
static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T));
// Size in bytes of the input tile
enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW };
enum { BYTES_PER_CTA = THREADS * BYTES_PER_LDG };
enum { LDGS = BYTES_PER_ROW / BYTES_PER_CTA };
static_assert(BYTES_PER_CTA * LDGS == BYTES_PER_ROW);
union Vec_t {
float4 raw;
T elt[NUM_ELTS];
};
// ZERO-OUT invalid positions in dQKV
const int total = cu_seqlens[batch_size];
if(blockIdx.x >= total){
enum { BYTES_PER_QKV_ROW = 3 * HIDDEN_SIZE * sizeof(T) };
enum { STGS = BYTES_PER_QKV_ROW / BYTES_PER_LDG };
const float4 zeros = make_float4(0.f, 0.f, 0.f, 0.f);
char *base_ptr = static_cast<char *>(out) + blockIdx.x * OUT_STRIDE_BYTES;
for(int tidx = threadIdx.x; tidx < STGS; tidx += THREADS){
stg128(base_ptr + tidx * BYTES_PER_LDG, zeros);
}
return;
}
// SETUP
const int offset_in = blockIdx.x * BYTES_PER_TILE + threadIdx.x * BYTES_PER_LDG;
const char *ptr_in = static_cast<const char *>(in) + offset_in;
const int offset_out = blockIdx.x * OUT_STRIDE_BYTES + threadIdx.x * BYTES_PER_LDG;
char *ptr_out = static_cast<char *>(out) + OUT_OFFSET_KV_BYTES + offset_out;
// LOAD
Vec_t local_in[CHUNKS][LDGS];
#pragma unroll
for( int c = 0; c < CHUNKS; c++ ) {
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
int offset = c * BYTES_PER_ROW + l * BYTES_PER_CTA;
local_in[c][l].raw = ldg128(ptr_in + offset);
}
}
// UNPACK
float acc[LDGS][NUM_ELTS];
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
#pragma unroll
for( int e = 0; e < NUM_ELTS; e++ ) {
acc[l][e] = float(local_in[0][l].elt[e]);
}
}
// COMPUTE
#pragma unroll
for( int c = 1; c < CHUNKS; c++ ) {
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
#pragma unroll
for( int e = 0; e < NUM_ELTS; e++ ) {
acc[l][e] += float(local_in[c][l].elt[e]);
}
}
}
// PACK
Vec_t local_out[LDGS];
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
#pragma unroll
for( int e = 0; e < NUM_ELTS; e++ ) {
local_out[l].elt[e] = T(acc[l][e]);
}
}
// STORE
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
const int offset = l * BYTES_PER_CTA;
stg128(ptr_out + offset, local_out[l].raw);
}
}
void fmha_run_noloop_reduce(void *out,
const void *in,
const int *cu_seqlens,
const int hidden_size,
const int batch_size,
const int total,
const int num_chunks,
cudaStream_t stream) {
const int blocks = total;
if(hidden_size == 1024){
constexpr int HIDDEN_SIZE = 1024;
constexpr int THREADS = 256;
if( num_chunks == 2 ) {
fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 2><<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);
} else if( num_chunks == 3 ) {
fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 3><<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);
} else {
assert(false && "Unsupported num_chunks");
}
}else{
assert(false && "Unsupported hidden_size");
}
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
...@@ -32,11 +32,14 @@ import fmhalib as mha ...@@ -32,11 +32,14 @@ import fmhalib as mha
class FMHAFun(torch.autograd.Function): class FMHAFun(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, qkv, cu_seqlens, seqlens, p_dropout, max_s, is_training): def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):
context, S_dmask = mha.fwd(qkv, cu_seqlens, seqlens, p_dropout, max_s, is_training, None) batch_size = cu_seqlens.numel() - 1
if batch_size < 4:
context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
else:
context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
ctx.save_for_backward(qkv, S_dmask) ctx.save_for_backward(qkv, S_dmask)
ctx.cu_seqlens = cu_seqlens ctx.cu_seqlens = cu_seqlens
ctx.seqlens = seqlens
ctx.p_dropout = p_dropout ctx.p_dropout = p_dropout
ctx.max_s = max_s ctx.max_s = max_s
return context return context
...@@ -44,7 +47,11 @@ class FMHAFun(torch.autograd.Function): ...@@ -44,7 +47,11 @@ class FMHAFun(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, dout): def backward(ctx, dout):
qkv, S_dmask = ctx.saved_tensors qkv, S_dmask = ctx.saved_tensors
dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.seqlens, ctx.p_dropout, ctx.max_s) batch_size = ctx.cu_seqlens.numel() - 1
if batch_size < 4:
dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
else:
dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
return dqkv, None, None, None, None, None, None return dqkv, None, None, None, None, None, None
...@@ -60,8 +67,8 @@ class FMHA(torch.nn.Module): ...@@ -60,8 +67,8 @@ class FMHA(torch.nn.Module):
self.d = self.hidden_size // self.h self.d = self.hidden_size // self.h
assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads" assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"
def forward(self, qkv, cu_seqlens, seqlens, max_s, is_training=True): def forward(self, qkv, cu_seqlens, max_s, is_training=True):
ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, seqlens, self.p_dropout, max_s, is_training) ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training)
return ctx.view(-1, self.hidden_size) return ctx.view(-1, self.hidden_size)
...@@ -51,7 +51,8 @@ def py_mha(qkv, amask, b, s, h, d): ...@@ -51,7 +51,8 @@ def py_mha(qkv, amask, b, s, h, d):
class TestFMHA(unittest.TestCase): class TestFMHA(unittest.TestCase):
def run_test(self, s): def run_test(self, s, b):
print(f'Test s={s} b={b}')
torch.manual_seed(1234) torch.manual_seed(1234)
torch.cuda.manual_seed(1234) torch.cuda.manual_seed(1234)
...@@ -59,7 +60,6 @@ class TestFMHA(unittest.TestCase): ...@@ -59,7 +60,6 @@ class TestFMHA(unittest.TestCase):
dtype = torch.float16 dtype = torch.float16
device = torch.device('cuda') device = torch.device('cuda')
b = 32
h = 16 h = 16
d = 64 d = 64
...@@ -76,7 +76,10 @@ class TestFMHA(unittest.TestCase): ...@@ -76,7 +76,10 @@ class TestFMHA(unittest.TestCase):
qkv.requires_grad = True qkv.requires_grad = True
ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, seqlens, 0.0, s, True, None) if b < 4:
ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, None)
else:
ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, None)
ctx = ctx.view(b,s,h,d) ctx = ctx.view(b,s,h,d)
ctx_ref = py_mha(qkv, amask, b,s,h,d) ctx_ref = py_mha(qkv, amask, b,s,h,d)
...@@ -91,23 +94,28 @@ class TestFMHA(unittest.TestCase): ...@@ -91,23 +94,28 @@ class TestFMHA(unittest.TestCase):
dw2 = dw.permute(0,2,1,3).clone().detach().contiguous() dw2 = dw.permute(0,2,1,3).clone().detach().contiguous()
dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, seqlens, 0.0, s) if b < 4:
dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
else:
dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d) dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d)
self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3)) self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))
def test_128(self): def test_128(self):
self.run_test(128) self.run_test(128, 32)
def test_256(self): def test_256(self):
self.run_test(256) self.run_test(256, 32)
def test_384(self): def test_384(self):
self.run_test(384) self.run_test(384, 32)
def test_512(self): def test_512(self):
self.run_test(512) self.run_test(512, 32)
self.run_test(512, 2)
self.run_test(512, 3)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -349,6 +349,7 @@ if "--fmha" in sys.argv: ...@@ -349,6 +349,7 @@ if "--fmha" in sys.argv:
CUDAExtension(name='fmhalib', CUDAExtension(name='fmhalib',
sources=[ sources=[
'apex/contrib/csrc/fmha/fmha_api.cpp', 'apex/contrib/csrc/fmha/fmha_api.cpp',
'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu', 'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu', 'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu', 'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment