Commit 0816a70e authored by flyingdown, committed by flyingdown

Add fmha support

parent 2a4864d5
......@@ -81,7 +81,11 @@ struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
inline __device__ void clear() {
#pragma unroll
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("s_mov_b32 %0, 0; \n" : "=s"(this->reg(ii)) : );
#else
asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : );
#endif
}
}
......@@ -142,6 +146,85 @@ struct Fragment_b : public Fragment<uint16_t, 8> {
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined (__HIP_PLATFORM_HCC__)
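// HIP path: no ldmatrix/HMMA equivalent is used here, so the m16n8k16 tensor-core
// op is emulated below with LDS staging and packed-f16 scalar math.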
__device__ inline void f16mulf16addf32(uint32_t & a, uint32_t & b, const float * c, float * d){
// Unpack the two f16 lanes packed in a and b, multiply element-wise, and
// accumulate their sum on top of *c into *d (c and d may alias).
__half * ha = reinterpret_cast<__half*>(&a);
__half * hb = reinterpret_cast<__half*>(&b);
*d = *c + __half2float(ha[0])*__half2float(hb[0]) + __half2float(ha[1])*__half2float(hb[1]);
}
// row 8 col 4
__device__ inline void m16n8k16(const uint32_t * A, const uint32_t * B, /*const float * C,*/ float * D) {
int tid = threadIdx.x;
int baseId = tid / 32 * 32;
__shared__ uint32_t smem[256*6];
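// Staging layout: 6 dwords per lane -- A[0..3] followed by B[0..1] -- sized for
// up to 256 threads per block.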
int base = tid*6;
__builtin_memcpy(smem+base, A, sizeof(uint32_t));
__builtin_memcpy(smem+(base+1), A+1, sizeof(uint32_t));
__builtin_memcpy(smem+(base+2), A+2, sizeof(uint32_t));
__builtin_memcpy(smem+(base+3), A+3, sizeof(uint32_t));
__builtin_memcpy(smem+(base+4), B, sizeof(uint32_t));
__builtin_memcpy(smem+(base+5), B+1, sizeof(uint32_t));
__syncthreads();
/* From D's point of view: each thread computes its own D values, looping from
lane 0 to fetch one row of A and two columns of B.
s is the lane id that holds the B fragment,
baseA is the lane id that holds the A fragment,
baseB0 is the first B column fetched by this thread, baseB1 the second.
*/
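/* Example: for tid 5 (warp 0), s = 8 and e = 12, so lane 5 accumulates its four
D values from the A fragments held by lanes 4-7 and the B fragments held by
lanes 8-15. */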
int s = baseId+(tid%4)*8, e = s+4;
for (int i = s; i < e; ++i) {
// A[0]->i A[1]->i+1 A[2]->i+2 A[3]->i+3 B[0]->i+4 B[1]->i+5
int baseA = (tid-tid%4+i-s)*6; // lane id of the first column in this tid's row, plus the stride, times 6
int baseB0 = i*6, baseB1 = (i+4)*6;
f16mulf16addf32(smem[baseA], smem[baseB0+4], D, D);
f16mulf16addf32(smem[baseA+2], smem[baseB0+5], D, D);
f16mulf16addf32(smem[baseA], smem[baseB1+4], D+1, D+1);
f16mulf16addf32(smem[baseA+2], smem[baseB1+5], D+1, D+1);
f16mulf16addf32(smem[baseA+1], smem[baseB0+4], D+2, D+2);
f16mulf16addf32(smem[baseA+3], smem[baseB0+5], D+2, D+2);
f16mulf16addf32(smem[baseA+1], smem[baseB1+4], D+3, D+3);
f16mulf16addf32(smem[baseA+3], smem[baseB1+5], D+3, D+3);
}
// Guard the LDS staging buffer: without this barrier, a subsequent call could
// start rewriting smem while some lanes are still reading from this one.
__syncthreads();
}
#endif
struct Fragment_accumulator : public Fragment<float, 8> {
// The base class.
......@@ -159,6 +242,15 @@ struct Fragment_accumulator : public Fragment<float, 8> {
template< typename Layout_a, typename Layout_b >
inline __device__ void mma(const Fragment_a<Layout_a> &a,
const Fragment_b<Layout_b> &b) {
#if defined (__HIP_PLATFORM_HCC__)
const uint32_t * A = reinterpret_cast<const uint32_t*>(a.regs_);
const uint32_t * B = reinterpret_cast<const uint32_t*>(b.regs_);
float * D = reinterpret_cast<float*>(regs_);
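// Two emulated HMMAs cover the full tile: B columns 0-7 accumulate into D[0..3],
// columns 8-15 into D[4..7].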
m16n8k16(A, B, D);
m16n8k16(A, B+2, D+4);
#else
asm volatile( \
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
" {%0, %1, %2, %3}, \n" \
......@@ -177,6 +269,7 @@ struct Fragment_accumulator : public Fragment<float, 8> {
: "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7))
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
, "r"(b.reg(2)), "r"(b.reg(3)));
#endif
}
};
......
......@@ -129,8 +129,13 @@ struct Smem_tile_without_skews {
this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;
// TODO: Why not merge it with the read offset?
#if defined (__HIP_PLATFORM_HCC__)
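// The broadcast from lane 0 keeps these offsets warp-uniform, matching the
// __shfl_sync trick used on the CUDA path below.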
this->smem_read_buffer_ = __shfl(0, 0);
this->smem_write_buffer_ = __shfl(0, 0);
#else
this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);
this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);
#endif
}
// Compute the store pointers.
......
......@@ -31,7 +31,19 @@
#include <stdint.h>
#include <stdlib.h>
#if defined(__HIP_PLATFORM_HCC__)
extern "C" {
__device__ inline size_t __nv_cvta_generic_to_shared_impl(const void *__ptr) {
return (size_t)(void __attribute__((address_space(3))) *)__ptr;
}
__device__ inline uint32_t __nvvm_get_smem_pointer(void *__ptr) {
return __nv_cvta_generic_to_shared_impl(__ptr);
}
} // extern "C"
#else
extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -246,7 +258,15 @@ inline int find_log_2(int x, bool round_up = false) {
static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) {
uint32_t c;
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#else
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#endif
return c;
}
......@@ -254,7 +274,11 @@ static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) {
static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {
uint32_t c;
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("v_pk_min_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#else
asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b));
#endif
return c;
}
......@@ -262,7 +286,11 @@ static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {
static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) {
uint32_t c;
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#else
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#endif
return c;
}
......@@ -301,6 +329,9 @@ static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
uint32_t res;
#if defined (__HIP_PLATFORM_HCC__)
asm volatile( "v_pk_max_f16 %0, %1, %2;\n" : "=v"(res) : "v"(x), "v"(lb));
#else
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb));
#else
......@@ -312,11 +343,21 @@ static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
"\t and.b32 %0, sela, %1;\n"
"}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
#endif
return res;
}
static inline __device__ uint32_t habs2(uint32_t x) {
uint32_t res;
#if defined (__HIP_PLATFORM_HCC__)
__half * hx = reinterpret_cast<__half*>(&x);
__half zero = __float2half(0.0);
__half neg_one = __float2half(-1.0);
hx[0] = hx[0] > zero ? hx[0] : hx[0]*neg_one;
hx[1] = hx[1] > zero ? hx[1] : hx[1]*neg_one;
__builtin_memcpy(&res, hx, sizeof(uint32_t));
#else
asm volatile( "abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x));
#endif
return res;
}
......@@ -331,7 +372,12 @@ static inline __device__ T clamp(T x, T lb, T ub) {
static inline __device__ uint16_t clamp_to_zero(uint16_t x) {
uint16_t mask;
#if defined (__HIP_PLATFORM_HCC__)
// "gtu" semantics: the mask is all ones when x is NaN or greater than zero.
// Compare the half value itself, not its raw bit pattern.
__half hx = reinterpret_cast<const __half&>(x);
mask = (__hisnan(hx) || __half2float(hx) > 0.0f) ? 0xffffu : 0x0000u;
#else
asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x));
#endif
return mask & x;
}
......@@ -339,7 +385,11 @@ static inline __device__ uint16_t clamp_to_zero(uint16_t x) {
static inline __device__ uint16_t float_to_half(float f) {
uint16_t h;
#if defined (__HIP_PLATFORM_HCC__)
__half tmp = __float2half(f);
__builtin_memcpy(&h, &tmp, sizeof(h));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f));
#endif
return h;
}
......@@ -347,6 +397,13 @@ static inline __device__ uint16_t float_to_half(float f) {
static inline __device__ uint32_t float2_to_half2(float a, float b) {
uint32_t c;
#if defined (__HIP_PLATFORM_HCC__)
__half h1 = __float2half(a);
__half h2 = __float2half(b);
__half * h = reinterpret_cast<__half*>(&c);
__builtin_memcpy(h, &h1, sizeof(uint16_t));
__builtin_memcpy(h+1, &h2, sizeof(uint16_t));
#else
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a));
#else
......@@ -354,6 +411,7 @@ static inline __device__ uint32_t float2_to_half2(float a, float b) {
uint16_t hi = float_to_half(b);
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi));
#endif
#endif
return c;
}
......@@ -382,7 +440,11 @@ static inline __device__ uint2 float4_to_half4(float x, float y, float z, float
static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
#else
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#endif
return d;
}
......@@ -402,8 +464,13 @@ static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c)
static inline __device__ uint32_t h0_h0(uint32_t x) {
uint32_t y;
#if defined (__HIP_PLATFORM_HCC__)
uint32_t lo = x & 0x0000ffff;
y = lo << 16 | lo;
#else
asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n"
: "=r"(y) : "r"(x));
#endif
return y;
}
......@@ -411,11 +478,16 @@ static inline __device__ uint32_t h0_h0(uint32_t x) {
static inline __device__ float h0_to_float(uint32_t h2) {
float f;
#if defined (__HIP_PLATFORM_HCC__)
uint32_t lo = h2 & 0x0000ffff;
f = __half2float(reinterpret_cast<__half&>(lo));
#else
asm volatile("{\n" \
".reg .f16 lo, hi;\n" \
"mov.b32 {lo, hi}, %1;\n" \
"cvt.f32.f16 %0, lo;\n" \
"}\n" : "=f"(f) : "r"(h2));
#endif
return f;
}
......@@ -423,8 +495,13 @@ static inline __device__ float h0_to_float(uint32_t h2) {
static inline __device__ uint32_t h1_h1(uint32_t x) {
uint32_t y;
#if defined (__HIP_PLATFORM_HCC__)
uint32_t hi = x & 0xffff0000;
y = hi << 16 | hi;
#else
asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n"
: "=r"(y) : "r"(x));
#endif
return y;
}
......@@ -432,7 +509,11 @@ static inline __device__ uint32_t h1_h1(uint32_t x) {
static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) {
uint16_t d;
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("v_add_f16 %0, %1, %2;" : "=v"(d) : "v"(a), "v"(b));
#else
asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
#endif
return d;
}
......@@ -489,7 +570,11 @@ static inline __device__ uint4 hadd(uint4 a, uint4 b) {
static inline __device__ float half_to_float(uint16_t h) {
float f;
#if defined (__HIP_PLATFORM_HCC__)
f = __half2float(reinterpret_cast<const __half&>(h));
#else
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
#endif
return f;
}
......@@ -497,7 +582,12 @@ static inline __device__ float half_to_float(uint16_t h) {
static inline __device__ float2 half2_to_float2(uint32_t x) {
uint16_t lo, hi;
#if defined (__HIP_PLATFORM_HCC__)
lo = x & 0xffff;
hi = (x >> 16) & 0xffff;
#else
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x));
#endif
return make_float2(half_to_float(lo), half_to_float(hi));
}
......@@ -513,7 +603,11 @@ static inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) {
static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) {
uint16_t d;
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("v_fma_f16 %0, %1, %2, %3;" : "=v"(d) : "v"(a), "v"(b), "v"(c));
#else
asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c));
#endif
return d;
}
......@@ -521,7 +615,11 @@ static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) {
static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) {
uint16_t d;
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("v_mul_f16 %0, %1, %2;" : "=v"(d) : "v"(a), "v"(b));
#else
asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
#endif
return d;
}
......@@ -754,30 +852,50 @@ inline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint16_t &dst, uint32_t ptr) {
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("ds_read_u16 %0, %1;" : "=v"(dst) : "v"(ptr));
#else
asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint32_t &dst, uint32_t ptr) {
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("ds_read_b32 %0, %1;" : "=v"(dst) : "v"(ptr));
#else
asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint2 &dst, uint32_t ptr) {
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("ds_read_b32 %0, %1;" : "=v"(dst.x) : "v"(ptr));
asm volatile("ds_read_b32 %0, %1;" : "=v"(dst.y) : "v"(ptr+4));
#else
asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint4 &dst, uint32_t ptr) {
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("ds_read_b32 %0, %1;" : "=v"(dst.x) : "v"(ptr));
asm volatile("ds_read_b32 %0, %1;" : "=v"(dst.y) : "v"(ptr+4));
asm volatile("ds_read_b32 %0, %1;" : "=v"(dst.z) : "v"(ptr+8));
asm volatile("ds_read_b32 %0, %1;" : "=v"(dst.w) : "v"(ptr+12));
#else
asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x)
, "=r"(dst.y)
, "=r"(dst.z)
, "=r"(dst.w)
: "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -823,19 +941,100 @@ inline __device__ void ldsmt(uint2 &dst, uint32_t ptr) {
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint4 &dst, uint32_t ptr) {
#if defined (__HIP_PLATFORM_HCC__)
extern __shared__ char smem[];
int laneId = threadIdx.x % 32;
int row = laneId / 4;
int col = laneId % 4;
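// Software ldmatrix.m8n8.x4: every lane published a row base pointer; lane
// (row, col) gathers its 4-byte slice of rows row, row+8, row+16 and row+24.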
unsigned ptr0 = __shfl(ptr, row, 32) + col * 4;
unsigned ptr1 = __shfl(ptr, row + 8, 32) + col * 4;
unsigned ptr2 = __shfl(ptr, row + 16, 32) + col * 4;
unsigned ptr3 = __shfl(ptr, row + 24, 32) + col * 4;
// asm volatile("ds_read_b32 %0, %1;" : "=v"(dst.x) : "v"(ptr0));
// asm volatile("ds_read_b32 %0, %1;" : "=v"(dst.y) : "v"(ptr1));
// asm volatile("ds_read_b32 %0, %1;" : "=v"(dst.z) : "v"(ptr2));
// asm volatile("ds_read_b32 %0, %1;" : "=v"(dst.w) : "v"(ptr3));
uint32_t base = __nvvm_get_smem_pointer(smem);
__builtin_memcpy(&dst.x, smem-base+ptr0, sizeof(uint32_t));
__builtin_memcpy(&dst.y, smem-base+ptr1, sizeof(uint32_t));
__builtin_memcpy(&dst.z, smem-base+ptr2, sizeof(uint32_t));
__builtin_memcpy(&dst.w, smem-base+ptr3, sizeof(uint32_t));
// printf("ldsm %03i: %03d %03d %03d %03d %03d %03d %f %f %f %f %f %f %f %f \n",
// threadIdx.x, __nvvm_get_smem_pointer(smem), ptr, ptr0-base, ptr1-base, ptr2-base, ptr3-base,
// __half2float(reinterpret_cast<__half*>(&dst.x)[0]),
// __half2float(reinterpret_cast<__half*>(&dst.x)[1]),
// __half2float(reinterpret_cast<__half*>(&dst.y)[0]),
// __half2float(reinterpret_cast<__half*>(&dst.y)[1]),
// __half2float(reinterpret_cast<__half*>(&dst.z)[0]),
// __half2float(reinterpret_cast<__half*>(&dst.z)[1]),
// __half2float(reinterpret_cast<__half*>(&dst.w)[0]),
// __half2float(reinterpret_cast<__half*>(&dst.w)[1])
// );
#else
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr));
#endif
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint4 &dst, uint32_t ptr) {
#if defined (__HIP_PLATFORM_HCC__)
extern __shared__ char smem[];
int laneId = threadIdx.x % 32;
int row = laneId % 4;
int col = laneId / 4;
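// Transposed variant of the ldmatrix emulation: each 32-bit result is repacked
// from two 16-bit elements fetched via pointers published by two source lanes.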
unsigned ptr0 = __shfl(ptr, 2*row, 32) + col * 2;
unsigned ptr1 = __shfl(ptr, 2*row+1, 32) + col * 2;
unsigned ptr2 = __shfl(ptr, 2*row+8, 32) + col * 2;
unsigned ptr3 = __shfl(ptr, 2*row+8+1, 32) + col * 2;
unsigned ptr4 = __shfl(ptr, 2*row+16, 32) + col * 2;
unsigned ptr5 = __shfl(ptr, 2*row+16+1, 32) + col * 2;
unsigned ptr6 = __shfl(ptr, 2*row+24, 32) + col * 2;
unsigned ptr7 = __shfl(ptr, 2*row+24+1, 32) + col * 2;
uint32_t base = __nvvm_get_smem_pointer(smem);
uint32_t h0, h1;
// asm volatile("ds_read_u16 %0, %1;" : "=v"(h0) : "v"(ptr0));
// asm volatile("ds_read_u16 %0, %1;" : "=v"(h1) : "v"(ptr1));
__builtin_memcpy(&h0, smem-base+ptr0, sizeof(uint16_t));
__builtin_memcpy(&h1, smem-base+ptr1, sizeof(uint16_t));
dst.x = h1 << 16 | (h0 & 0xffff);
// asm volatile("ds_read_u16 %0, %1;" : "=v"(h0) : "v"(ptr2));
// asm volatile("ds_read_u16 %0, %1;" : "=v"(h1) : "v"(ptr3));
__builtin_memcpy(&h0, smem-base+ptr2, sizeof(uint16_t));
__builtin_memcpy(&h1, smem-base+ptr3, sizeof(uint16_t));
dst.y = h1 << 16 | (h0 & 0xffff);
// asm volatile("ds_read_u16 %0, %1;" : "=v"(h0) : "v"(ptr4));
// asm volatile("ds_read_u16 %0, %1;" : "=v"(h1) : "v"(ptr5));
__builtin_memcpy(&h0, smem-base+ptr4, sizeof(uint16_t));
__builtin_memcpy(&h1, smem-base+ptr5, sizeof(uint16_t));
dst.z = h1 << 16 | (h0 & 0xffff);
// asm volatile("ds_read_u16 %0, %1;" : "=v"(h0) : "v"(ptr6));
// asm volatile("ds_read_u16 %0, %1;" : "=v"(h1) : "v"(ptr7));
__builtin_memcpy(&h0, smem-base+ptr6, sizeof(uint16_t));
__builtin_memcpy(&h1, smem-base+ptr7, sizeof(uint16_t));
dst.w = h1 << 16 | (h0 & 0xffff);
// printf("ldsmt %03i: %03d %03d %03d %03d %03d %03d %03d %03d %03d \n %f %f %f %f %f %f %f %f \n", threadIdx.x, ptr, ptr0, ptr1, ptr2, ptr3, ptr4, ptr5, ptr6, ptr7,
// __half2float(reinterpret_cast<__half*>(&dst.x)[0]),
// __half2float(reinterpret_cast<__half*>(&dst.x)[1]),
// __half2float(reinterpret_cast<__half*>(&dst.y)[0]),
// __half2float(reinterpret_cast<__half*>(&dst.y)[1]),
// __half2float(reinterpret_cast<__half*>(&dst.z)[0]),
// __half2float(reinterpret_cast<__half*>(&dst.z)[1]),
// __half2float(reinterpret_cast<__half*>(&dst.w)[0]),
// __half2float(reinterpret_cast<__half*>(&dst.w)[1])
// );
#else
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr));
#endif
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -879,28 +1078,47 @@ inline __device__ void stg(void *ptr, uint4 val) {
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint16_t val) {
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("ds_write_b16 %0, %1;" : : "v"(ptr) , "v"(val));
#else
asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint32_t val) {
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr) , "v"(val));
#else
asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint2 val) {
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr) , "v"(val.x));
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+4) , "v"(val.y));
#else
asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n"
:
: "r"(ptr)
, "r"(val.x)
, "r"(val.y));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint4 val) {
#if defined (__HIP_PLATFORM_HCC__)
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr) , "v"(val.x));
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+4) , "v"(val.y));
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+8) , "v"(val.z));
asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+12) , "v"(val.w));
#else
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
:
: "r"(ptr)
......@@ -908,6 +1126,7 @@ inline __device__ void sts(uint32_t ptr, uint4 val) {
, "r"(val.y)
, "r"(val.z)
, "r"(val.w));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -970,7 +1189,11 @@ struct Allreduce {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
#if defined (__HIP_PLATFORM_HCC__)
// HIP's __shfl_xor takes (var, laneMask); there is no member-mask argument.
x = op(x, __shfl_xor(x, OFFSET));
#else
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
#endif
return Allreduce<OFFSET>::run(x, op);
}
};
......@@ -981,7 +1204,11 @@ template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
#if defined (__HIP_PLATFORM_HCC__)
x = op(x, __shfl_xor(x, 1));
#else
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
#endif
return x;
}
};
......@@ -993,8 +1220,13 @@ __device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
#if defined (__HIP_PLATFORM_HCC__)
dst[mi] = op(dst[mi], __shfl_down(dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down(dst[mi], 1));
#else
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
#endif
}
}
......
......@@ -49,11 +49,20 @@ void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_param
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
#if defined (__HIP_PLATFORM_HCC__)
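// Presumably this std::max call is not usable in a constexpr context under hipcc,
// hence the runtime value on the HIP path (same pattern in the launchers below).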
int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
#else
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
#endif
if( smem_size >= 48 * 1024 ) {
#if defined (__HIP_PLATFORM_HCC__)
FMHA_CHECK_CUDA(hipFuncSetAttribute(
(const void *)fmha_dgrad_fp16_128_64_sm80_kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#else
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_128_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#endif
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_128_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
......
......@@ -49,11 +49,20 @@ void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_param
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
#if defined (__HIP_PLATFORM_HCC__)
int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
#else
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
#endif
if( smem_size >= 48 * 1024 ) {
#if defined (__HIP_PLATFORM_HCC__)
FMHA_CHECK_CUDA(hipFuncSetAttribute(
(const void *)fmha_dgrad_fp16_256_64_sm80_kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#else
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_256_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#endif
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_256_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
......
......@@ -49,11 +49,20 @@ void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_param
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
#if defined (__HIP_PLATFORM_HCC__)
int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
#else
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
#endif
if( smem_size >= 48 * 1024 ) {
#if defined (__HIP_PLATFORM_HCC__)
FMHA_CHECK_CUDA(hipFuncSetAttribute(
(const void *)fmha_dgrad_fp16_384_64_sm80_kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#else
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_384_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#endif
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_384_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
......
......@@ -57,11 +57,20 @@ void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_param
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
#if defined (__HIP_PLATFORM_HCC__)
int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
#else
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
#endif
if( smem_size >= 48 * 1024 ) {
#if defined (__HIP_PLATFORM_HCC__)
FMHA_CHECK_CUDA(hipFuncSetAttribute(
(const void *)fmha_dgrad_fp16_512_64_sm80_kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#else
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_512_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#endif
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_512_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
......@@ -81,7 +90,11 @@ void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_pa
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
#if defined (__HIP_PLATFORM_HCC__)
int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
#else
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
#endif
auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;
......@@ -94,7 +107,12 @@ void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_pa
}
if( smem_size >= 48 * 1024 ) {
#if defined (__HIP_PLATFORM_HCC__)
FMHA_CHECK_CUDA(hipFuncSetAttribute(
(const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#else
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#endif
}
dim3 grid(params.h, params.b, num_chunks);
......
......@@ -50,7 +50,12 @@ void run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_par
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
#if defined (__HIP_PLATFORM_HCC__)
FMHA_CHECK_CUDA(hipFuncSetAttribute(
(const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#else
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#endif
}
const int sm_count = launch_params.props->multiProcessorCount;
......
......@@ -50,7 +50,12 @@ void run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_par
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
#if defined (__HIP_PLATFORM_HCC__)
FMHA_CHECK_CUDA(hipFuncSetAttribute(
(const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#else
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#endif
}
const int sm_count = launch_params.props->multiProcessorCount;
......
......@@ -50,7 +50,12 @@ void run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_par
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
#if defined (__HIP_PLATFORM_HCC__)
FMHA_CHECK_CUDA(hipFuncSetAttribute(
(const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#else
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#endif
}
const int sm_count = launch_params.props->multiProcessorCount;
......
......@@ -58,7 +58,12 @@ void run_fmha_fp16_512_64_sm80_(Launch_params<Fused_multihead_attention_fprop_pa
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
#if defined (__HIP_PLATFORM_HCC__)
FMHA_CHECK_CUDA(hipFuncSetAttribute(
(const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#else
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#endif
}
const int sm_count = launch_params.props->multiProcessorCount;
......@@ -96,7 +101,12 @@ void run_fmha_fp16_512_64_sm80_nl_(Launch_params<Fused_multihead_attention_fprop
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
#if defined (__HIP_PLATFORM_HCC__)
FMHA_CHECK_CUDA(hipFuncSetAttribute(
(const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#else
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#endif
}
const int sm_count = launch_params.props->multiProcessorCount;
......
......@@ -470,6 +470,8 @@ if "--fast_layer_norm" in sys.argv:
"cxx": ["-O3"] + version_dependent_macros + generator_flag,
"nvcc": [
"-O3",
"-g",
'-ggdb',
'-U__HIP_NO_HALF_OPERATORS__',
'-U__HIP_NO_HALF_CONVERSIONS__',
"-I./apex/contrib/csrc/layer_norm/",
......@@ -481,23 +483,30 @@ if "--fast_layer_norm" in sys.argv:
if "--fmha" in sys.argv:
sys.argv.remove("--fmha")
raise_if_cuda_home_none("--fmha")
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) < 11:
raise RuntimeError("--fmha only supported on SM80")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--fmha was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
raise_if_cuda_home_none("--fmha")
else:
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) < 11:
raise RuntimeError("--fmha only supported on SM80")
if not IS_ROCM_PYTORCH:
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) < 11:
raise RuntimeError("--fmha only supported on SM80")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
nvcc_args_mha = ['-O3',
'-gencode', 'arch=compute_80,code=sm_80',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag
hipcc_args_mha = [
#'-O3',
'-U__HIP_NO_HALF_OPERATORS__',
'-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag
ext_modules.append(
CUDAExtension(name='fmhalib',
......@@ -513,15 +522,8 @@ if "--fmha" in sys.argv:
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu',
],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros + generator_flag,
'nvcc': nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc"), os.path.join(this_dir, "apex/contrib/csrc/fmha/src")]))
......