Unverified commit 3c88451a authored by yjk21, committed by GitHub

update fmha (#1344)

parent a0ed4151
......@@ -72,7 +72,7 @@ void set_params(Fused_multihead_attention_fprop_params &params,
constexpr float scale_softmax = 1.f;
constexpr float scale_bmm2 = 1.f;
set_alpha(params.scale_bmm1, scale_bmm1, acc_type);
set_alpha(params.scale_bmm1, scale_bmm1, data_type);
set_alpha(params.scale_softmax, scale_softmax, acc_type);
set_alpha(params.scale_bmm2, scale_bmm2, data_type);
......@@ -83,16 +83,21 @@ void set_params(Fused_multihead_attention_fprop_params &params,
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
}
std::vector<at::Tensor>
mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
std::vector<at::Tensor>
mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const float p_dropout,
const int max_seq_len,
const bool is_training,
const bool is_nl,
const bool zero_tensors,
c10::optional<at::Generator> gen_) {
auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
auto stream = at::cuda::getCurrentCUDAStream().stream();
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_training, is_nl);
int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80;
if( max_seq_len <= 128 ) {
......@@ -111,18 +116,6 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
TORCH_CHECK(false);
}
constexpr int warps_m = 1;
constexpr int warps_n = 4; // this leads to an upper bound
const int mmas_m = seq_len / 16 / warps_m;
const int mmas_n = seq_len / 16 / warps_n;
const int elts_per_thread = 8 * mmas_m * mmas_n;
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
......@@ -156,9 +149,8 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
Fused_multihead_attention_fprop_params params;
set_params(params,
set_params(launch_params.params,
batch_size,
seq_len,
num_heads,
......@@ -169,22 +161,24 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
s.data_ptr(),
p_dropout);
// number of times random will be generated per thread, to offset philox counter in the random
launch(launch_params, /*configure=*/ true);
// number of times random will be generated per thread, to offset philox counter in the random
// state
int64_t counter_offset = elts_per_thread;
int64_t counter_offset = launch_params.elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
if( is_training ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
}
launch(params, is_training, stream);
launch(launch_params, /*configure=*/ false);
return { ctx, s };
}
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // total x num_heads x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
......@@ -270,92 +264,6 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
return { dqkv, softmax };
}
std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const float p_dropout,
const int max_seq_len,
const bool is_training,
const bool zero_tensors,
c10::optional<at::Generator> gen_) {
int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80_nl;
TORCH_CHECK(max_seq_len == seq_len);
constexpr int warps_m = 1;
constexpr int warps_n = 4; // this leads to an upper bound
const int mmas_m = seq_len / 16 / warps_m;
const int mmas_n = seq_len / 16 / warps_n;
// static_assert( mmas_m == 32 );
// static_assert( mmas_n == 4 );
const int elts_per_thread = 8 * mmas_m * mmas_n;
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto opts = qkv.options();
auto ctx = torch::empty({ total, num_heads, head_size }, opts);
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
if( zero_tensors ) {
ctx.zero_();
s.zero_();
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
ctx.data_ptr(),
s.data_ptr(),
p_dropout);
// number of times random will be generated per thread, to offset philox counter in the random
// state
int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
if( is_training ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}
int num_chunks = 3;
if(batch_size == 3) {
num_chunks = 2;
}
launch(params, is_training, num_chunks, stream);
return { ctx, s };
}
std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num_heads x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
......@@ -449,6 +357,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "Fused Multi-head Self-attention for BERT";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("fwd_nl", &mha_fwd_nl, "Forward pass (small-batch)");
m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
}
......@@ -50,7 +50,7 @@ constexpr int D_DIM = 3;
struct Qkv_params {
// The QKV matrices.
void *qkv_ptr;
void * __restrict__ qkv_ptr;
// The stride between rows of the Q, K and V matrices.
size_t qkv_stride_in_bytes;
......@@ -64,19 +64,19 @@ struct Qkv_params {
struct Fused_multihead_attention_fprop_params : public Qkv_params {
// The dQKV matrices.
void *dqkv_ptr;
void * __restrict__ dqkv_ptr;
// Temporary for dKV.
void *dkv_ptr;
void * __restrict__ dkv_ptr;
// The O matrix (output).
void *o_ptr;
void * __restrict__ o_ptr;
// The stride between rows of O.
int64_t o_stride_in_bytes;
// The pointer to the S matrix, overwritten by the dP matrix (bwd).
void *s_ptr;
void * __restrict__ s_ptr;
// The stride between rows of the S matrix.
int64_t s_stride_in_bytes;
......@@ -87,7 +87,7 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
uint32_t scale_bmm1, scale_softmax, scale_bmm2;
// array of length b+1 holding starting offset of each sequence.
int *cu_seqlens;
int * __restrict__ cu_seqlens;
// The dropout probability (probability of keeping an activation).
float p_dropout;
......@@ -104,10 +104,43 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
template<typename Kernel_params>
struct Launch_params{
Launch_params(cudaDeviceProp * props_,
cudaStream_t stream_,
bool is_training_,
bool is_nl_)
: elts_per_thread(0)
, props(props_)
, stream(stream_)
, is_training(is_training_)
, is_nl(is_nl_) {
}
size_t elts_per_thread;
cudaDeviceProp * props;
cudaStream_t stream;
bool is_training;
Kernel_params params;
int num_full_heads;
int num_main_groups;
int heads_last_wave;
int main_steps;
int rest_steps;
bool is_nl;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
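// Illustrative sketch (not part of this diff) of the two-phase protocol the new launchers
// expect: call once with configure=true so the launcher fills the work distribution and
// elts_per_thread, size the Philox offset from that, then launch with configure=false.
// Names mirror mha_fwd above.
//
//   auto dprops = at::cuda::getCurrentDeviceProperties();
//   auto stream = at::cuda::getCurrentCUDAStream().stream();
//   Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, /*is_training=*/true, /*is_nl=*/false);
//   set_params(launch_params.params, batch_size, seq_len, num_heads, head_size,
//              qkv_ptr, cu_seqlens_ptr, ctx_ptr, s_ptr, p_dropout);
//   run_fmha_fp16_512_64_sm80(launch_params, /*configure=*/true);   // compute elts_per_thread etc.
//   int64_t counter_offset = launch_params.elts_per_thread;         // offsets the Philox counter
//   launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
//   run_fmha_fp16_512_64_sm80(launch_params, /*configure=*/false);  // actual kernel launch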
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
......
......@@ -210,9 +210,6 @@ struct Clear_accumulator<float, WARPS_K> {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Acc, typename A, typename B, int M, int N>
......
......@@ -60,7 +60,7 @@ struct Gmem_tile_qkv {
// Ctor.
template< typename Params, typename BInfo >
inline __device__ Gmem_tile_qkv(const Params &params, int qkv_offset, const BInfo &binfo, int tidx)
inline __device__ Gmem_tile_qkv(const Params &params, const int qkv_offset, const BInfo &binfo, const int tidx)
: params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)
, actual_seqlen(binfo.actual_seqlen)
, qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr)) {
......@@ -125,6 +125,11 @@ struct Gmem_tile_qkv {
actual_seqlen -= ROWS;
}
inline __device__ void move(int steps) {
qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps;
actual_seqlen -= ROWS * steps;
}
// The stride between rows for the QKV matrix.
int64_t params_qkv_stride_in_bytes_;
// The pointer.
......@@ -224,6 +229,11 @@ struct Gmem_tile_o {
o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_;
}
inline __device__ void move(const int steps) {
row_ += ROWS * steps;
o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_ * steps;
}
// The stride between rows for the O matrix.
int64_t params_o_stride_in_bytes_;
// The pointer.
......@@ -270,13 +280,9 @@ struct Gmem_tile_mma_sd {
// Ctor.
template<typename Params>
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int tidx)
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int bidb, const int bidh, const int tidx)
: ptr_(static_cast<char *>(ptr)) {
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The block index.
size_t bidx = bidb * params.h + bidh;
......@@ -300,6 +306,9 @@ struct Gmem_tile_mma_sd {
inline __device__ void move() {
ptr_ += LOOP_STRIDE_BYTES;
}
inline __device__ void move(const int steps) {
ptr_ += LOOP_STRIDE_BYTES * steps;
}
// The pointer in global memory.
char *ptr_;
......@@ -318,9 +327,9 @@ struct Gmem_tile_mma_s : public Base {
using Type = typename Base::Type;
// Ctor.
template< typename Params >
inline __device__ Gmem_tile_mma_s(void *ptr, const Params &params, const int tidx)
: Base(ptr, params, tidx) {
template< typename Params, typename Block_info >
inline __device__ Gmem_tile_mma_s(const Params &params, const Block_info& binfo, const int tidx)
: Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
}
// Store to global memory.
......@@ -353,6 +362,25 @@ struct Gmem_tile_mma_s : public Base {
}
}
// Store to global memory.
template<typename Mask, typename Fragment>
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 dst;
dst.x = frag[ni][mi].reg(0);
dst.y = frag[ni][mi].reg(2);
dst.z = frag[ni][mi].reg(1);
dst.w = frag[ni][mi].reg(3);
if( mask.any_valid(mi, ni) ) {
Base::store(dst, mi, ni);
}
}
}
}
// Load from global memory.
template<typename Mask>
inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
......@@ -361,7 +389,7 @@ struct Gmem_tile_mma_s : public Base {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
regs[mi][ni] = make_uint4(0, 0, 0, 0);
if( mask.is_valid(mi, ni, 0, 0) ) {
if( mask.any_valid(mi, ni) ) {
Base::load(regs[mi][ni], mi, ni);
}
}
......
......@@ -29,7 +29,7 @@
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x8u>
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u>
struct FMHA_kernel_traits {
// The CTA description for the 1st GEMM.
......@@ -38,7 +38,9 @@ struct FMHA_kernel_traits {
using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;
// Do we use one buffer for K and V.
enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u };
enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u };
// Do we keep K in registers.
enum { K_IN_REGS = (FLAGS & 0x10u) == 0u };
// The global memory tile to load Q.
using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
......
......@@ -63,6 +63,11 @@ struct Mask {
// return row_valid && col_valid;
}
// BERT mask: if the upper-left element is invalid, none are valid.
inline __device__ bool any_valid(int mi, int ni) const {
return is_valid(mi, ni, 0, 0);
}
inline __device__ void load(int it) {
row_offset = it * Cta_tile::M + row;
}
......
......@@ -1266,8 +1266,6 @@ struct Smem_tile_mma_epilogue : public Base {
}
}
template<int M, int N>
inline __device__ void store(const uint4 (&regs)[M][N]) {
for( int mi = 0; mi < M; mi++ ) {
......
......@@ -55,6 +55,88 @@ inline __device__ float apply_exp_(float x, float max) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int COLS> struct ReadType {};
template<> struct ReadType<4> { using T = float;};
template<> struct ReadType<8> { using T = float2;};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile, typename Kernel_traits>
struct Smem_tile_reduce {
// Helper class that stages per-warp row reductions of the MMA tile in shared memory so the cross-warp reduction can be finished with quad shuffles.
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
enum { MMAS_M = Mma_tile::MMAS_M };
enum { MMAS_N = Mma_tile::MMAS_N };
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
static constexpr int ROWS = WARPS_M * MMAS_M * 16;
static constexpr int COLS = WARPS_N;
static_assert(COLS == 4 || COLS == 8);
static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
static constexpr int ELTS_PER_TILE = ROWS * COLS;
static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW;
static_assert(THREADS_PER_GROUP == 16); // DEBUG
static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP;
static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
static_assert(LOOPS == 1);
using read_t = typename ReadType<COLS>::T;
__device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
int lane = tidx % 32;
int warp = tidx / 32;
int warp_m = warp % WARPS_M;
int warp_n = warp / WARPS_M;
qid_ = lane % 4;
int qp = lane / 4;
// Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.
// This won't affect reading as we assume commutative reduction ops.
const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
}
__device__ inline void store(float (&frag)[2 * MMAS_M]) {
if( qid_ == 0 ) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * WARPS_N;
smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];
smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];
}
}
}
__device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * 4;
frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];
frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];
}
}
int qid_;
float *smem_write_;
read_t *smem_read_;
};
template<typename Cta_tile, typename Kernel_traits>
struct Softmax_base {
......@@ -136,201 +218,6 @@ struct Softmax_base {
}
}
// Do a CTA-wide reduction.
template<typename Functor>
inline __device__ void reduce_1x4(float (&dst)[MMAS_M * 2]) {
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
if( Functor::IS_SUM ) {
// Apply the summation inside the thread.
float tmp[MMAS_M * 2][2];
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
tmp[mi][0] = 0.f;
tmp[mi][1] = 0.f;
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
tmp[mi][0] += elt_[mi][4 * ni + 0];
tmp[mi][0] += elt_[mi][4 * ni + 1];
tmp[mi][1] += elt_[mi][4 * ni + 2];
tmp[mi][1] += elt_[mi][4 * ni + 3];
}
dst[mi] = tmp[mi][0] + tmp[mi][1];
}
} else
#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
{
// Apply the functor for each row inside a thread.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = elt_[mi][0];
#pragma unroll
for( int ni = 1; ni < MMAS_N * 4; ++ni ) {
dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);
}
}
}
// Apply the functor for each row inside each group of 4 threads.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));
__syncwarp();
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));
__syncwarp();
}
// Store the different values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
if( tidx_ % 4 == 0 ) {
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];
}
}
// Make sure the values are in shared memory.
__syncthreads();
// Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the
// float4.
float4 tmp[1];
if( tidx_ < Cta_tile::M ) {
tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];
}
// Compute the reduction of those 8 values in a binary-tree fashion.
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);
tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);
// Make sure we can write to shared memory.
__syncthreads();
// Store the value back to shared memory.
if( tidx_ < Cta_tile::M ) {
smem_[tidx_] = tmp[0].x;
}
// Make sure the data is in shared memory.
__syncthreads();
// Finally read the values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];
dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];
}
}
// Do a CTA-wide reduction.
template<typename Functor>
inline __device__ void reduce_1x8(float (&dst)[MMAS_M * 2]) {
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
if( Functor::IS_SUM ) {
// Apply the summation inside the thread.
float tmp[MMAS_M * 2][2];
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
tmp[mi][0] = 0.f;
tmp[mi][1] = 0.f;
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
tmp[mi][0] += elt_[mi][4 * ni + 0];
tmp[mi][0] += elt_[mi][4 * ni + 1];
tmp[mi][1] += elt_[mi][4 * ni + 2];
tmp[mi][1] += elt_[mi][4 * ni + 3];
}
dst[mi] = tmp[mi][0] + tmp[mi][1];
}
} else
#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
{
// Apply the functor for each row inside a thread.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = elt_[mi][0];
#pragma unroll
for( int ni = 1; ni < MMAS_N * 4; ++ni ) {
dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);
}
}
}
// Apply the functor for each row inside each group of 4 threads.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));
__syncwarp();
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));
__syncwarp();
}
// Store the different values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
if( tidx_ % 4 == 0 ) {
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];
}
}
// Make sure the values are in shared memory.
__syncthreads();
// Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the
// float4.
float4 tmp[2];
if( tidx_ < Cta_tile::M ) {
tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];
tmp[1] = reinterpret_cast<const float4 *>(&smem_[1 * ELEMENTS / 2])[tidx_];
}
// Compute the reduction of those 8 values in a binary-tree fashion.
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);
tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);
tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y);
tmp[1].z = Functor::apply(tmp[1].z, tmp[1].w);
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);
tmp[1].x = Functor::apply(tmp[1].x, tmp[1].z);
tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x);
// Make sure we can write to shared memory.
__syncthreads();
// Store the value back to shared memory.
if( tidx_ < Cta_tile::M ) {
smem_[tidx_] = tmp[0].x;
}
// Make sure the data is in shared memory.
__syncthreads();
// Finally read the values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];
dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];
}
}
// Do a CTA-wide reduction.
template<typename Functor>
inline __device__ void reduce(float (&dst)[MMAS_M * 2]) {
static_assert(Cta_tile::WARPS_M == 1 && (Cta_tile::WARPS_N == 4 || Cta_tile::WARPS_N == 8));
if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 4 ) {
reduce_1x4<Functor>(dst);
} else if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 8 ) {
reduce_1x8<Functor>(dst);
} else {
assert(false);
}
// Make sure we are done reading from shared memory.
__syncthreads();
}
// Scale all the elements.
inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
// Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge difference.
......@@ -372,6 +259,8 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
static_assert(Fragment_a::NUM_REGS == 4);
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
// The MMAs.
enum { MMAS_M = Base::MMAS_M };
enum { MMAS_N = Base::MMAS_N };
......@@ -383,41 +272,15 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
static_assert(std::is_same<Accumulator::Data_type, float>::value);
using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
// Ctor.
template<typename Params>
inline __device__ Softmax(const Params &params, void *smem, int bidb, int tidx)
: Base(params, smem, bidb, tidx), params_scale_bmm1_(params.scale_bmm1) {
}
// Store the tile after softmax.
template<typename Gmem_tile>
inline __device__ void store(Gmem_tile &gmem_tile) {
Accumulator_out acc[MMAS_M][MMAS_N];
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// The elements.
float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0];
float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1];
float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2];
float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3];
float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0];
float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1];
float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2];
float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3];
// Transform to accumulators.
acc[mi][ni].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
acc[mi][ni].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
acc[mi][ni].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
acc[mi][ni].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
}
}
// Delegate to the gmem tile to store.
gmem_tile.store(acc);
: Base(params, smem, bidb, tidx)
, params_scale_bmm1_(params.scale_bmm1)
, smem_sum_(static_cast<float*>(smem), tidx)
, smem_max_(static_cast<float*>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {
}
// Pack the data to a fragment for the next GEMM.
......@@ -470,7 +333,61 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
}
}
}
// Unpack FP32 fragments without scaling.
inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4);
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5);
// 2nd row - 4 elements per row.
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2);
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
}
}
}
template<typename Operator>
__device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
for( int mi = 0; mi < 2 * MMAS_M; mi++ ) {
frag[mi] = this->elt_[mi][0];
for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
frag[mi] = op(frag[mi], this->elt_[mi][ni]);
}
}
quad_reduce(frag, frag, op);
smem_red.store(frag);
__syncthreads();
typename Smem_tile_red::read_t tmp[2 * MMAS_M];
smem_red.load(tmp);
quad_allreduce(frag, tmp, op);
}
__device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
MaxOp<float> max;
reduce_(frag, max, smem_max_);
}
__device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
SumOp<float> sum;
reduce_(frag, sum, smem_sum_);
}
const uint32_t params_scale_bmm1_;
Smem_tile_red smem_max_;
Smem_tile_red smem_sum_;
};
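// Hedged sketch of how the new reduce_max/reduce_sum members are typically combined for a
// numerically stable softmax. The real per-step loop lives in fmha_fprop_kernel_1xN.h (not
// shown in this hunk), and apply_exp below is a hypothetical helper wrapping apply_exp_.
//
//   softmax.unpack_noscale(acc_p);                // copy the P = Q*K^T accumulators into elt_
//   float p_max[2 * Mma_tile_p::MMAS_M];
//   softmax.reduce_max(p_max);                    // per-row max via quad + smem reduction
//   softmax.apply_exp(p_max);                     // hypothetical: elt_ = exp(elt_ - row max)
//   float p_sum[2 * Mma_tile_p::MMAS_M];
//   softmax.reduce_sum(p_sum);                    // per-row sum, same reduction machinery
//   softmax.scale(p_sum);                         // normalize each row by its sum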
////////////////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -950,4 +950,89 @@ inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_reduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = Allreduce<4>::run(dst[mi], op);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_allreduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
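// Sketch of how these helpers compose (mirrors Softmax::reduce_ in the softmax header from
// this diff): quad_reduce folds the 4 threads of a quad so lane % 4 == 0 holds a row
// partial, Smem_tile_reduce stages the per-warp partials, and quad_allreduce (built on
// Allreduce<4> and __shfl_xor_sync) finishes the cross-warp reduction so every thread in
// the quad ends up with the full row value.
//
//   fmha::SumOp<float> sum;
//   float frag[2 * MMAS_M];                       // per-thread row partials
//   fmha::quad_reduce(frag, frag, sum);           // intra-quad: shuffle down by 2, then 1
//   smem_red.store(frag);                         // only lane % 4 == 0 writes to smem
//   __syncthreads();
//   typename Smem_tile_red::read_t tmp[2 * MMAS_M];
//   smem_red.load(tmp);                           // each quad reads the WARPS_N partials
//   fmha::quad_allreduce(frag, tmp, sum);         // every thread gets the full row value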
......@@ -28,7 +28,7 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits< 128, 64, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
......
......@@ -28,7 +28,7 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits< 256, 64, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
......
......@@ -28,7 +28,7 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits< 384, 64, 16, 1, 8, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 8, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
......
......@@ -29,7 +29,7 @@
#include "fmha_dgrad_kernel_1xN_reload.h"
#include "fmha_dgrad_kernel_1xN_reload_nl.h"
using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
......
......@@ -141,7 +141,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
Gmem_tile_s gmem_s(params, binfo, tidx);
// Create the object to do the softmax.
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
......@@ -231,7 +231,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
}
float p_sum[2 * M];
softmax.template reduce<fmha::Sum_>(p_sum);
softmax.reduce_sum(p_sum);
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
#pragma unroll
......@@ -406,7 +406,7 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
// Trigger the loads for K.
gmem_k.load(smem_k);
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
Gmem_tile_s gmem_s(params, binfo, tidx);
// Load dP
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
......
......@@ -114,11 +114,11 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
Gmem_tile_s gmem_s(params, binfo, tidx);
using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;
Noloop nl_traits(bidc);
Noloop nl_traits(bidc, binfo);
nl_traits.move_all(gmem_q, gmem_s);
// Trigger the loads for Q.
......@@ -163,8 +163,6 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
// Load over the entire sequence length.
for(int l = 0; l < nl_traits.num_steps_;l++) {
const int loop = nl_traits.offset_loop_count(l);
if( loop >= binfo.actual_seqlen ) break;
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
......@@ -230,7 +228,7 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
}
float p_sum[2 * M];
softmax.template reduce<fmha::Sum_>(p_sum);
softmax.reduce_sum(p_sum);
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
#pragma unroll
......@@ -400,7 +398,7 @@ inline __device__ void compute_dq_dk_1xN_nl(const Params &params) {
// Allocate the shared memory tile loader for Q (as B).
Smem_tile_qt smem_qt(&smem_[0], tidx);
// Allocate the global memory tile loader for dP.
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
Gmem_tile_s gmem_s(params, binfo, tidx);
// Allocate the shared memory tile loader for dP.
Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);
......@@ -414,7 +412,7 @@ inline __device__ void compute_dq_dk_1xN_nl(const Params &params) {
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);
Noloop nl_traits(bidc);
Noloop nl_traits(bidc, binfo);
nl_traits.move_all(gmem_q, gmem_o, gmem_s);
......
......@@ -28,31 +28,57 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits< 128, 64, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, true>(params);
}
template<bool Is_training>
__global__
void fmha_fprop_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
fmha::device_1xN<Kernel_traits, Is_training>(
params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
}
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_128_64_sm80_train_kernel : &fmha_fprop_fp16_128_64_sm80_predict_kernel;
void run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
auto kernel = launch_params.is_training ? &fmha_fprop_fp16_128_64_sm80_kernel<true> : &fmha_fprop_fp16_128_64_sm80_kernel<false>;
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
const int sm_count = launch_params.props->multiProcessorCount;
int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
if(configure) {
const int heads_total = launch_params.params.b * launch_params.params.h;
std::tie(launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps,
launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
......@@ -28,31 +28,57 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits< 256, 64, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, true>(params);
}
template<bool Is_training>
__global__
void fmha_fprop_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
fmha::device_1xN<Kernel_traits, Is_training>(
params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
}
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_256_64_sm80_train_kernel : &fmha_fprop_fp16_256_64_sm80_predict_kernel;
void run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
auto kernel = launch_params.is_training ? &fmha_fprop_fp16_256_64_sm80_kernel<true> : &fmha_fprop_fp16_256_64_sm80_kernel<false>;
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
const int sm_count = launch_params.props->multiProcessorCount;
int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
if(configure) {
const int heads_total = launch_params.params.b * launch_params.params.h;
std::tie(launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps,
launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
......@@ -26,32 +26,59 @@
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN_reload_v.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits< 384, 64, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x18u>;
extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, true>(params);
}
template<bool Is_training>
__global__
void fmha_fprop_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
fmha::device_1xN<Kernel_traits, Is_training>(
params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
}
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_384_64_sm80_train_kernel : &fmha_fprop_fp16_384_64_sm80_predict_kernel;
void run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
auto kernel = launch_params.is_training ? &fmha_fprop_fp16_384_64_sm80_kernel<true> : &fmha_fprop_fp16_384_64_sm80_kernel<false>;
constexpr int smem_size = smem_size_v + smem_size_o + smem_size_softmax;
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
const int sm_count = launch_params.props->multiProcessorCount;
int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
if(configure) {
const int heads_total = launch_params.params.b * launch_params.params.h;
std::tie(launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps,
launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
......@@ -27,72 +27,111 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
#include "fmha_fprop_kernel_1xN_nl.h"
using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>;
extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, true>(params);
}
extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
}
template<bool Is_training>
__global__
void fmha_fprop_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
const int total_heads) {
template<int CHUNKS>
__global__ void fmha_fprop_fp16_512_64_sm80_train_nl_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN_nl<CHUNKS,Kernel_traits, true>(params);
fmha::device_1xN<Kernel_traits, Is_training>(params, total_heads);
}
template<int CHUNKS>
__global__ void fmha_fprop_fp16_512_64_sm80_predict_nl_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN_nl<CHUNKS, Kernel_traits, false>(params);
template<bool Is_training>
__global__
void fmha_fprop_fp16_512_64_sm80_kernel_nl(Fused_multihead_attention_fprop_params params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
fmha::device_1xN<Kernel_traits, Is_training>(
params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
}
void run_fmha_fp16_512_64_sm80_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel<true> : &fmha_fprop_fp16_512_64_sm80_kernel<false>;
auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_kernel : &fmha_fprop_fp16_512_64_sm80_predict_kernel;
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
if( num_chunks == 2 ) {
kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2>
: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
} else if( num_chunks == 3 ) {
kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<3>
: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<3>;
} else if( num_chunks == 4 ) {
kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<4>
: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<4>;
} else {
assert(false && "Unsupported num_chunks");
const int sm_count = launch_params.props->multiProcessorCount;
int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
const int heads_total = launch_params.params.b * launch_params.params.h;
if(configure) {
using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas);
size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8;
launch_params.elts_per_thread = heads_per_cta * elts_per_head;
return;
}
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
heads_total);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
void run_fmha_fp16_512_64_sm80_nl_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel_nl<true> : &fmha_fprop_fp16_512_64_sm80_kernel_nl<false>;
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b, num_chunks);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
const int sm_count = launch_params.props->multiProcessorCount;
int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
if(configure) {
const int heads_total = launch_params.params.b * launch_params.params.h;
std::tie(launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps,
launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
void run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
if( launch_params.is_nl ) {
run_fmha_fp16_512_64_sm80_nl_(launch_params, configure);
} else {
run_fmha_fp16_512_64_sm80_(launch_params, configure);
}
}