Commit 7eed2594 authored by flyingdown

Rewrite the GEMM in the 128/256 forward pass using mmac instructions

parent 0816a70e
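Note: at its core, this commit replaces the shared-memory emulation of NVIDIA's m16n8k16 HMMA on the HIP side with AMD's fp32 mmac intrinsic, building each 16x16x16 product out of four chained 16x16x4 steps. A minimal standalone sketch of that inner loop (the helper name and fragment arrays are illustrative; the real code below carries the accumulator as a 4-float vector per lane):

    using v4f = __attribute__((__vector_size__(4 * sizeof(float)))) float;

    // One 16x16x16 MMA as four K=4 slices; a and b hold one fp32 element
    // per lane for each slice, c is the per-lane accumulator.
    __device__ inline void mma_16x16x16(const float (&a)[4], const float (&b)[4], v4f &c) {
    #pragma unroll
        for (int k = 0; k < 4; ++k) {
            c = __builtin_amdgcn_mmac_f32_16x16x4f32(a[k], b[k], c, 0);
        }
    }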
@@ -80,7 +80,11 @@ void set_params(Fused_multihead_attention_fprop_params &params,
     params.p_dropout = 1.f - p_dropout;
     params.rp_dropout = 1.f / params.p_dropout;
     TORCH_CHECK(p_dropout < 1.f);
+#if defined (__HIP_PLATFORM_HCC__)
+    set_alpha(params.scale_dropout, params.rp_dropout, acc_type);
+#else
     set_alpha(params.scale_dropout, params.rp_dropout, data_type);
+#endif
 }
 std::vector<at::Tensor>
@@ -94,24 +98,38 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
         c10::optional<at::Generator> gen_) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
+#if not defined(__HIP_PLATFORM_HCC__)
     TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
+#endif
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_training, is_nl);
-    int seq_len = 512;
-    auto launch = &run_fmha_fp16_512_64_sm80;
+    // int seq_len = 512;
+    // auto launch = &run_fmha_fp16_512_64_sm80;
+    // if( max_seq_len <= 128 ) {
+    //     seq_len = 128;
+    //     launch = &run_fmha_fp16_128_64_sm80;
+    // } else if( max_seq_len <= 256 ) {
+    //     seq_len = 256;
+    //     launch = &run_fmha_fp16_256_64_sm80;
+    // } else if( max_seq_len <= 384 ) {
+    //     seq_len = 384;
+    //     launch = &run_fmha_fp16_384_64_sm80;
+    // } else if( max_seq_len <= 512 ) {
+    //     seq_len = 512;
+    //     launch = &run_fmha_fp16_512_64_sm80;
+    // } else {
+    //     TORCH_CHECK(false);
+    // }
+    int seq_len = 256;
+    auto launch = &run_fmha_fp16_256_64_sm80;
     if( max_seq_len <= 128 ) {
         seq_len = 128;
         launch = &run_fmha_fp16_128_64_sm80;
     } else if( max_seq_len <= 256 ) {
         seq_len = 256;
         launch = &run_fmha_fp16_256_64_sm80;
-    } else if( max_seq_len <= 384 ) {
-        seq_len = 384;
-        launch = &run_fmha_fp16_384_64_sm80;
-    } else if( max_seq_len <= 512 ) {
-        seq_len = 512;
-        launch = &run_fmha_fp16_512_64_sm80;
     } else {
         TORCH_CHECK(false);
     }
@@ -178,7 +196,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
     return { ctx, s };
 }
+/*
 std::vector<at::Tensor>
 mha_bwd(const at::Tensor &dout, // total x num_heads x head_size
         const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
@@ -351,11 +369,11 @@ std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num
         dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total, num_chunks, stream);
     return { dqkv, softmax, dkv };
-}
+}*/
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "Fused Multi-head Self-attention for BERT";
     m.def("fwd", &mha_fwd, "Forward pass");
-    m.def("bwd", &mha_bwd, "Backward pass");
-    m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
+    // m.def("bwd", &mha_bwd, "Backward pass");
+    // m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
 }
@@ -145,85 +145,57 @@ struct Fragment_b : public Fragment<uint16_t, 8> {
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 #if defined (__HIP_PLATFORM_HCC__)
-__device__ inline void f16mulf16addf32(uint32_t & a, uint32_t & b, const float * c, float * d){
-    // uint32_t res = 0;
-    // asm volatile("v_pk_fma_f16 %0, %1,%2,%3" : "=v"(res) : "v"(a), "v"(b), "v"(res));
-    // __half * h = reinterpret_cast<__half*>(&res);
-    __half * ha = reinterpret_cast<__half*>(&a);
-    __half * hb = reinterpret_cast<__half*>(&b);
-    float C = *c, D = *d;
-    *d = *c + __half2float(ha[0])*__half2float(hb[0]) + __half2float(ha[1])*__half2float(hb[1]);
-    // if (threadIdx.x == 15) {
-    //     printf("f16mulf16addf32 %i: A %f, %f, B %f, %f, RES %f, %f, %f, C %f, %f, D %f, %f \n", threadIdx.x,
-    //         __half2float(ha[0]), __half2float(ha[1]),
-    //         __half2float(hb[0]), __half2float(hb[1]),
-    //         __half2float(ha[0])*__half2float(hb[0]),
-    //         __half2float(ha[1])*__half2float(hb[1]),
-    //         __half2float(ha[0])*__half2float(hb[0]) + __half2float(ha[1])*__half2float(hb[1]),
-    //         C, *c, D, *d
-    //     );
-    // }
-}
-// row 8 col 4
-__device__ inline void m16n8k16(const uint32_t * A, const uint32_t * B, /*const float * C,*/ float * D) {
-    int tid = threadIdx.x;
-    int baseId = tid / 32 * 32;
-    __shared__ uint32_t smem[256*6];
-    int base = tid*6;
-    __builtin_memcpy(smem+base, A, sizeof(uint32_t));
-    __builtin_memcpy(smem+(base+1), A+1, sizeof(uint32_t));
-    __builtin_memcpy(smem+(base+2), A+2, sizeof(uint32_t));
-    __builtin_memcpy(smem+(base+3), A+3, sizeof(uint32_t));
-    __builtin_memcpy(smem+(base+4), B, sizeof(uint32_t));
-    __builtin_memcpy(smem+(base+5), B+1, sizeof(uint32_t));
-    __syncthreads();
-    /* From D's point of view: each thread computes its own D elements, looping from
-       thread 0 and fetching one row of A and two columns of B.
-       s is the thread index into the B matrix,
-       baseA is the thread index into A,
-       baseB0 is the first column of B fetched by the current thread, baseB1 the second.
-    */
-    int s = baseId+(tid%4)*8, e = s+4;
-    for (int i = s; i < e; ++i) {
-        // A[0]->i A[1]->i+1 A[2]->i+2 A[3]->i+3 B[0]->i+4 B[1]->i+5
-        int baseA = (tid-tid%4+i-s)*6; // thread index of the first column of this tid's row, plus the stride, then *6
-        int baseB0 = i*6, baseB1 = (i+4)*6;
-        f16mulf16addf32(smem[baseA], smem[baseB0+4], D, D);
-        f16mulf16addf32(smem[baseA+2], smem[baseB0+5], D, D);
-        f16mulf16addf32(smem[baseA], smem[baseB1+4], D+1, D+1);
-        f16mulf16addf32(smem[baseA+2], smem[baseB1+5], D+1, D+1);
-        f16mulf16addf32(smem[baseA+1], smem[baseB0+4], D+2, D+2);
-        f16mulf16addf32(smem[baseA+3], smem[baseB0+5], D+2, D+2);
-        f16mulf16addf32(smem[baseA+1], smem[baseB1+4], D+3, D+3);
-        f16mulf16addf32(smem[baseA+3], smem[baseB1+5], D+3, D+3);
-    }
-    // __half * a0 = reinterpret_cast<__half*>(smem+base);
-    // __half * a1 = reinterpret_cast<__half*>(smem+base+1);
-    // __half * a2 = reinterpret_cast<__half*>(smem+base+2);
-    // __half * a3 = reinterpret_cast<__half*>(smem+base+3);
-    // __half * b0 = reinterpret_cast<__half*>(smem+base+4);
-    // __half * b1 = reinterpret_cast<__half*>(smem+base+5);
-    // printf("m16n8k16 %i: \n A %f, %f, %f, %f, %f, %f, %f, %f \n B %f, %f, %f, %f \n D %f, %f, %f, %f \n", threadIdx.x,
-    //     __half2float(a0[0]), __half2float(a0[1]),
-    //     __half2float(a1[0]), __half2float(a1[1]),
-    //     __half2float(a2[0]), __half2float(a2[1]),
-    //     __half2float(a3[0]), __half2float(a3[1]),
-    //     __half2float(b0[0]), __half2float(b0[1]),
-    //     __half2float(b1[0]), __half2float(b1[1]),
-    //     D[0], D[1], D[2], D[3]
-    //     );
-}
-#endif
+struct Fragment_accumulator : public Fragment<float, 4> {
+    // The base class.
+    using Base = Fragment<float, 8>;
+    // Add two fragments.
+    template< typename Other_fragment_ >
+    inline __device__ void add(const Other_fragment_ &other) {
+        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
+            this->elt(ii) = this->elt(ii) + other.elt(ii);
+        }
+    }
+    // Do the HMMA.
+    template< typename Layout_a, typename Layout_b >
+    inline __device__ void mma(const Fragment_a<Layout_a> &a,
+                               const Fragment_b<Layout_b> &b) {
+        // const uint32_t * A = reinterpret_cast<const uint32_t*>(a.regs_);
+        // const uint32_t * B = reinterpret_cast<const uint32_t*>(b.regs_);
+        // float * D = reinterpret_cast<float*>(regs_);
+        // float regs[8];
+        // __builtin_memcpy(regs, D, sizeof(float)*8);
+        // m16n8k16(A, B, D);
+        // m16n8k16(A, B+2, D+4);
+        using v4f = __attribute__( (__vector_size__(4 * sizeof(float)) )) float;
+        v4f * rC = reinterpret_cast<v4f*>(regs_);
+        // float rA = reinterpret_cast<const float&>(a.reg(0));
+        // float rB = reinterpret_cast<const float&>(b.reg(0));
+        float rA0 = a.template elt_as<float>(0);
+        float rB0 = b.template elt_as<float>(0);
+        *rC = __builtin_amdgcn_mmac_f32_16x16x4f32(rA0, rB0, *rC, 0);
+        float rA1 = a.template elt_as<float>(1);
+        float rB1 = b.template elt_as<float>(1);
+        *rC = __builtin_amdgcn_mmac_f32_16x16x4f32(rA1, rB1, *rC, 0);
+        float rA2 = a.template elt_as<float>(2);
+        float rB2 = b.template elt_as<float>(2);
+        *rC = __builtin_amdgcn_mmac_f32_16x16x4f32(rA2, rB2, *rC, 0);
+        float rA3 = a.template elt_as<float>(3);
+        float rB3 = b.template elt_as<float>(3);
+        *rC = __builtin_amdgcn_mmac_f32_16x16x4f32(rA3, rB3, *rC, 0);
+        // if (blockIdx.x == 0) {
+        //     printf("tid:%d rA0:%6.4f rB0:%6.4f rA1:%6.4f rB1:%6.4f rA2:%6.4f rB2:%6.4f rA3:%6.4f rB3:%6.4f c0:%6.4f c1:%6.4f c2:%6.4f c3:%6.4f\n", threadIdx.x,
+        //            rA0, rB0, rA1, rB1, rA2, rB2, rA3, rB3, elt(0), elt(1), elt(2), elt(3));
+        // }
+    }
+};
+#else
+////////////////////////////////////////////////////////////////////////////////////////////////////
 struct Fragment_accumulator : public Fragment<float, 8> {
@@ -242,15 +214,6 @@ struct Fragment_accumulator : public Fragment<float, 8> {
     template< typename Layout_a, typename Layout_b >
     inline __device__ void mma(const Fragment_a<Layout_a> &a,
                                const Fragment_b<Layout_b> &b) {
-#if defined (__HIP_PLATFORM_HCC__)
-        const uint32_t * A = reinterpret_cast<const uint32_t*>(a.regs_);
-        const uint32_t * B = reinterpret_cast<const uint32_t*>(b.regs_);
-        float * D = reinterpret_cast<float*>(regs_);
-        float regs[8];
-        __builtin_memcpy(regs, D, sizeof(float)*8);
-        m16n8k16(A, B, D);
-        m16n8k16(A, B+2, D+4);
-#else
         asm volatile( \
             "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
             "    {%0, %1, %2, %3}, \n" \
@@ -269,11 +232,10 @@ struct Fragment_accumulator : public Fragment<float, 8> {
             : "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7))
             : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
             , "r"(b.reg(2)), "r"(b.reg(3)));
-#endif
     }
 };
+#endif
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 template< typename Fragment, int M, int N >
@@ -310,8 +272,24 @@ inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N])
 #pragma unroll
     for( int mi = 0; mi < M; ++mi ) {
+        // wangaq debug
+        // if (blockIdx.x == 0) {
+        //     printf("a tid:%d mi:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi,
+        //            a[mi].template elt_as<float>(0),
+        //            a[mi].template elt_as<float>(1),
+        //            a[mi].template elt_as<float>(2),
+        //            a[mi].template elt_as<float>(3));
+        // }
 #pragma unroll
         for( int ni = 0; ni < N; ++ni ) {
+            // wangaq debug
+            // if (blockIdx.x == 0) {
+            //     printf("b tid:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, ni,
+            //            b[ni].template elt_as<float>(0),
+            //            b[ni].template elt_as<float>(1),
+            //            b[ni].template elt_as<float>(2),
+            //            b[ni].template elt_as<float>(3));
+            // }
             acc[mi][ni].mma(a[mi], b[ni]);
         }
     }
@@ -340,7 +318,11 @@ struct Cta_tile_ {
     // The number of warps per CTA.
     enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K };
     // The number of threads per warp.
+#if defined(__HIP_PLATFORM_HCC__)
+    enum { THREADS_PER_WARP = 64 };
+#else
     enum { THREADS_PER_WARP = 32 };
+#endif
     // The number of threads per CTA.
     enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP };
 };
@@ -350,7 +332,11 @@ struct Cta_tile_ {
 template<typename Cta_tile>
 struct Hmma_tile {
     // The number of elements computed with a single warp-MMA.
+    // #if defined(__HIP_PLATFORM_HCC__)
+    // enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 4 };
+    // #else
     enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 };
+    // #endif
     // The number of elements computed with a single CTA-MMA.
     enum {
...
@@ -85,6 +85,20 @@ struct Gmem_tile_qkv {
     // Store data to shared memory.
     template< typename Smem_tile >
     inline __device__ void commit(Smem_tile &smem_tile) {
+        // wangaq debug
+        // for( int ii = 0; ii < LDGS; ++ii ) {
+        //     if (blockIdx.x == 0) {
+        //         printf("commit tid:%d LDGS:%d ii:%d %f %f %f %f %f %f %f %f\n", threadIdx.x, LDGS, ii,
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[0]),
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[1]),
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[2]),
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[3]),
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[4]),
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[5]),
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[6]),
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[7]));
+        //     }
+        // }
         smem_tile.store(fetch_);
     }
@@ -105,6 +119,18 @@ struct Gmem_tile_qkv {
 #pragma unroll
         for( int ii = 0; ii < LDGS; ++ii ) {
             fct.load(ii, preds[ii]);
+            // wangaq debug
+            // if (blockIdx.x == 0) {
+            //     printf("load tid:%d LDGS:%d ii:%d %f %f %f %f %f %f %f %f\n", threadIdx.x, LDGS, ii,
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[0]),
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[1]),
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[2]),
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[3]),
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[4]),
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[5]),
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[6]),
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[7]));
+            // }
         }
     }
@@ -254,8 +280,13 @@ struct Gmem_tile_mma_sd {
     // The mma tile.
     using Mma_tile = fmha::Hmma_tile<Cta_tile>;
+#if defined(__HIP_PLATFORM_HCC__)
+    // Each STG stores 4 elements.
+    enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 4 };
+#else
     // Each STG stores 8 elements.
     enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 };
+#endif
     // The number of MMAs in the M dimension.
     enum { MMAS_M = Mma_tile::MMAS_M };
     // The number of MMAs in the N dimension.
@@ -369,6 +400,14 @@ struct Gmem_tile_mma_s : public Base {
         for( int mi = 0; mi < M; mi++ ) {
 #pragma unroll
             for( int ni = 0; ni < N; ni++ ) {
+#if defined(__HIP_PLATFORM_HCC__)
+                uint2 dst;
+                dst.x = float2_to_half2(frag[ni][mi].reg(0), frag[ni][mi].reg(1));
+                dst.y = float2_to_half2(frag[ni][mi].reg(2), frag[ni][mi].reg(3));
+                if( mask.any_valid(mi, ni) ) {
+                    Base::store(dst, mi, ni);
+                }
+#else
                 uint4 dst;
                 dst.x = frag[ni][mi].reg(0);
                 dst.y = frag[ni][mi].reg(2);
@@ -377,6 +416,7 @@ struct Gmem_tile_mma_s : public Base {
                 if( mask.any_valid(mi, ni) ) {
                     Base::store(dst, mi, ni);
                 }
+#endif
             }
         }
     }
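Note: on the HIP path the P = softmax(QK^T) fragment is held in fp32 and only converted back to fp16 at this store, which is also why BYTES_PER_STG shrinks above. A sketch of the conversion step (assuming the uint32-register overload of float2_to_half2 that the store above relies on):

    // Pack four fp32 accumulator registers (bit patterns in uint32_t) into two
    // half2 words, i.e. the 8-byte uint2 that a single HIP STG writes out.
    __device__ inline uint2 pack_p_fragment(uint32_t r0, uint32_t r1, uint32_t r2, uint32_t r3) {
        uint2 dst;
        dst.x = fmha::float2_to_half2(r0, r1);
        dst.y = fmha::float2_to_half2(r2, r3);
        return dst;
    }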
...
@@ -47,11 +47,28 @@ struct Mask {
         // find the warp in the Cta tile
         const int warp_n = (warp / Cta_tile::WARPS_M);
         const int warp_m = (warp % Cta_tile::WARPS_M);
+#if defined(__HIP_PLATFORM_HCC__)
+        // decompose warp into 16x16 tile
+        const int quad = lane % 16;
+        const int tid = lane / 16;
+        row = warp_m * 16 + quad;
+        col = warp_n * 16 + tid;
+#else
         // decompose warp into 8x4 tile
         const int quad = lane / 4;
         const int tid = (lane % 4) * 2;
         row = warp_m * 16 + quad;
         col = warp_n * 16 + tid;
+#endif
+    }
+    inline __device__ bool is_valid(const int mi, const int ni, const int jj) const {
+        // jj iterates over the 1x4 fragment
+        const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + 4 * jj) < actual_seqlen;
+        //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen;
+        return col_valid;
+        // return row_valid && col_valid;
     }
     inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {
@@ -65,7 +82,11 @@ struct Mask {
     //BERT Mask: if upper left is invalid, none are valid
     inline __device__ bool any_valid(int mi, int ni) const {
+#if defined(__HIP_PLATFORM_HCC__)
+        return is_valid(mi, ni, 0);
+#else
         return is_valid(mi, ni, 0, 0);
+#endif
     }
     inline __device__ void load(int it) {
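Note: the HIP branches above map each 64-lane wavefront onto a full 16x16 MMA tile: a lane owns one row and four accumulator columns at stride 4, so the validity check only needs a single jj index. A standalone sketch of the mapping (helper name is illustrative):

    // Lane L of a 64-wide wavefront covers row (L % 16) and columns
    // (L / 16) + 4 * jj for jj in 0..3 -- matching is_valid(mi, ni, jj) above.
    __device__ inline void hip_lane_to_tile(int lane, int &row, int &col0) {
        row  = lane % 16;   // 16 rows
        col0 = lane / 16;   // 4 column groups per row
    }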
...
@@ -69,7 +69,7 @@ struct Smem_tile_without_skews {
     // The number of bytes per row without packing of rows.
     enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };
     // The number of bytes per row -- we want at least 128B per row.
-    enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };
+    enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE + 4 };
     // The number of rows in shared memory (two rows may be packed into a single one).
     enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };
@@ -117,6 +117,18 @@ struct Smem_tile_without_skews {
     inline __device__ Smem_tile_without_skews(void *smem, int tidx)
         : smem_(__nvvm_get_smem_pointer(smem)) {
+#if defined (__HIP_PLATFORM_HCC__)
+        int smem_write_row = tidx / THREADS_PER_ROW;
+        int smem_write_col = tidx % THREADS_PER_ROW;
+        // The offset.
+        this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;
+        // TODO: Why not merge it with the read offset?
+        this->smem_read_buffer_ = __shfl(0, 0);
+        this->smem_write_buffer_ = __shfl(0, 0);
+#else
         // The row written by a thread. See doc/mma_smem_layout.xlsx.
         int smem_write_row = tidx / THREADS_PER_ROW;
@@ -129,10 +141,6 @@
         this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;
         // TODO: Why not merge it with the read offset?
-#if defined (__HIP_PLATFORM_HCC__)
-        this->smem_read_buffer_ = __shfl(0, 0);
-        this->smem_write_buffer_ = __shfl(0, 0);
-#else
         this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);
         this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);
 #endif
@@ -259,6 +267,32 @@ struct Smem_tile_without_skews {
         uint32_t smem_ptrs[N];
         this->compute_store_pointers(smem_ptrs);
         sts(smem_ptrs, data);
+        // wangaq debug
+        // if (blockIdx.x == 0) {
+        //     extern __shared__ char smem[];
+        //     uint32_t base = __nvvm_get_smem_pointer(smem);
+        //     for (int ii = 0; ii < N; ++ii) {
+        //         printf("data tid:%d N:%d ii:%d %f %f %f %f %f %f %f %f\n", threadIdx.x, N, ii,
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[0]),
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[1]),
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[2]),
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[3]),
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[4]),
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[5]),
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[6]),
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[7]));
+        //         __half * smem_ptr = reinterpret_cast<__half*>(smem-base+smem_ptrs[ii]);
+        //         printf("smem_ptrs tid:%d N:%d ii:%d %f %f %f %f %f %f %f %f\n", threadIdx.x, N, ii,
+        //                __half2float(smem_ptr[0]),
+        //                __half2float(smem_ptr[1]),
+        //                __half2float(smem_ptr[2]),
+        //                __half2float(smem_ptr[3]),
+        //                __half2float(smem_ptr[4]),
+        //                __half2float(smem_ptr[5]),
+        //                __half2float(smem_ptr[6]),
+        //                __half2float(smem_ptr[7]));
+        //     }
+        // }
     }
     // Store to the tile in shared memory.
@@ -408,17 +442,28 @@ struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
     const int WARPS_K = Cta_tile::WARPS_K;
     static_assert(WARPS_M == 1);
+#if defined (__HIP_PLATFORM_HCC__)
+    static_assert(WARPS_N == 2 || WARPS_N == 4);
+#else
     static_assert(WARPS_N == 4 || WARPS_N == 8);
+#endif
     static_assert(WARPS_K == 1);
     static_assert(Base::ROWS_PER_XOR_PATTERN == 8);
     // The row and column read by the thread.
+#if defined(__HIP_PLATFORM_HCC__)
+    const int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA;
+    int smem_read_row = (tidx % M_PER_MMA_PER_CTA);
+    int smem_read_col = ((tidx & 0x3f) / M_PER_MMA_PER_CTA);
+    this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*(Base::BITS_PER_ELEMENT/8);
+#else
     int smem_read_row = (tidx & 0x0f);
    int smem_read_col = (tidx & 0x07);
    smem_read_col ^= (tidx & 0x10) / 16;
    // The shared memory offset.
    this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;
+#endif
 }
 // Rewind smem_read_offset for last LDS phase in main loop.
@@ -437,6 +482,21 @@ struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
         // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
         int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
+#if defined(__HIP_PLATFORM_HCC__)
+        int k_offset = 4 /* the instruction's K dimension is 4 */ * (Base::BITS_PER_ELEMENT/8);
+        int ki_offset = ki * Mma_tile::K_PER_MMA * (Base::BITS_PER_ELEMENT/8);
+        ldsm(a[mi].reg(0), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 0 * k_offset + ki_offset);
+        ldsm(a[mi].reg(1), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 1 * k_offset + ki_offset);
+        ldsm(a[mi].reg(2), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 2 * k_offset + ki_offset);
+        ldsm(a[mi].reg(3), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 3 * k_offset + ki_offset);
+        // if (blockIdx.x == 0) {
+        //     printf("smem a load tid:%d %f %f %f %f\n", threadIdx.x,
+        //            a[mi].template elt_as<float>(0),
+        //            a[mi].template elt_as<float>(1),
+        //            a[mi].template elt_as<float>(2),
+        //            a[mi].template elt_as<float>(3));
+        // }
+#else
         // Load using LDSM.M88.4.
         uint4 tmp;
         ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
@@ -446,8 +506,12 @@ struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
         a[mi].reg(1) = tmp.y;
         a[mi].reg(2) = tmp.z;
         a[mi].reg(3) = tmp.w;
+#endif
     }
+#if defined(__HIP_PLATFORM_HCC__)
+    // this->smem_read_offset_ = (ki+1) * Mma_tile::K_PER_MMA * (Base::BITS_PER_ELEMENT/8);
+#else
     // Move the offset to the next position. See doc/mma_smem_layout.xlsx.
     static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
     if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
@@ -461,6 +525,7 @@ struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
     } else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
         this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
     }
+#endif
 }
 // Reset the read offset.
@@ -593,9 +658,15 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
     const int WARPS_K = Cta_tile::WARPS_K;
     static_assert(Base::ROWS_PER_XOR_PATTERN == 8);
     static_assert(WARPS_M == 1);
-    static_assert(WARPS_N == 4 || WARPS_N == 8);
+    static_assert(WARPS_N == 2 || WARPS_N == 4 || WARPS_N == 8);
     static_assert(WARPS_K == 1);
+#if defined(__HIP_PLATFORM_HCC__)
+    const int N_PER_MMA = Mma_tile::N_PER_MMA;
+    int smem_read_row = (tidx % N_PER_MMA) + tidx / Cta_tile::THREADS_PER_WARP * N_PER_MMA;
+    int smem_read_col = (tidx / N_PER_MMA) % 4; // the instruction's K dimension is 4
+    this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*(Base::BITS_PER_ELEMENT/8);
+#else
     // The masks to select the warps.
     const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
@@ -610,6 +681,7 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
     smem_read_col ^= (tidx & 0x08) / 8;
     // The shared memory offset.
     this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;
+#endif
 }
 // Rewind smem_read_offset for last LDS phase in main loop.
@@ -628,6 +700,21 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
         // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
         int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
+#if defined(__HIP_PLATFORM_HCC__)
+        int k_offset = 4 /* the instruction's K dimension is 4 */ * (Base::BITS_PER_ELEMENT/8);
+        int ki_offset = ki * Mma_tile::K_PER_MMA_PER_CTA * (Base::BITS_PER_ELEMENT/8);
+        ldsm(b[ni].reg(0), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 0 * k_offset + ki_offset);
+        ldsm(b[ni].reg(1), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 1 * k_offset + ki_offset);
+        ldsm(b[ni].reg(2), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 2 * k_offset + ki_offset);
+        ldsm(b[ni].reg(3), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 3 * k_offset + ki_offset);
+        // if (blockIdx.x == 0) {
+        //     printf("smem b load tid:%d %f %f %f %f\n", threadIdx.x,
+        //            b[ni].template elt_as<float>(0),
+        //            b[ni].template elt_as<float>(1),
+        //            b[ni].template elt_as<float>(2),
+        //            b[ni].template elt_as<float>(3));
+        // }
+#else
         // Load using LDSM.M88.4.
         uint4 tmp;
         ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
@@ -637,8 +724,12 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
         b[ni].reg(1) = tmp.y;
         b[ni].reg(2) = tmp.z;
         b[ni].reg(3) = tmp.w;
+#endif
     }
+#if defined(__HIP_PLATFORM_HCC__)
+    // this->smem_read_offset_ = (ki+1) * Mma_tile::K_PER_MMA_PER_CTA * (Base::BITS_PER_ELEMENT/8);
+#else
     // Move the offset to the next position. See doc/mma_smem_layout.xlsx.
     static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
     if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
@@ -652,6 +743,7 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
     } else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
         this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
     }
+#endif
 }
 // Reset the read offset.
@@ -919,20 +1011,41 @@ struct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K,
     // The row/col read by the thread.
     int read_row, read_col;
-    static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
+    static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 2 || Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
+#if defined(__HIP_PLATFORM_HCC__)
+    const int K_PER_MMA = Mma_tile::K_PER_MMA;
+    read_row = (tidx / 16) % 4 + (tidx / Cta_tile::THREADS_PER_WARP) * K_PER_MMA;
+    read_col = tidx % 16;
+    // The shared memory offset.
+    this->smem_read_offset_ = read_row*Base::BYTES_PER_ROW + read_col*(Base::BITS_PER_ELEMENT/8);
+#else
     read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f);
     read_col = (tidx & 0x07);
     read_col ^= (tidx & 0x10) / 16;
     // The shared memory offset.
     this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS;
+#endif
 }
 // Load from shared memory.
 inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
 #pragma unroll
     for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
+#if defined(__HIP_PLATFORM_HCC__)
+        // Jump by 16 * #warps row.
+        // int row = ki * Cta_tile::K_PER_MMA_PER_CTA;
+        int col_offset = ni * Mma_tile::N_PER_MMA * (Base::BITS_PER_ELEMENT/8);
+        int k_offset = 4 /* the instruction's K dimension is 4 */ * Base::BYTES_PER_ROW;
+        int ki_offset = ki * Mma_tile::K_PER_MMA_PER_CTA * Base::BYTES_PER_ROW;
+        ldsm(b[ni].reg(0), this->smem_ + this->smem_read_offset_ + col_offset + 0 * k_offset + ki_offset);
+        ldsm(b[ni].reg(1), this->smem_ + this->smem_read_offset_ + col_offset + 1 * k_offset + ki_offset);
+        ldsm(b[ni].reg(2), this->smem_ + this->smem_read_offset_ + col_offset + 2 * k_offset + ki_offset);
+        ldsm(b[ni].reg(3), this->smem_ + this->smem_read_offset_ + col_offset + 3 * k_offset + ki_offset);
+#else
         // Jump by 16 * #warps row.
         int row = ki * 16 * Cta_tile::WARPS_K;
@@ -950,6 +1063,7 @@ struct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K,
         } else {
             assert(false); // Not implemented!
         }
+#endif
         }
     }
 };
@@ -1010,8 +1124,25 @@ struct Smem_tile_o {
     // Get a 32-bit value for the shared memory address.
     uint32_t smem_ = __nvvm_get_smem_pointer(smem);
-    static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
+    static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 2 || Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
+#if defined(__HIP_PLATFORM_HCC__)
+    int write_row = tidx % 16;
+    int write_col = (tidx / 16) % 4 + (tidx / 64) * Mma_tile::K_PER_MMA * Mma_tile::MMAS_N;
+    // Assemble the write pointer.
+    smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_ELEMENT;
+    // The element read by each thread.
+    int read_row = tidx / THREADS_PER_ROW;
+    int read_col = tidx % THREADS_PER_ROW;
+    // Take the XOR pattern into account for the column.
+    // read_col ^= 2 * (read_row & 0x7);
+    // Assemble the read pointer.
+    this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
+#else
     int write_row = (tidx & 0x1c) / 4;
     int write_col = (tidx);
@@ -1027,6 +1158,7 @@ struct Smem_tile_o {
     // Assemble the read pointer.
     this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
+#endif
     // Is that thread active on the last LDS?
     if( HAS_INCOMPLETE_LDS ) {
@@ -1036,6 +1168,34 @@ struct Smem_tile_o {
     // Load the output fragments.
     inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {
+#if defined(__HIP_PLATFORM_HCC__)
+        for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {
+            // Load the elements before the reduction (split-K).
+            uint4 tmp[Cta_tile::WARPS_K];
+#pragma unroll
+            for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
+                int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
+                if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
+                    fmha::lds(tmp[jj], this->smem_read_ + imm);
+                    // wangaq debug
+                    // float * xxx = reinterpret_cast<float*>(&tmp[0]);
+                    // xxx[0] = threadIdx.x * 10.0 + 0;
+                    // xxx[1] = threadIdx.x * 10.0 + 1;
+                    // xxx[2] = threadIdx.x * 10.0 + 2;
+                    // xxx[3] = threadIdx.x * 10.0 + 3;
+                    // fmha::sts(this->smem_read_ + imm, tmp[0]);
+                }
+            }
+            // Perform the reduction.
+            out[ii] = tmp[0];
+#pragma unroll
+            for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
+                out[ii] = fmha::fadd4(out[ii], tmp[jj]);
+            }
+        }
+#else
 #pragma unroll
         for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {
@@ -1056,10 +1216,31 @@ struct Smem_tile_o {
                 out[ii] = fmha::fadd4(out[ii], tmp[jj]);
             }
         }
+#endif
     }
     // Store the accumulators.
     template <int M, int N>
     inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
+#if defined(__HIP_PLATFORM_HCC__)
+        for (int mi = 0; mi < M; ++mi) {
+            for (int ni = 0; ni < N; ++ni) {
+                int ni_offset = Mma_tile::K_PER_MMA * ni * BYTES_PER_ELEMENT;
+                // uint32_t tmp[4];
+                // reinterpret_cast<float&>(tmp[0]) = threadIdx.x * 100.0 + ni * 10 + 0;
+                // reinterpret_cast<float&>(tmp[1]) = threadIdx.x * 100.0 + ni * 10 + 1;
+                // reinterpret_cast<float&>(tmp[2]) = threadIdx.x * 100.0 + ni * 10 + 2;
+                // reinterpret_cast<float&>(tmp[3]) = threadIdx.x * 100.0 + ni * 10 + 3;
+                // fmha::sts(this->smem_write_ + ni_offset + 0 * BYTES_PER_ELEMENT, tmp[0]);
+                // fmha::sts(this->smem_write_ + ni_offset + 4 * BYTES_PER_ELEMENT, tmp[1]);
+                // fmha::sts(this->smem_write_ + ni_offset + 8 * BYTES_PER_ELEMENT, tmp[2]);
+                // fmha::sts(this->smem_write_ + ni_offset + 12 * BYTES_PER_ELEMENT, tmp[3]);
+                fmha::sts(this->smem_write_ + ni_offset + 0 * BYTES_PER_ELEMENT, acc[mi][ni].reg(0));
+                fmha::sts(this->smem_write_ + ni_offset + 4 * BYTES_PER_ELEMENT, acc[mi][ni].reg(1));
+                fmha::sts(this->smem_write_ + ni_offset + 8 * BYTES_PER_ELEMENT, acc[mi][ni].reg(2));
+                fmha::sts(this->smem_write_ + ni_offset + 12 * BYTES_PER_ELEMENT, acc[mi][ni].reg(3));
+            }
+        }
+#else
         enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA };
 #pragma unroll
         for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
@@ -1109,6 +1290,7 @@ struct Smem_tile_o {
             // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B.
             this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
         }
+#endif
     }
 };
@@ -1177,11 +1359,11 @@ struct Smem_tile_mma_transposed : public Base {
     enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
     enum { WARPS_M = Base::WARPS_M };
     enum { WARPS_N = Base::WARPS_N };
-    static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
+    static_assert(WARPS_M == 1 && (WARPS_N == 2 || WARPS_N == 4 || WARPS_N == 8));
     using Fragment = typename Base::Fragment;
     inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) {
-        static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
+        static_assert(WARPS_M == 1 && (WARPS_N == 2 || WARPS_N == 4 || WARPS_N == 8));
         int read_row, read_col;
         read_row = (tidx & 0x0f);
         read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;
@@ -1221,7 +1403,7 @@ struct Smem_tile_mma_epilogue : public Base {
     static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M);
     enum { WARPS_M = Base::WARPS_M };
     enum { WARPS_N = Base::WARPS_N };
-    static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
+    static_assert((WARPS_N == 2 || WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
     using Acc = fmha::Fragment_accumulator;
...
@@ -56,8 +56,13 @@ inline __device__ float apply_exp_(float x, float max) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 template<int COLS> struct ReadType {};
+#if defined(__HIP_PLATFORM_HCC__)
+template<> struct ReadType<2> { using T = float;};
+template<> struct ReadType<4> { using T = float2;};
+#else
 template<> struct ReadType<4> { using T = float;};
 template<> struct ReadType<8> { using T = float2;};
+#endif
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -78,7 +83,11 @@ struct Smem_tile_reduce {
     static constexpr int ROWS = WARPS_M * MMAS_M * 16;
     static constexpr int COLS = WARPS_N;
+#if defined(__HIP_PLATFORM_HCC__)
+    static_assert(COLS == 2 || COLS == 4);
+#else
     static_assert(COLS == 4 || COLS == 8);
+#endif
     static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
     static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
     static constexpr int ELTS_PER_TILE = ROWS * COLS;
@@ -93,6 +102,20 @@ struct Smem_tile_reduce {
     __device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
+#if defined(__HIP_PLATFORM_HCC__)
+        int lane = tidx % Cta_tile::THREADS_PER_WARP;
+        int warp = tidx / Cta_tile::THREADS_PER_WARP;
+        int warp_m = warp % WARPS_M;
+        int warp_n = warp / WARPS_M;
+        qid_ = lane / 16; // only the first 16 lanes perform the write
+        int qp = lane % 16;
+        const int col = warp_n;
+        smem_write_ = &smem_[warp_m * ELTS_PER_TILE + qp * WARPS_N + col];
+        smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * ELTS_PER_TILE + qp * 2 + qid_ / WARPS_N];
+#else
         int lane = tidx % 32;
         int warp = tidx / 32;
@@ -107,9 +130,19 @@ struct Smem_tile_reduce {
         const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
         smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
         smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
+#endif
     }
+    __device__ inline void store(float (&frag)[MMAS_M]) {
+        if( qid_ == 0 ) {
+#pragma unroll
+            for( int mi = 0; mi < MMAS_M; mi++ ) {
+                *smem_write_ = frag[mi];
+            }
+        }
+    }
     __device__ inline void store(float (&frag)[2 * MMAS_M]) {
         if( qid_ == 0 ) {
 #pragma unroll
@@ -121,6 +154,13 @@ struct Smem_tile_reduce {
         }
     }
+    __device__ inline void load(read_t (&frag)[MMAS_M]) {
+#pragma unroll
+        for( int mi = 0; mi < MMAS_M; mi++ ) {
+            frag[mi] = *smem_read_;
+        }
+    }
     __device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
 #pragma unroll
         for( int mi = 0; mi < MMAS_M; mi++ ) {
@@ -188,6 +228,29 @@ struct Softmax_base {
         smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
     }
+#if defined(__HIP_PLATFORM_HCC__)
+    template<typename Mask>
+    inline __device__ void apply_mask(const Mask &mask) {
+#pragma unroll
+        for( int mi = 0; mi < MMAS_M; ++mi ) {
+#pragma unroll
+            for( int ni = 0; ni < MMAS_N; ++ni ) {
+#pragma unroll
+                for( int jj = 0; jj < 4; ++jj ) {
+                    if( !mask.is_valid(mi, ni, jj) ) {
+                        elt_[mi][4 * ni + jj] = -INFINITY;
+                    }
+                }
+                // wangaq debug
+                // if (blockIdx.x == 0) {
+                //     printf("apply_mask tid:%d mi:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi, ni,
+                //            this->elt_[mi][4 * ni + 0], this->elt_[mi][4 * ni + 1], this->elt_[mi][4 * ni + 2], this->elt_[mi][4 * ni + 3]);
+                // }
+            }
+        }
+    }
+#else
     template<typename Mask>
     inline __device__ void apply_mask(const Mask &mask) {
 #pragma unroll
@@ -206,6 +269,26 @@ struct Softmax_base {
            }
        }
    }
+#endif
+    // Apply the exp to all the elements.
+    inline __device__ void apply_exp(const float (&max)[MMAS_M]) {
+#pragma unroll
+        for( int mi = 0; mi < MMAS_M; ++mi ) {
+#pragma unroll
+            for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
+                elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);
+            }
+            // wangaq debug
+            // if (blockIdx.x == 0) {
+            //     for( int ni = 0; ni < MMAS_N; ++ni ) {
+            //         printf("apply_exp tid:%d mi:%d ni:%d max:%6.4f %f %f %f %f\n", threadIdx.x, mi, ni, max[mi],
+            //                this->elt_[mi][4 * ni + 0], this->elt_[mi][4 * ni + 1], this->elt_[mi][4 * ni + 2], this->elt_[mi][4 * ni + 3]);
+            //     }
+            // }
+        }
+    }
     // Apply the exp to all the elements.
     inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {
@@ -218,6 +301,33 @@ struct Softmax_base {
         }
     }
+    // Scale all the elements.
+    inline __device__ void scale(const float (&sum)[MMAS_M]) {
+        // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge difference.
+        float inv_sum[MMAS_M];
+#pragma unroll
+        for( int mi = 0; mi < MMAS_M; ++mi ) {
+            inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];
+        }
+        // Update the values.
+#pragma unroll
+        for( int mi = 0; mi < MMAS_M; ++mi ) {
+#pragma unroll
+            for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
+                elt_[mi][ni] *= inv_sum[mi];
+            }
+            // wangaq debug
+            // if (blockIdx.x == 0) {
+            //     for( int ni = 0; ni < MMAS_N; ++ni ) {
+            //         printf("scale tid:%d mi:%d ni:%d sum:%6.4f inv_sum:%6.4f %f %f %f %f\n", threadIdx.x, mi, ni, sum[mi], inv_sum[mi],
+            //                this->elt_[mi][4 * ni + 0], this->elt_[mi][4 * ni + 1], this->elt_[mi][4 * ni + 2], this->elt_[mi][4 * ni + 3]);
+            //     }
+            // }
+        }
+    }
     // Scale all the elements.
     inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
         // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge difference.
@@ -244,7 +354,11 @@ struct Softmax_base {
     // The current thread index.
     int tidx_;
     // The elements.
+#if defined(__HIP_PLATFORM_HCC__)
+    float elt_[MMAS_M][MMAS_N * 4];
+#else
     float elt_[MMAS_M * 2][MMAS_N * 4];
+#endif
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////
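Note: the smaller elt_ array is why HIP overloads sized [MMAS_M] sit next to the CUDA [2 * MMAS_M] versions throughout this struct: one CUDA mma.m16n8k16 leaves each of 32 lanes holding two rows of a 16x16 tile, while one AMD mmac 16x16x4 leaves each of 64 lanes holding a single row. Both cover the same 256 accumulator elements (illustrative check):

    static_assert(32 * 2 * 4 == 16 * 16, "CUDA: 32 lanes x 2 rows x 4 cols per 16x16 tile");
    static_assert(64 * 1 * 4 == 16 * 16, "HIP: 64 lanes x 1 row x 4 cols per 16x16 tile");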
@@ -290,7 +404,20 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
         for( int mi = 0; mi < M; ++mi ) {
 #pragma unroll
             for( int ki = 0; ki < K; ++ki ) {
+#if defined(__HIP_PLATFORM_HCC__)
+                dst[ki][mi].template elt_as<float>(0) = this->elt_[mi][4 * ki + 0];
+                dst[ki][mi].template elt_as<float>(1) = this->elt_[mi][4 * ki + 1];
+                dst[ki][mi].template elt_as<float>(2) = this->elt_[mi][4 * ki + 2];
+                dst[ki][mi].template elt_as<float>(3) = this->elt_[mi][4 * ki + 3];
+                // wangaq debug
+                // if (blockIdx.x == 0) {
+                //     printf("pack tid:%d mi:%d ki:%d %6.4f %6.4f %6.4f %6.4f -> %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi, ki,
+                //            this->elt_[mi][4 * ki + 0], this->elt_[mi][4 * ki + 1], this->elt_[mi][4 * ki + 2], this->elt_[mi][4 * ki + 3],
+                //            dst[ki][mi].template elt_as<float>(0), dst[ki][mi].template elt_as<float>(1), dst[ki][mi].template elt_as<float>(2), dst[ki][mi].template elt_as<float>(3));
+                //     // printf("pack tid:%d mi:%d ki:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi, ki,
+                //     //        dst[ki][mi].template elt_as<float>(0), dst[ki][mi].template elt_as<float>(1), dst[ki][mi].template elt_as<float>(2), dst[ki][mi].template elt_as<float>(3));
+                // }
+#else
                 // 1st row - 4 elements per row.
                 float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];
                 float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];
@@ -308,6 +435,7 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
                 dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
                 dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
                 dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
+#endif
             }
         }
     }
@@ -340,6 +468,19 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
         for( int mi = 0; mi < MMAS_M; ++mi ) {
 #pragma unroll
             for( int ni = 0; ni < MMAS_N; ++ni ) {
+#if defined(__HIP_PLATFORM_HCC__)
+                this->elt_[mi][4 * ni + 0] = acc[mi][ni].elt(0);
+                this->elt_[mi][4 * ni + 1] = acc[mi][ni].elt(1);
+                this->elt_[mi][4 * ni + 2] = acc[mi][ni].elt(2);
+                this->elt_[mi][4 * ni + 3] = acc[mi][ni].elt(3);
+                // wangaq debug
+                // if (blockIdx.x == 0) {
+                //     printf("unpack_noscale tid:%d mi:%d ni:%d %6.4f %6.4f %6.4f %6.4f -> %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi, ni,
+                //            acc[mi][ni].template elt_as<float>(0), acc[mi][ni].template elt_as<float>(1), acc[mi][ni].template elt_as<float>(2), acc[mi][ni].template elt_as<float>(3),
+                //            this->elt_[mi][4 * ni + 0], this->elt_[mi][4 * ni + 1], this->elt_[mi][4 * ni + 2], this->elt_[mi][4 * ni + 3]);
+                // }
+#else
                 // 1st row - 4 elements per row.
                 this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
                 this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
@@ -350,11 +491,39 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
                 this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
                 this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
                 this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
+#endif
+            }
+        }
+    }
+    template<typename Operator>
+    __device__ inline void reduce_(float (&frag)[MMAS_M], Operator &op, Smem_tile_red & smem_red) {
+        for( int mi = 0; mi < MMAS_M; mi++ ) {
+            frag[mi] = this->elt_[mi][0];
+            for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
+                frag[mi] = op(frag[mi], this->elt_[mi][ni]);
             }
         }
+        quad_reduce(frag, frag, op);
+        smem_red.store(frag);
+        __syncthreads();
+        typename Smem_tile_red::read_t tmp[MMAS_M];
+        smem_red.load(tmp);
+        binary_allreduce(frag, tmp, op);
     }
+    __device__ inline void reduce_max(float (&frag)[MMAS_M]){
+        MaxOp<float> max;
+        reduce_(frag, max, smem_max_);
+    }
+    __device__ inline void reduce_sum(float (&frag)[MMAS_M]){
+        SumOp<float> sum;
+        reduce_(frag, sum, smem_sum_);
+    }
     template<typename Operator>
     __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
...
@@ -294,6 +294,13 @@ static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) {
     return c;
 }
+static inline __device__ uint32_t fmul(uint32_t a, uint32_t b) {
+    uint32_t c;
+    float tmp = reinterpret_cast<float&>(a) * reinterpret_cast<float&>(b);
+    c = reinterpret_cast<uint32_t&>(tmp);
+    return c;
+}
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
@@ -346,6 +353,15 @@ static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
 #endif
     return res;
 }
+static inline __device__ uint32_t frelu(uint32_t x, uint32_t lb = 0) {
+    uint32_t res;
+    float tmp_x = reinterpret_cast<float&>(x);
+    // lb carries an fp32 bit pattern, like x; compare as floats, not against the raw integer.
+    float tmp_lb = reinterpret_cast<float&>(lb);
+    tmp_x = tmp_x > tmp_lb ? tmp_x : tmp_lb;
+    __builtin_memcpy(&res, &tmp_x, sizeof(uint32_t));
+    return res;
+}
 static inline __device__ uint32_t habs2(uint32_t x) {
     uint32_t res;
 #if defined (__HIP_PLATFORM_HCC__)
...@@ -905,10 +921,19 @@ inline __device__ void lds(uint4 &dst, uint32_t ptr) { ...@@ -905,10 +921,19 @@ inline __device__ void lds(uint4 &dst, uint32_t ptr) {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) { inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) {
#if defined (__HIP_PLATFORM_HCC__)
    // ptr is a 32-bit shared-memory offset; rebuild a generic address from the
    // dynamic shared-memory base, then widen one fp16 element to fp32 bits,
    // since the HIP path keeps fragment values as float bit patterns.
    extern __shared__ char smem[];
    uint32_t base = __nvvm_get_smem_pointer(smem);
    float tmp = __half2float(*(__half*)(smem - base + ptr));
    __builtin_memcpy(&dst, &tmp, sizeof(uint32_t));
    // if (blockIdx.x == 0)
    //     printf("ldsm tid:%d tmp:%f\n", threadIdx.x, tmp);
#else
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
: "=r"(dst) : "r"(ptr)); : "=r"(dst) : "r"(ptr));
#endif #endif
#endif
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
...@@ -1114,10 +1139,32 @@ inline __device__ void sts(uint32_t ptr, uint2 val) { ...@@ -1114,10 +1139,32 @@ inline __device__ void sts(uint32_t ptr, uint2 val) {
inline __device__ void sts(uint32_t ptr, uint4 val) { inline __device__ void sts(uint32_t ptr, uint4 val) {
#if defined (__HIP_PLATFORM_HCC__) #if defined (__HIP_PLATFORM_HCC__)
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr) , "v"(val.x)); // asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr) , "v"(val.x));
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+4) , "v"(val.y)); // asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+4) , "v"(val.y));
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+8) , "v"(val.z)); // asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+8) , "v"(val.z));
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+12) , "v"(val.w)); // asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+12) , "v"(val.w));
    extern __shared__ char smem[];
    uint32_t base = __nvvm_get_smem_pointer(smem);
    // Rebuild a generic address from the 32-bit shared-memory offset and store
    // all 16 bytes in one shot, replacing the four ds_write_b32 ops above.
    __builtin_memcpy(smem - base + ptr, &val, sizeof(uint4));
// if (blockIdx.x == 0) {
// printf("sts tid:%d %f %f %f %f %f %f %f %f -> %f %f %f %f %f %f %f %f\n", threadIdx.x,
// __half2float(reinterpret_cast<half*>(&val)[0]),
// __half2float(reinterpret_cast<half*>(&val)[1]),
// __half2float(reinterpret_cast<half*>(&val)[2]),
// __half2float(reinterpret_cast<half*>(&val)[3]),
// __half2float(reinterpret_cast<half*>(&val)[4]),
// __half2float(reinterpret_cast<half*>(&val)[5]),
// __half2float(reinterpret_cast<half*>(&val)[6]),
// __half2float(reinterpret_cast<half*>(&val)[7]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[0]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[1]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[2]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[3]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[4]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[5]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[6]),
// __half2float(reinterpret_cast<half*>(smem-base+ptr)[7]));
// }
#else #else
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
: :
...@@ -1190,7 +1237,7 @@ struct Allreduce { ...@@ -1190,7 +1237,7 @@ struct Allreduce {
static __device__ inline T run(T x, Operator &op) { static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2; constexpr int OFFSET = THREADS / 2;
#if defined (__HIP_PLATFORM_HCC__) #if defined (__HIP_PLATFORM_HCC__)
x = op(x, __shfl_xor(uint32_t(-1), x, OFFSET)); x = op(x, __shfl_xor(x, OFFSET));
#else #else
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
#endif #endif
...@@ -1221,8 +1268,8 @@ __device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator & ...@@ -1221,8 +1268,8 @@ __device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &
for(int mi=0; mi < M; mi++){ for(int mi=0; mi < M; mi++){
dst[mi] = src[mi]; dst[mi] = src[mi];
#if defined (__HIP_PLATFORM_HCC__) #if defined (__HIP_PLATFORM_HCC__)
dst[mi] = op(dst[mi], __shfl_down(dst[mi], 2)); dst[mi] = op(dst[mi], __shfl_down(dst[mi], 32));
dst[mi] = op(dst[mi], __shfl_down(dst[mi], 1)); dst[mi] = op(dst[mi], __shfl_down(dst[mi], 16));
#else #else
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2)); dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1)); dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
...@@ -1267,4 +1314,34 @@ __device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operato ...@@ -1267,4 +1314,34 @@ __device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operato
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined (__HIP_PLATFORM_HCC__)
template<int THREADS>
struct Allreduce32 {
static_assert(THREADS == 64);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
return op(x, __shfl_xor(x, 32));
}
};
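// Usage sketch for Allreduce32 (assumes a 64-lane wavefront and that lanes i
// and i^32 already hold the two partials to combine, e.g. after quad_reduce;
// SumOp as sketched earlier):
__device__ inline float allreduce32_sum(float partial) {
    SumOp<float> sum;
    // one xor-32 butterfly: every lane ends with op(partial_i, partial_{i^32})
    return Allreduce32<64>::run(partial, sum);
}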
template<typename Operator, int M>
__device__ inline void binary_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = Allreduce32<64>::run(dst[mi], op);
}
}
template<typename Operator, int M>
__device__ inline void binary_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
binary_allreduce(dst, tmp, op);
}
#endif
} // namespace fmha } // namespace fmha
...@@ -28,7 +28,11 @@ ...@@ -28,7 +28,11 @@
#include "fmha.h" #include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h" #include "fmha_dgrad_kernel_1xN_reload.h"
#if defined(__HIP_PLATFORM_HCC__)
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 2, 0x08u>;
#else
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
#endif
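// If the fifth FMHA_kernel_traits argument is the warp count along N (as in the
// upstream traits header -- an assumption here), halving it on HIP keeps the CTA
// size unchanged, since an AMD wavefront has 64 lanes vs 32 per NVIDIA warp:
static_assert(1 * 2 * 64 == 1 * 4 * 32, "HIP (2 wavefronts) and CUDA (4 warps) both yield 128 threads per CTA");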
extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) { extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params); fmha::compute_dv_1xN<Kernel_traits>(params);
......
...@@ -28,7 +28,11 @@ ...@@ -28,7 +28,11 @@
#include "fmha.h" #include "fmha.h"
#include "fmha_fprop_kernel_1xN.h" #include "fmha_fprop_kernel_1xN.h"
#if defined(__HIP_PLATFORM_HCC__)
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 2, 0x08u>;
#else
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
#endif
template<bool Is_training> template<bool Is_training>
__global__ __global__
......
...@@ -28,7 +28,11 @@ ...@@ -28,7 +28,11 @@
#include "fmha.h" #include "fmha.h"
#include "fmha_fprop_kernel_1xN.h" #include "fmha_fprop_kernel_1xN.h"
#if defined(__HIP_PLATFORM_HCC__)
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 2, 0x08u>;
#else
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
#endif
template<bool Is_training> template<bool Is_training>
__global__ __global__
......
...@@ -28,7 +28,11 @@ ...@@ -28,7 +28,11 @@
#include "fmha.h" #include "fmha.h"
#include "fmha_fprop_kernel_1xN.h" #include "fmha_fprop_kernel_1xN.h"
#if defined(__HIP_PLATFORM_HCC__)
using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 2, 0x18u>;
#else
using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x18u>; using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x18u>;
#endif
template<bool Is_training> template<bool Is_training>
__global__ __global__
......
...@@ -28,7 +28,11 @@ ...@@ -28,7 +28,11 @@
#include "fmha.h" #include "fmha.h"
#include "fmha_fprop_kernel_1xN.h" #include "fmha_fprop_kernel_1xN.h"
#if defined(__HIP_PLATFORM_HCC__)
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 4, 0x00u>;
#else
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>; using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>;
#endif
template<bool Is_training> template<bool Is_training>
__global__ __global__
......
...@@ -111,11 +111,67 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> { ...@@ -111,11 +111,67 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
Base::smem_q.load(Base::frag_q[ki & 1], ki); Base::smem_q.load(Base::frag_q[ki & 1], ki);
// Do the math for the values already in registers. // Do the math for the values already in registers.
fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0) {
// for(int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
// printf("frag_q[%d] tid:%d ki:%d mi:%d %6.4f %6.4f %6.4f %6.4f\n", (ki - 1) & 1, threadIdx.x, ki, mi,
// Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(0),
// Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(1),
// Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(2),
// Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(3));
// }
// for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
// printf("frag_k[%d] tid:%d ki:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", (ki - 1) & 1, threadIdx.x, ki, ni,
// frag_k[(ki - 1) & 1][ni].template elt_as<float>(0),
// frag_k[(ki - 1) & 1][ni].template elt_as<float>(1),
// frag_k[(ki - 1) & 1][ni].template elt_as<float>(2),
// frag_k[(ki - 1) & 1][ni].template elt_as<float>(3));
// }
// for(int m = 0; m < M; ++m) {
// for (int n = 0; n < N; ++n) {
// printf("acc_p tid:%d ki:%d mi:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, ki, m, n,
// acc_p[m][n].template elt_as<float>(0),
// acc_p[m][n].template elt_as<float>(1),
// acc_p[m][n].template elt_as<float>(2),
// acc_p[m][n].template elt_as<float>(3));
// }
// }
// }
} }
// Do the final stage of math. // Do the final stage of math.
{ {
int ki = Mma_tile_p::MMAS_K; int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0) {
// for(int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
// printf("frag_q[%d] tid:%d ki:%d mi:%d %6.4f %6.4f %6.4f %6.4f\n", (ki - 1) & 1, threadIdx.x, ki, mi,
// Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(0),
// Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(1),
// Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(2),
// Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(3));
// }
// for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
// printf("frag_k[%d] tid:%d ki:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", (ki - 1) & 1, threadIdx.x, ki, ni,
// frag_k[(ki - 1) & 1][ni].template elt_as<float>(0),
// frag_k[(ki - 1) & 1][ni].template elt_as<float>(1),
// frag_k[(ki - 1) & 1][ni].template elt_as<float>(2),
// frag_k[(ki - 1) & 1][ni].template elt_as<float>(3));
// }
// for(int m = 0; m < M; ++m) {
// for (int n = 0; n < N; ++n) {
// printf("acc_p tid:%d ki:%d mi:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, ki, m, n,
// acc_p[m][n].template elt_as<float>(0),
// acc_p[m][n].template elt_as<float>(1),
// acc_p[m][n].template elt_as<float>(2),
// acc_p[m][n].template elt_as<float>(3));
// }
// }
// }
} }
} }
...@@ -188,6 +244,7 @@ constexpr size_t get_dynamic_smem_size(){ ...@@ -188,6 +244,7 @@ constexpr size_t get_dynamic_smem_size(){
template<typename Kernel_traits, bool Is_training, typename Params, typename Prng> template<typename Kernel_traits, bool Is_training, typename Params, typename Prng>
inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, const int begin, const int steps, Prng & ph) { inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, const int begin, const int steps, Prng & ph) {
// if (blockIdx.x == 0 && threadIdx.x == 0) printf("steps:%d\n", steps);
// The description of the CTA tile for the 1st batched GEMM. // The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p; using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
...@@ -291,8 +348,42 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i ...@@ -291,8 +348,42 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
__syncthreads(); __syncthreads();
// wangaq debug
// if (blockIdx.x == 0 && tidx == 0) {
// __half * smem = reinterpret_cast<__half*>(smem_);
// printf("begin:%d q %d bytes smem:\n", begin, Gemm1::Smem_tile_q::BYTES_PER_TILE);
// for (int row = 0; row < Gemm1::Smem_tile_q::BYTES_PER_TILE / 2 / 8; ++row) {
// printf("row:%d ", row);
// for (int col = 0; col < 8; ++col) {
// printf("col:%d value:%6.4f\t", col, __half2float(smem[row*8+col]));
// }
// printf("\n");
// }
// printf("begin:%d v %d bytes smem:\n", begin, Smem_tile_v::BYTES_PER_TILE);
// smem = reinterpret_cast<__half*>(smem_v_);
// for (int row = 0; row < Smem_tile_v::BYTES_PER_TILE / 2 / 8; ++row) {
// printf("row:%d ", row);
// for (int col = 0; col < 8; ++col) {
// printf("col:%d value:%6.4f\t", col, __half2float(smem[row*8+col]));
// }
// printf("\n");
// }
// }
// Load the fragments for Q. // Load the fragments for Q.
gemm_q_k.load_q(); gemm_q_k.load_q();
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0) {
// for(int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
// printf("frag_q tid:%d mi:%d %6.4f %6.4f %6.4f %6.4f\n", tidx, mi,
// gemm_q_k.frag_q[0][mi].template elt_as<float>(0),
// gemm_q_k.frag_q[0][mi].template elt_as<float>(1),
// gemm_q_k.frag_q[0][mi].template elt_as<float>(2),
// gemm_q_k.frag_q[0][mi].template elt_as<float>(3));
// }
// }
// Load the fragments for V. We keep the data in registers during the entire kernel. // Load the fragments for V. We keep the data in registers during the entire kernel.
typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N]; typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
...@@ -311,11 +402,39 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i ...@@ -311,11 +402,39 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Make sure the data is in shared memory. // Make sure the data is in shared memory.
__syncthreads(); __syncthreads();
// wangaq debug
// if (blockIdx.x == 0 && tidx == 0) {
// printf("begin:%d k %d bytes smem:\n", begin, Gemm1::Smem_tile_k::BYTES_PER_TILE);
// __half * smem = reinterpret_cast<__half*>(smem_ + Gemm1::Smem_tile_q::BYTES_PER_TILE);
// for (int row = 0; row < Gemm1::Smem_tile_k::BYTES_PER_TILE / 2 / 8; ++row) {
// printf("row:%d ", row);
// for (int col = 0; col < 8; ++col) {
// printf("col:%d value:%6.4f\t", col, __half2float(smem[row*8+col]));
// }
// printf("\n");
// }
// }
} }
// Load the fragments for K. // Load the fragments for K.
gemm_q_k.load_k(); gemm_q_k.load_k();
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0) {
// // Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
// for(int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki) {
// for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
// printf("frag_k tid:%d ki:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", tidx, ki, ni,
// gemm_q_k.frag_k[ki][ni].template elt_as<float>(0),
// gemm_q_k.frag_k[ki][ni].template elt_as<float>(1),
// gemm_q_k.frag_k[ki][ni].template elt_as<float>(2),
// gemm_q_k.frag_k[ki][ni].template elt_as<float>(3));
// }
// }
// }
// Create the object to do the softmax. // Create the object to do the softmax.
Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE], bidb, tidx); Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
...@@ -330,6 +449,19 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i ...@@ -330,6 +449,19 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Do this part of P^T = (Q * K^T)^T. // Do this part of P^T = (Q * K^T)^T.
gemm_q_k(acc_p); gemm_q_k(acc_p);
// wangaq debug
// if (blockIdx.x == 0) {
// for(int m = 0; m < Mma_tile_p::MMAS_M; ++m) {
// for (int n = 0; n < Mma_tile_p::MMAS_N; ++n) {
// printf("acc_p steps:%d step:%d tid:%d mi:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", steps, l, threadIdx.x, m, n,
// acc_p[m][n].template elt_as<float>(0),
// acc_p[m][n].template elt_as<float>(1),
// acc_p[m][n].template elt_as<float>(2),
// acc_p[m][n].template elt_as<float>(3));
// }
// }
// }
// Trigger the load for the next Q values. // Trigger the load for the next Q values.
if( l < steps - 1) { if( l < steps - 1) {
gemm_q_k.smem_q.move_to_next_write_buffer(); gemm_q_k.smem_q.move_to_next_write_buffer();
...@@ -351,18 +483,33 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i ...@@ -351,18 +483,33 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
__syncthreads(); __syncthreads();
} }
// Compute the max. // Compute the max.
#if defined(__HIP_PLATFORM_HCC__)
float p_max[Mma_tile_p::MMAS_M];
#else
float p_max[Mma_tile_p::MMAS_M * 2]; float p_max[Mma_tile_p::MMAS_M * 2];
#endif
//softmax.template reduce<fmha::Max_>(p_max); //softmax.template reduce<fmha::Max_>(p_max);
softmax.reduce_max(p_max); softmax.reduce_max(p_max);
// wangaq debug
// if (blockIdx.x == 0) {
// for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
// printf("tid:%d mi:%d p_max:%f\n", threadIdx.x, mi, p_max[mi]);
// }
// }
// Compute the exponential value. // Compute the exponential value.
softmax.apply_exp(p_max); softmax.apply_exp(p_max);
// Compute the sum. // Compute the sum.
#if defined(__HIP_PLATFORM_HCC__)
float p_sum[Mma_tile_p::MMAS_M];
#else
float p_sum[Mma_tile_p::MMAS_M * 2]; float p_sum[Mma_tile_p::MMAS_M * 2];
#endif
softmax.reduce_sum(p_sum); softmax.reduce_sum(p_sum);
// Finalize softmax on the accumulators of P^T. // Finalize softmax on the accumulators of P.
softmax.scale(p_sum); softmax.scale(p_sum);
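// For orientation, a scalar sketch of what the reduce_max / apply_exp /
// reduce_sum / scale sequence computes per softmax row, ignoring the MMA tile
// layout and the shared-memory reductions:
__device__ inline void softmax_row_sketch(float *p, int n) {
    float m = p[0];
    for( int i = 1; i < n; ++i ) m = fmaxf(m, p[i]);                    // reduce_max
    float s = 0.f;
    for( int i = 0; i < n; ++i ) { p[i] = expf(p[i] - m); s += p[i]; }  // apply_exp + reduce_sum
    for( int i = 0; i < n; ++i ) p[i] *= 1.f / s;                       // scale
}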
using Frag_p = fmha::Fragment_a<fmha::Row>; using Frag_p = fmha::Fragment_a<fmha::Row>;
...@@ -372,24 +519,55 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i ...@@ -372,24 +519,55 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
    #pragma unroll
    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
        float4 tmp = uniform4(ph());
        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from pre-existing zeros
#if defined(__HIP_PLATFORM_HCC__)
        softmax.elt_[mi][4 * ni + 0] =
            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[mi][4 * ni + 0]);
        softmax.elt_[mi][4 * ni + 1] =
            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[mi][4 * ni + 1]);
        softmax.elt_[mi][4 * ni + 2] =
            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[mi][4 * ni + 2]);
        softmax.elt_[mi][4 * ni + 3] =
            encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[mi][4 * ni + 3]);
#else
        // CUDA keeps two softmax rows per MMA; iterate them explicitly so `ii` stays defined.
        #pragma unroll
        for( int ii = 0; ii < 2; ii++ ) {
            softmax.elt_[2 * mi + ii][4 * ni + 0] =
                encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
            softmax.elt_[2 * mi + ii][4 * ni + 1] =
                encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
            softmax.elt_[2 * mi + ii][4 * ni + 2] =
                encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
            softmax.elt_[2 * mi + ii][4 * ni + 3] =
                encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
        }
#endif
    }
}
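// encode_dropout is defined elsewhere in these headers; judging from the comment
// above, a plausible scalar equivalent (hypothetical reconstruction) is:
static inline __device__ float encode_dropout_sketch(bool keep, float x) {
    // softmax outputs are non-negative, so the sign bit is free to carry the mask
    return keep ? x : -x;
}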
softmax.pack(frag_p); softmax.pack(frag_p);
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0) {
// for (int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki) {
// for (int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi) {
// printf("frag_p tid:%d ki:%d mi:%d %6.4f %6.4f %6.4f %6.4f\n", tidx, ki, mi,
// frag_p[ki][mi].template elt_as<float>(0),
// frag_p[ki][mi].template elt_as<float>(1),
// frag_p[ki][mi].template elt_as<float>(2),
// frag_p[ki][mi].template elt_as<float>(3));
// }
// }
// }
gmem_s.store(frag_p, mask); gmem_s.store(frag_p, mask);
// wangaq debug
// printf("begin:%d gmem s:\n", begin);
// __half * gmem = reinterpret_cast<__half*>(gmem_s.ptr_);
// for (int i = 0; i < Gmem_tile_s::LOOP_STRIDE_BYTES / 2 / 8; ++i) {
// printf("tid:%d row:%d ", threadIdx.x, i);
// for (int j = 0; j < 8; ++j) {
// printf("col:%d value:%d\t", j, gmem[i*8+j]);
// }
// }
gmem_s.move(); gmem_s.move();
} else { } else {
softmax.pack(frag_p); softmax.pack(frag_p);
...@@ -407,8 +585,8 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i ...@@ -407,8 +585,8 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
#pragma unroll #pragma unroll
for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) { for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
//"Apply" the dropout. //"Apply" the dropout.
frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout); frag_p[ki][mi].reg(ii) = fmha::fmul(frag_p[ki][mi].reg(ii), params.scale_dropout);
frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii)); frag_p[ki][mi].reg(ii) = fmha::frelu(frag_p[ki][mi].reg(ii));
} }
} }
} }
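// The fmul/frelu pair above implements standard inverted dropout on the
// sign-encoded values; a scalar sketch (scale_dropout holds 1/(1-p), per
// set_params in this commit):
__device__ inline float apply_dropout_sketch(float x, float rp_dropout) {
    x = x * rp_dropout;        // fmul: rescale survivors; dropped values stay negative
    return x > 0.f ? x : 0.f;  // frelu: clamp sign-encoded dropped values to zero
}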
...@@ -421,6 +599,34 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i ...@@ -421,6 +599,34 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
#pragma unroll #pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
fmha::gemm(acc_o, frag_p[ki], frag_v[ki]); fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0) {
// for (int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi) {
// printf("frag_p tid:%d ki:%d mi:%d %6.4f %6.4f %6.4f %6.4f\n", tidx, ki, mi,
// frag_p[ki][mi].template elt_as<float>(0),
// frag_p[ki][mi].template elt_as<float>(1),
// frag_p[ki][mi].template elt_as<float>(2),
// frag_p[ki][mi].template elt_as<float>(3));
// }
// for (int ni = 0; ni < Mma_tile_o::MMAS_N; ++ni) {
// printf("frag_v tid:%d ki:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", tidx, ki, ni,
// frag_v[ki][ni].template elt_as<float>(0),
// frag_v[ki][ni].template elt_as<float>(1),
// frag_v[ki][ni].template elt_as<float>(2),
// frag_v[ki][ni].template elt_as<float>(3));
// }
// for (int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi) {
// for (int ni = 0; ni < Mma_tile_o::MMAS_N; ++ni) {
// printf("acc_o tid:%d ki:%d mi:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", tidx, ki, mi, ni,
// acc_o[mi][ni].template elt_as<float>(0),
// acc_o[mi][ni].template elt_as<float>(1),
// acc_o[mi][ni].template elt_as<float>(2),
// acc_o[mi][ni].template elt_as<float>(3));
// }
// }
// }
} }
// Loop over MMAS_M. // Loop over MMAS_M.
...@@ -429,6 +635,18 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i ...@@ -429,6 +635,18 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Swizzle the elements and do the final reduction. // Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii); smem_o.store(acc_o, ii);
// wangaq debug
// if (blockIdx.x == 0 && tidx == 0) {
// printf("ii:%d smem_o %d bytes smem:\n", ii, Smem_tile_o::BYTES_PER_TILE);
// float * smem_o = reinterpret_cast<float*>(&smem_[Gemm1::SMEM_OFFSET_O]);
// for (int row = 0; row < Smem_tile_o::BYTES_PER_TILE / 4 / 8; ++row) {
// printf("row:%d ", row);
// for (int col = 0; col < 8; ++col) {
// printf("col:%d value:%6.4f\t", col, smem_o[row*8+col]);
// }
// printf("\n");
// }
// }
// Make sure the data is in shared memory. // Make sure the data is in shared memory.
__syncthreads(); __syncthreads();
...@@ -436,6 +654,18 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i ...@@ -436,6 +654,18 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Load from shared memory. // Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP]; uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out); smem_o.load(out);
// wangaq debug
// if (blockIdx.x == 0 && tidx == 0) {
// printf("ii:%d smem_o %d bytes smem:\n", ii, Smem_tile_o::BYTES_PER_TILE);
// float * smem_o = reinterpret_cast<float*>(&smem_[Gemm1::SMEM_OFFSET_O]);
// for (int row = 0; row < Smem_tile_o::BYTES_PER_TILE / 4 / 8; ++row) {
// printf("row:%d ", row);
// for (int col = 0; col < 8; ++col) {
// printf("col:%d value:%6.4f\t", col, smem_o[row*8+col]);
// }
// printf("\n");
// }
// }
// Make sure the data was read from shared memory. // Make sure the data was read from shared memory.
if( ii < Gmem_tile_o::LOOPS - 1 ) { if( ii < Gmem_tile_o::LOOPS - 1 ) {
...@@ -472,10 +702,14 @@ inline __device__ void device_1xN(const Params &params, ...@@ -472,10 +702,14 @@ inline __device__ void device_1xN(const Params &params,
// Note: a unique per-thread Philox subsequence needs blockDim.x here, not gridDim.x.
const int tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
auto seeds = at::cuda::philox::unpack(params.philox_args); auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
// if (blockIdx.x == 0 && threadIdx.x == 0)
// printf("num_full_heads:%d num_main_groups:%d main_group_size:%d main_steps:%d, rest_steps:%d", num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
for( int it = 0; it < num_full_heads; it++ ) { for( int it = 0; it < num_full_heads; it++ ) {
const int bidx = it * gridDim.x + blockIdx.x; const int bidx = it * gridDim.x + blockIdx.x;
const int bidh = bidx % params.h; const int bidh = bidx % params.h;
const int bidb = bidx / params.h; const int bidb = bidx / params.h;
// if (blockIdx.x == 0 && threadIdx.x == 0)
// printf("%s:%d N:%d M:%d steps:%d\n", __FILE__, __LINE__, Kernel_traits::Cta_tile_p::N, Kernel_traits::Cta_tile_p::M, STEPS);
fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, 0, STEPS, ph); fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, 0, STEPS, ph);
__syncthreads(); __syncthreads();
} }
...@@ -490,6 +724,8 @@ inline __device__ void device_1xN(const Params &params, ...@@ -490,6 +724,8 @@ inline __device__ void device_1xN(const Params &params,
const int bidh = (head_offset + bidx) % params.h; const int bidh = (head_offset + bidx) % params.h;
const int bidb = (head_offset + bidx) / params.h; const int bidb = (head_offset + bidx) / params.h;
const int offset = group * main_steps; const int offset = group * main_steps;
// if (blockIdx.x == 0 && threadIdx.x == 0)
// printf("%s:%d N:%d M:%d steps:%d\n", __FILE__, __LINE__, Kernel_traits::Cta_tile_p::N, Kernel_traits::Cta_tile_p::M, STEPS);
fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, main_steps, ph); fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, main_steps, ph);
} else { } else {
if(rest_steps == 0 ) return; if(rest_steps == 0 ) return;
...@@ -501,6 +737,8 @@ inline __device__ void device_1xN(const Params &params, ...@@ -501,6 +737,8 @@ inline __device__ void device_1xN(const Params &params,
for( int it = head_offset + bidx; it < total_heads; it += rest_ctas ) { for( int it = head_offset + bidx; it < total_heads; it += rest_ctas ) {
const int bidh = it % params.h; const int bidh = it % params.h;
const int bidb = it / params.h; const int bidb = it / params.h;
// if (blockIdx.x == 0 && threadIdx.x == 0)
// printf("%s:%d N:%d M:%d steps:%d\n", __FILE__, __LINE__, Kernel_traits::Cta_tile_p::N, Kernel_traits::Cta_tile_p::M, STEPS);
fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, rest_steps, ph); fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, rest_steps, ph);
__syncthreads(); __syncthreads();
} }
......
...@@ -131,6 +131,8 @@ std::tuple<int , int, int, int, int, int> work_dist(const int total_ctas, const ...@@ -131,6 +131,8 @@ std::tuple<int , int, int, int, int, int> work_dist(const int total_ctas, const
const int num_full_heads = heads_total / total_ctas; const int num_full_heads = heads_total / total_ctas;
const int heads_last_wave = heads_total % total_ctas; const int heads_last_wave = heads_total % total_ctas;
// printf("total_ctas:%d heads_total:%d num_full_heads:%d heads_last_wave:%d N:%d M:%d steps:%d\n",
// total_ctas, heads_total, num_full_heads, heads_last_wave, Kernel_traits::Cta_tile_p::N, Kernel_traits::Cta_tile_p::M, STEPS_PER_HEAD);
int num_main_groups = 0; int num_main_groups = 0;
int main_steps = 0; int main_steps = 0;
......
...@@ -512,15 +512,15 @@ if "--fmha" in sys.argv: ...@@ -512,15 +512,15 @@ if "--fmha" in sys.argv:
CUDAExtension(name='fmhalib', CUDAExtension(name='fmhalib',
sources=[ sources=[
'apex/contrib/csrc/fmha/fmha_api.cpp', 'apex/contrib/csrc/fmha/fmha_api.cpp',
'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu', # 'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu', 'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu', 'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu', 'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu', # 'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu', # 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu', # 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu', # 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu', # 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu',
], ],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}, 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha},
......