Commit bf4adfeb authored by sxtyzhangzk's avatar sxtyzhangzk Committed by Zhekai Zhang
Browse files

Support Turing (sm_75) architecture

parent 3ef186fd
...@@ -335,6 +335,7 @@ public: ...@@ -335,6 +335,7 @@ public:
template<typename Epilogue, bool USE_ALPHA> template<typename Epilogue, bool USE_ALPHA>
struct gemm_w4a4_fp4_kernel { struct gemm_w4a4_fp4_kernel {
static constexpr int MIN_ARCH = 1200;
__device__ __device__
void operator()( void operator()(
const packed_act_t *act, const packed_act_t *act,
...@@ -389,67 +390,16 @@ public: ...@@ -389,67 +390,16 @@ public:
static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt) { static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt) {
packed_psum_t psum; packed_psum_t psum;
if constexpr (!ACT_UNSIGNED) { uint4 out1 = mma_m16n8kx_s32common<mma_helper::s4u4<ACT_UNSIGNED>, mma_helper::s4>(act, uint2(wgt.x, wgt.y), uint4(0, 0, 0, 0));
asm volatile( uint4 out2 = mma_m16n8kx_s32common<mma_helper::s4u4<ACT_UNSIGNED>, mma_helper::s4>(act, uint2(wgt.z, wgt.w), uint4(0, 0, 0, 0));
"mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " psum.data[0] = out1.x;
"{%0, %1, %2, %3}," psum.data[1] = out1.y;
"{%4, %5, %6, %7}," psum.data[2] = out1.z;
"{%8, %9}," psum.data[3] = out1.w;
"{%10, %11, %12, %13};\n" psum.data[4] = out2.x;
: psum.data[5] = out2.y;
"=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3]) psum.data[6] = out2.z;
: psum.data[7] = out2.w;
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.x), "r"(wgt.y),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.z), "r"(wgt.w),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
);
}
if constexpr (ACT_UNSIGNED) {
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.x), "r"(wgt.y),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.z), "r"(wgt.w),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
);
}
return psum; return psum;
} }
...@@ -554,7 +504,7 @@ public: ...@@ -554,7 +504,7 @@ public:
// [WARP_M, WARP_N * 2] when fuse_glu // [WARP_M, WARP_N * 2] when fuse_glu
template<bool fuse_glu> template<bool fuse_glu>
struct load_act_to_fpsum { struct load_act_to_fpsum {
using matrix_t = half_t[WARP_M][WARP_N + 8]; using matrix_t = half_t[INSN_M][WARP_N + 8];
static constexpr size_t SHMEM_SIZE = sizeof(matrix_t); static constexpr size_t SHMEM_SIZE = sizeof(matrix_t);
__device__ __forceinline__ __device__ __forceinline__
...@@ -568,41 +518,42 @@ public: ...@@ -568,41 +518,42 @@ public:
using packed_raw_input = std::array<half2_t, PACK_SIZE>; using packed_raw_input = std::array<half2_t, PACK_SIZE>;
#pragma unroll #pragma unroll
for (int row = 0; row < WARP_M; row++) { for (int m = 0; m < WARP_M_TILES; m++) {
packed_input pack; #pragma unroll
// TODO: numCols not multiples of PACK_SIZE for (int row = 0; row < INSN_M; row++) {
if constexpr (fuse_glu) { packed_input pack;
packed_raw_input raw; // TODO: numCols not multiples of PACK_SIZE
raw.fill(half2_t(0, 0)); if constexpr (fuse_glu) {
bool pred = row < maxRows && laneId * PACK_SIZE * 2 < maxCols; packed_raw_input raw;
if (pred) { raw.fill(half2_t(0, 0));
raw = load(reinterpret_cast<const packed_raw_input *>(input + row * stride + laneId * PACK_SIZE * 2)); bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE * 2 < maxCols;
} if (pred) {
#pragma unroll raw = load(reinterpret_cast<const packed_raw_input *>(input + (m * INSN_M + row) * stride + laneId * PACK_SIZE * 2));
for (int j = 0; j < PACK_SIZE; j++) { }
pack[j] = raw[j].x * silu(raw[j].y); #pragma unroll
} for (int j = 0; j < PACK_SIZE; j++) {
} else { pack[j] = raw[j].x * silu(raw[j].y);
pack.fill(half_t(0)); }
bool pred = row < maxRows && laneId * PACK_SIZE < maxCols; } else {
if (pred) { pack.fill(half_t(0));
pack = load(reinterpret_cast<const packed_input *>(input + row * stride + laneId * PACK_SIZE)); bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE < maxCols;
if (pred) {
pack = load(reinterpret_cast<const packed_input *>(input + (m * INSN_M + row) * stride + laneId * PACK_SIZE));
}
} }
store<true>(reinterpret_cast<packed_input *>(&mat[row][laneId * PACK_SIZE]), pack);
} }
store<true>(reinterpret_cast<packed_input *>(&mat[row][laneId * PACK_SIZE]), pack); __syncwarp();
}
__syncwarp();
for (int m = 0; m < WARP_M_TILES; m++) {
for (int n = 0; n < WARP_N_TILES; n++) { for (int n = 0; n < WARP_N_TILES; n++) {
const int row = m * INSN_M + laneId % 16; const int row = laneId % 16;
const int col = n * INSN_N + laneId / 16 * 8; const int col = n * INSN_N + laneId / 16 * 8;
uint4 tmp; uint4 tmp;
ldmatrix(&mat[row][col], tmp); ldmatrix(&mat[row][col], tmp);
*reinterpret_cast<uint4 *>(&out[m * WARP_N_TILES + n]) = tmp; *reinterpret_cast<uint4 *>(&out[m * WARP_N_TILES + n]) = tmp;
} }
__syncwarp();
} }
__syncwarp();
} }
}; };
...@@ -707,6 +658,7 @@ public: ...@@ -707,6 +658,7 @@ public:
// each thread block (1 warp) quantize WARP_M * WARP_K tile (32 * 64) // each thread block (1 warp) quantize WARP_M * WARP_K tile (32 * 64)
struct quantize_w4a4_act_kernel { struct quantize_w4a4_act_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
__device__ __device__
void operator()(const half_t *input, packed_act_t *output, packed_ascale_t *oscales, int K) { void operator()(const half_t *input, packed_act_t *output, packed_ascale_t *oscales, int K) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
...@@ -744,6 +696,7 @@ public: ...@@ -744,6 +696,7 @@ public:
// each thread block (1 warp) quantize WARP_N * WARP_K tile (128 * 64) // each thread block (1 warp) quantize WARP_N * WARP_K tile (128 * 64)
struct quantize_w4a4_wgt_kernel { struct quantize_w4a4_wgt_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
__device__ __device__
void operator()(const half_t *input, packed_wgt_t *output, packed_wscale_t *oscales, int K) { void operator()(const half_t *input, packed_wgt_t *output, packed_wscale_t *oscales, int K) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
...@@ -777,10 +730,54 @@ public: ...@@ -777,10 +730,54 @@ public:
} }
}; };
struct i2f_sm80 {
__device__ __forceinline__
static float2 int2float2(int x, int y) {
return make_float2(int2float_fast(x), int2float_fast(y));
}
__device__ __forceinline__
static half2_t int2half2(int x, int y) {
return float22half2<half2_t>(int2float2(x, y));
}
};
struct i2f_sm75 {
__device__ __forceinline__
static float2 int2float2(int x, int y) {
return make_float2(int2float_fast(x), int2float_fast(y));
}
__device__ __forceinline__
static half2_t int2half2(int x, int y) {
return half2(__int2half_rn(x), __int2half_rn(y));
}
};
struct i2f_sm75_fast {
__device__ __forceinline__
static float2 int2float2(int x, int y) {
return make_float2(int2float_fast(x), int2float_fast(y));
}
__device__ __forceinline__
static half2_t int2half2(int x, int y) {
return int2half2_fast_512(x, y);
}
};
template<bool ACT_UNSIGNED, typename T> template<bool ACT_UNSIGNED, typename T>
__device__ __forceinline__ __device__ __forceinline__
static void compute(act_warp A, wgt_warp W, ascale_warp ascale, wscale_warp wscale, T &fpsum) { static void compute(act_warp A, wgt_warp W, ascale_warp ascale, wscale_warp wscale, T &fpsum) {
apply_scales<true>([&](int i, int j) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
using int2half2 = i2f_sm80;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
using int2half2 = std::conditional_t<Config::FASTER_I2F, i2f_sm75_fast, i2f_sm75>;;
#else
using int2half2 = Base::i2f_normal;
#endif
Base::template apply_scales<int2half2>([&](int i, int j) {
return mma<ACT_UNSIGNED>(A[i], W[j]); return mma<ACT_UNSIGNED>(A[i], W[j]);
}, ascale, wscale, fpsum); }, ascale, wscale, fpsum);
} }
...@@ -875,7 +872,7 @@ public: ...@@ -875,7 +872,7 @@ public:
} }
// out: [M / BLOCK_M, N / BLOCK_N, NUM_WARPS, 1, NUM_M_TILES, NUM_N_TILES, WARP_SIZE] of fpsum_warp // out: [M / BLOCK_M, N / BLOCK_N, NUM_WARPS, 1, NUM_M_TILES, NUM_N_TILES, WARP_SIZE] of fpsum_warp
template<typename Epilogue, bool ACT_UNSIGNED> template<typename Epilogue, bool ACT_UNSIGNED, bool USE_FP32_ACCUM>
__device__ __forceinline__ __device__ __forceinline__
static void gemm_w4a4_block( static void gemm_w4a4_block(
const BlockInfo binfo, const BlockInfo binfo,
...@@ -902,7 +899,7 @@ public: ...@@ -902,7 +899,7 @@ public:
wgt_warp W[NUM_STAGES]; // 32 wgt_warp W[NUM_STAGES]; // 32
ascale_warp ascale[NUM_STAGES]; // 1 ascale_warp ascale[NUM_STAGES]; // 1
wscale_warp wscale[NUM_STAGES]; // 2 wscale_warp wscale[NUM_STAGES]; // 2
fpsum_warp fpsum; // 64 std::conditional_t<USE_FP32_ACCUM, f32psum_warp, fpsum_warp> fpsum; // 64
// load_wscale<true>(wscales, wscale[0], true); // load_wscale<true>(wscales, wscale[0], true);
// load_wscale<false>(wscales, wscale[1], true); // load_wscale<false>(wscales, wscale[1], true);
...@@ -916,16 +913,16 @@ public: ...@@ -916,16 +913,16 @@ public:
} }
for (auto &pack : fpsum) { for (auto &pack : fpsum) {
#if 1 if constexpr (USE_FP32_ACCUM) {
for (int i = 0; i < 4; i++) { for (int i = 0; i < 8; i++) {
pack.data[i].x = 0; pack.data[i] = 0;
pack.data[i].y = 0; }
} } else {
#else for (int i = 0; i < 4; i++) {
for (int i = 0; i < 8; i++) { pack.data[i].x = 0;
pack.data[i] = 0; pack.data[i].y = 0;
}
} }
#endif
} }
int dummy = 0; int dummy = 0;
...@@ -949,9 +946,11 @@ public: ...@@ -949,9 +946,11 @@ public:
compute<ACT_UNSIGNED>(A[k2], W[k2], ascale[k2], wscale[k2], fpsum); compute<ACT_UNSIGNED>(A[k2], W[k2], ascale[k2], wscale[k2], fpsum);
//#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
if (alwaysfalse) { if (alwaysfalse) {
dummy = clock(); dummy = clock();
} }
//#endif
// asm volatile ("membar.cta;"); // asm volatile ("membar.cta;");
} }
...@@ -961,11 +960,12 @@ public: ...@@ -961,11 +960,12 @@ public:
#endif #endif
#if 0 fpsum_warp f16psum;
auto f16psum = packed_fp32_to_fp16(fpsum); if constexpr (USE_FP32_ACCUM) {
#else f16psum = packed_fp32_to_fp16(fpsum);
auto f16psum = fpsum; } else {
#endif f16psum = fpsum;
}
CHECK_NAN(f16psum, "f16psum"); CHECK_NAN(f16psum, "f16psum");
...@@ -1324,6 +1324,7 @@ public: ...@@ -1324,6 +1324,7 @@ public:
struct quantize_w4a4_fuse_lora_kernel { struct quantize_w4a4_fuse_lora_kernel {
using oscales_t = typename std::conditional_t<use_fp4, packed_amscale_t, packed_ascale_t>; using oscales_t = typename std::conditional_t<use_fp4, packed_amscale_t, packed_ascale_t>;
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<fuse_glu>::SHMEM_SIZE, 128) * 128; static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<fuse_glu>::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS; static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
...@@ -2059,6 +2060,7 @@ public: ...@@ -2059,6 +2060,7 @@ public:
// q: [batch_size, #blocks, block_size, #heads, HEAD_DIM] // q: [batch_size, #blocks, block_size, #heads, HEAD_DIM]
// vk: [batch_size, #heads, HEAD_DIM+1, HEAD_DIM] // vk: [batch_size, #heads, HEAD_DIM+1, HEAD_DIM]
struct vk_mul_q_kernel { struct vk_mul_q_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
// FIXME FIXME FIXME // FIXME FIXME FIXME
__device__ __device__
void operator()(half_t *q, const float *vk, float eps, int num_tokens) { void operator()(half_t *q, const float *vk, float eps, int num_tokens) {
...@@ -2116,6 +2118,9 @@ public: ...@@ -2116,6 +2118,9 @@ public:
template<typename Epilogue, bool ACT_UNSIGNED> template<typename Epilogue, bool ACT_UNSIGNED>
struct gemm_w4a4_kernel { struct gemm_w4a4_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
static constexpr int MAX_ARCH = Config::FASTER_I2F ? 750 : INT_MAX; // FASTER_I2F is only needed on sm_75
__device__ __device__
void operator()( void operator()(
const packed_act_t *act, const packed_act_t *act,
...@@ -2146,7 +2151,7 @@ public: ...@@ -2146,7 +2151,7 @@ public:
// bool fusequant = !out; // bool fusequant = !out;
gemm_w4a4_block<Epilogue, ACT_UNSIGNED>( gemm_w4a4_block<Epilogue, ACT_UNSIGNED, false>(
binfo, binfo,
act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE, act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE, wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
...@@ -2164,6 +2169,7 @@ public: ...@@ -2164,6 +2169,7 @@ public:
template<typename Epilogue> template<typename Epilogue>
struct test_epilogue_kernel { struct test_epilogue_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<false>::SHMEM_SIZE, 128) * 128; static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<false>::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS; static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
......
...@@ -448,6 +448,7 @@ public: ...@@ -448,6 +448,7 @@ public:
// out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N] // out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
template<typename Epilogue> template<typename Epilogue>
struct gemm_w8a8_kernel { struct gemm_w8a8_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
__device__ __device__
void operator()( void operator()(
const packed_act_t *act, const packed_act_t *act,
......
...@@ -69,6 +69,8 @@ void attention_fp16( ...@@ -69,6 +69,8 @@ void attention_fp16(
float scale float scale
); );
// EXPERIMENTAL, for sm_75
void set_faster_i2f_mode(std::string mode);
// FOR TEST ONLY // FOR TEST ONLY
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb); void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb);
......
Subproject commit 0d23f715690c5171fd93679de8afd149376db167 Subproject commit 99511c34554a13ffaa81321834faf66389ffcb30
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment