Unverified Commit ed719967 authored by yjk21's avatar yjk21 Committed by GitHub
Browse files

Adds small-batch kernels (#1126)

parent c1378e6f
......@@ -30,28 +30,6 @@
#include "fmha.h"
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void set_params(Fused_multihead_attention_fprop_params &params,
// sizes
const size_t b,
......@@ -61,7 +39,6 @@ void set_params(Fused_multihead_attention_fprop_params &params,
// device pointers
void *qkv_packed_d,
void *cu_seqlens_d,
void *seqlens_d,
void *o_packed_d,
void *s_d,
float p_dropout) {
......@@ -79,7 +56,6 @@ void set_params(Fused_multihead_attention_fprop_params &params,
params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);
params.cu_seqlens = static_cast<int *>(cu_seqlens_d);
params.seqlens = static_cast<int *>(seqlens_d);
// S = softmax(P)
params.s_ptr = s_d;
......@@ -107,13 +83,9 @@ void set_params(Fused_multihead_attention_fprop_params &params,
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
}
constexpr uint32_t NUM_HEADS_DIM = 2;
constexpr uint32_t THREE_DIM = 1;
std::vector<at::Tensor>
mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const at::Tensor &seqlens, // b
const float p_dropout,
const int max_seq_len,
const bool is_training,
......@@ -149,17 +121,14 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
......@@ -167,10 +136,9 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
TORCH_CHECK(seqlens.numel() == batch_size);
const int total = sizes[0];
const int num_heads = sizes[NUM_HEADS_DIM];
const int head_size = sizes[3];
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto opts = qkv.options();
......@@ -191,7 +159,6 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
seqlens.data_ptr(),
ctx.data_ptr(),
s.data_ptr(),
p_dropout);
......@@ -217,7 +184,6 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1
const at::Tensor &seqlens, // b
const float p_dropout, // probability to drop
const int max_seq_len // max sequence length to choose the kernel
) {
......@@ -247,17 +213,14 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
TORCH_CHECK(dout.dtype() == torch::kFloat16);
TORCH_CHECK(softmax.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda());
TORCH_CHECK(cu_seqlens.is_cuda());
TORCH_CHECK(qkv.is_contiguous());
TORCH_CHECK(cu_seqlens.is_contiguous());
TORCH_CHECK(seqlens.is_contiguous());
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
......@@ -265,9 +228,8 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
TORCH_CHECK(seqlens.numel() == batch_size);
const int num_heads = sizes[NUM_HEADS_DIM];
const int head_size = sizes[3];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
......@@ -282,12 +244,11 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
seqlens.data_ptr(),
dout.data_ptr(), // we set o_ptr to dout
softmax.data_ptr(), // softmax gets overwritten by dP!
p_dropout);
// we're re-using these scales scales
// we're re-using these scales
Data_type acc_type = DATA_TYPE_FP32;
set_alpha(params.scale_bmm1, 1.f, acc_type);
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
......@@ -298,8 +259,174 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
return { dqkv, softmax };
}
std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const float p_dropout,
const int max_seq_len,
const bool is_training,
c10::optional<at::Generator> gen_) {
int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80_nl;
TORCH_CHECK(max_seq_len == seq_len);
constexpr int warps_m = 1;
constexpr int warps_n = 4; // this leads to an upper bound
const int mmas_m = seq_len / 16 / warps_m;
const int mmas_n = seq_len / 16 / warps_n;
// static_assert( mmas_m == 32 );
// static_assert( mmas_n == 4 );
const int elts_per_thread = 8 * mmas_m * mmas_n;
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto opts = qkv.options();
auto ctx = torch::empty({ total, num_heads, head_size }, opts);
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
ctx.data_ptr(),
s.data_ptr(),
p_dropout);
// number of times random will be generated per thread, to offset philox counter in thc random
// state
int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
if( is_training ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}
int num_chunks = 3;
if(batch_size == 3) {
num_chunks = 2;
}
launch(params, is_training, num_chunks, stream);
return { ctx, s };
}
std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1
const float p_dropout, // probability to drop
const int max_seq_len // max sequence length to choose the kernel
) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
int seq_len = 512;
auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl;
auto opts = qkv.options();
auto dqkv = torch::empty_like(qkv);
int num_chunks = 2;
if( batch_size == 1 ) {
num_chunks = 4;
}else if( batch_size == 2 ) {
num_chunks = 3;
}
auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts);
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
dout.data_ptr(), // o_ptr = dout
softmax.data_ptr(), // softmax gets overwritten by dP!
p_dropout);
params.dkv_ptr = dkv.data_ptr();
Data_type acc_type = DATA_TYPE_FP32;
set_alpha(params.scale_bmm1, 1.f, acc_type);
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
params.dqkv_ptr = dqkv.data_ptr();
launch(params, num_chunks, stream);
//SPLIT-K reduction of num_chunks dK, dV parts
// The equivalent of the following Pytorch code:
// using namespace torch::indexing;
// at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});
// torch::sum_out(view_out, dkv, 1);
const int hidden_size = num_heads * head_size;
fmha_run_noloop_reduce(
dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total, num_chunks, stream);
return { dqkv, softmax, dkv };
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "Fused Multi-head Self-attention for BERT";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("fwd_nl", &mha_fwd_nl, "Forward pass (small-batch)");
m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
}
......@@ -35,6 +35,12 @@
#include <fmha_utils.h>
constexpr int TOTAL_DIM = 0;
constexpr int THREE_DIM = 1;
constexpr int H_DIM = 2;
constexpr int D_DIM = 3;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
......@@ -43,6 +49,9 @@ struct Qkv_params {
// The stride between rows of the Q, K and V matrices.
size_t qkv_stride_in_bytes;
// The number of heads.
int h;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -52,6 +61,9 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
// The dQKV matrices.
void *dqkv_ptr;
// Temporary for dKV.
void *dkv_ptr;
// The O matrix (output).
void *o_ptr;
......@@ -64,7 +76,7 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
int64_t s_stride_in_bytes;
// The dimensions.
int b, h, s, d;
int b, s, d;
// The scaling factors for the kernel.
uint32_t scale_bmm1, scale_softmax, scale_bmm2;
......@@ -72,9 +84,6 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
// array of length b+1 holding starting offset of each sequence.
int *cu_seqlens;
// array of length b holding the actual sequence lenghts.
int *seqlens;
// The dropout probability (probability of keeping an activation).
float p_dropout;
......@@ -90,3 +99,27 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream);
void fmha_run_noloop_reduce(void *out,
const void *in,
const int *cu_seqlens,
const int hidden_size,
const int batch_size,
const int total,
const int num_chunks,
cudaStream_t stream);
......@@ -39,7 +39,9 @@ template<
// The number of rows of Q, K or V loaded by this tile.
int ROWS,
// The number of columns.
int COLS
int COLS,
// The number of matrics.
int NUM_MATS = 3
>
struct Gmem_tile_qkv {
......@@ -74,7 +76,7 @@ struct Gmem_tile_qkv {
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
// Add the block index.
row_offset += (int64_t)((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
// Assemble the final pointer.
qkv_ptr_ += row_offset + col * BYTES_PER_LDG;
......
......@@ -1217,7 +1217,8 @@ struct Smem_tile_mma_epilogue : public Base {
enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N };
static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
using Fragment = typename Base::Fragment;
using Acc = fmha::Fragment_accumulator;
inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) {
const int read_row = tidx / THREADS_PER_ROW;
......@@ -1233,6 +1234,40 @@ struct Smem_tile_mma_epilogue : public Base {
}
}
template<int M, int N>
inline __device__ void store(const Acc (&acc)[M][N]){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// 1st row - 4 elements per row.
float tmp00 = acc[mi][ni].elt(0);
float tmp01 = acc[mi][ni].elt(1);
float tmp02 = acc[mi][ni].elt(4);
float tmp03 = acc[mi][ni].elt(5);
// 2nd row - 4 elements per row.
float tmp10 = acc[mi][ni].elt(2);
float tmp11 = acc[mi][ni].elt(3);
float tmp12 = acc[mi][ni].elt(6);
float tmp13 = acc[mi][ni].elt(7);
uint32_t x = fmha::float2_to_half2(tmp00, tmp01);
uint32_t y = fmha::float2_to_half2(tmp02, tmp03);
uint32_t z = fmha::float2_to_half2(tmp10, tmp11);
uint32_t w = fmha::float2_to_half2(tmp12, tmp13);
size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z);
offset ^= 4 * Base::BYTES_PER_STS;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);
}
}
}
template<int M, int N>
inline __device__ void store(const uint4 (&regs)[M][N]) {
for( int mi = 0; mi < M; mi++ ) {
......
......@@ -27,6 +27,7 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
#include "fmha_dgrad_kernel_1xN_reload_nl.h"
using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;
......@@ -35,6 +36,13 @@ extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_at
fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}
template<int CHUNKS>
__global__
void fmha_dgrad_fp16_512_64_sm80_nl_kernel(Fused_multihead_attention_fprop_params params){
fmha::compute_dv_1xN_nl<CHUNKS, Kernel_traits>(params);
fmha::compute_dq_dk_1xN_nl<CHUNKS, Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
......@@ -58,3 +66,40 @@ void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_param
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_512_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 512 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;
if( num_chunks == 2 ) {
kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;
}else if( num_chunks == 3 ) {
kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<3>;
} else {
assert(false && "Unsupperted number of chunks");
}
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b, num_chunks);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
......@@ -156,8 +156,10 @@ inline __device__ void compute_dv_1xN(const Params &params) {
fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);
enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
// Load over the entire sequence length.
for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
for( int l = 0; l < STEPS; l++ ) {
const int loop = l * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen )
break;
......@@ -185,6 +187,13 @@ inline __device__ void compute_dv_1xN(const Params &params) {
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
if( l < STEPS - 1) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
......@@ -203,8 +212,6 @@ inline __device__ void compute_dv_1xN(const Params &params) {
}
}
float d_s[2 * M][4 * N];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
......@@ -213,10 +220,11 @@ inline __device__ void compute_dv_1xN(const Params &params) {
for( int ni = 0; ni < N; ni++ ) {
#pragma unroll
for( int jj = 0; jj < 4; jj++ ) {
const float s_dmask = s_mat[2 * mi + ii][4 * ni + jj];
float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj];
const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000;
d_s[2 * mi + ii][4 * ni + jj] = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;
softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s[2 * mi + ii][4 * ni + jj] * fabsf(s_dmask);
const float d_s = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;
s_dmask = fabsf(s_dmask);
softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * fabsf(s_dmask);
}
}
}
......@@ -225,6 +233,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
float p_sum[2 * M];
softmax.template reduce<fmha::Sum_>(p_sum);
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
......@@ -233,20 +242,12 @@ inline __device__ void compute_dv_1xN(const Params &params) {
for( int ni = 0; ni < N; ni++ ) {
#pragma unroll
for( int jj = 0; jj < 4; jj++ ) {
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
softmax.elt_[2 * mi + ii][4 * ni + jj] = (d_s[2 * mi + ii][4 * ni + jj] - p_sum[2 * mi + ii]) *
fabsf(s_mat[2 * mi + ii][4 * ni + jj]) * scalef;
softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]) ;
softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;
}
}
}
}
// Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];
smem_s.load(frag_s);
for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) {
......@@ -275,7 +276,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Commit the values for Q into shared memory.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
if(l < STEPS - 1) {
gmem_q.commit(smem_q);
}
......@@ -295,36 +296,15 @@ inline __device__ void compute_dv_1xN(const Params &params) {
// Epilogue swizzle for dV
Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);
uint4 dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
#pragma unroll
for( int mi = 0; mi < Mma_tile_dv::MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_dv::MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
float tmp00 = acc_dv[mi][ni].elt(0);
float tmp01 = acc_dv[mi][ni].elt(1);
float tmp02 = acc_dv[mi][ni].elt(4);
float tmp03 = acc_dv[mi][ni].elt(5);
// 2nd row - 4 elements per row.
float tmp10 = acc_dv[mi][ni].elt(2);
float tmp11 = acc_dv[mi][ni].elt(3);
float tmp12 = acc_dv[mi][ni].elt(6);
float tmp13 = acc_dv[mi][ni].elt(7);
dv[mi][ni].x = fmha::float2_to_half2(tmp00, tmp01);
dv[mi][ni].y = fmha::float2_to_half2(tmp02, tmp03);
dv[mi][ni].z = fmha::float2_to_half2(tmp10, tmp11);
dv[mi][ni].w = fmha::float2_to_half2(tmp12, tmp13);
}
}
smem_dv.store(acc_dv);
smem_dv.store(dv);
__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];
smem_dv.load(dv_out);
Qkv_params dv_params;
dv_params.qkv_ptr = params.dqkv_ptr;
dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dv_params.h = params.h;
Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
gmem_dv.store(dv_out);
}
......@@ -447,13 +427,15 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
enum { THREADS_PER_ROW = 32 };
enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);
// Load over the entire sequence length.
for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
for( int l=0;l<STEPS;l++) {
const int loop = l * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen )
break;
......@@ -492,7 +474,7 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
// Store dP to smem for transpose
smem_s.store(s_regs);
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
if(l < STEPS - 1) {
// Load next part of S
gmem_s.load(s_regs, mask);
gmem_s.move();
......@@ -544,7 +526,7 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
}
// Commit the values for Q into shared memory.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
if( l < STEPS - 1) {
gmem_q.commit(smem_q);
}
......@@ -559,37 +541,14 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
// Epilogue swizzle for dK
Smem_tile_dk smem_dk(&smem_[0], tidx);
uint4 dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
#pragma unroll
for( int mi = 0; mi < Mma_tile_dk::MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_dk::MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
float tmp00 = acc_dk[mi][ni].elt(0);
float tmp01 = acc_dk[mi][ni].elt(1);
float tmp02 = acc_dk[mi][ni].elt(4);
float tmp03 = acc_dk[mi][ni].elt(5);
// 2nd row - 4 elements per row.
float tmp10 = acc_dk[mi][ni].elt(2);
float tmp11 = acc_dk[mi][ni].elt(3);
float tmp12 = acc_dk[mi][ni].elt(6);
float tmp13 = acc_dk[mi][ni].elt(7);
dk[mi][ni].x = fmha::float2_to_half2(tmp00, tmp01);
dk[mi][ni].y = fmha::float2_to_half2(tmp02, tmp03);
dk[mi][ni].z = fmha::float2_to_half2(tmp10, tmp11);
dk[mi][ni].w = fmha::float2_to_half2(tmp12, tmp13);
}
}
smem_dk.store(dk);
smem_dk.store(acc_dk);
__syncthreads();
uint4 dk_out[Smem_tile_dk::NUM_LDS];
smem_dk.load(dk_out);
Qkv_params dk_params;
dk_params.qkv_ptr = params.dqkv_ptr;
dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dk_params.h = params.h;
Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
gmem_dk.store(dk_out);
}
......
This diff is collapsed.
......@@ -27,6 +27,7 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
#include "fmha_fprop_kernel_1xN_nl.h"
using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;
......@@ -38,6 +39,17 @@ extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_predict_kernel(Fused_mult
fmha::device_1xN<Kernel_traits, false>(params);
}
template<int CHUNKS>
__global__ void fmha_fprop_fp16_512_64_sm80_train_nl_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN_nl<CHUNKS,Kernel_traits, true>(params);
}
template<int CHUNKS>
__global__ void fmha_fprop_fp16_512_64_sm80_predict_nl_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN_nl<CHUNKS, Kernel_traits, false>(params);
}
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_kernel : &fmha_fprop_fp16_512_64_sm80_predict_kernel;
......@@ -54,3 +66,33 @@ void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &par
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
if( num_chunks == 2 ) {
kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2>
: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
} else if( num_chunks == 3 ) {
kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<3>
: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<3>;
} else if( num_chunks == 4 ) {
kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<4>
: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<4>;
} else {
assert(false && "Unsupported num_chunks");
}
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b, num_chunks);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
......@@ -174,9 +174,11 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
enum { THREADS_PER_ROW = 32 };
enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
// Load over the entire sequence length.
for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
for( int l = 0; l < STEPS; l++ ) {
const int loop = l * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen )
break;
......@@ -200,12 +202,8 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Store the P matrix.
#if defined(STORE_P)
gmem_p.store(acc_p);
#endif
// Load the mask for that iteration.
mask.load(outer);
mask.load(l);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
......@@ -213,7 +211,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
// Apply the mask.
softmax.apply_mask(mask);
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && loop == 0 ) {
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
__syncthreads();
}
......@@ -261,7 +259,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
}
// Trigger the load for the next Q values.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
if(l < STEPS - 1) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
......@@ -320,7 +318,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
gmem_o.move();
// Commit the values for Q into shared memory.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
if(l < STEPS - 1) {
gmem_q.commit(smem_q);
}
......
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int CHUNKS, typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN_nl(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
// The global memory tile to store S/D.
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;
// Shared memory.
extern __shared__ char smem_[];
const int bidc = blockIdx.z;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
Noloop nl_traits(bidc);
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));
fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[0], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v;
char *smem_v_ = nullptr;
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];
} else {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
}
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
nl_traits.move_all(gmem_q, gmem_o, gmem_s);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
// Trigger the loads for K.
gmem_v.load(smem_v);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Commit the data for V to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_v.commit(smem_v);
}
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
smem_q.load(frag_q[0], 0);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
smem_k.load(frag_k[ki], ki);
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_v.commit(smem_v);
// Make sure the data is in shared memory.
__syncthreads();
}
// Load the fragments for V. We keep the data in registers during the entire kernel.
typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
smem_v.load(frag_v[ki], ki);
}
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
// Create the object to do the softmax.
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
// The number of threads per row.
enum { THREADS_PER_ROW = 32 };
// Load over the entire sequence length.
for(int l = 0; l < nl_traits.num_steps_;l++) {
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_q.load(frag_q[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Trigger the load for the next Q values.
if( l < nl_traits.num_steps_- 1) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
// Load the mask for that iteration.
mask.load(nl_traits.loop_offset_ + l);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
__syncthreads();
}
// Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Max_>(p_max);
// Make sure we are done reading shared memory.
__syncthreads();
// Compute the exponential value.
softmax.apply_exp(p_max);
// Compute the sum.
float p_sum[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Sum_>(p_sum);
// Finalize softmax on the accumulators of P^T.
softmax.scale(p_sum);
if( Is_training ) {
auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
float4 tmp = uniform4(ph());
// We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from pre-existing zeros
softmax.elt_[2 * mi + ii][4 * ni + 0] =
encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
softmax.elt_[2 * mi + ii][4 * ni + 1] =
encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
softmax.elt_[2 * mi + ii][4 * ni + 2] =
encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
softmax.elt_[2 * mi + ii][4 * ni + 3] =
encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
}
}
}
gmem_s.store(softmax.elt_, mask);
gmem_s.move();
}
using Frag_p = fmha::Fragment_a<fmha::Row>;
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
softmax.pack(frag_p);
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
//"Apply" the dropout.
frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
}
}
}
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
// Do this part of O = P^T * V^T.
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_o::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_o.store(out, ii);
}
// Move to the next part of the output.
gmem_o.move();
// Commit the values for Q into shared memory.
if( l < nl_traits.num_steps_- 1) {
gmem_q.commit(smem_q);
__syncthreads();
smem_q.load(frag_q[0], 0);
}
} // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
......@@ -40,94 +40,130 @@ namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int FMHA_VERSION> struct BlockInfo {};
template <> struct BlockInfo<1> {
int actual_seqlen;
int bidx;
int sum_s;
int bidh;
int bidb;
template<int THREADS_PER_CTA>
struct BlockInfoPadded {
template<typename Params>
__device__ BlockInfo( const Params &params,
__device__ BlockInfoPadded(const Params &params,
const int bidb,
const int bidh,
const int tidx )
: bidb( bidb ), bidh( bidh ) {
const int tidx)
: bidb(bidb), bidh(bidh), h(params.h) {
// The block index.
sum_s = params.b * params.s;
actual_seqlen = params.s;
bidx = bidb * params.h + bidh;
sum_s = params.cu_seqlens[bidb];
actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;
bidx = sum_s * params.h + bidh;
tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
}
__device__ bool stop_early() const {
return false;
return actual_seqlen == 0;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <> struct BlockInfo<2> {
int actual_seqlen;
int bidx;
int sum_s;
int bidh;
int bidb;
int tidx_global;
int h;
};
template<typename Params>
__device__ BlockInfo( const Params &params,
const int bidb,
const int bidh,
const int tidx )
: bidb( bidb ), bidh( bidh ) {
////////////////////////////////////////////////////////////////////////////////////////////////////
// The block index.
sum_s = params.cu_seqlens[bidb];
actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;
bidx = sum_s * params.h + bidh;
template<int CHUNKS, typename Cta_tile>
struct Noloop_traits{
// Interpretation of Cta_tile dims, i.e. Cta_tile_p:
enum{ STEP = Cta_tile::M };
enum{ SEQLEN = Cta_tile::N };
// The size of the subsequence this CTA is processing
enum { SUBSEQ = SEQLEN / CHUNKS };
static_assert(SUBSEQ * CHUNKS == SEQLEN);
// The number of steps to process the subsequence
enum { NUM_STEPS = SUBSEQ / STEP };
static_assert(NUM_STEPS * Cta_tile::M == SUBSEQ);
inline __device__ Noloop_traits(const int bidc)
: loop_offset_(NUM_STEPS * bidc)
, bidc_(bidc) {
}
__device__ bool stop_early() const {
return actual_seqlen == 0;
template<typename ... Tiles>
inline __device__ void move_all(Tiles & ... tiles) const {
using expand_type = int[];
for( int s = 0; s < loop_offset_; s++ ) {
expand_type{ (tiles.move(), 0)... };
}
}
inline __device__ int get_idx_dk() const {
//return bidc_;
return bidc_ * 2 + 0;
}
inline __device__ int get_idx_dv() const {
//return CHUNKS + bidc_;
return bidc_ * 2 + 1;
}
inline __device__ int offset_loop_count(const int l) {
// convert loop counter to position in the outer sequence
return (loop_offset_ + l) * STEP;
}
const int loop_offset_;
const uint32_t bidc_;
const int num_steps_ = NUM_STEPS;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS_PER_CTA>
struct BlockInfoPadded {
template<typename Cta_tile>
struct Noloop_traits<3, Cta_tile>{
// Interpretation of Cta_tile dims, i.e. Cta_tile_p:
enum{ STEP = Cta_tile::M };
enum{ SEQLEN = Cta_tile::N };
template<typename Params>
__device__ BlockInfoPadded( const Params &params,
const int bidb,
const int bidh,
const int tidx )
: bidb( bidb ), bidh( bidh ), h(params.h) {
static_assert(STEP == 16 && SEQLEN == 512);
// The block index.
sum_s = params.cu_seqlens[bidb];
actual_seqlen = params.seqlens[bidb];
bidx = sum_s * params.h + bidh;
inline __device__ Noloop_traits(const int bidc)
: bidc_(bidc)
, num_steps_(bidc < 2 ? 11 : 10)
, loop_offset_(bidc * 11) {
}
tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
template<typename ... Tiles>
inline __device__ void move_all(Tiles & ... tiles) const {
using expand_type = int[];
for( int s = 0; s < loop_offset_; s++ ) {
expand_type{ (tiles.move(), 0)... };
}
}
__device__ bool stop_early() const {
return actual_seqlen == 0;
inline __device__ int get_idx_dk() const {
//return bidc_;
return bidc_ * 2 + 0;
}
int actual_seqlen;
int bidx;
int sum_s;
int bidh;
int bidb;
int tidx_global;
int h;
inline __device__ int get_idx_dv() const {
//return CHUNKS + bidc_;
return bidc_ * 2 + 1;
}
inline __device__ int offset_loop_count(const int l) {
// convert loop counter to position in the outer sequence
return (loop_offset_ + l) * STEP;
}
const int loop_offset_;
const uint32_t bidc_;
const int num_steps_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
inline __device__ float4 ldg128(const void *ptr) {
return *static_cast<const float4 *>(ptr);
}
inline __device__ void stg128(void *ptr, const float4 &data) {
*static_cast<float4 *>(ptr) = data;
}
template<typename T, int THREADS, int HIDDEN_SIZE, int CHUNKS>
__global__ __launch_bounds__(THREADS) void fmha_noloop_reduce_kernel(void *__restrict__ out,
const void *__restrict__ in,
const int *__restrict__ cu_seqlens,
const int batch_size) {
enum { BYTES_PER_LDG = 16 };
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(T) };
// One CTA hidden vector for K and V
enum { BYTES_PER_ROW = HIDDEN_SIZE * sizeof(T) * 2 };
// The stride in bytes in dQKV
enum { OUT_STRIDE_BYTES = 3 * HIDDEN_SIZE * sizeof(T) };
// The offset in bytes in dQKV to the dKV part for non-interleaved heads
enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) };
static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T));
// Size in bytes of the input tile
enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW };
enum { BYTES_PER_CTA = THREADS * BYTES_PER_LDG };
enum { LDGS = BYTES_PER_ROW / BYTES_PER_CTA };
static_assert(BYTES_PER_CTA * LDGS == BYTES_PER_ROW);
union Vec_t {
float4 raw;
T elt[NUM_ELTS];
};
// ZERO-OUT invalid positions in dQKV
const int total = cu_seqlens[batch_size];
if(blockIdx.x >= total){
enum { BYTES_PER_QKV_ROW = 3 * HIDDEN_SIZE * sizeof(T) };
enum { STGS = BYTES_PER_QKV_ROW / BYTES_PER_LDG };
const float4 zeros = make_float4(0.f, 0.f, 0.f, 0.f);
char *base_ptr = static_cast<char *>(out) + blockIdx.x * OUT_STRIDE_BYTES;
for(int tidx = threadIdx.x; tidx < STGS; tidx += THREADS){
stg128(base_ptr + tidx * BYTES_PER_LDG, zeros);
}
return;
}
// SETUP
const int offset_in = blockIdx.x * BYTES_PER_TILE + threadIdx.x * BYTES_PER_LDG;
const char *ptr_in = static_cast<const char *>(in) + offset_in;
const int offset_out = blockIdx.x * OUT_STRIDE_BYTES + threadIdx.x * BYTES_PER_LDG;
char *ptr_out = static_cast<char *>(out) + OUT_OFFSET_KV_BYTES + offset_out;
// LOAD
Vec_t local_in[CHUNKS][LDGS];
#pragma unroll
for( int c = 0; c < CHUNKS; c++ ) {
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
int offset = c * BYTES_PER_ROW + l * BYTES_PER_CTA;
local_in[c][l].raw = ldg128(ptr_in + offset);
}
}
// UNPACK
float acc[LDGS][NUM_ELTS];
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
#pragma unroll
for( int e = 0; e < NUM_ELTS; e++ ) {
acc[l][e] = float(local_in[0][l].elt[e]);
}
}
// COMPUTE
#pragma unroll
for( int c = 1; c < CHUNKS; c++ ) {
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
#pragma unroll
for( int e = 0; e < NUM_ELTS; e++ ) {
acc[l][e] += float(local_in[c][l].elt[e]);
}
}
}
// PACK
Vec_t local_out[LDGS];
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
#pragma unroll
for( int e = 0; e < NUM_ELTS; e++ ) {
local_out[l].elt[e] = T(acc[l][e]);
}
}
// STORE
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
const int offset = l * BYTES_PER_CTA;
stg128(ptr_out + offset, local_out[l].raw);
}
}
void fmha_run_noloop_reduce(void *out,
const void *in,
const int *cu_seqlens,
const int hidden_size,
const int batch_size,
const int total,
const int num_chunks,
cudaStream_t stream) {
const int blocks = total;
if(hidden_size == 1024){
constexpr int HIDDEN_SIZE = 1024;
constexpr int THREADS = 256;
if( num_chunks == 2 ) {
fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 2><<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);
} else if( num_chunks == 3 ) {
fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 3><<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);
} else {
assert(false && "Unsupported num_chunks");
}
}else{
assert(false && "Unsupported hidden_size");
}
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
......@@ -32,11 +32,14 @@ import fmhalib as mha
class FMHAFun(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, seqlens, p_dropout, max_s, is_training):
context, S_dmask = mha.fwd(qkv, cu_seqlens, seqlens, p_dropout, max_s, is_training, None)
def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):
batch_size = cu_seqlens.numel() - 1
if batch_size < 4:
context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
else:
context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
ctx.save_for_backward(qkv, S_dmask)
ctx.cu_seqlens = cu_seqlens
ctx.seqlens = seqlens
ctx.p_dropout = p_dropout
ctx.max_s = max_s
return context
......@@ -44,7 +47,11 @@ class FMHAFun(torch.autograd.Function):
@staticmethod
def backward(ctx, dout):
qkv, S_dmask = ctx.saved_tensors
dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.seqlens, ctx.p_dropout, ctx.max_s)
batch_size = ctx.cu_seqlens.numel() - 1
if batch_size < 4:
dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
else:
dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
return dqkv, None, None, None, None, None, None
......@@ -60,8 +67,8 @@ class FMHA(torch.nn.Module):
self.d = self.hidden_size // self.h
assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"
def forward(self, qkv, cu_seqlens, seqlens, max_s, is_training=True):
def forward(self, qkv, cu_seqlens, max_s, is_training=True):
ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, seqlens, self.p_dropout, max_s, is_training)
ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training)
return ctx.view(-1, self.hidden_size)
......@@ -51,7 +51,8 @@ def py_mha(qkv, amask, b, s, h, d):
class TestFMHA(unittest.TestCase):
def run_test(self, s):
def run_test(self, s, b):
print(f'Test s={s} b={b}')
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
......@@ -59,7 +60,6 @@ class TestFMHA(unittest.TestCase):
dtype = torch.float16
device = torch.device('cuda')
b = 32
h = 16
d = 64
......@@ -76,7 +76,10 @@ class TestFMHA(unittest.TestCase):
qkv.requires_grad = True
ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, seqlens, 0.0, s, True, None)
if b < 4:
ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, None)
else:
ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, None)
ctx = ctx.view(b,s,h,d)
ctx_ref = py_mha(qkv, amask, b,s,h,d)
......@@ -91,23 +94,28 @@ class TestFMHA(unittest.TestCase):
dw2 = dw.permute(0,2,1,3).clone().detach().contiguous()
dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, seqlens, 0.0, s)
if b < 4:
dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
else:
dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d)
self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))
def test_128(self):
self.run_test(128)
self.run_test(128, 32)
def test_256(self):
self.run_test(256)
self.run_test(256, 32)
def test_384(self):
self.run_test(384)
self.run_test(384, 32)
def test_512(self):
self.run_test(512)
self.run_test(512, 32)
self.run_test(512, 2)
self.run_test(512, 3)
if __name__ == '__main__':
unittest.main()
......@@ -349,6 +349,7 @@ if "--fmha" in sys.argv:
CUDAExtension(name='fmhalib',
sources=[
'apex/contrib/csrc/fmha/fmha_api.cpp',
'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment