Commit 7eed2594 authored by flyingdown

Rewrite the GEMM in the 128/256 forward pass using mmac instructions

parent 0816a70e
......@@ -80,7 +80,11 @@ void set_params(Fused_multihead_attention_fprop_params &params,
params.p_dropout = 1.f - p_dropout;
params.rp_dropout = 1.f / params.p_dropout;
TORCH_CHECK(p_dropout < 1.f);
#if defined (__HIP_PLATFORM_HCC__)
set_alpha(params.scale_dropout, params.rp_dropout, acc_type);
#else
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
#endif
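// Note (assumption, not stated in the patch): acc_type is presumably the fp32 accumulator
// type while data_type is fp16, so on ROCm the dropout rescale factor is expressed in fp32
// to match the fp32 mmac accumulation used by the HIP path.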
}
std::vector<at::Tensor>
......@@ -94,24 +98,38 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
c10::optional<at::Generator> gen_) {
auto dprops = at::cuda::getCurrentDeviceProperties();
#if not defined(__HIP_PLATFORM_HCC__)
TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
#endif
auto stream = at::cuda::getCurrentCUDAStream().stream();
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_training, is_nl);
int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80;
// int seq_len = 512;
// auto launch = &run_fmha_fp16_512_64_sm80;
// if( max_seq_len <= 128 ) {
// seq_len = 128;
// launch = &run_fmha_fp16_128_64_sm80;
// } else if( max_seq_len <= 256 ) {
// seq_len = 256;
// launch = &run_fmha_fp16_256_64_sm80;
// } else if( max_seq_len <= 384 ) {
// seq_len = 384;
// launch = &run_fmha_fp16_384_64_sm80;
// } else if( max_seq_len <= 512 ) {
// seq_len = 512;
// launch = &run_fmha_fp16_512_64_sm80;
// } else {
// TORCH_CHECK(false);
// }
int seq_len = 256;
auto launch = &run_fmha_fp16_256_64_sm80;
if( max_seq_len <= 128 ) {
seq_len = 128;
launch = &run_fmha_fp16_128_64_sm80;
} else if( max_seq_len <= 256 ) {
seq_len = 256;
launch = &run_fmha_fp16_256_64_sm80;
} else if( max_seq_len <= 384 ) {
seq_len = 384;
launch = &run_fmha_fp16_384_64_sm80;
} else if( max_seq_len <= 512 ) {
seq_len = 512;
launch = &run_fmha_fp16_512_64_sm80;
} else {
TORCH_CHECK(false);
}
......@@ -178,7 +196,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
return { ctx, s };
}
/*
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
......@@ -351,11 +369,11 @@ std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num
dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total, num_chunks, stream);
return { dqkv, softmax, dkv };
}
}*/
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "Fused Multi-head Self-attention for BERT";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
// m.def("bwd", &mha_bwd, "Backward pass");
// m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
}
......@@ -145,85 +145,57 @@ struct Fragment_b : public Fragment<uint16_t, 8> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined (__HIP_PLATFORM_HCC__)
__device__ inline void f16mulf16addf32(uint32_t & a, uint32_t & b, const float * c, float * d){
// uint32_t res = 0;
// asm volatile("v_pk_fma_f16 %0, %1,%2,%3" : "=v"(res) : "v"(a), "v"(b), "v"(res));
// __half * h = reinterpret_cast<__half*>(&res);
__half * ha = reinterpret_cast<__half*>(&a);
__half * hb = reinterpret_cast<__half*>(&b);
float C = *c, D = *d;
*d = *c + __half2float(ha[0])*__half2float(hb[0]) + __half2float(ha[1])*__half2float(hb[1]);
// if (threadIdx.x == 15) {
// printf("f16mulf16addf32 %i: A %f, %f, B %f, %f, RES %f, %f, %f, C %f, %f, D %f, %f \n", threadIdx.x,
// __half2float(ha[0]), __half2float(ha[1]),
// __half2float(hb[0]), __half2float(hb[1]),
// __half2float(ha[0])*__half2float(hb[0]),
// __half2float(ha[1])*__half2float(hb[1]),
// __half2float(ha[0])*__half2float(hb[0]) + __half2float(ha[1])*__half2float(hb[1]),
// C, *c, D, *d
// );
// }
}
struct Fragment_accumulator : public Fragment<float, 4> {
// row 8 col 4
__device__ inline void m16n8k16(const uint32_t * A, const uint32_t * B, /*const float * C,*/ float * D) {
int tid = threadIdx.x;
int baseId = tid / 32 * 32;
__shared__ uint32_t smem[256*6];
int base = tid*6;
__builtin_memcpy(smem+base, A, sizeof(uint32_t));
__builtin_memcpy(smem+(base+1), A+1, sizeof(uint32_t));
__builtin_memcpy(smem+(base+2), A+2, sizeof(uint32_t));
__builtin_memcpy(smem+(base+3), A+3, sizeof(uint32_t));
__builtin_memcpy(smem+(base+4), B, sizeof(uint32_t));
__builtin_memcpy(smem+(base+5), B+1, sizeof(uint32_t));
__syncthreads();
/* From D's point of view, each lane is responsible for computing its own D values.
Starting from lane 0, it loops over the lanes, fetching one row of A and two columns of B.
s is the lane id for the B matrix,
baseA is the lane id for A,
baseB0 is the first B column fetched by the current lane, baseB1 is the second B column.
*/
int s = baseId+(tid%4)*8, e = s+4;
for (int i = s; i < e; ++i) {
// A[0]->i A[1]->i+1 A[2]->i+2 A[3]->i+3 B[0]->i+4 B[1]->i+5
int baseA = (tid-tid%4+i-s)*6; // lane id of the first column in the current tid's row, plus the stride, then times 6
int baseB0 = i*6, baseB1 = (i+4)*6;
f16mulf16addf32(smem[baseA], smem[baseB0+4], D, D);
f16mulf16addf32(smem[baseA+2], smem[baseB0+5], D, D);
f16mulf16addf32(smem[baseA], smem[baseB1+4], D+1, D+1);
f16mulf16addf32(smem[baseA+2], smem[baseB1+5], D+1, D+1);
f16mulf16addf32(smem[baseA+1], smem[baseB0+4], D+2, D+2);
f16mulf16addf32(smem[baseA+3], smem[baseB0+5], D+2, D+2);
f16mulf16addf32(smem[baseA+1], smem[baseB1+4], D+3, D+3);
f16mulf16addf32(smem[baseA+3], smem[baseB1+5], D+3, D+3);
}
// The base class.
using Base = Fragment<float, 8>;
// __half * a0 = reinterpret_cast<__half*>(smem+base);
// __half * a1 = reinterpret_cast<__half*>(smem+base+1);
// __half * a2 = reinterpret_cast<__half*>(smem+base+2);
// __half * a3 = reinterpret_cast<__half*>(smem+base+3);
// __half * b0 = reinterpret_cast<__half*>(smem+base+4);
// __half * b1 = reinterpret_cast<__half*>(smem+base+5);
// printf("m16n8k16 %i: \n A %f, %f, %f, %f, %f, %f, %f, %f \n B %f, %f, %f, %f \n D %f, %f, %f, %f \n", threadIdx.x,
// __half2float(a0[0]), __half2float(a0[1]),
// __half2float(a1[0]), __half2float(a1[1]),
// __half2float(a2[0]), __half2float(a2[1]),
// __half2float(a3[0]), __half2float(a3[1]),
// __half2float(b0[0]), __half2float(b0[1]),
// __half2float(b1[0]), __half2float(b1[1]),
// D[0], D[1], D[2], D[3]
// );
}
#endif
// Add two fragments.
template< typename Other_fragment_ >
inline __device__ void add(const Other_fragment_ &other) {
for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
this->elt(ii) = this->elt(ii) + other.elt(ii);
}
}
// Do the HMMA.
template< typename Layout_a, typename Layout_b >
inline __device__ void mma(const Fragment_a<Layout_a> &a,
const Fragment_b<Layout_b> &b) {
// const uint32_t * A = reinterpret_cast<const uint32_t*>(a.regs_);
// const uint32_t * B = reinterpret_cast<const uint32_t*>(b.regs_);
// float * D = reinterpret_cast<float*>(regs_);
// float regs[8];
// __builtin_memcpy(regs, D, sizeof(float)*8);
// m16n8k16(A, B, D);
// m16n8k16(A, B+2, D+4);
using v4f = __attribute__( (__vector_size__(4 * sizeof(float)) )) float;
v4f * rC = reinterpret_cast<v4f*>(regs_);
// float rA = reinterpret_cast<const float&>(a.reg(0));
// float rB = reinterpret_cast<const float&>(b.reg(0));
float rA0 = a.template elt_as<float>(0);
float rB0 = b.template elt_as<float>(0);
*rC = __builtin_amdgcn_mmac_f32_16x16x4f32(rA0, rB0, *rC, 0);
float rA1 = a.template elt_as<float>(1);
float rB1 = b.template elt_as<float>(1);
*rC = __builtin_amdgcn_mmac_f32_16x16x4f32(rA1, rB1, *rC, 0);
float rA2 = a.template elt_as<float>(2);
float rB2 = b.template elt_as<float>(2);
*rC = __builtin_amdgcn_mmac_f32_16x16x4f32(rA2, rB2, *rC, 0);
float rA3 = a.template elt_as<float>(3);
float rB3 = b.template elt_as<float>(3);
*rC = __builtin_amdgcn_mmac_f32_16x16x4f32(rA3, rB3, *rC, 0);
// if (blockIdx.x == 0) {
// printf("tid:%d rA0:%6.4f rB0:%6.4f rA1:%6.4f rB1:%6.4f rA2:%6.4f rB2:%6.4f rA3:%6.4f rB3:%6.4f c0:%6.4f c1:%6.4f c2:%6.4f c3:%6.4f\n", threadIdx.x,
// rA0, rB0, rA1, rB1, rA2, rB2, rA3, rB3, elt(0), elt(1), elt(2), elt(3));
// }
}
};
#else
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Fragment_accumulator : public Fragment<float, 8> {
......@@ -242,15 +214,6 @@ struct Fragment_accumulator : public Fragment<float, 8> {
template< typename Layout_a, typename Layout_b >
inline __device__ void mma(const Fragment_a<Layout_a> &a,
const Fragment_b<Layout_b> &b) {
#if defined (__HIP_PLATFORM_HCC__)
const uint32_t * A = reinterpret_cast<const uint32_t*>(a.regs_);
const uint32_t * B = reinterpret_cast<const uint32_t*>(b.regs_);
float * D = reinterpret_cast<float*>(regs_);
float regs[8];
__builtin_memcpy(regs, D, sizeof(float)*8);
m16n8k16(A, B, D);
m16n8k16(A, B+2, D+4);
#else
asm volatile( \
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
" {%0, %1, %2, %3}, \n" \
......@@ -269,11 +232,10 @@ struct Fragment_accumulator : public Fragment<float, 8> {
: "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7))
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
, "r"(b.reg(2)), "r"(b.reg(3)));
#endif
}
};
#endif
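// Illustrative sketch (not part of the patch): the HIP mma() above splits one 16x16x16
// warp-level GEMM step into four 16x16x4 mmac instructions. Each lane supplies one fp32
// element of A and one of B per K-slice (the HIP ldsm path widens fp16 smem data to fp32),
// and the builtin accumulates into a 4-wide float vector per lane. The names below
// (v4f_sketch, mmac_16x16x16_sketch) are illustrative only.
#if defined(__HIP_PLATFORM_HCC__)
using v4f_sketch = __attribute__( (__vector_size__(4 * sizeof(float)) )) float;

__device__ inline void mmac_16x16x16_sketch(const float (&a)[4],  // this lane's 4 K-slices of A
                                            const float (&b)[4],  // this lane's 4 K-slices of B
                                            v4f_sketch &c) {      // this lane's 4 accumulator values
    #pragma unroll
    for( int k = 0; k < 4; ++k ) {
        c = __builtin_amdgcn_mmac_f32_16x16x4f32(a[k], b[k], c, 0);
    }
}
#endif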
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Fragment, int M, int N >
......@@ -310,8 +272,24 @@ inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N])
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
// wangaq debug
// if (blockIdx.x == 0) {
// printf("a tid:%d mi:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi,
// a[mi].template elt_as<float>(0),
// a[mi].template elt_as<float>(1),
// a[mi].template elt_as<float>(2),
// a[mi].template elt_as<float>(3));
// }
#pragma unroll
for( int ni = 0; ni < N; ++ni ) {
// wangaq debug
// if (blockIdx.x == 0) {
// printf("b tid:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, ni,
// b[ni].template elt_as<float>(0),
// b[ni].template elt_as<float>(1),
// b[ni].template elt_as<float>(2),
// b[ni].template elt_as<float>(3));
// }
acc[mi][ni].mma(a[mi], b[ni]);
}
}
......@@ -340,7 +318,11 @@ struct Cta_tile_ {
// The number of warps per CTA.
enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K };
// The number of threads per warp.
#if defined(__HIP_PLATFORM_HCC__)
enum { THREADS_PER_WARP = 64 };
#else
enum { THREADS_PER_WARP = 32 };
#endif
// The number of threads per CTA.
enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP };
};
......@@ -350,7 +332,11 @@ struct Cta_tile_ {
template<typename Cta_tile>
struct Hmma_tile {
// The number of elements computed with a single warp-MMA.
// #if defined(__HIP_PLATFORM_HCC__)
// enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 4 };
// #else
enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 };
// #endif
// The number of elements computed with a single CTA-MMA.
enum {
......
......@@ -85,6 +85,20 @@ struct Gmem_tile_qkv {
// Store data to shared memory.
template< typename Smem_tile >
inline __device__ void commit(Smem_tile &smem_tile) {
// wangaq debug
// for( int ii = 0; ii < LDGS; ++ii ) {
// if (blockIdx.x == 0) {
// printf("commit tid:%d LDGS:%d ii:%d %f %f %f %f %f %f %f %f\n", threadIdx.x, LDGS, ii,
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[0]),
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[1]),
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[2]),
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[3]),
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[4]),
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[5]),
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[6]),
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[7]));
// }
// }
smem_tile.store(fetch_);
}
......@@ -105,6 +119,18 @@ struct Gmem_tile_qkv {
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
fct.load(ii, preds[ii]);
// wangaq debug
// if (blockIdx.x == 0) {
// printf("load tid:%d LDGS:%d ii:%d %f %f %f %f %f %f %f %f\n", threadIdx.x, LDGS, ii,
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[0]),
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[1]),
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[2]),
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[3]),
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[4]),
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[5]),
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[6]),
// __half2float(reinterpret_cast<__half*>(&fetch_[ii])[7]));
// }
}
}
......@@ -254,8 +280,13 @@ struct Gmem_tile_mma_sd {
// The mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
#if defined(__HIP_PLATFORM_HCC__)
// Each STG stores 4 elements.
enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 4 };
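// (Assumption for context: the HIP accumulator fragment carries 4 values per lane, which
// Gmem_tile_mma_s::store below packs into a uint2 of fp16, hence half the bytes per STG
// compared with the CUDA path.)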
#else
// Each STG stores 8 elements.
enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 };
#endif
// The number of MMAs in the M dimension.
enum { MMAS_M = Mma_tile::MMAS_M };
// The number of MMAs in the N dimension.
......@@ -369,6 +400,14 @@ struct Gmem_tile_mma_s : public Base {
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
#if defined(__HIP_PLATFORM_HCC__)
uint2 dst;
dst.x = float2_to_half2(frag[ni][mi].reg(0), frag[ni][mi].reg(1));
dst.y = float2_to_half2(frag[ni][mi].reg(2), frag[ni][mi].reg(3));
if( mask.any_valid(mi, ni) ) {
Base::store(dst, mi, ni);
}
#else
uint4 dst;
dst.x = frag[ni][mi].reg(0);
dst.y = frag[ni][mi].reg(2);
......@@ -377,6 +416,7 @@ struct Gmem_tile_mma_s : public Base {
if( mask.any_valid(mi, ni) ) {
Base::store(dst, mi, ni);
}
#endif
}
}
}
......
......@@ -47,11 +47,28 @@ struct Mask {
// find the warp in the Cta tile
const int warp_n = (warp / Cta_tile::WARPS_M);
const int warp_m = (warp % Cta_tile::WARPS_M);
#if defined(__HIP_PLATFORM_HCC__)
// decompose warp into 16x16 tile
const int quad = lane % 16;
const int tid = lane / 16;
row = warp_m * 16 + quad;
col = warp_n * 16 + tid;
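// Worked example of this mapping: lane 37 gives quad = 37 % 16 = 5 and tid = 37 / 16 = 2,
// i.e. row = warp_m * 16 + 5, and with the 4 * jj stride in is_valid below the lane
// covers columns warp_n * 16 + {2, 6, 10, 14} of the 16x16 tile.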
#else
// decompose warp into 8x4 tile
const int quad = lane / 4;
const int tid = (lane % 4) * 2;
row = warp_m * 16 + quad;
col = warp_n * 16 + tid;
#endif
}
inline __device__ bool is_valid(const int mi, const int ni, const int jj) const {
// jj iterate over the 1x4 fragment
const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + 4 * jj) < actual_seqlen;
//&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen;
return col_valid;
// return row_valid && col_valid;
}
inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {
......@@ -65,7 +82,11 @@ struct Mask {
//BERT Mask: if upper left is invalid, none are valid
inline __device__ bool any_valid(int mi, int ni) const {
#if defined(__HIP_PLATFORM_HCC__)
return is_valid(mi, ni, 0);
#else
return is_valid(mi, ni, 0, 0);
#endif
}
inline __device__ void load(int it) {
......
......@@ -69,7 +69,7 @@ struct Smem_tile_without_skews {
// The number of bytes per row without packing of rows.
enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };
// The number of bytes per row -- we want at least 128B per row.
enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };
enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE + 4 };
// The number of rows in shared memory (two rows may be packed into a single one).
enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };
......@@ -117,6 +117,18 @@ struct Smem_tile_without_skews {
inline __device__ Smem_tile_without_skews(void *smem, int tidx)
: smem_(__nvvm_get_smem_pointer(smem)) {
#if defined (__HIP_PLATFORM_HCC__)
int smem_write_row = tidx / THREADS_PER_ROW;
int smem_write_col = tidx % THREADS_PER_ROW;
// The offset.
this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;
// TODO: Why not merge it with the read offset?
this->smem_read_buffer_ = __shfl(0, 0);
this->smem_write_buffer_ = __shfl(0, 0);
#else
// The row written by a thread. See doc/mma_smem_layout.xlsx.
int smem_write_row = tidx / THREADS_PER_ROW;
......@@ -129,10 +141,6 @@ struct Smem_tile_without_skews {
this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;
// TODO: Why not merge it with the read offset?
#if defined (__HIP_PLATFORM_HCC__)
this->smem_read_buffer_ = __shfl(0, 0);
this->smem_write_buffer_ = __shfl(0, 0);
#else
this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);
this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);
#endif
......@@ -259,6 +267,32 @@ struct Smem_tile_without_skews {
uint32_t smem_ptrs[N];
this->compute_store_pointers(smem_ptrs);
sts(smem_ptrs, data);
// wangaq debug
// if (blockIdx.x == 0) {
// extern __shared__ char smem[];
// uint32_t base = __nvvm_get_smem_pointer(smem);
// for (int ii = 0; ii < N; ++ii) {
// printf("data tid:%d N:%d ii:%d %f %f %f %f %f %f %f %f\n", threadIdx.x, N, ii,
// __half2float(reinterpret_cast<const __half*>(&data[ii])[0]),
// __half2float(reinterpret_cast<const __half*>(&data[ii])[1]),
// __half2float(reinterpret_cast<const __half*>(&data[ii])[2]),
// __half2float(reinterpret_cast<const __half*>(&data[ii])[3]),
// __half2float(reinterpret_cast<const __half*>(&data[ii])[4]),
// __half2float(reinterpret_cast<const __half*>(&data[ii])[5]),
// __half2float(reinterpret_cast<const __half*>(&data[ii])[6]),
// __half2float(reinterpret_cast<const __half*>(&data[ii])[7]));
// __half * smem_ptr = reinterpret_cast<__half*>(smem-base+smem_ptrs[ii]);
// printf("smem_ptrs tid:%d N:%d ii:%d %f %f %f %f %f %f %f %f\n", threadIdx.x, N, ii,
// __half2float(smem_ptr[0]),
// __half2float(smem_ptr[1]),
// __half2float(smem_ptr[2]),
// __half2float(smem_ptr[3]),
// __half2float(smem_ptr[4]),
// __half2float(smem_ptr[5]),
// __half2float(smem_ptr[6]),
// __half2float(smem_ptr[7]));
// }
// }
}
// Store to the tile in shared memory.
......@@ -408,17 +442,28 @@ struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(WARPS_M == 1);
#if defined (__HIP_PLATFORM_HCC__)
static_assert(WARPS_N == 2 || WARPS_N == 4);
#else
static_assert(WARPS_N == 4 || WARPS_N == 8);
#endif
static_assert(WARPS_K == 1);
static_assert(Base::ROWS_PER_XOR_PATTERN == 8);
// The row and column read by the thread.
#if defined(__HIP_PLATFORM_HCC__)
const int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA;
int smem_read_row = (tidx % M_PER_MMA_PER_CTA);
int smem_read_col = ((tidx & 0x3f) / M_PER_MMA_PER_CTA);
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*(Base::BITS_PER_ELEMENT/8);
#else
int smem_read_row = (tidx & 0x0f);
int smem_read_col = (tidx & 0x07);
smem_read_col ^= (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;
#endif
}
// Rewind smem_read_offset for last LDS phase in main loop.
......@@ -437,6 +482,21 @@ struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
#if defined(__HIP_PLATFORM_HCC__)
int k_offset = 4 /* the instruction's K dimension is 4 */ * (Base::BITS_PER_ELEMENT/8);
int ki_offset = ki * Mma_tile::K_PER_MMA * (Base::BITS_PER_ELEMENT/8);
ldsm(a[mi].reg(0), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 0 * k_offset + ki_offset);
ldsm(a[mi].reg(1), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 1 * k_offset + ki_offset);
ldsm(a[mi].reg(2), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 2 * k_offset + ki_offset);
ldsm(a[mi].reg(3), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 3 * k_offset + ki_offset);
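// Each reg(i) receives a single fp16 element widened to fp32 by the HIP ldsm, taken from
// K position smem_read_col + 4 * i within this ki slice (plus ki * K_PER_MMA); the four
// registers feed the four 16x16x4 mmac steps of one 16x16x16 MMA.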
// if (blockIdx.x == 0) {
// printf("smem a load tid:%d %f %f %f %f\n", threadIdx.x,
// a[mi].template elt_as<float>(0),
// a[mi].template elt_as<float>(1),
// a[mi].template elt_as<float>(2),
// a[mi].template elt_as<float>(3));
// }
#else
// Load using LDSM.M88.4.
uint4 tmp;
ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
......@@ -446,8 +506,12 @@ struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
a[mi].reg(1) = tmp.y;
a[mi].reg(2) = tmp.z;
a[mi].reg(3) = tmp.w;
#endif
}
#if defined(__HIP_PLATFORM_HCC__)
// this->smem_read_offset_ = (ki+1) * Mma_tile::K_PER_MMA * (Base::BITS_PER_ELEMENT/8);
#else
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
......@@ -461,6 +525,7 @@ struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
} else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
}
#endif
}
// Reset the read offset.
......@@ -593,9 +658,15 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(Base::ROWS_PER_XOR_PATTERN == 8);
static_assert(WARPS_M == 1);
static_assert(WARPS_N == 4 || WARPS_N == 8);
static_assert(WARPS_N == 2 || WARPS_N == 4 || WARPS_N == 8);
static_assert(WARPS_K == 1);
#if defined(__HIP_PLATFORM_HCC__)
const int N_PER_MMA = Mma_tile::N_PER_MMA;
int smem_read_row = (tidx % N_PER_MMA) + tidx / Cta_tile::THREADS_PER_WARP * N_PER_MMA;
int smem_read_col = (tidx / N_PER_MMA) % 4; // the instruction's K dimension is 4
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*(Base::BITS_PER_ELEMENT/8);
#else
// The masks to select the warps.
const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
......@@ -610,6 +681,7 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
smem_read_col ^= (tidx & 0x08) / 8;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;
#endif
}
// Rewind smem_read_offset for last LDS phase in main loop.
......@@ -628,6 +700,21 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
#if defined(__HIP_PLATFORM_HCC__)
int k_offset = 4 /* the instruction's K dimension is 4 */ * (Base::BITS_PER_ELEMENT/8);
int ki_offset = ki * Mma_tile::K_PER_MMA_PER_CTA * (Base::BITS_PER_ELEMENT/8);
ldsm(b[ni].reg(0), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 0 * k_offset + ki_offset);
ldsm(b[ni].reg(1), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 1 * k_offset + ki_offset);
ldsm(b[ni].reg(2), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 2 * k_offset + ki_offset);
ldsm(b[ni].reg(3), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 3 * k_offset + ki_offset);
// if (blockIdx.x == 0) {
// printf("smem b load tid:%d %f %f %f %f\n", threadIdx.x,
// b[ni].template elt_as<float>(0),
// b[ni].template elt_as<float>(1),
// b[ni].template elt_as<float>(2),
// b[ni].template elt_as<float>(3));
// }
#else
// Load using LDSM.M88.4.
uint4 tmp;
ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
......@@ -637,8 +724,12 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
#endif
}
#if defined(__HIP_PLATFORM_HCC__)
// this->smem_read_offset_ = (ki+1) * Mma_tile::K_PER_MMA_PER_CTA * (Base::BITS_PER_ELEMENT/8);
#else
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
......@@ -652,6 +743,7 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
} else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
}
#endif
}
// Reset the read offset.
......@@ -919,20 +1011,41 @@ struct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K,
// The row/col read by the thread.
int read_row, read_col;
static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 2 || Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
#if defined(__HIP_PLATFORM_HCC__)
const int K_PER_MMA = Mma_tile::K_PER_MMA;
read_row = (tidx / 16) % 4 + (tidx / Cta_tile::THREADS_PER_WARP) * K_PER_MMA;
read_col = tidx % 16;
// The shared memory offset.
this->smem_read_offset_ = read_row*Base::BYTES_PER_ROW + read_col*(Base::BITS_PER_ELEMENT/8);
#else
read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f);
read_col = (tidx & 0x07);
read_col ^= (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS;
#endif
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
#if defined(__HIP_PLATFORM_HCC__)
// Jump by 16 * #warps row.
// int row = ki * Cta_tile::K_PER_MMA_PER_CTA;
int col_offset = ni * Mma_tile::N_PER_MMA * (Base::BITS_PER_ELEMENT/8) ;
int k_offset = 4 /* the instruction's K dimension is 4 */ * Base::BYTES_PER_ROW;
int ki_offset = ki * Mma_tile::K_PER_MMA_PER_CTA * Base::BYTES_PER_ROW;
ldsm(b[ni].reg(0), this->smem_ + this->smem_read_offset_ + col_offset + 0 * k_offset + ki_offset);
ldsm(b[ni].reg(1), this->smem_ + this->smem_read_offset_ + col_offset + 1 * k_offset + ki_offset);
ldsm(b[ni].reg(2), this->smem_ + this->smem_read_offset_ + col_offset + 2 * k_offset + ki_offset);
ldsm(b[ni].reg(3), this->smem_ + this->smem_read_offset_ + col_offset + 3 * k_offset + ki_offset);
#else
// Jump by 16 * #warps row.
int row = ki * 16 * Cta_tile::WARPS_K;
......@@ -950,6 +1063,7 @@ struct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K,
} else {
assert(false); // Not implemented!
}
#endif
}
}
};
......@@ -1010,8 +1124,25 @@ struct Smem_tile_o {
// Get a 32-bit value for the shared memory address.
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 2 || Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
#if defined(__HIP_PLATFORM_HCC__)
int write_row = tidx % 16;
int write_col = (tidx / 16) % 4 + (tidx / 64) * Mma_tile::K_PER_MMA * Mma_tile::MMAS_N;
// Assemble the write pointer.
smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_ELEMENT;
// The element read by each thread.
int read_row = tidx / THREADS_PER_ROW;
int read_col = tidx % THREADS_PER_ROW;
// Take the XOR pattern into account for the column.
// read_col ^= 2 * (read_row & 0x7);
// Assemble the read pointer.
this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
#else
int write_row = (tidx & 0x1c) / 4;
int write_col = (tidx);
......@@ -1027,6 +1158,7 @@ struct Smem_tile_o {
// Assemble the read pointer.
this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
#endif
// Is that thread active on the last LDS?
if( HAS_INCOMPLETE_LDS ) {
......@@ -1036,6 +1168,34 @@ struct Smem_tile_o {
// Load the output fragments.
inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {
#if defined(__HIP_PLATFORM_HCC__)
for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {
// Load the elements before the reduction (split-K).
uint4 tmp[Cta_tile::WARPS_K];
#pragma unroll
for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
fmha::lds(tmp[jj], this->smem_read_ + imm);
// wangaq debug
// float * xxx = reinterpret_cast<float*>(&tmp[0]);
// xxx[0] = threadIdx.x * 10.0 + 0;
// xxx[1] = threadIdx.x * 10.0 + 1;
// xxx[2] = threadIdx.x * 10.0 + 2;
// xxx[3] = threadIdx.x * 10.0 + 3;
// fmha::sts(this->smem_read_ + imm, tmp[0]);
}
}
// Perform the reduction.
out[ii] = tmp[0];
#pragma unroll
for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
out[ii] = fmha::fadd4(out[ii], tmp[jj]);
}
}
#else
#pragma unroll
for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {
......@@ -1056,10 +1216,31 @@ struct Smem_tile_o {
out[ii] = fmha::fadd4(out[ii], tmp[jj]);
}
}
#endif
}
// Store the accumulators.
template <int M, int N>
inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
#if defined(__HIP_PLATFORM_HCC__)
for (int mi = 0; mi < M; ++mi) {
for (int ni = 0; ni < N; ++ni) {
int ni_offset = Mma_tile::K_PER_MMA * ni * BYTES_PER_ELEMENT;
// uint32_t tmp[4];
// reinterpret_cast<float&>(tmp[0]) = threadIdx.x * 100.0 + ni * 10 + 0;
// reinterpret_cast<float&>(tmp[1]) = threadIdx.x * 100.0 + ni * 10 + 1;
// reinterpret_cast<float&>(tmp[2]) = threadIdx.x * 100.0 + ni * 10 + 2;
// reinterpret_cast<float&>(tmp[3]) = threadIdx.x * 100.0 + ni * 10 + 3;
// fmha::sts(this->smem_write_ + ni_offset + 0 * BYTES_PER_ELEMENT, tmp[0]);
// fmha::sts(this->smem_write_ + ni_offset + 4 * BYTES_PER_ELEMENT, tmp[1]);
// fmha::sts(this->smem_write_ + ni_offset + 8 * BYTES_PER_ELEMENT, tmp[2]);
// fmha::sts(this->smem_write_ + ni_offset + 12 * BYTES_PER_ELEMENT, tmp[3]);
fmha::sts(this->smem_write_ + ni_offset + 0 * BYTES_PER_ELEMENT, acc[mi][ni].reg(0));
fmha::sts(this->smem_write_ + ni_offset + 4 * BYTES_PER_ELEMENT, acc[mi][ni].reg(1));
fmha::sts(this->smem_write_ + ni_offset + 8 * BYTES_PER_ELEMENT, acc[mi][ni].reg(2));
fmha::sts(this->smem_write_ + ni_offset + 12 * BYTES_PER_ELEMENT, acc[mi][ni].reg(3));
}
}
#else
enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA };
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
......@@ -1109,6 +1290,7 @@ struct Smem_tile_o {
// Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B.
this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
}
#endif
}
};
......@@ -1177,11 +1359,11 @@ struct Smem_tile_mma_transposed : public Base {
enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N };
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
static_assert(WARPS_M == 1 && (WARPS_N == 2 || WARPS_N == 4 || WARPS_N == 8));
using Fragment = typename Base::Fragment;
inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) {
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
static_assert(WARPS_M == 1 && (WARPS_N == 2 || WARPS_N == 4 || WARPS_N == 8));
int read_row, read_col;
read_row = (tidx & 0x0f);
read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;
......@@ -1221,7 +1403,7 @@ struct Smem_tile_mma_epilogue : public Base {
static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M);
enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N };
static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
static_assert((WARPS_N == 2 || WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
using Acc = fmha::Fragment_accumulator;
......
......@@ -56,8 +56,13 @@ inline __device__ float apply_exp_(float x, float max) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int COLS> struct ReadType {};
#if defined(__HIP_PLATFORM_HCC__)
template<> struct ReadType<2> { using T = float;};
template<> struct ReadType<4> { using T = float2;};
#else
template<> struct ReadType<4> { using T = float;};
template<> struct ReadType<8> { using T = float2;};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -78,7 +83,11 @@ struct Smem_tile_reduce {
static constexpr int ROWS = WARPS_M * MMAS_M * 16;
static constexpr int COLS = WARPS_N;
#if defined(__HIP_PLATFORM_HCC__)
static_assert(COLS == 2 || COLS == 4);
#else
static_assert(COLS == 4 || COLS == 8);
#endif
static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
static constexpr int ELTS_PER_TILE = ROWS * COLS;
......@@ -93,6 +102,20 @@ struct Smem_tile_reduce {
__device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
#if defined(__HIP_PLATFORM_HCC__)
int lane = tidx % Cta_tile::THREADS_PER_WARP;
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int warp_m = warp % WARPS_M;
int warp_n = warp / WARPS_M;
qid_ = lane / 16; // only the first 16 lanes (qid_ == 0) write
int qp = lane % 16;
const int col = warp_n;
smem_write_ = &smem_[warp_m * ELTS_PER_TILE + qp * WARPS_N + col];
smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * ELTS_PER_TILE + qp * 2 + qid_ / WARPS_N];
#else
int lane = tidx % 32;
int warp = tidx / 32;
......@@ -107,9 +130,19 @@ struct Smem_tile_reduce {
const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
#endif
}
__device__ inline void store(float (&frag)[MMAS_M]) {
if( qid_ == 0 ) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
*smem_write_ = frag[mi];
}
}
}
__device__ inline void store(float (&frag)[2 * MMAS_M]) {
if( qid_ == 0 ) {
#pragma unroll
......@@ -121,6 +154,13 @@ struct Smem_tile_reduce {
}
}
__device__ inline void load(read_t (&frag)[MMAS_M]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
frag[mi] = *smem_read_;
}
}
__device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
......@@ -188,6 +228,29 @@ struct Softmax_base {
smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
}
#if defined(__HIP_PLATFORM_HCC__)
template<typename Mask>
inline __device__ void apply_mask(const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
#pragma unroll
for( int jj = 0; jj < 4; ++jj ) {
if( !mask.is_valid(mi, ni, jj) ) {
elt_[mi][4 * ni + jj] = -INFINITY;
}
}
// wangaq debug
// if (blockIdx.x == 0) {
// printf("apply_mask tid:%d mi:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi, ni,
// this->elt_[mi][4 * ni + 0], this->elt_[mi][4 * ni + 1], this->elt_[mi][4 * ni + 2], this->elt_[mi][4 * ni + 3]);
// }
}
}
}
#else
template<typename Mask>
inline __device__ void apply_mask(const Mask &mask) {
#pragma unroll
......@@ -206,6 +269,26 @@ struct Softmax_base {
}
}
}
#endif
// Apply the exp to all the elements.
inline __device__ void apply_exp(const float (&max)[MMAS_M]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);
}
// wangaq debug
// if (blockIdx.x == 0) {
// for( int ni = 0; ni < MMAS_N; ++ni ) {
// printf("apply_exp tid:%d mi:%d ni:%d max:%6.4f %f %f %f %f\n", threadIdx.x, mi, ni, max[mi],
// this->elt_[mi][4 * ni + 0], this->elt_[mi][4 * ni + 1], this->elt_[mi][4 * ni + 2], this->elt_[mi][4 * ni + 3]);
// }
// }
}
}
// Apply the exp to all the elements.
inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {
......@@ -218,6 +301,33 @@ struct Softmax_base {
}
}
// Scale all the elements.
inline __device__ void scale(const float (&sum)[MMAS_M]) {
// Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge difference.
float inv_sum[MMAS_M];
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];
}
// Update the values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
elt_[mi][ni] *= inv_sum[mi];
}
// wangaq debug
// if (blockIdx.x == 0) {
// for( int ni = 0; ni < MMAS_N; ++ni ) {
// printf("scale tid:%d mi:%d ni:%d sum:%6.4f inv_sum:%6.4f %f %f %f %f\n", threadIdx.x, mi, ni, sum[mi], inv_sum[mi],
// this->elt_[mi][4 * ni + 0], this->elt_[mi][4 * ni + 1], this->elt_[mi][4 * ni + 2], this->elt_[mi][4 * ni + 3]);
// }
// }
}
}
// Scale all the elements.
inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
// Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge difference.
......@@ -244,7 +354,11 @@ struct Softmax_base {
// The current thread index.
int tidx_;
// The elements.
#if defined(__HIP_PLATFORM_HCC__)
float elt_[MMAS_M][MMAS_N * 4];
#else
float elt_[MMAS_M * 2][MMAS_N * 4];
#endif
};
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -290,7 +404,20 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
for( int mi = 0; mi < M; ++mi ) {
#pragma unroll
for( int ki = 0; ki < K; ++ki ) {
#if defined(__HIP_PLATFORM_HCC__)
dst[ki][mi].template elt_as<float>(0) = this->elt_[mi][4 * ki + 0];
dst[ki][mi].template elt_as<float>(1) = this->elt_[mi][4 * ki + 1];
dst[ki][mi].template elt_as<float>(2) = this->elt_[mi][4 * ki + 2];
dst[ki][mi].template elt_as<float>(3) = this->elt_[mi][4 * ki + 3];
// wangaq debug
// if (blockIdx.x == 0) {
// printf("pack tid:%d mi:%d ki:%d %6.4f %6.4f %6.4f %6.4f -> %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi, ki,
// this->elt_[mi][4 * ki + 0], this->elt_[mi][4 * ki + 1], this->elt_[mi][4 * ki + 2], this->elt_[mi][4 * ki + 3],
// dst[ki][mi].template elt_as<float>(0), dst[ki][mi].template elt_as<float>(1), dst[ki][mi].template elt_as<float>(2), dst[ki][mi].template elt_as<float>(3));
// // printf("pack tid:%d mi:%d ki:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi, ki,
// // dst[ki][mi].template elt_as<float>(0), dst[ki][mi].template elt_as<float>(1), dst[ki][mi].template elt_as<float>(2), dst[ki][mi].template elt_as<float>(3));
// }
#else
// 1st row - 4 elements per row.
float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];
float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];
......@@ -308,6 +435,7 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
#endif
}
}
}
......@@ -340,6 +468,19 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
#if defined(__HIP_PLATFORM_HCC__)
this->elt_[mi][4 * ni + 0] = acc[mi][ni].elt(0);
this->elt_[mi][4 * ni + 1] = acc[mi][ni].elt(1);
this->elt_[mi][4 * ni + 2] = acc[mi][ni].elt(2);
this->elt_[mi][4 * ni + 3] = acc[mi][ni].elt(3);
// wangaq debug
// if (blockIdx.x == 0) {
// printf("unpack_noscale tid:%d mi:%d ni:%d %6.4f %6.4f %6.4f %6.4f -> %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi, ni,
// acc[mi][ni].template elt_as<float>(0), acc[mi][ni].template elt_as<float>(1), acc[mi][ni].template elt_as<float>(2), acc[mi][ni].template elt_as<float>(3),
// this->elt_[mi][4 * ni + 0], this->elt_[mi][4 * ni + 1], this->elt_[mi][4 * ni + 2], this->elt_[mi][4 * ni + 3]);
// }
#else
// 1st row - 4 elements per row.
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
......@@ -350,11 +491,39 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
#endif
}
}
}
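// HIP variants of the row reductions: each lane first folds its 4 * MMAS_N elements with
// the reduction operator, quad_reduce then combines lanes inside the wavefront, the
// per-warp partials are exchanged through the reduction smem tile, and binary_allreduce
// broadcasts the final max/sum back to every lane.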
template<typename Operator>
__device__ inline void reduce_(float (&frag)[MMAS_M], Operator &op, Smem_tile_red & smem_red) {
for( int mi = 0; mi < MMAS_M; mi++ ) {
frag[mi] = this->elt_[mi][0];
for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
frag[mi] = op(frag[mi], this->elt_[mi][ni]);
}
}
quad_reduce(frag, frag, op);
smem_red.store(frag);
__syncthreads();
typename Smem_tile_red::read_t tmp[MMAS_M];
smem_red.load(tmp);
binary_allreduce(frag, tmp, op);
}
__device__ inline void reduce_max(float (&frag)[MMAS_M]){
MaxOp<float> max;
reduce_(frag, max, smem_max_);
}
__device__ inline void reduce_sum(float (&frag)[MMAS_M]){
SumOp<float> sum;
reduce_(frag, sum, smem_sum_);
}
template<typename Operator>
__device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
......
......@@ -294,6 +294,13 @@ static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) {
return c;
}
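// fmul/frelu below mirror hmul2/hrelu2 for the HIP path, where each 32-bit register
// carries a single fp32 value (converted from fp16 on load) rather than a packed half2.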
static inline __device__ uint32_t fmul(uint32_t a, uint32_t b) {
uint32_t c;
float tmp = reinterpret_cast<float&>(a) * reinterpret_cast<float&>(b);
c = reinterpret_cast<uint32_t&>(tmp);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
......@@ -346,6 +353,15 @@ static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
#endif
return res;
}
static inline __device__ uint32_t frelu(uint32_t x, uint32_t lb = 0) {
uint32_t res;
float tmp_x = reinterpret_cast<float&>(x);
// lb carries a float bit pattern, like x; compare the two as floats.
float tmp_lb = reinterpret_cast<float&>(lb);
tmp_x = tmp_x > tmp_lb ? tmp_x : tmp_lb;
__builtin_memcpy(&res, &tmp_x, sizeof(uint32_t));
return res;
}
static inline __device__ uint32_t habs2(uint32_t x) {
uint32_t res;
#if defined (__HIP_PLATFORM_HCC__)
......@@ -905,10 +921,19 @@ inline __device__ void lds(uint4 &dst, uint32_t ptr) {
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) {
#if defined (__HIP_PLATFORM_HCC__)
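// No ldmatrix on the HIP path: each lane reads one fp16 element from shared memory,
// converts it to fp32, and keeps it in the 32-bit destination register, which is the
// layout the mmac_f32_16x16x4f32-based mma() consumes.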
extern __shared__ char smem[];
uint32_t base = __nvvm_get_smem_pointer(smem);
float tmp = __half2float(*(__half*)(smem-base+ptr));
__builtin_memcpy(&dst, &tmp, sizeof(uint32_t));
// if (blockIdx.x == 0)
// printf("ldsm tid:%d tmp:%f\n", threadIdx.x, tmp);
#else
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
: "=r"(dst) : "r"(ptr));
#endif
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -1114,10 +1139,32 @@ inline __device__ void sts(uint32_t ptr, uint2 val) {
inline __device__ void sts(uint32_t ptr, uint4 val) {
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr) , "v"(val.x));
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+4) , "v"(val.y));
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+8) , "v"(val.z));
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+12) , "v"(val.w));
// asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr) , "v"(val.x));
// asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+4) , "v"(val.y));
// asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+8) , "v"(val.z));
// asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+12) , "v"(val.w));
extern __shared__ char smem[];
uint32_t base = __nvvm_get_smem_pointer(smem);
__builtin_memcpy(smem-base+ptr, &val, sizeof(uint4));
// if (blockIdx.x == 0) {
// printf("sts tid:%d %f %f %f %f %f %f %f %f -> %f %f %f %f %f %f %f %f\n", threadIdx.x,
// __half2float(reinterpret_cast<half*>(&val)[0]),
// __half2float(reinterpret_cast<half*>(&val)[1]),
// __half2float(reinterpret_cast<half*>(&val)[2]),
// __half2float(reinterpret_cast<half*>(&val)[3]),
// __half2float(reinterpret_cast<half*>(&val)[4]),
// __half2float(reinterpret_cast<half*>(&val)[5]),
// __half2float(reinterpret_cast<half*>(&val)[6]),
// __half2float(reinterpret_cast<half*>(&val)[7]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[0]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[1]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[2]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[3]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[4]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[5]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[6]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[7]));
// }
#else
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
:
......@@ -1190,7 +1237,7 @@ struct Allreduce {
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
#if defined (__HIP_PLATFORM_HCC__)
x = op(x, __shfl_xor(uint32_t(-1), x, OFFSET));
x = op(x, __shfl_xor(x, OFFSET));
#else
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
#endif
......@@ -1221,8 +1268,8 @@ __device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
#if defined (__HIP_PLATFORM_HCC__)
dst[mi] = op(dst[mi], __shfl_down(dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down(dst[mi], 1));
dst[mi] = op(dst[mi], __shfl_down(dst[mi], 32));
dst[mi] = op(dst[mi], __shfl_down(dst[mi], 16));
#else
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
......@@ -1267,4 +1314,34 @@ __device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operato
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined (__HIP_PLATFORM_HCC__)
template<int THREADS>
struct Allreduce32 {
static_assert(THREADS == 64);
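// With a 64-lane wavefront, XOR-ing the lane id with 32 pairs each lane with its
// counterpart in the other half, so a single shuffle completes the cross-half reduction.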
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
return op(x, __shfl_xor(x, 32));
}
};
template<typename Operator, int M>
__device__ inline void binary_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = Allreduce32<64>::run(dst[mi], op);
}
}
template<typename Operator, int M>
__device__ inline void binary_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
binary_allreduce(dst, tmp, op);
}
#endif
} // namespace fmha
......@@ -28,7 +28,11 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
#if defined(__HIP_PLATFORM_HCC__)
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 2, 0x08u>;
#else
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
#endif
extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
......
......@@ -28,7 +28,11 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
#if defined(__HIP_PLATFORM_HCC__)
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 2, 0x08u>;
#else
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
#endif
template<bool Is_training>
__global__
......
......@@ -28,7 +28,11 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
#if defined(__HIP_PLATFORM_HCC__)
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 2, 0x08u>;
#else
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
#endif
template<bool Is_training>
__global__
......
......@@ -28,7 +28,11 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
#if defined(__HIP_PLATFORM_HCC__)
using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 2, 0x18u>;
#else
using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x18u>;
#endif
template<bool Is_training>
__global__
......
......@@ -28,7 +28,11 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
#if defined(__HIP_PLATFORM_HCC__)
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 4, 0x00u>;
#else
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>;
#endif
template<bool Is_training>
__global__
......
......@@ -131,6 +131,8 @@ std::tuple<int , int, int, int, int, int> work_dist(const int total_ctas, const
const int num_full_heads = heads_total / total_ctas;
const int heads_last_wave = heads_total % total_ctas;
// printf("total_ctas:%d heads_total:%d num_full_heads:%d heads_last_wave:%d N:%d M:%d steps:%d\n",
// total_ctas, heads_total, num_full_heads, heads_last_wave, Kernel_traits::Cta_tile_p::N, Kernel_traits::Cta_tile_p::M, STEPS_PER_HEAD);
int num_main_groups = 0;
int main_steps = 0;
......
......@@ -512,15 +512,15 @@ if "--fmha" in sys.argv:
CUDAExtension(name='fmhalib',
sources=[
'apex/contrib/csrc/fmha/fmha_api.cpp',
'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu',
# 'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu',
# 'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu',
# 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu',
# 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu',
# 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu',
# 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu',
],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha},
......