Unverified commit 96850dfa authored by Jithun Nair, committed by GitHub

Merge pull request #80 from ROCmSoftwarePlatform/IFU-master-2022-07-29

IFU-master-2022-07-29
parents 87fc4125 cc5f83b5
@@ -63,6 +63,11 @@ struct Mask {
// return row_valid && col_valid;
}
//BERT Mask: if upper left is invalid, none are valid
inline __device__ bool any_valid(int mi, int ni) const {
return is_valid(mi, ni, 0, 0);
}
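The new any_valid hook leans on the property spelled out in the comment above: with a BERT-style padding mask, validity depends only on whether the row and column fall inside the actual sequence length, so probing a tile's upper-left element is enough to accept or reject the whole tile. A minimal host-side sketch of that property, using a toy MaskModel type that is not part of this diff:

    #include <cassert>

    // Toy model of the BERT-style padding mask: an element is valid iff both its
    // row and its column fall below the actual sequence length. Under that rule,
    // if the first element of a tile is invalid, every element of the tile is.
    struct MaskModel {
        int actual_seqlen;
        bool is_valid(int row, int col) const {
            return row < actual_seqlen && col < actual_seqlen;
        }
        // Mirrors Mask::any_valid(mi, ni): probe only the tile's first element.
        bool any_valid(int tile_row0, int tile_col0) const {
            return is_valid(tile_row0, tile_col0);
        }
    };

    int main() {
        MaskModel m{100};
        assert(m.any_valid(96, 0));     // tile starts inside the valid region
        assert(!m.any_valid(112, 0));   // rows 112 and beyond are all padding
        return 0;
    }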
inline __device__ void load(int it) {
row_offset = it * Cta_tile::M + row;
}
...
@@ -1266,8 +1266,6 @@ struct Smem_tile_mma_epilogue : public Base {
}
}
template<int M, int N>
inline __device__ void store(const uint4 (&regs)[M][N]) {
for( int mi = 0; mi < M; mi++ ) {
...
@@ -55,6 +55,88 @@ inline __device__ float apply_exp_(float x, float max) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int COLS> struct ReadType {};
template<> struct ReadType<4> { using T = float;};
template<> struct ReadType<8> { using T = float2;};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile, typename Kernel_traits>
struct Smem_tile_reduce {
// Helper class that distributes the per-row reduction results of the MMA tiles across the quads of each warp.
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
enum { MMAS_M = Mma_tile::MMAS_M };
enum { MMAS_N = Mma_tile::MMAS_N };
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
static constexpr int ROWS = WARPS_M * MMAS_M * 16;
static constexpr int COLS = WARPS_N;
static_assert(COLS == 4 || COLS == 8);
static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
static constexpr int ELTS_PER_TILE = ROWS * COLS;
static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW;
static_assert(THREADS_PER_GROUP == 16); // DEBUG
static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP;
static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
static_assert(LOOPS == 1);
using read_t = typename ReadType<COLS>::T;
__device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
int lane = tidx % 32;
int warp = tidx / 32;
int warp_m = warp % WARPS_M;
int warp_n = warp / WARPS_M;
qid_ = lane % 4;
int qp = lane / 4;
// Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.
// This won't affect reading as we assume commutative reduction ops.
const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
}
__device__ inline void store(float (&frag)[2 * MMAS_M]) {
if( qid_ == 0 ) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * WARPS_N;
smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];
smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];
}
}
}
__device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * 4;
frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];
frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];
}
}
int qid_;
float *smem_write_;
read_t *smem_read_;
};
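A quick sizing sketch for Smem_tile_reduce, assuming example values WARPS_M = 1, WARPS_N = 4 and MMAS_M = 1 (the real numbers come from Cta_tile and the MMA tile): each warp writes one column of per-row partials, and on reload each lane of a quad picks up one read_t, i.e. one float when there are 4 warps and one float2 when there are 8, so a full row of partials is always consumed by exactly four lanes.

    #include <cstdio>

    int main() {
        constexpr int WARPS_M = 1, WARPS_N = 4, MMAS_M = 1;   // assumed example config
        constexpr int ROWS = WARPS_M * MMAS_M * 16;           // 16 rows of partials
        constexpr int COLS = WARPS_N;                         // one column per warp
        static_assert(COLS == 4 || COLS == 8, "only 4 or 8 warps in N are supported");
        constexpr int ELTS_PER_TILE  = ROWS * COLS;
        constexpr int BYTES_PER_TILE = ELTS_PER_TILE * int(sizeof(float));
        // COLS == 4 reads back one float per quad lane, COLS == 8 one float2,
        // so four lanes always cover a whole row.
        constexpr int FLOATS_PER_LANE_READ = COLS / 4;
        std::printf("rows=%d cols=%d elts=%d bytes=%d floats per lane read=%d\n",
                    ROWS, COLS, ELTS_PER_TILE, BYTES_PER_TILE, FLOATS_PER_LANE_READ);
        return 0;
    }

The Softmax struct further down carves two of these tiles (one for the sum, one for the max) out of its scratch space, which is consistent with the SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2 budget declared in Gemm_Q_K_base later in this diff.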
template<typename Cta_tile, typename Kernel_traits>
struct Softmax_base {
@@ -136,201 +218,6 @@ struct Softmax_base {
}
}
// Do a CTA-wide reduction.
template<typename Functor>
inline __device__ void reduce_1x4(float (&dst)[MMAS_M * 2]) {
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
if( Functor::IS_SUM ) {
// Apply the summation inside the thread.
float tmp[MMAS_M * 2][2];
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
tmp[mi][0] = 0.f;
tmp[mi][1] = 0.f;
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
tmp[mi][0] += elt_[mi][4 * ni + 0];
tmp[mi][0] += elt_[mi][4 * ni + 1];
tmp[mi][1] += elt_[mi][4 * ni + 2];
tmp[mi][1] += elt_[mi][4 * ni + 3];
}
dst[mi] = tmp[mi][0] + tmp[mi][1];
}
} else
#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
{
// Apply the functor for each row inside a thread.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = elt_[mi][0];
#pragma unroll
for( int ni = 1; ni < MMAS_N * 4; ++ni ) {
dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);
}
}
}
// Apply the functor for each row inside each group of 4 threads.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));
__syncwarp();
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));
__syncwarp();
}
// Store the different values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
if( tidx_ % 4 == 0 ) {
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];
}
}
// Make sure the values are in shared memory.
__syncthreads();
// Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the
// float4.
float4 tmp[1];
if( tidx_ < Cta_tile::M ) {
tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];
}
// Compute the reduction of those 8 values in a binary-tree fashion.
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);
tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);
// Make sure we can write to shared memory.
__syncthreads();
// Store the value back to shared memory.
if( tidx_ < Cta_tile::M ) {
smem_[tidx_] = tmp[0].x;
}
// Make sure the data is in shared memory.
__syncthreads();
// Finally read the values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];
dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];
}
}
// Do a CTA-wide reduction.
template<typename Functor>
inline __device__ void reduce_1x8(float (&dst)[MMAS_M * 2]) {
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
if( Functor::IS_SUM ) {
// Apply the summation inside the thread.
float tmp[MMAS_M * 2][2];
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
tmp[mi][0] = 0.f;
tmp[mi][1] = 0.f;
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
tmp[mi][0] += elt_[mi][4 * ni + 0];
tmp[mi][0] += elt_[mi][4 * ni + 1];
tmp[mi][1] += elt_[mi][4 * ni + 2];
tmp[mi][1] += elt_[mi][4 * ni + 3];
}
dst[mi] = tmp[mi][0] + tmp[mi][1];
}
} else
#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
{
// Apply the functor for each row inside a thread.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = elt_[mi][0];
#pragma unroll
for( int ni = 1; ni < MMAS_N * 4; ++ni ) {
dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);
}
}
}
// Apply the functor for each row inside each group of 4 threads.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));
__syncwarp();
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));
__syncwarp();
}
// Store the different values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
if( tidx_ % 4 == 0 ) {
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];
}
}
// Make sure the values are in shared memory.
__syncthreads();
// Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the
// float4.
float4 tmp[2];
if( tidx_ < Cta_tile::M ) {
tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];
tmp[1] = reinterpret_cast<const float4 *>(&smem_[1 * ELEMENTS / 2])[tidx_];
}
// Compute the reduction of those 8 values in a binary-tree fashion.
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);
tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);
tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y);
tmp[1].z = Functor::apply(tmp[1].z, tmp[1].w);
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);
tmp[1].x = Functor::apply(tmp[1].x, tmp[1].z);
tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x);
// Make sure we can write to shared memory.
__syncthreads();
// Store the value back to shared memory.
if( tidx_ < Cta_tile::M ) {
smem_[tidx_] = tmp[0].x;
}
// Make sure the data is in shared memory.
__syncthreads();
// Finally read the values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];
dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];
}
}
// Do a CTA-wide reduction.
template<typename Functor>
inline __device__ void reduce(float (&dst)[MMAS_M * 2]) {
static_assert(Cta_tile::WARPS_M == 1 && (Cta_tile::WARPS_N == 4 || Cta_tile::WARPS_N == 8));
if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 4 ) {
reduce_1x4<Functor>(dst);
} else if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 8 ) {
reduce_1x8<Functor>(dst);
} else {
assert(false);
}
// Make sure we are done reading from shared memory.
__syncthreads();
}
// Scale all the elements.
inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
// Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.
@@ -372,6 +259,8 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
static_assert(Fragment_a::NUM_REGS == 4);
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
// The MMAs.
enum { MMAS_M = Base::MMAS_M };
enum { MMAS_N = Base::MMAS_N };
@@ -383,41 +272,15 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
static_assert(std::is_same<Accumulator::Data_type, float>::value);
using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
// Ctor.
template<typename Params>
inline __device__ Softmax(const Params &params, void *smem, int bidb, int tidx)
-: Base(params, smem, bidb, tidx), params_scale_bmm1_(params.scale_bmm1) {
-}
+: Base(params, smem, bidb, tidx)
+, params_scale_bmm1_(params.scale_bmm1)
+, smem_sum_(static_cast<float*>(smem), tidx)
+, smem_max_(static_cast<float*>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {
+}
// Store the tile after softmax.
template<typename Gmem_tile>
inline __device__ void store(Gmem_tile &gmem_tile) {
Accumulator_out acc[MMAS_M][MMAS_N];
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// The elements.
float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0];
float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1];
float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2];
float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3];
float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0];
float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1];
float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2];
float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3];
// Transform to accumulators.
acc[mi][ni].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
acc[mi][ni].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
acc[mi][ni].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
acc[mi][ni].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
}
}
// Delegate to the gmem tile to store.
gmem_tile.store(acc);
}
// Pack the data to a fragment for the next GEMM.
@@ -470,7 +333,61 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
}
}
}
// Unpack FP32 accumulators into elt_ without applying the softmax scale.
inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4);
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5);
// 2nd row - 4 elements per row.
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2);
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
}
}
}
template<typename Operator>
__device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
for( int mi = 0; mi < 2 * MMAS_M; mi++ ) {
frag[mi] = this->elt_[mi][0];
for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
frag[mi] = op(frag[mi], this->elt_[mi][ni]);
}
}
quad_reduce(frag, frag, op);
smem_red.store(frag);
__syncthreads();
typename Smem_tile_red::read_t tmp[2 * MMAS_M];
smem_red.load(tmp);
quad_allreduce(frag, tmp, op);
}
__device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
MaxOp<float> max;
reduce_(frag, max, smem_max_);
}
__device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
SumOp<float> sum;
reduce_(frag, sum, smem_sum_);
}
const uint32_t params_scale_bmm1_;
Smem_tile_red smem_max_;
Smem_tile_red smem_sum_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -950,4 +950,89 @@ inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_reduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = Allreduce<4>::run(dst[mi], op);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_allreduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
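The two reduction helpers added above differ only in which lanes end up holding the result: quad_reduce uses __shfl_down_sync, so only the first lane of each group of four has the complete value afterwards (matching Smem_tile_reduce::store, which writes from the qid_ == 0 lane only), while Allreduce and quad_allreduce use __shfl_xor_sync butterflies that leave the value in every lane. A standalone sketch, independent of the fmha types, showing both patterns on a sum:

    #include <cstdio>

    __global__ void quad_reduce_demo(const float *in, float *down_out, float *xor_out) {
        int lane = threadIdx.x;                       // one warp of 32 threads
        float x = in[lane];

        // quad_reduce style: shfl_down by 2 then 1; only lanes 0, 4, 8, ... hold
        // the sum of their four-lane group afterwards.
        float d = x;
        d += __shfl_down_sync(0xffffffffu, d, 2);
        d += __shfl_down_sync(0xffffffffu, d, 1);
        down_out[lane] = d;

        // quad_allreduce / Allreduce<4> style: xor butterflies broadcast the same
        // sum to all four lanes of the group.
        float a = x;
        a += __shfl_xor_sync(0xffffffffu, a, 2);
        a += __shfl_xor_sync(0xffffffffu, a, 1);
        xor_out[lane] = a;
    }

    int main() {
        float h_in[32], h_down[32], h_xor[32];
        for( int i = 0; i < 32; ++i ) h_in[i] = 1.0f + float(i % 4);   // quad sum = 10
        float *d_in, *d_down, *d_xor;
        cudaMalloc(&d_in, sizeof(h_in));
        cudaMalloc(&d_down, sizeof(h_down));
        cudaMalloc(&d_xor, sizeof(h_xor));
        cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
        quad_reduce_demo<<<1, 32>>>(d_in, d_down, d_xor);
        cudaMemcpy(h_down, d_down, sizeof(h_down), cudaMemcpyDeviceToHost);
        cudaMemcpy(h_xor, d_xor, sizeof(h_xor), cudaMemcpyDeviceToHost);
        std::printf("lane 0: shfl_down=%g, lanes 0..3: shfl_xor=%g %g %g %g\n",
                    h_down[0], h_xor[0], h_xor[1], h_xor[2], h_xor[3]);
        cudaFree(d_in); cudaFree(d_down); cudaFree(d_xor);
        return 0;
    }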
@@ -28,7 +28,7 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
-using Kernel_traits = FMHA_kernel_traits< 128, 64, 16, 1, 4, 0x08u>;
+using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
...
@@ -28,7 +28,7 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
-using Kernel_traits = FMHA_kernel_traits< 256, 64, 16, 1, 4, 0x08u>;
+using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
...
@@ -28,7 +28,7 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
-using Kernel_traits = FMHA_kernel_traits< 384, 64, 16, 1, 8, 0x08u>;
+using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 8, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
...
@@ -29,7 +29,7 @@
#include "fmha_dgrad_kernel_1xN_reload.h"
#include "fmha_dgrad_kernel_1xN_reload_nl.h"
-using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;
+using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
...
@@ -141,7 +141,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
-Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
+Gmem_tile_s gmem_s(params, binfo, tidx);
// Create the object to do the softmax.
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
@@ -231,7 +231,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
}
float p_sum[2 * M];
-softmax.template reduce<fmha::Sum_>(p_sum);
+softmax.reduce_sum(p_sum);
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
#pragma unroll
@@ -406,7 +406,7 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
// Trigger the loads for K.
gmem_k.load(smem_k);
-Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
+Gmem_tile_s gmem_s(params, binfo, tidx);
// Load dP
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
...
@@ -114,11 +114,11 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
-Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
+Gmem_tile_s gmem_s(params, binfo, tidx);
using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;
-Noloop nl_traits(bidc);
+Noloop nl_traits(bidc, binfo);
nl_traits.move_all(gmem_q, gmem_s);
// Trigger the loads for Q.
@@ -163,8 +163,6 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
// Load over the entire sequence length.
for(int l = 0; l < nl_traits.num_steps_;l++) {
const int loop = nl_traits.offset_loop_count(l);
if( loop >= binfo.actual_seqlen ) break;
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
@@ -230,7 +228,7 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
}
float p_sum[2 * M];
-softmax.template reduce<fmha::Sum_>(p_sum);
+softmax.reduce_sum(p_sum);
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
#pragma unroll
@@ -400,7 +398,7 @@ inline __device__ void compute_dq_dk_1xN_nl(const Params &params) {
// Allocate the shared memory tile loader for Q (as B).
Smem_tile_qt smem_qt(&smem_[0], tidx);
// Allocate the global memory tile loader for dP.
-Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
+Gmem_tile_s gmem_s(params, binfo, tidx);
// Allocate the shared memory tile loader for dP.
Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);
@@ -414,7 +412,7 @@ inline __device__ void compute_dq_dk_1xN_nl(const Params &params) {
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);
-Noloop nl_traits(bidc);
+Noloop nl_traits(bidc, binfo);
nl_traits.move_all(gmem_q, gmem_o, gmem_s);
...
@@ -28,31 +28,57 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
-using Kernel_traits = FMHA_kernel_traits< 128, 64, 16, 1, 4, 0x08u>;
+using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
-fmha::device_1xN<Kernel_traits, true>(params);
-}
+template<bool Is_training>
+__global__
+void fmha_fprop_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
-extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
-fmha::device_1xN<Kernel_traits, false>(params);
-}
-void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
-auto kernel = is_training ? &fmha_fprop_fp16_128_64_sm80_train_kernel : &fmha_fprop_fp16_128_64_sm80_predict_kernel;
-constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
-constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
+fmha::device_1xN<Kernel_traits, Is_training>(
+params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
+}
+void run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+auto kernel = launch_params.is_training ? &fmha_fprop_fp16_128_64_sm80_kernel<true> : &fmha_fprop_fp16_128_64_sm80_kernel<false>;
+constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
-dim3 grid(params.h, params.b);
-kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
+const int sm_count = launch_params.props->multiProcessorCount;
+int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
if(configure) {
const int heads_total = launch_params.params.b * launch_params.params.h;
std::tie(launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps,
launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
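The rewritten launcher no longer maps one CTA to each (head, batch) pair: it queries the occupancy API for the number of resident CTAs per SM, launches sm_count * ctas_per_sm CTAs, and in the configure pass uses fmha::work_dist to split the b * h heads across them before the real launch. A self-contained sketch of that sizing pattern with a placeholder kernel; THREADS and the zero dynamic shared memory size are stand-ins for Kernel_traits::THREADS and fmha::get_dynamic_smem_size():

    #include <cstdio>

    __global__ void dummy_kernel(int *out) {
        // Count how many CTAs actually ran.
        if( threadIdx.x == 0 ) atomicAdd(out, 1);
    }

    int main() {
        int device = 0;
        cudaDeviceProp props;
        cudaGetDeviceProperties(&props, device);

        int ctas_per_sm = 0;
        constexpr int THREADS = 256;     // stand-in for Kernel_traits::THREADS
        constexpr int smem_size = 0;     // stand-in for get_dynamic_smem_size()
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, dummy_kernel, THREADS, smem_size);

        const int total_ctas = props.multiProcessorCount * ctas_per_sm;
        std::printf("launching %d CTAs (%d SMs x %d CTAs/SM)\n",
                    total_ctas, props.multiProcessorCount, ctas_per_sm);

        int *d_count;
        cudaMalloc(&d_count, sizeof(int));
        cudaMemset(d_count, 0, sizeof(int));
        dummy_kernel<<<total_ctas, THREADS>>>(d_count);
        int h_count = 0;
        cudaMemcpy(&h_count, d_count, sizeof(int), cudaMemcpyDeviceToHost);
        std::printf("CTAs that ran: %d\n", h_count);
        cudaFree(d_count);
        return 0;
    }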
@@ -28,31 +28,57 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
-using Kernel_traits = FMHA_kernel_traits< 256, 64, 16, 1, 4, 0x08u>;
+using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
-fmha::device_1xN<Kernel_traits, true>(params);
-}
+template<bool Is_training>
+__global__
+void fmha_fprop_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
-extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
-fmha::device_1xN<Kernel_traits, false>(params);
-}
-void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
-auto kernel = is_training ? &fmha_fprop_fp16_256_64_sm80_train_kernel : &fmha_fprop_fp16_256_64_sm80_predict_kernel;
-constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
-constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
+fmha::device_1xN<Kernel_traits, Is_training>(
+params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
+}
+void run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+auto kernel = launch_params.is_training ? &fmha_fprop_fp16_256_64_sm80_kernel<true> : &fmha_fprop_fp16_256_64_sm80_kernel<false>;
+constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
-dim3 grid(params.h, params.b);
-kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
+const int sm_count = launch_params.props->multiProcessorCount;
+int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
if(configure) {
const int heads_total = launch_params.params.b * launch_params.params.h;
std::tie(launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps,
launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
@@ -26,32 +26,59 @@
******************************************************************************/
#include "fmha.h"
-#include "fmha_fprop_kernel_1xN_reload_v.h"
-using Kernel_traits = FMHA_kernel_traits< 384, 64, 16, 1, 4, 0x08u>;
-extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
-fmha::device_1xN<Kernel_traits, true>(params);
-}
+#include "fmha_fprop_kernel_1xN.h"
+using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x18u>;
+template<bool Is_training>
+__global__
+void fmha_fprop_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
-extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
-fmha::device_1xN<Kernel_traits, false>(params);
-}
-void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
-auto kernel = is_training ? &fmha_fprop_fp16_384_64_sm80_train_kernel : &fmha_fprop_fp16_384_64_sm80_predict_kernel;
-constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-constexpr int smem_size = smem_size_v + smem_size_o + smem_size_softmax;
+fmha::device_1xN<Kernel_traits, Is_training>(
+params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
+}
+void run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+auto kernel = launch_params.is_training ? &fmha_fprop_fp16_384_64_sm80_kernel<true> : &fmha_fprop_fp16_384_64_sm80_kernel<false>;
+constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
-dim3 grid(params.h, params.b);
-kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
+const int sm_count = launch_params.props->multiProcessorCount;
+int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
if(configure) {
const int heads_total = launch_params.params.b * launch_params.params.h;
std::tie(launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps,
launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
@@ -27,72 +27,111 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
#include "fmha_fprop_kernel_1xN_nl.h"
-using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;
-extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
-fmha::device_1xN<Kernel_traits, true>(params);
-}
+using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>;
+template<bool Is_training>
+__global__
+void fmha_fprop_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
const int total_heads) {
extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
}
-template<int CHUNKS>
+fmha::device_1xN<Kernel_traits, Is_training>(params, total_heads);
__global__ void fmha_fprop_fp16_512_64_sm80_train_nl_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN_nl<CHUNKS,Kernel_traits, true>(params);
-}
-template<int CHUNKS>
-__global__ void fmha_fprop_fp16_512_64_sm80_predict_nl_kernel(Fused_multihead_attention_fprop_params params) {
-fmha::device_1xN_nl<CHUNKS, Kernel_traits, false>(params);
+}
+template<bool Is_training>
+__global__
+void fmha_fprop_fp16_512_64_sm80_kernel_nl(Fused_multihead_attention_fprop_params params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
fmha::device_1xN<Kernel_traits, Is_training>(
params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
}
void run_fmha_fp16_512_64_sm80_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
-void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
-auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_kernel : &fmha_fprop_fp16_512_64_sm80_predict_kernel;
-constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
-constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
+auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel<true> : &fmha_fprop_fp16_512_64_sm80_kernel<false>;
+constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
-void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream) {
-auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
-if( num_chunks == 2 ) {
+const int sm_count = launch_params.props->multiProcessorCount;
+int ctas_per_sm;
+FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
+int total_ctas = sm_count * ctas_per_sm;
kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2>
-: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
-} else if( num_chunks == 3 ) {
+const int heads_total = launch_params.params.b * launch_params.params.h;
+if(configure) {
kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<3>
-: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<3>;
-} else if( num_chunks == 4 ) {
-kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<4>
-: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<4>;
+using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
+constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
+constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
+constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
} else {
-assert(false && "Unsupported num_chunks");
+size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas);
size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8;
launch_params.elts_per_thread = heads_per_cta * elts_per_head;
return;
}
-constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
-constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
+dim3 grid(total_ctas);
+kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
+launch_params.params,
+heads_total);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
void run_fmha_fp16_512_64_sm80_nl_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel_nl<true> : &fmha_fprop_fp16_512_64_sm80_kernel_nl<false>;
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
-dim3 grid(params.h, params.b, num_chunks);
-kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
+const int sm_count = launch_params.props->multiProcessorCount;
+int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
if(configure) {
const int heads_total = launch_params.params.b * launch_params.params.h;
std::tie(launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps,
launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
void run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
if( launch_params.is_nl ) {
run_fmha_fp16_512_64_sm80_nl_(launch_params, configure);
} else {
run_fmha_fp16_512_64_sm80_(launch_params, configure);
}
}
/***************************************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
@@ -35,7 +35,159 @@ namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
-template <typename Kernel_traits, bool Is_training, typename Params> inline __device__ void device_1xN(const Params &params) {
+template<typename Kernel_traits>
struct Gemm_Q_K_base {
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
using Fragment_q = typename Smem_tile_q::Fragment;
using Fragment_k = typename Smem_tile_k::Fragment;
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2;
__device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx)
: smem_q(smem_ptr_q, tidx)
, smem_k(smem_ptr_k, tidx) {
}
__device__ inline void load_q() {
smem_q.load(frag_q[0], 0);
}
__device__ inline void reload_q() {
smem_q.load(frag_q[0], 0);
}
Fragment_q frag_q[2][Mma_tile_p::MMAS_M];
Smem_tile_q smem_q;
Smem_tile_k smem_k;
};
template<typename Kernel_traits, bool K_in_regs>
struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
using Base = Gemm_Q_K_base<Kernel_traits>;
using Smem_tile_o = typename Base::Smem_tile_o;
using Smem_tile_q = typename Base::Smem_tile_q;
using Smem_tile_k = typename Base::Smem_tile_k;
using Fragment_k = typename Base::Fragment_k;
using Mma_tile_p = typename Base::Mma_tile_p;
enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V };
enum { SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE };
enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) };
// Q | K / V
// | O | SOFTMAX
static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
+ std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE,
Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX);
__device__ inline Gemm_Q_K(char * smem_, const int tidx)
: Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
}
__device__ inline void load_k(){
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
Base::smem_k.load(frag_k[ki], ki);
}
}
template<typename Acc, int M, int N>
__device__ inline void operator()(Acc (&acc_p)[M][N]){
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
Base::smem_q.load(Base::frag_q[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
}
__device__ inline void reload_k(){
// Noop.
}
Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
};
template<typename Kernel_traits>
struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
using Base = Gemm_Q_K_base<Kernel_traits>;
using Smem_tile_o = typename Base::Smem_tile_o;
using Smem_tile_q = typename Base::Smem_tile_q;
using Smem_tile_k = typename Base::Smem_tile_k;
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
using Fragment_k = typename Base::Fragment_k;
using Mma_tile_p = typename Base::Mma_tile_p;
Fragment_k frag_k[2][Mma_tile_p::MMAS_N];
enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V };
enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) };
static_assert(Smem_tile_v::BYTES_PER_TILE == (int) Smem_tile_k::BYTES_PER_TILE);
enum { SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE };
// Q | K/V + O + SOFTMAX
static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
+ (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE
+ Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX;
__device__ inline Gemm_Q_K(char * smem_, const int tidx)
: Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
}
__device__ inline void load_k(){
Base::smem_k.load(frag_k[0], 0);
}
template<typename Acc, int M, int N>
__device__ inline void operator()(Acc (&acc_p)[M][N]){
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
Base::smem_q.load(Base::frag_q[ki & 1], ki);
Base::smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
}
__device__ inline void reload_k(){
Base::smem_k.load(frag_k[0], 0);
}
};
template<typename Kernel_traits>
constexpr size_t get_dynamic_smem_size(){
return Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>::SMEM_BYTES;
}
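get_dynamic_smem_size just exposes Gemm_Q_K<...>::SMEM_BYTES, and the two specializations above size shared memory differently: when K stays in registers, the K/V region can be overlapped with the O tile plus the softmax scratch; when K is reloaded from shared memory, everything has to coexist. A small host-side sketch of that arithmetic with made-up tile sizes (the real values come from the Smem_tile_* types in Kernel_traits):

    #include <algorithm>
    #include <cstdio>

    int main() {
        // Hypothetical per-tile byte counts; the real ones are the BYTES_PER_TILE
        // of the Q/K/V/O shared memory tiles plus SMEM_BYTES_SOFTMAX.
        constexpr int Q = 16 * 1024, K = 16 * 1024, O = 8 * 1024, SOFTMAX = 2 * 1024;
        constexpr bool SHARE_SMEM_FOR_K_AND_V = false;
        constexpr int KV = (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * K;
        // Gemm_Q_K<..., true>: K lives in registers, so K/V overlaps O + softmax.
        constexpr int smem_k_in_regs = Q + std::max(KV, O + SOFTMAX);
        // Gemm_Q_K<..., false>: K is reloaded from shared memory, nothing overlaps.
        constexpr int smem_k_reloaded = Q + KV + O + SOFTMAX;
        std::printf("K in registers: %d bytes, K reloaded: %d bytes\n",
                    smem_k_in_regs, smem_k_reloaded);
        return 0;
    }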
template<typename Kernel_traits, bool Is_training, typename Params, typename Prng>
inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, const int begin, const int steps, Prng & ph) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
@@ -49,13 +201,9 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
@@ -69,81 +217,88 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
// The number of threads per row.
enum { THREADS_PER_ROW = 32 };
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
// Shared memory.
extern __shared__ char smem_[];
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
-if( binfo.stop_early() )
-return;
+if( binfo.stop_early() ) return;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));
Mask<Cta_tile_p> mask(params, binfo, tidx);
Gemm1 gemm_q_k(smem_, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
-// Allocate the shared memory tile loader for Q.
-Smem_tile_q smem_q(&smem_[0], tidx);
+// Allocate the global memory tile loader for O.
+Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the global memory tile loader for S.
Gmem_tile_s gmem_s(params, binfo, tidx);
// Wind gmem tiles to the correct position.
for( int it = 0; it < begin; it++ ) {
gmem_q.move();
gmem_s.move();
gmem_o.move();
}
fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v;
-char *smem_v_ = nullptr;
+char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];
} else {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
}
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
-Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
+Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
// Trigger the loads for K.
gmem_k.load(gemm_q_k.smem_k);
// Trigger the loads for Q.
gmem_q.load(gemm_q_k.smem_q);
// Trigger the loads for V.
gmem_v.load(smem_v);
-// Commit the data for Q and K to shared memory.
-gmem_q.commit(smem_q);
-gmem_k.commit(smem_k);
+const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
+#pragma unroll
+for(int it=0;it < Gmem_tile_k::LDGS;it++){
gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
}
// Commit the data for V to shared memory.
// Commit the data for Q and V to shared memory.
gmem_q.commit(gemm_q_k.smem_q);
gmem_v.commit(smem_v);
// Commit the data for K to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
-gmem_v.commit(smem_v);
+gmem_k.commit(gemm_q_k.smem_k);
}
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
-typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
+gemm_q_k.load_q();
smem_q.load(frag_q[0], 0);
-// Load the fragments for K. We keep the data in registers during the entire kernel.
-typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
+// Load the fragments for V. We keep the data in registers during the entire kernel.
+typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
#pragma unroll
-for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
-smem_k.load(frag_k[ki], ki);
+for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
+smem_v.load(frag_v[ki], ki);
}
// Commit the data for V to shared memory if it has not been done already.
@@ -152,61 +307,41 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
__syncthreads();
// Commit the data to shared memory for V.
-gmem_v.commit(smem_v);
+gmem_k.commit(gemm_q_k.smem_k);
// Make sure the data is in shared memory.
__syncthreads();
}
-// Load the fragments for V. We keep the data in registers during the entire kernel.
-typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
+// Load the fragments for K.
+gemm_q_k.load_k();
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
smem_v.load(frag_v[ki], ki);
}
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
// Create the object to do the softmax.
-using Softmax = fmha::Softmax< Cta_tile_p, Kernel_traits>;
+Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
enum { THREADS_PER_ROW = 32 };
enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
// Load over the entire sequence length.
-for( int l = 0; l < STEPS; l++ ) {
-const int loop = l * Cta_tile_p::M;
+for( int l = 0; l < steps; l++ ) {
+if(begin + l * Cta_tile_p::M >= binfo.actual_seqlen) break;
if( loop >= binfo.actual_seqlen )
break;
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
-#pragma unroll
-for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
-// Trigger the load from shared memory for the next series of Q values.
-smem_q.load(frag_q[ki & 1], ki);
-// Do the math for the values already in registers.
-fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
-}
+gemm_q_k(acc_p);
+// Trigger the load for the next Q values.
+if( l < steps - 1) {
+gemm_q_k.smem_q.move_to_next_write_buffer();
+gmem_q.move();
+gmem_q.load(gemm_q_k.smem_q);
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
} }
// Load the mask for that iteration. // Load the mask for that iteration.
mask.load(l); mask.load(begin + l);
// Convert from the accumulator type to FP32 for Softmax. // Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p); softmax.unpack_noscale(acc_p);
// Apply the mask. // Apply the mask.
softmax.apply_mask(mask); softmax.apply_mask(mask);
...@@ -217,21 +352,21 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de ...@@ -217,21 +352,21 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
} }
// Compute the max. // Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2]; float p_max[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Max_>(p_max); //softmax.template reduce<fmha::Max_>(p_max);
softmax.reduce_max(p_max);
// Make sure we are done reading shared memory.
__syncthreads();
// Compute the exponential value. // Compute the exponential value.
softmax.apply_exp(p_max); softmax.apply_exp(p_max);
// Compute the sum. // Compute the sum.
float p_sum[Mma_tile_p::MMAS_M * 2]; float p_sum[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Sum_>(p_sum); softmax.reduce_sum(p_sum);
// Finalize softmax on the accumulators of P^T. // Finalize softmax on the accumulators of P^T.
softmax.scale(p_sum); softmax.scale(p_sum);
using Frag_p = fmha::Fragment_a<fmha::Row>;
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
if( Is_training ) { if( Is_training ) {
auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; }; auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
#pragma unroll #pragma unroll
...@@ -241,8 +376,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de ...@@ -241,8 +376,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
#pragma unroll #pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
float4 tmp = uniform4(ph()); float4 tmp = uniform4(ph());
// We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from pre-existing zeros
// pre-existing zeros
softmax.elt_[2 * mi + ii][4 * ni + 0] = softmax.elt_[2 * mi + ii][4 * ni + 0] =
encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]); encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
softmax.elt_[2 * mi + ii][4 * ni + 1] = softmax.elt_[2 * mi + ii][4 * ni + 1] =
...@@ -254,20 +388,18 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de ...@@ -254,20 +388,18 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
} }
} }
} }
gmem_s.store(softmax.elt_, mask); softmax.pack(frag_p);
gmem_s.store(frag_p, mask);
gmem_s.move(); gmem_s.move();
} else {
softmax.pack(frag_p);
} }
// Trigger the load for the next Q values. // Commit the values for Q into shared memory.
if(l < STEPS - 1) { if(l < steps - 1) {
smem_q.move_to_next_write_buffer(); gmem_q.commit(gemm_q_k.smem_q);
gmem_q.move();
gmem_q.load(smem_q);
} }
using Frag_p = fmha::Fragment_a< fmha::Row>;
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
softmax.pack(frag_p);
#pragma unroll #pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) { for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll #pragma unroll
...@@ -316,21 +448,84 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de ...@@ -316,21 +448,84 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
// Move to the next part of the output. // Move to the next part of the output.
gmem_o.move(); gmem_o.move();
gemm_q_k.reload_k();
// Commit the values for Q into shared memory. // Commit the values for Q into shared memory.
if(l < STEPS - 1) { if(l < steps - 1) {
gmem_q.commit(smem_q); gemm_q_k.reload_q();
} }
// Make sure the data is in shared memory. } // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN(const Params &params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
for( int it = 0; it < num_full_heads; it++ ) {
const int bidx = it * gridDim.x + blockIdx.x;
const int bidh = bidx % params.h;
const int bidb = bidx / params.h;
fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, 0, STEPS, ph);
__syncthreads(); __syncthreads();
}
if( main_group_size == 0 )
return;
const int head_offset = num_full_heads * gridDim.x;
if( blockIdx.x < main_group_size * num_main_groups ) {
// process within heads
const int group = blockIdx.x % num_main_groups;
const int bidx = blockIdx.x / num_main_groups;
const int bidh = (head_offset + bidx) % params.h;
const int bidb = (head_offset + bidx) / params.h;
const int offset = group * main_steps;
fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, main_steps, ph);
} else {
if(rest_steps == 0 ) return;
// process across heads
const int bidx = blockIdx.x - main_group_size * num_main_groups;
const int offset = num_main_groups * main_steps;
const int total_heads = params.b * params.h;
const int rest_ctas = gridDim.x - main_group_size * num_main_groups;
for( int it = head_offset + bidx; it < total_heads; it += rest_ctas ) {
const int bidh = it % params.h;
const int bidb = it / params.h;
fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, rest_steps, ph);
__syncthreads();
}
}
}
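// Scheduling sketch for the launcher above: the grid first runs "num_full_heads" complete
// waves in which every CTA processes one whole (batch, head) pair. For the heads left in
// the last wave, the CTAs split in two: "main_group_size * num_main_groups" CTAs work
// within heads (each group of num_main_groups CTAs covers one head, main_steps row-block
// steps per CTA), while the remaining CTAs sweep across those heads and handle the
// trailing rest_steps row-blocks of each. work_dist() further below derives these counts
// on the host.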
// Trigger the loads for the values of Q for the next iteration. ////////////////////////////////////////////////////////////////////////////////////////////////////
smem_q.load(frag_q[0], 0);
} // Outer loop over the sequence length. template<typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN(const Params &params, const int total_heads) {
const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
for(int bidx = blockIdx.x; bidx < total_heads; bidx += gridDim.x){
const int bidh = bidx % params.h;
const int bidb = bidx / params.h;
fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, 0, STEPS, ph);
__syncthreads();
}
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha } // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int CHUNKS, typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN_nl(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
// The global memory tile to store S/D.
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;
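// "No-loop" variant: the outer loop over the sequence is split into CHUNKS pieces along
// blockIdx.z. Each CTA only executes the [loop_offset_, loop_offset_ + num_steps_) range
// of row-block steps selected by Noloop_traits, instead of the whole sequence.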
// Shared memory.
extern __shared__ char smem_[];
const int bidc = blockIdx.z;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
Noloop nl_traits(bidc);
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));
fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[0], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v.
char *smem_v_ = nullptr;
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];
} else {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
}
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
nl_traits.move_all(gmem_q, gmem_o, gmem_s);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
// Trigger the loads for V.
gmem_v.load(smem_v);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Commit the data for V to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_v.commit(smem_v);
}
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
smem_q.load(frag_q[0], 0);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
smem_k.load(frag_k[ki], ki);
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_v.commit(smem_v);
// Make sure the data is in shared memory.
__syncthreads();
}
// Load the fragments for V. We keep the data in registers during the entire kernel.
typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
smem_v.load(frag_v[ki], ki);
}
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
// Create the object to do the softmax.
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
// The number of threads per row.
enum { THREADS_PER_ROW = 32 };
// Load over the entire sequence length.
for(int l = 0; l < nl_traits.num_steps_;l++) {
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_q.load(frag_q[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Trigger the load for the next Q values.
if( l < nl_traits.num_steps_- 1) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
// Load the mask for that iteration.
mask.load(nl_traits.loop_offset_ + l);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// If we share K and V, V may not have been fully read yet, but we are about to write into that smem for the softmax reduction.
__syncthreads();
}
// Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Max_>(p_max);
// Make sure we are done reading shared memory.
__syncthreads();
// Compute the exponential value.
softmax.apply_exp(p_max);
// Compute the sum.
float p_sum[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Sum_>(p_sum);
// Finalize softmax on the accumulators of P^T.
softmax.scale(p_sum);
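            // At this point each row i of the P^T tile holds
            //   softmax(x)_ij = exp(x_ij - max_i) / sum_j exp(x_ij - max_i),
            // i.e. the row max is subtracted before exponentiation for numerical stability
            // and the result is normalized by the row sum.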
if( Is_training ) {
auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
float4 tmp = uniform4(ph());
// We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from pre-existing zeros
softmax.elt_[2 * mi + ii][4 * ni + 0] =
encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
softmax.elt_[2 * mi + ii][4 * ni + 1] =
encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
softmax.elt_[2 * mi + ii][4 * ni + 2] =
encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
softmax.elt_[2 * mi + ii][4 * ni + 3] =
encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
}
}
}
gmem_s.store(softmax.elt_, mask);
gmem_s.move();
}
using Frag_p = fmha::Fragment_a<fmha::Row>;
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
softmax.pack(frag_p);
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
//"Apply" the dropout.
frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
}
}
}
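        // Dropout bookkeeping: dropped elements were marked above by flipping the sign of
        // the (non-negative) softmax value. hrelu2 zeroes those negative lanes, and hmul2
        // rescales the surviving ones by params.scale_dropout, which is assumed to hold
        // the usual inverted-dropout factor 1 / p_keep.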
        // Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
// Do this part of O = P^T * V^T.
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_o::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_o.store(out, ii);
}
// Move to the next part of the output.
gmem_o.move();
// Commit the values for Q into shared memory.
if( l < nl_traits.num_steps_- 1) {
gmem_q.commit(smem_q);
__syncthreads();
smem_q.load(frag_q[0], 0);
}
} // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Kernel_traits, bool Is_training, typename Params> inline __device__ void device_1xN(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
// Shared memory.
extern __shared__ char smem_[];
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
Mask<Cta_tile_p> mask(params, binfo, tidx);
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));
static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[0], tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v.
char *smem_v_ = nullptr;
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
smem_v_ = &smem_[0];
} else {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
}
static_assert(Kernel_traits::SHARE_SMEM_FOR_K_AND_V);
static_assert(Smem_tile_k::BYTES_PER_TILE == Smem_tile_v::BYTES_PER_TILE);
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
    // Allocate the shared memory tile loader for O. We reuse the same memory as Q so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);
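    // Shared memory layout implied by the offsets above: K lives at offset 0 and V reuses
    // the same buffer (sharing is static_assert'ed), Q starts at Smem_tile_v::BYTES_PER_TILE,
    // O reuses Q's region, and the softmax scratch allocated below follows at
    // Smem_tile_v::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE.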
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
    // Trigger the loads for V.
gmem_v.load(smem_v);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Commit the data for V to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_v.commit(smem_v);
}
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
typename Smem_tile_q::Fragment frag_q[1][Mma_tile_p::MMAS_M];
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
smem_k.load(frag_k[ki], ki);
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_v.commit(smem_v);
}
enum { BITS_PER_ELT_S = sizeof(typename fmha::A_type) * 8 };
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
// Create the object to do the softmax.
using Softmax = fmha::Softmax< Cta_tile_p, Kernel_traits>;
Softmax softmax(params, &smem_[Smem_tile_v::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
constexpr int SMEM_BYTES_SOFTMAX = Softmax::ELEMENTS * sizeof(float);
static_assert(SMEM_BYTES_SOFTMAX == Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float));
enum { THREADS_PER_ROW = 32 };
const float pinv = 1.f / params.p_dropout;
// Load over the entire sequence length.
for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
if( loop >= binfo.actual_seqlen )
break;
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_q.load(frag_q[0], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, frag_q[0], frag_k[ki]);
}
// Load the mask for that iteration.
mask.load(outer);
        // Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);
// Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Max_>(p_max);
// Make sure we are done reading shared memory.
__syncthreads();
// Compute the exponential value.
softmax.apply_exp(p_max);
// Compute the sum.
float p_sum[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Sum_>(p_sum);
// Finalize softmax on the accumulators of P^T.
softmax.scale(p_sum);
__syncthreads();
if( Is_training ) {
auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
float4 tmp = uniform4(ph());
// We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from
// pre-existing zeros
softmax.elt_[2 * mi + ii][4 * ni + 0] =
encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
softmax.elt_[2 * mi + ii][4 * ni + 1] =
encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
softmax.elt_[2 * mi + ii][4 * ni + 2] =
encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
softmax.elt_[2 * mi + ii][4 * ni + 3] =
encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
}
}
}
gmem_s.store(softmax.elt_, mask);
gmem_s.move();
}
// Trigger the load for the next Q values.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
typename Smem_tile_v::Fragment frag_v[1][Mma_tile_o::MMAS_N];
using Frag_p = fmha::Fragment_a< fmha::Row>;
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
softmax.pack(frag_p);
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
//"Apply" the dropout.
frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
}
}
}
        // Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of V values.
smem_v.load(frag_v[0], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_o, frag_p[ki], frag_v[0]);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Always sync after last iter: shared smem_q and smem_o!
__syncthreads();
// Output the values.
gmem_o.store(out, ii);
}
// same smem as o
// Move to the next part of the output.
gmem_o.move();
// Commit the values for Q into shared memory.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
gmem_q.commit(smem_q);
}
// Make sure the data is in shared memory.
__syncthreads();
} // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
...@@ -79,17 +79,19 @@ struct Noloop_traits{ ...@@ -79,17 +79,19 @@ struct Noloop_traits{
enum{ STEP = Cta_tile::M }; enum{ STEP = Cta_tile::M };
enum{ SEQLEN = Cta_tile::N }; enum{ SEQLEN = Cta_tile::N };
// The size of the subsequence this CTA is processing template<typename Block_info>
enum { SUBSEQ = SEQLEN / CHUNKS }; inline __device__ Noloop_traits(const int bidc, const Block_info& binfo)
static_assert(SUBSEQ * CHUNKS == SEQLEN); : bidc_(bidc) {
const int seqlen = binfo.actual_seqlen;
const int steps = (seqlen + STEP - 1) / STEP;
const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;
const int step_begin = bidc_ * steps_per_chunk;
const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk);
const int actual_steps = max(0, step_end - step_begin);
loop_offset_ = step_begin;
num_steps_ = actual_steps;
// The number of steps to process the subsequence
enum { NUM_STEPS = SUBSEQ / STEP };
static_assert(NUM_STEPS * Cta_tile::M == SUBSEQ);
inline __device__ Noloop_traits(const int bidc)
: loop_offset_(NUM_STEPS * bidc)
, bidc_(bidc) {
} }
template<typename ... Tiles> template<typename ... Tiles>
...@@ -115,54 +117,62 @@ struct Noloop_traits{ ...@@ -115,54 +117,62 @@ struct Noloop_traits{
return (loop_offset_ + l) * STEP; return (loop_offset_ + l) * STEP;
} }
const int loop_offset_;
const uint32_t bidc_; const uint32_t bidc_;
const int num_steps_ = NUM_STEPS; int loop_offset_;
int num_steps_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile> template<typename Kernel_traits>
struct Noloop_traits<3, Cta_tile>{ std::tuple<int , int, int, int, int, int> work_dist(const int total_ctas, const int heads_total) {
// Interpretation of Cta_tile dims, i.e. Cta_tile_p:
enum{ STEP = Cta_tile::M }; constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
enum{ SEQLEN = Cta_tile::N };
const int num_full_heads = heads_total / total_ctas;
static_assert(STEP == 16 && SEQLEN == 512); const int heads_last_wave = heads_total % total_ctas;
inline __device__ Noloop_traits(const int bidc) int num_main_groups = 0;
: bidc_(bidc) int main_steps = 0;
, num_steps_(bidc < 2 ? 11 : 10) int rest_steps = 0;
, loop_offset_(bidc * 11) { if( heads_last_wave > 0 ) {
} // Number of CTA groups that process within heads.
num_main_groups = total_ctas / heads_last_wave;
template<typename ... Tiles> // Remaining CTAs that process between heads.
inline __device__ void move_all(Tiles & ... tiles) const { const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups);
using expand_type = int[]; if(rest_ctas == 0) {
for( int s = 0; s < loop_offset_; s++ ) { // We have exactly "num_main_groups" CTAs to process each of the remaining heads.
expand_type{ (tiles.move(), 0)... }; main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups;
num_main_groups = STEPS_PER_HEAD / main_steps; // Here: main_step > 0
rest_steps = STEPS_PER_HEAD % main_steps;
} else {
// Ideal number of steps if we could load-balance as evenly as possible.
const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas;
// Iterations that a "rest" CTA has to do at most.
const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas;
// Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main CTAs.
main_steps = steps_ideal;
rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
for( ; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++ ) {
rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
const int max_rest_total_steps = rest_steps * max_rest_iters;
if( max_rest_total_steps < main_steps )
break;
}
rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
} }
} }
inline __device__ int get_idx_dk() const { using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
//return bidc_; using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
return bidc_ * 2 + 0;
}
inline __device__ int get_idx_dv() const {
//return CHUNKS + bidc_;
return bidc_ * 2 + 1;
}
inline __device__ int offset_loop_count(const int l) { const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps);
// convert loop counter to position in the outer sequence const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8;
return (loop_offset_ + l) * STEP; const int elts_per_thread = max_steps * elts_per_thread_per_step;
}
const int loop_offset_; return {num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread};
const uint32_t bidc_; }
const int num_steps_;
};
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
......
#include <torch/torch.h>
#include <vector>
#include <cstdint>
// CUDA forward declarations
std::vector<at::Tensor> focal_loss_forward_cuda(
const at::Tensor &cls_output,
const at::Tensor &cls_targets_at_level,
const at::Tensor &num_positives_sum,
const int64_t num_real_classes,
const float alpha,
const float gamma,
const float smoothing_factor);
at::Tensor focal_loss_backward_cuda(
const at::Tensor &grad_output,
const at::Tensor &partial_grad,
const at::Tensor &num_positives_sum);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> focal_loss_forward(
const at::Tensor &cls_output,
const at::Tensor &cls_targets_at_level,
const at::Tensor &num_positives_sum,
const int64_t num_real_classes,
const float alpha,
const float gamma,
const float smoothing_factor
) {
CHECK_INPUT(cls_output);
CHECK_INPUT(cls_targets_at_level);
CHECK_INPUT(num_positives_sum);
return focal_loss_forward_cuda(
cls_output,
cls_targets_at_level,
num_positives_sum,
num_real_classes,
alpha,
gamma,
smoothing_factor);
}
at::Tensor focal_loss_backward(
const at::Tensor &grad_output,
const at::Tensor &partial_grad,
const at::Tensor &num_positives_sum
) {
CHECK_INPUT(grad_output);
CHECK_INPUT(partial_grad);
return focal_loss_backward_cuda(grad_output, partial_grad, num_positives_sum);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &focal_loss_forward,
"Focal loss calculation forward (CUDA)");
m.def("backward", &focal_loss_backward,
"Focal loss calculation backward (CUDA)");
}
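// A minimal usage sketch of the two entry points (illustrative only; tensor names are
// placeholders and any autograd wrapping is not shown here):
//
//   auto outs = focal_loss_forward(cls_output, cls_targets_at_level, num_positives_sum,
//                                  num_real_classes, /*alpha=*/0.25f, /*gamma=*/2.0f,
//                                  /*smoothing_factor=*/0.0f);
//   at::Tensor loss = outs[0], partial_grad = outs[1];
//   // Later, with grad_output = d(total_loss)/d(loss):
//   at::Tensor grad_cls_output = focal_loss_backward(grad_output, partial_grad,
//                                                    num_positives_sum);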
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#define ASSERT_UINT4_ALIGNED(PTR) \
TORCH_INTERNAL_ASSERT(is_aligned<uint4>(PTR), "Tensor " #PTR " is not uint4 aligned")
template <class T> bool is_aligned(const void *ptr) noexcept {
auto iptr = reinterpret_cast<std::uintptr_t>(ptr);
return !(iptr % alignof(T));
}
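// Both kernels below move data as uint4 (16 bytes), so each thread handles
// ILP = sizeof(uint4) / sizeof(scalar_t) elements per iteration: 8 for fp16, 4 for fp32.
// ASSERT_UINT4_ALIGNED guards the pointer alignment those vector accesses require.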
template <bool SMOOTHING, int ILP, typename scalar_t, typename labelscalar_t,
typename accscalar_t, typename outscalar_t>
__global__ void focal_loss_forward_cuda_kernel(
outscalar_t *loss, scalar_t *partial_grad,
const scalar_t *__restrict__ cls_output,
const labelscalar_t *__restrict__ cls_targets_at_level,
const float *__restrict__ num_positives_sum, const int64_t num_examples,
const int64_t num_classes, const int64_t num_real_classes,
const float alpha, const float gamma, const float smoothing_factor) {
extern __shared__ unsigned char shm[];
accscalar_t *loss_shm = reinterpret_cast<accscalar_t *>(shm);
loss_shm[threadIdx.x] = 0;
accscalar_t loss_acc = 0;
accscalar_t one = accscalar_t(1.0);
accscalar_t K = accscalar_t(2.0);
accscalar_t normalizer = one / static_cast<accscalar_t>(num_positives_sum[0]);
accscalar_t nn_norm, np_norm, pn_norm, pp_norm;
// *_norm is used for label smoothing only
if (SMOOTHING) {
nn_norm = one - smoothing_factor / K;
np_norm = smoothing_factor / K;
pn_norm = smoothing_factor - smoothing_factor / K;
pp_norm = one - smoothing_factor + smoothing_factor / K;
}
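  // Per element this kernel computes alpha-balanced sigmoid focal loss,
  //   FL(x) = -alpha_t * (1 - p_t)^gamma * log(p_t),
  // where sigma = sigmoid(x), p_t = sigma and alpha_t = alpha for the ground-truth class
  // of a positive match, and p_t = 1 - sigma, alpha_t = 1 - alpha otherwise. Below, off_a
  // is a numerically stable -log(sigma) and (base + off_a) evaluates to -log(p_t); with
  // label smoothing, the *_norm weights turn the hard targets into
  // (smoothing_factor/2, 1 - smoothing_factor/2) soft targets.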
uint4 p_vec, grad_vec;
// Accumulate loss on each thread
for (int64_t i = (blockIdx.x * blockDim.x + threadIdx.x) * ILP;
i < num_examples * num_classes; i += gridDim.x * blockDim.x * ILP) {
int64_t idy = i / num_classes;
labelscalar_t y = cls_targets_at_level[idy];
int64_t base_yid = i % num_classes;
int64_t pos_idx = idy * num_classes + y;
p_vec = *(uint4 *)&cls_output[i];
// Skip ignored matches
if (y == -2) {
#pragma unroll
for (int j = 0; j < ILP; j++) {
*((scalar_t *)(&grad_vec) + j) = 0;
}
*(uint4 *)&partial_grad[i] = grad_vec;
continue;
}
#pragma unroll
for (int j = 0; j < ILP; j++) {
// Skip the pad classes
if (base_yid + j >= num_real_classes) {
*((scalar_t *)(&grad_vec) + j) = 0;
continue;
}
accscalar_t p = static_cast<accscalar_t>(*((scalar_t *)(&p_vec) + j));
accscalar_t exp_np = ::exp(-p);
accscalar_t exp_pp = ::exp(p);
accscalar_t sigma = one / (one + exp_np);
accscalar_t logee = (p >= 0) ? exp_np : exp_pp;
accscalar_t addee = (p >= 0) ? 0 : -p;
accscalar_t off_a = addee + ::log(one + logee);
// Negative matches
accscalar_t base = SMOOTHING ? nn_norm * p : p;
accscalar_t off_b = (SMOOTHING ? np_norm : 0) - sigma;
accscalar_t coeff_f1 = one - alpha;
accscalar_t coeff_f2 = sigma;
accscalar_t coeff_b1 = gamma;
accscalar_t coeff_b2 = one - sigma;
// Positive matches
if (y >= 0 && (i + j == pos_idx)) {
base = SMOOTHING ? pn_norm * p : 0;
off_b = (SMOOTHING ? pp_norm : one) - sigma;
coeff_f1 = alpha;
coeff_f2 = one - sigma;
coeff_b1 = -gamma;
coeff_b2 = sigma;
}
accscalar_t coeff_f = coeff_f1 * ::pow(coeff_f2, gamma);
accscalar_t coeff_b = coeff_b1 * coeff_b2;
accscalar_t loss_t = coeff_f * (base + off_a);
accscalar_t grad = coeff_f * (coeff_b * (base + off_a) - off_b);
      // Delay normalizing the partial gradient by num_positives_sum until back
      // propagation because scalar_t reduces precision and focal loss is very
      // sensitive to small gradients. There is no overflow concern here since the
      // gradient has a smaller range than the input.
loss_acc += loss_t;
*((scalar_t *)(&grad_vec) + j) = static_cast<scalar_t>(grad);
}
    // This does not guarantee an stg.128 instruction; it may be emitted as two stg.64 stores.
*(uint4 *)&partial_grad[i] = grad_vec;
}
loss_shm[threadIdx.x] = loss_acc;
// Intra-CTA reduction
__syncthreads();
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s) {
loss_shm[threadIdx.x] += loss_shm[threadIdx.x + s];
}
__syncthreads();
}
// Inter-CTA reduction
if (threadIdx.x == 0) {
loss_acc = loss_shm[0] * normalizer;
atomicAdd(loss, loss_acc);
}
}
template <int ILP, typename scalar_t, typename accscalar_t,
typename outscalar_t>
__global__ void focal_loss_backward_cuda_kernel(
scalar_t *partial_grad, const outscalar_t *__restrict__ grad_output,
const float *__restrict__ num_positives_sum, const uint64_t numel) {
int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * ILP;
accscalar_t normalizer = static_cast<accscalar_t>(grad_output[0]) /
static_cast<accscalar_t>(num_positives_sum[0]);
  // The input is required to be padded for vector loads, so there is no need to
  // check whether the last ILP element goes out of bounds.
if (idx >= numel)
return;
uint4 grad_vec;
grad_vec = *(uint4 *)&partial_grad[idx];
#pragma unroll(ILP)
for (int i = 0; i < ILP; i++) {
auto grad = static_cast<accscalar_t>(*((scalar_t *)(&grad_vec) + i));
grad *= normalizer;
*((scalar_t *)(&grad_vec) + i) = static_cast<scalar_t>(grad);
}
*(uint4 *)&partial_grad[idx] = grad_vec;
}
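// The backward pass is just the normalization deferred from the forward kernel:
//   d(loss)/d(cls_output) = partial_grad * grad_output / num_positives_sum,
// applied in place on partial_grad.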
std::vector<at::Tensor> focal_loss_forward_cuda(
const at::Tensor &cls_output, const at::Tensor &cls_targets_at_level,
const at::Tensor &num_positives_sum, const int64_t num_real_classes,
const float alpha, const float gamma, const float smoothing_factor) {
// Checks required for correctness
TORCH_INTERNAL_ASSERT(cls_output.size(-1) >= num_real_classes,
"Incorrect number of real classes.");
TORCH_INTERNAL_ASSERT(cls_targets_at_level.scalar_type() == at::kLong,
"Invalid label type.");
TORCH_INTERNAL_ASSERT(
(num_positives_sum.numel() == 1) &&
(num_positives_sum.scalar_type() == at::kFloat),
"Expect num_positives_sum to be a float32 tensor with only one element.");
TORCH_INTERNAL_ASSERT(cls_output.dim() == cls_targets_at_level.dim() + 1,
"Mis-matched dimensions between class output and label.");
for (int64_t i = 0; i < cls_targets_at_level.dim(); i++)
TORCH_INTERNAL_ASSERT(cls_output.size(i) == cls_targets_at_level.size(i),
"Mis-matched shape between class output and label.");
// Checks required for better performance
const int ILP = sizeof(uint4) / cls_output.element_size();
ASSERT_UINT4_ALIGNED(cls_output.data_ptr());
TORCH_INTERNAL_ASSERT(cls_output.size(-1) % ILP == 0,
"Pad number of classes first to take advantage of 128 bit load.");
TORCH_INTERNAL_ASSERT(num_real_classes >= ILP, "Too few classes.");
int64_t num_classes = cls_output.size(-1);
int64_t num_examples = cls_output.numel() / num_classes;
at::Tensor loss = at::zeros({}, cls_output.options().dtype(at::kFloat));
  // Compute the incomplete gradient during fprop since most of the heavy math
  // of bprop is the same as in fprop; trading memory for compute helps here.
at::Tensor partial_grad = at::empty_like(cls_output);
  // The grid contains 2 CTAs per SM; each CTA loops over the input with a grid
  // stride until the last item.
cudaDeviceProp props;
cudaGetDeviceProperties(&props, at::cuda::current_device());
dim3 block(512);
dim3 grid(2 * props.multiProcessorCount);
// Specialize on label smoothing or not to reduce redundant operations
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (smoothing_factor == 0.0f) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
cls_output.scalar_type(), "focal_loss_fprop", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
using labelscalar_t = int64_t;
using outscalar_t = float;
const int ILP = sizeof(uint4) / sizeof(scalar_t);
focal_loss_forward_cuda_kernel<false, ILP, scalar_t, labelscalar_t,
accscalar_t, outscalar_t>
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
loss.data_ptr<outscalar_t>(),
partial_grad.data_ptr<scalar_t>(),
cls_output.data_ptr<scalar_t>(),
cls_targets_at_level.data_ptr<labelscalar_t>(),
num_positives_sum.data_ptr<float>(), num_examples,
num_classes, num_real_classes, alpha, gamma,
smoothing_factor);
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
cls_output.scalar_type(), "focal_loss_fprop", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
using labelscalar_t = int64_t;
using outscalar_t = float;
const int ILP = sizeof(uint4) / sizeof(scalar_t);
focal_loss_forward_cuda_kernel<true, ILP, scalar_t, labelscalar_t,
accscalar_t, outscalar_t>
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
loss.data_ptr<outscalar_t>(),
partial_grad.data_ptr<scalar_t>(),
cls_output.data_ptr<scalar_t>(),
cls_targets_at_level.data_ptr<labelscalar_t>(),
num_positives_sum.data_ptr<float>(), num_examples,
num_classes, num_real_classes, alpha, gamma,
smoothing_factor);
});
}
AT_CUDA_CHECK(cudaGetLastError());
return {loss, partial_grad};
}
at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output,
const at::Tensor &partial_grad,
const at::Tensor &num_positives_sum) {
// Each thread process ILP elements
const int ILP = sizeof(uint4) / partial_grad.element_size();
dim3 block(512);
dim3 grid((partial_grad.numel() + block.x * ILP - 1) / (block.x * ILP));
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
partial_grad.scalar_type(), "focal_loss_bprop", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
using outscalar_t = float;
const int ILP = sizeof(uint4) / sizeof(scalar_t);
focal_loss_backward_cuda_kernel<ILP, scalar_t, accscalar_t, outscalar_t>
<<<grid, block, 0, stream>>>(partial_grad.data_ptr<scalar_t>(),
grad_output.data_ptr<outscalar_t>(),
num_positives_sum.data_ptr<float>(),
partial_grad.numel());
});
AT_CUDA_CHECK(cudaGetLastError());
return partial_grad;
}
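// ----------------------------------------------------------------------------
// For reference, a minimal host-side sketch of the per-element loss computed by
// focal_loss_forward_cuda_kernel without label smoothing. It is illustrative only and not
// used by the extension; the helper name is hypothetical.
#include <cmath>

static inline float focal_loss_element_reference(float x, bool positive, float alpha,
                                                 float gamma) {
  const float sigma = 1.f / (1.f + std::exp(-x));
  // Probability assigned to the ground-truth outcome and its alpha weight.
  const float p_t = positive ? sigma : 1.f - sigma;
  const float alpha_t = positive ? alpha : 1.f - alpha;
  // Numerically stable -log(p_t): softplus(-x) for positives, softplus(x) for negatives,
  // matching off_a / (base + off_a) in the kernel.
  const float nll = positive ? std::log1p(std::exp(-std::fabs(x))) + std::fmax(-x, 0.f)
                             : std::log1p(std::exp(-std::fabs(x))) + std::fmax(x, 0.f);
  return alpha_t * std::pow(1.f - p_t, gamma) * nll;
}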