Unverified commit 96850dfa authored by Jithun Nair, committed by GitHub

Merge pull request #80 from ROCmSoftwarePlatform/IFU-master-2022-07-29

IFU-master-2022-07-29
parents 87fc4125 cc5f83b5
@@ -63,6 +63,11 @@ struct Mask {
// return row_valid && col_valid;
}
// BERT mask: if the upper-left element is invalid, none are valid.
inline __device__ bool any_valid(int mi, int ni) const {
return is_valid(mi, ni, 0, 0);
}
inline __device__ void load(int it) {
row_offset = it * Cta_tile::M + row;
}
@@ -1266,8 +1266,6 @@ struct Smem_tile_mma_epilogue : public Base {
}
}
template<int M, int N>
inline __device__ void store(const uint4 (&regs)[M][N]) {
for( int mi = 0; mi < M; mi++ ) {
@@ -55,6 +55,88 @@ inline __device__ float apply_exp_(float x, float max) {
////////////////////////////////////////////////////////////////////////////////////////////////////
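// Read type for loading the per-warp partials back from shared memory: with
// 4 warps along N each quad lane reads a single float, with 8 warps a float2
// holding two packed partials, so the four lanes of a quad cover all columns.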
template<int COLS> struct ReadType {};
template<> struct ReadType<4> { using T = float;};
template<> struct ReadType<8> { using T = float2;};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile, typename Kernel_traits>
struct Smem_tile_reduce {
// Helper class that distributes the per-row reductions of each warp's MMA tiles across the quads.
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
enum { MMAS_M = Mma_tile::MMAS_M };
enum { MMAS_N = Mma_tile::MMAS_N };
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
static constexpr int ROWS = WARPS_M * MMAS_M * 16;
static constexpr int COLS = WARPS_N;
static_assert(COLS == 4 || COLS == 8);
static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
static constexpr int ELTS_PER_TILE = ROWS * COLS;
static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW;
static_assert(THREADS_PER_GROUP == 16); // DEBUG
static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP;
static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
static_assert(LOOPS == 1);
using read_t = typename ReadType<COLS>::T;
__device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
int lane = tidx % 32;
int warp = tidx / 32;
int warp_m = warp % WARPS_M;
int warp_n = warp / WARPS_M;
qid_ = lane % 4;
int qp = lane / 4;
// Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.
// This won't affect reading as we assume commutative reduction ops.
const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
}
__device__ inline void store(float (&frag)[2 * MMAS_M]) {
if( qid_ == 0 ) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * WARPS_N;
smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];
smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];
}
}
}
__device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * 4;
frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];
frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];
}
}
int qid_;
float *smem_write_;
read_t *smem_read_;
};
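// Usage pattern for Smem_tile_reduce (see Softmax::reduce_ further below): each warp
// first reduces its rows down to one value per quad, lane 0 of the quad calls store(),
// the CTA synchronizes, then every thread calls load() and finishes with a quad-wide
// allreduce so all four lanes hold the same per-row result.
////////////////////////////////////////////////////////////////////////////////////////////////////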
template<typename Cta_tile, typename Kernel_traits>
struct Softmax_base {
@@ -136,201 +218,6 @@ struct Softmax_base {
}
}
// Do a CTA-wide reduction.
template<typename Functor>
inline __device__ void reduce_1x4(float (&dst)[MMAS_M * 2]) {
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
if( Functor::IS_SUM ) {
// Apply the summation inside the thread.
float tmp[MMAS_M * 2][2];
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
tmp[mi][0] = 0.f;
tmp[mi][1] = 0.f;
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
tmp[mi][0] += elt_[mi][4 * ni + 0];
tmp[mi][0] += elt_[mi][4 * ni + 1];
tmp[mi][1] += elt_[mi][4 * ni + 2];
tmp[mi][1] += elt_[mi][4 * ni + 3];
}
dst[mi] = tmp[mi][0] + tmp[mi][1];
}
} else
#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
{
// Apply the functor for each row inside a thread.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = elt_[mi][0];
#pragma unroll
for( int ni = 1; ni < MMAS_N * 4; ++ni ) {
dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);
}
}
}
// Apply the functor for each row inside each group of 4 threads.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));
__syncwarp();
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));
__syncwarp();
}
// Store the different values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
if( tidx_ % 4 == 0 ) {
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];
}
}
// Make sure the values are in shared memory.
__syncthreads();
// Load 4 values (one for each of the 4 warps), packed as a single float4.
float4 tmp[1];
if( tidx_ < Cta_tile::M ) {
tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];
}
// Compute the reduction of those 4 values in a binary-tree fashion.
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);
tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);
// Make sure we can write to shared memory.
__syncthreads();
// Store the value back to shared memory.
if( tidx_ < Cta_tile::M ) {
smem_[tidx_] = tmp[0].x;
}
// Make sure the data is in shared memory.
__syncthreads();
// Finally read the values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];
dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];
}
}
// Do a CTA-wide reduction.
template<typename Functor>
inline __device__ void reduce_1x8(float (&dst)[MMAS_M * 2]) {
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
if( Functor::IS_SUM ) {
// Apply the summation inside the thread.
float tmp[MMAS_M * 2][2];
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
tmp[mi][0] = 0.f;
tmp[mi][1] = 0.f;
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
tmp[mi][0] += elt_[mi][4 * ni + 0];
tmp[mi][0] += elt_[mi][4 * ni + 1];
tmp[mi][1] += elt_[mi][4 * ni + 2];
tmp[mi][1] += elt_[mi][4 * ni + 3];
}
dst[mi] = tmp[mi][0] + tmp[mi][1];
}
} else
#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
{
// Apply the functor for each row inside a thread.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = elt_[mi][0];
#pragma unroll
for( int ni = 1; ni < MMAS_N * 4; ++ni ) {
dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);
}
}
}
// Apply the functor for each row inside each group of 4 threads.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));
__syncwarp();
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));
__syncwarp();
}
// Store the different values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
if( tidx_ % 4 == 0 ) {
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];
}
}
// Make sure the values are in shared memory.
__syncthreads();
// Load 8 values (one for each of the 8 warps), packed as two float4s.
float4 tmp[2];
if( tidx_ < Cta_tile::M ) {
tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];
tmp[1] = reinterpret_cast<const float4 *>(&smem_[1 * ELEMENTS / 2])[tidx_];
}
// Compute the reduction of those 8 values in a binary-tree fashion.
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);
tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);
tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y);
tmp[1].z = Functor::apply(tmp[1].z, tmp[1].w);
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);
tmp[1].x = Functor::apply(tmp[1].x, tmp[1].z);
tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x);
// Make sure we can write to shared memory.
__syncthreads();
// Store the value back to shared memory.
if( tidx_ < Cta_tile::M ) {
smem_[tidx_] = tmp[0].x;
}
// Make sure the data is in shared memory.
__syncthreads();
// Finally read the values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];
dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];
}
}
// Do a CTA-wide reduction.
template<typename Functor>
inline __device__ void reduce(float (&dst)[MMAS_M * 2]) {
static_assert(Cta_tile::WARPS_M == 1 && (Cta_tile::WARPS_N == 4 || Cta_tile::WARPS_N == 8));
if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 4 ) {
reduce_1x4<Functor>(dst);
} else if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 8 ) {
reduce_1x8<Functor>(dst);
} else {
assert(false);
}
// Make sure we are done reading from shared memory.
__syncthreads();
}
// Scale all the elements.
inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
// Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge difference.
@@ -372,6 +259,8 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
static_assert(Fragment_a::NUM_REGS == 4);
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
// The MMAs.
enum { MMAS_M = Base::MMAS_M };
enum { MMAS_N = Base::MMAS_N };
@@ -383,41 +272,15 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
static_assert(std::is_same<Accumulator::Data_type, float>::value);
using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
// Ctor.
template<typename Params>
inline __device__ Softmax(const Params &params, void *smem, int bidb, int tidx)
: Base(params, smem, bidb, tidx), params_scale_bmm1_(params.scale_bmm1) {
}
// Store the tile after softmax.
template<typename Gmem_tile>
inline __device__ void store(Gmem_tile &gmem_tile) {
Accumulator_out acc[MMAS_M][MMAS_N];
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// The elements.
float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0];
float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1];
float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2];
float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3];
float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0];
float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1];
float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2];
float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3];
// Transform to accumulators.
acc[mi][ni].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
acc[mi][ni].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
acc[mi][ni].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
acc[mi][ni].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
}
}
// Delegate to the gmem tile to store.
gmem_tile.store(acc);
: Base(params, smem, bidb, tidx)
, params_scale_bmm1_(params.scale_bmm1)
, smem_sum_(static_cast<float*>(smem), tidx)
, smem_max_(static_cast<float*>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {
}
// Pack the data to a fragment for the next GEMM.
@@ -470,7 +333,61 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
}
}
}
// Unpack FP32 fragments without scaling.
inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4);
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5);
// 2nd row - 4 elements per row.
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2);
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
}
}
}
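// (Per-thread accumulator layout behind the mapping above: registers {0,1,4,5}
// feed the first row held by the thread and {2,3,6,7} the row eight below it.)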
template<typename Operator>
__device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
// Reduce each row's elements within the thread first.
for( int mi = 0; mi < 2 * MMAS_M; mi++ ) {
frag[mi] = this->elt_[mi][0];
for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
frag[mi] = op(frag[mi], this->elt_[mi][ni]);
}
}
// Combine the four lanes of each quad; lane 0 of the quad stages the result in shared memory.
quad_reduce(frag, frag, op);
smem_red.store(frag);
__syncthreads();
// Read the per-warp partials back and finish with a quad-wide allreduce.
typename Smem_tile_red::read_t tmp[2 * MMAS_M];
smem_red.load(tmp);
quad_allreduce(frag, tmp, op);
}
__device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
MaxOp<float> max;
reduce_(frag, max, smem_max_);
}
__device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
SumOp<float> sum;
reduce_(frag, sum, smem_sum_);
}
const uint32_t params_scale_bmm1_;
Smem_tile_red smem_max_;
Smem_tile_red smem_sum_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -950,4 +950,89 @@ inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
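// Illustrative helper (not from the original source; shown only to make the pattern
// concrete): every lane of a quad ends up with the maximum over the four lanes' values.
template<typename T>
__device__ inline T quad_max_example(T x) {
    MaxOp<T> op;
    return Allreduce<4>::run(x, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////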
template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
}
}
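// Note: quad_reduce leaves the complete result only in lane 0 of each quad (qid == 0),
// which is why Smem_tile_reduce::store() writes from qid_ == 0 only; quad_allreduce
// below broadcasts the result to all four lanes instead.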
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_reduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = Allreduce<4>::run(dst[mi], op);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_allreduce(dst, tmp, op);
}
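// The float2 overloads serve the 8-warp configuration, where Smem_tile_reduce hands
// each quad lane a float2 of two packed partials (ReadType<8>::T); the pair is combined
// first and then reduced across the quad as before.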
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
@@ -28,7 +28,7 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits< 128, 64, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
@@ -28,7 +28,7 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits< 256, 64, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
@@ -28,7 +28,7 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits< 384, 64, 16, 1, 8, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 8, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
@@ -29,7 +29,7 @@
#include "fmha_dgrad_kernel_1xN_reload.h"
#include "fmha_dgrad_kernel_1xN_reload_nl.h"
using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
@@ -141,7 +141,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
Gmem_tile_s gmem_s(params, binfo, tidx);
// Create the object to do the softmax.
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
@@ -231,7 +231,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
}
float p_sum[2 * M];
softmax.template reduce<fmha::Sum_>(p_sum);
softmax.reduce_sum(p_sum);
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
#pragma unroll
@@ -406,7 +406,7 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
// Trigger the loads for K.
gmem_k.load(smem_k);
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
Gmem_tile_s gmem_s(params, binfo, tidx);
// Load dP
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
@@ -114,11 +114,11 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
Gmem_tile_s gmem_s(params, binfo, tidx);
using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;
Noloop nl_traits(bidc);
Noloop nl_traits(bidc, binfo);
nl_traits.move_all(gmem_q, gmem_s);
// Trigger the loads for Q.
@@ -163,8 +163,6 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
// Load over the entire sequence length.
for(int l = 0; l < nl_traits.num_steps_;l++) {
const int loop = nl_traits.offset_loop_count(l);
if( loop >= binfo.actual_seqlen ) break;
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
@@ -230,7 +228,7 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
}
float p_sum[2 * M];
softmax.template reduce<fmha::Sum_>(p_sum);
softmax.reduce_sum(p_sum);
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
#pragma unroll
@@ -400,7 +398,7 @@ inline __device__ void compute_dq_dk_1xN_nl(const Params &params) {
// Allocate the shared memory tile loader for Q (as B).
Smem_tile_qt smem_qt(&smem_[0], tidx);
// Allocate the global memory tile loader for dP.
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
Gmem_tile_s gmem_s(params, binfo, tidx);
// Allocate the shared memory tile loader for dP.
Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);
@@ -414,7 +412,7 @@ inline __device__ void compute_dq_dk_1xN_nl(const Params &params) {
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);
Noloop nl_traits(bidc);
Noloop nl_traits(bidc, binfo);
nl_traits.move_all(gmem_q, gmem_o, gmem_s);
@@ -28,31 +28,57 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits< 128, 64, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, true>(params);
}
template<bool Is_training>
__global__
void fmha_fprop_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
fmha::device_1xN<Kernel_traits, Is_training>(
params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
}
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_128_64_sm80_train_kernel : &fmha_fprop_fp16_128_64_sm80_predict_kernel;
void run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
auto kernel = launch_params.is_training ? &fmha_fprop_fp16_128_64_sm80_kernel<true> : &fmha_fprop_fp16_128_64_sm80_kernel<false>;
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
const int sm_count = launch_params.props->multiProcessorCount;
int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
if(configure) {
const int heads_total = launch_params.params.b * launch_params.params.h;
std::tie(launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps,
launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
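// Call-site sketch (assumed; the surrounding host code is not part of this diff):
// the runner is now invoked twice, once to size the work distribution and once to
// launch the persistent kernel over total_ctas CTAs.
//
//   Launch_params<Fused_multihead_attention_fprop_params> lp = /* props, stream, params, is_training */;
//   run_fmha_fp16_128_64_sm80(lp, /*configure=*/true);   // fills num_full_heads, main_steps, ...
//   run_fmha_fp16_128_64_sm80(lp, /*configure=*/false);  // actual kernel launch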
@@ -28,31 +28,57 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits< 256, 64, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, true>(params);
}
template<bool Is_training>
__global__
void fmha_fprop_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
fmha::device_1xN<Kernel_traits, Is_training>(
params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
}
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_256_64_sm80_train_kernel : &fmha_fprop_fp16_256_64_sm80_predict_kernel;
void run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
auto kernel = launch_params.is_training ? &fmha_fprop_fp16_256_64_sm80_kernel<true> : &fmha_fprop_fp16_256_64_sm80_kernel<false>;
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
const int sm_count = launch_params.props->multiProcessorCount;
int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
if(configure) {
const int heads_total = launch_params.params.b * launch_params.params.h;
std::tie(launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps,
launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}