Commit 05087332 authored by Tri Dao

Remove softmax fp16 max

parent 14dc326e
@@ -58,12 +58,6 @@ inline __device__ float apply_exp_(float x, float max) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-inline __device__ __half2 apply_exp_(__half2 x, __half2 max) {
-    return h2exp(__hsub2(x, max));
-}
-////////////////////////////////////////////////////////////////////////////////////////////////////
 inline __device__ float apply_exp2_(float x, float max) {
     return exp2f(x - max);
     // With fast-math, this produces the same PTX instruction as the assembly below
@@ -75,17 +69,9 @@ inline __device__ float apply_exp2_(float x, float max) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-inline __device__ __half2 apply_exp2_(__half2 x, __half2 max) {
-    return h2exp2(__hsub2(x, max));
-}
-////////////////////////////////////////////////////////////////////////////////////////////////////
-template<int COLS, bool half> struct ReadType {};
-template<> struct ReadType<4, false> { using T = float;};
-template<> struct ReadType<8, false> { using T = float2;};
-template<> struct ReadType<4, true> { using T = __half2;};
-template<> struct ReadType<8, true> { using T = float2;};
+template<int COLS> struct ReadType {};
+template<> struct ReadType<4> { using T = float;};
+template<> struct ReadType<8> { using T = float2;};
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -118,8 +104,7 @@ struct Smem_tile_reduce {
     static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
     static_assert(LOOPS == 1);
-    using read_t = typename ReadType<COLS, /*half=*/false>::T;
-    using read_half_t = typename ReadType<COLS, /*half=*/true>::T;
+    using read_t = typename ReadType<COLS>::T;
     __device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
@@ -152,17 +137,6 @@ struct Smem_tile_reduce {
         }
     }
-    __device__ inline void store(__half2 (&frag)[MMAS_M]) {
-        __half2 *smem_write_half_ = reinterpret_cast<__half2 *>(smem_write_);
-        if( qid_ == 0 ) {
-            #pragma unroll
-            for( int mi = 0; mi < MMAS_M; mi++ ) {
-                int offset = mi * 16 * WARPS_N;
-                smem_write_half_[offset + 0 * 8 * WARPS_N] = frag[mi];
-            }
-        }
-    }
     __device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
         #pragma unroll
         for( int mi = 0; mi < MMAS_M; mi++ ) {
@@ -172,15 +146,6 @@ struct Smem_tile_reduce {
         }
     }
-    __device__ inline void load(read_half_t (&frag)[MMAS_M]) {
-        read_half_t *smem_read_half_ = reinterpret_cast<read_half_t *>(smem_read_);
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; mi++ ) {
-            int offset = mi * 16 * 4;
-            frag[mi] = smem_read_half_[offset + 0 * 8 * 4];
-        }
-    }
     __device__ inline void load_row(read_t (&frag)[MMAS_M], int row) {
         #pragma unroll
         for( int mi = 0; mi < MMAS_M; mi++ ) {
@@ -304,29 +269,6 @@ struct Softmax_base {
         }
     }
-    // Apply the exp to all the elements.
-    inline __device__ void apply_exp(const __half2 (&max)[MMAS_M]) {
-        #pragma unroll
-        for (int mi = 0; mi < MMAS_M; ++mi) {
-            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
-            // max * log_2(e)) This allows the compiler to use the ffma
-            // instruction instead of fadd and fmul separately.
-            constexpr float kLog2e = M_LOG2E;
-            const float2 max_f = __half22float2(max[mi]);
-            const float max0_log2e = max_f.x * kLog2e, max1_log2e = max_f.y * kLog2e;
-            #pragma unroll
-            for (int ni = 0; ni < MMAS_N * 4; ++ni) {
-                float2 elt = __half22float2(elt_half_[mi][ni]);
-                elt_[mi * 2 + 0][ni] = apply_exp2_(elt.x * kLog2e, max0_log2e);
-                elt_[mi * 2 + 1][ni] = apply_exp2_(elt.y * kLog2e, max1_log2e);
-                // __half2 out = apply_exp_(elt_half_[mi][ni], max[mi]);
-                // float2 outf = __half22float2(out);
-                // elt_[mi * 2 + 0][ni] = outf.x;
-                // elt_[mi * 2 + 1][ni] = outf.y;
-            }
-        }
-    }
     // Apply the exp to all the elements.
     template <bool max_in_base2=false>
     inline __device__ void apply_exp_col(const float (&max)[MMAS_N * 4]) {
@@ -527,7 +469,6 @@ struct Softmax_base {
     int tidx_;
     // The elements.
     float elt_[MMAS_M * 2][MMAS_N * 4];
-    __half2 elt_half_[MMAS_M][MMAS_N * 4];
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -638,34 +579,6 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
         }
     }
-    // Scale FP32 fragments
-    template <typename Mask>
-    inline __device__ void unpack_noscale_half_and_apply_mask(const Accumulator (&acc)[MMAS_M][MMAS_N],
-                                                               const Mask &mask) {
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; ++mi ) {
-            #pragma unroll
-            for( int ni = 0; ni < MMAS_N; ++ni ) {
-                float tmp[2][4];
-                // 1st row - 4 elements per row.
-                tmp[0][0] = mask.is_valid(mi, ni, 0, 0) ? acc[mi][ni].elt(0) : -INFINITY;
-                tmp[0][1] = mask.is_valid(mi, ni, 0, 1) ? acc[mi][ni].elt(1) : -INFINITY;
-                tmp[0][2] = mask.is_valid(mi, ni, 0, 2) ? acc[mi][ni].elt(4) : -INFINITY;
-                tmp[0][3] = mask.is_valid(mi, ni, 0, 3) ? acc[mi][ni].elt(5) : -INFINITY;
-                // 2nd row - 4 elements per row.
-                tmp[1][0] = mask.is_valid(mi, ni, 1, 0) ? acc[mi][ni].elt(2) : -INFINITY;
-                tmp[1][1] = mask.is_valid(mi, ni, 1, 1) ? acc[mi][ni].elt(3) : -INFINITY;
-                tmp[1][2] = mask.is_valid(mi, ni, 1, 2) ? acc[mi][ni].elt(6) : -INFINITY;
-                tmp[1][3] = mask.is_valid(mi, ni, 1, 3) ? acc[mi][ni].elt(7) : -INFINITY;
-                this->elt_half_[mi][4 * ni + 0] = __floats2half2_rn(tmp[0][0], tmp[1][0]);
-                this->elt_half_[mi][4 * ni + 1] = __floats2half2_rn(tmp[0][1], tmp[1][1]);
-                this->elt_half_[mi][4 * ni + 2] = __floats2half2_rn(tmp[0][2], tmp[1][2]);
-                this->elt_half_[mi][4 * ni + 3] = __floats2half2_rn(tmp[0][3], tmp[1][3]);
-            }
-        }
-    }
     template<bool zero_init=true, typename Operator>
     __device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) {
         #pragma unroll
@@ -678,18 +591,6 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
         }
     }
-    template<typename Operator>
-    __device__ inline void thread_reduce_(__half2 (&frag)[MMAS_M], Operator &op) {
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; mi++ ) {
-            frag[mi] = this->elt_half_[mi][0];
-            #pragma unroll
-            for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
-                frag[mi] = op(frag[mi], this->elt_half_[mi][ni]);
-            }
-        }
-    }
     template<bool zero_init=true, typename Operator>
     __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
         thread_reduce_<zero_init>(frag, op);
@@ -701,29 +602,13 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
         quad_allreduce(frag, tmp, op);
     }
-    template<typename Operator>
-    __device__ inline void reduce_(__half2 (&frag)[MMAS_M], Operator &op, Smem_tile_red & smem_red) {
-        thread_reduce_(frag, op);
-        quad_reduce(frag, frag, op);
-        smem_red.store(frag);
-        __syncthreads();
-        typename Smem_tile_red::read_half_t tmp[MMAS_M];
-        smem_red.load(tmp);
-        quad_allreduce(frag, tmp, op);
-    }
     template<bool zero_init=true>
     __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
         MaxOp<float> max;
         reduce_<zero_init>(frag, max, smem_max_);
     }
-    __device__ inline void reduce_max(__half2 (&frag)[MMAS_M]){
-        MaxOp<__half2> max;
-        reduce_(frag, max, smem_max_);
-    }
     __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
         SumOp<float> sum;
         reduce_(frag, sum, smem_sum_);
     }
...
@@ -1024,11 +1024,6 @@ struct MaxOp<float> {
     __device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
 };
-template <>
-struct MaxOp<__half2> {
-    __device__ inline __half2 operator()(__half2 const &x, __half2 const &y) { return __hmax2(x, y); }
-};
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 template<typename T>
...
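
The comment removed from apply_exp above explains the trick the surviving float path (apply_exp2_) still relies on: instead of exp(x - max), compute exp2(x * log_2(e) - max * log_2(e)) so the compiler can emit an ffma in place of a separate fadd and fmul. Below is a minimal standalone CUDA sketch of that identity, not part of this commit; the kernel name, launch shape, and test values are invented for illustration.

// Sketch only: verifies on the device that exp2f(x * log2(e) - max * log2(e))
// matches expf(x - max), the rewrite described in the removed apply_exp comment.
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void exp2_rewrite_check(const float *x, float max_val,
                                   float *out_ref, float *out_alt, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        constexpr float kLog2e = M_LOG2E;                      // log2(e), as in the kernel
        out_ref[i] = expf(x[i] - max_val);                     // exp(x - max)
        out_alt[i] = exp2f(x[i] * kLog2e - max_val * kLog2e);  // ffma-friendly form
    }
}

int main() {
    const int n = 8;
    float h_x[n], h_ref[n], h_alt[n];
    for (int i = 0; i < n; ++i) h_x[i] = 0.5f * i;
    float *d_x, *d_ref, *d_alt;
    cudaMalloc(&d_x, n * sizeof(float));
    cudaMalloc(&d_ref, n * sizeof(float));
    cudaMalloc(&d_alt, n * sizeof(float));
    cudaMemcpy(d_x, h_x, n * sizeof(float), cudaMemcpyHostToDevice);
    exp2_rewrite_check<<<1, 32>>>(d_x, /*max_val=*/3.5f, d_ref, d_alt, n);
    cudaMemcpy(h_ref, d_ref, n * sizeof(float), cudaMemcpyDeviceToHost);
    cudaMemcpy(h_alt, d_alt, n * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i)
        printf("x=%4.1f  exp(x-max)=%.7g  exp2 form=%.7g\n", h_x[i], h_ref[i], h_alt[i]);
    cudaFree(d_x); cudaFree(d_ref); cudaFree(d_alt);
    return 0;
}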