Commit bf4adfeb authored by sxtyzhangzk's avatar sxtyzhangzk Committed by Zhekai Zhang
Browse files

Support Turing (sm_75) architecture

parent 3ef186fd
......@@ -335,6 +335,7 @@ public:
template<typename Epilogue, bool USE_ALPHA>
struct gemm_w4a4_fp4_kernel {
static constexpr int MIN_ARCH = 1200;
__device__
void operator()(
const packed_act_t *act,
......@@ -389,67 +390,16 @@ public:
// Multiplies one packed 4-bit activation fragment by one packed 4-bit weight
// fragment and returns the eight 32-bit partial sums held by this thread.
//
// act  : activation fragment; treated as unsigned 4-bit when ACT_UNSIGNED is
//        true and as signed 4-bit otherwise (selected via mma_helper::s4u4).
// wgt  : weight fragment (signed 4-bit). Its (x, y) half feeds the first MMA
//        and its (z, w) half feeds the second one.
// Returns psum.data[0..3] from the first MMA and psum.data[4..7] from the
// second. Both MMAs start from a zero accumulator, so nothing is carried over
// from earlier calls.
//
// The hand-written "mma.sync.aligned.m16n8k64" inline assembly that used to
// live here has been dropped: its outputs were overwritten unconditionally by
// the helper results, and that 4-bit m16n8k64 shape is only available on
// sm_80 and newer according to the PTX ISA, so it cannot be assembled for
// sm_75.
// NOTE(review): mma_m16n8kx_s32common is defined outside this view; it is
// presumably what picks an instruction shape valid for the target
// architecture - confirm.
static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt) {
    packed_psum_t psum;

    // Same activation fragment for both calls; only the weight half differs.
    uint4 out1 = mma_m16n8kx_s32common<mma_helper::s4u4<ACT_UNSIGNED>, mma_helper::s4>(act, uint2(wgt.x, wgt.y), uint4(0, 0, 0, 0));
    uint4 out2 = mma_m16n8kx_s32common<mma_helper::s4u4<ACT_UNSIGNED>, mma_helper::s4>(act, uint2(wgt.z, wgt.w), uint4(0, 0, 0, 0));

    psum.data[0] = out1.x;
    psum.data[1] = out1.y;
    psum.data[2] = out1.z;
    psum.data[3] = out1.w;
    psum.data[4] = out2.x;
    psum.data[5] = out2.y;
    psum.data[6] = out2.z;
    psum.data[7] = out2.w;
    return psum;
}
......@@ -554,7 +504,7 @@ public:
// [WARP_M, WARP_N * 2] when fuse_glu
template<bool fuse_glu>
struct load_act_to_fpsum {
using matrix_t = half_t[WARP_M][WARP_N + 8];
using matrix_t = half_t[INSN_M][WARP_N + 8];
static constexpr size_t SHMEM_SIZE = sizeof(matrix_t);
__device__ __forceinline__
......@@ -568,41 +518,42 @@ public:
using packed_raw_input = std::array<half2_t, PACK_SIZE>;
#pragma unroll
for (int row = 0; row < WARP_M; row++) {
packed_input pack;
// TODO: numCols not multiples of PACK_SIZE
if constexpr (fuse_glu) {
packed_raw_input raw;
raw.fill(half2_t(0, 0));
bool pred = row < maxRows && laneId * PACK_SIZE * 2 < maxCols;
if (pred) {
raw = load(reinterpret_cast<const packed_raw_input *>(input + row * stride + laneId * PACK_SIZE * 2));
}
#pragma unroll
for (int j = 0; j < PACK_SIZE; j++) {
pack[j] = raw[j].x * silu(raw[j].y);
}
} else {
pack.fill(half_t(0));
bool pred = row < maxRows && laneId * PACK_SIZE < maxCols;
if (pred) {
pack = load(reinterpret_cast<const packed_input *>(input + row * stride + laneId * PACK_SIZE));
for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
for (int row = 0; row < INSN_M; row++) {
packed_input pack;
// TODO: numCols not multiples of PACK_SIZE
if constexpr (fuse_glu) {
packed_raw_input raw;
raw.fill(half2_t(0, 0));
bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE * 2 < maxCols;
if (pred) {
raw = load(reinterpret_cast<const packed_raw_input *>(input + (m * INSN_M + row) * stride + laneId * PACK_SIZE * 2));
}
#pragma unroll
for (int j = 0; j < PACK_SIZE; j++) {
pack[j] = raw[j].x * silu(raw[j].y);
}
} else {
pack.fill(half_t(0));
bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE < maxCols;
if (pred) {
pack = load(reinterpret_cast<const packed_input *>(input + (m * INSN_M + row) * stride + laneId * PACK_SIZE));
}
}
store<true>(reinterpret_cast<packed_input *>(&mat[row][laneId * PACK_SIZE]), pack);
}
store<true>(reinterpret_cast<packed_input *>(&mat[row][laneId * PACK_SIZE]), pack);
}
__syncwarp();
__syncwarp();
for (int m = 0; m < WARP_M_TILES; m++) {
for (int n = 0; n < WARP_N_TILES; n++) {
const int row = m * INSN_M + laneId % 16;
const int row = laneId % 16;
const int col = n * INSN_N + laneId / 16 * 8;
uint4 tmp;
ldmatrix(&mat[row][col], tmp);
*reinterpret_cast<uint4 *>(&out[m * WARP_N_TILES + n]) = tmp;
}
__syncwarp();
}
__syncwarp();
}
};
......@@ -707,6 +658,7 @@ public:
// each thread block (1 warp) quantize WARP_M * WARP_K tile (32 * 64)
struct quantize_w4a4_act_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
__device__
void operator()(const half_t *input, packed_act_t *output, packed_ascale_t *oscales, int K) {
const int laneId = threadIdx.x % WARP_SIZE;
......@@ -744,6 +696,7 @@ public:
// each thread block (1 warp) quantize WARP_N * WARP_K tile (128 * 64)
struct quantize_w4a4_wgt_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
__device__
void operator()(const half_t *input, packed_wgt_t *output, packed_wscale_t *oscales, int K) {
const int laneId = threadIdx.x % WARP_SIZE;
......@@ -777,10 +730,54 @@ public:
}
};
// Integer -> floating-point conversion policy; `compute` selects it when
// __CUDA_ARCH__ is exactly 800.
struct i2f_sm80 {
    // Converts x and y with int2float_fast and packs them as (x, y).
    __device__ __forceinline__
    static float2 int2float2(int x, int y) {
        const float first  = int2float_fast(x);
        const float second = int2float_fast(y);
        return make_float2(first, second);
    }

    // Produces a half2_t by going through the float pair from int2float2 and
    // narrowing it with float22half2.
    __device__ __forceinline__
    static half2_t int2half2(int x, int y) {
        const float2 widened = int2float2(x, y);
        return float22half2<half2_t>(widened);
    }
};
// Integer -> floating-point conversion policy; `compute` selects it when
// __CUDA_ARCH__ is exactly 750 and Config::FASTER_I2F is false.
struct i2f_sm75 {
    // Converts x and y with int2float_fast and packs them as (x, y).
    __device__ __forceinline__
    static float2 int2float2(int x, int y) {
        const float first  = int2float_fast(x);
        const float second = int2float_fast(y);
        return make_float2(first, second);
    }

    // Converts each integer straight to half with __int2half_rn
    // (round-to-nearest) and pairs the two results, skipping the float
    // intermediate.
    __device__ __forceinline__
    static half2_t int2half2(int x, int y) {
        const auto first  = __int2half_rn(x);
        const auto second = __int2half_rn(y);
        return half2(first, second);
    }
};
// Integer -> floating-point conversion policy; `compute` selects it when
// __CUDA_ARCH__ is exactly 750 and Config::FASTER_I2F is true.
struct i2f_sm75_fast {
    // Converts x and y with int2float_fast and packs them as (x, y).
    __device__ __forceinline__
    static float2 int2float2(int x, int y) {
        const float first  = int2float_fast(x);
        const float second = int2float_fast(y);
        return make_float2(first, second);
    }

    // Delegates to int2half2_fast_512.
    // NOTE(review): that helper is defined outside this view; the "512" suffix
    // suggests a restricted input range - confirm before relying on it.
    __device__ __forceinline__
    static half2_t int2half2(int x, int y) {
        const half2_t packed = int2half2_fast_512(x, y);
        return packed;
    }
};
// Runs the 4-bit MMA for every (activation tile, weight tile) pair requested
// by Base::apply_scales and lets apply_scales fold the scaled results into
// fpsum.
//
// ACT_UNSIGNED : forwarded to mma<> to choose unsigned vs. signed activations.
// T            : accumulator container type; passed through to apply_scales
//                by reference.
// A, W         : activation / weight fragments, indexed by the (i, j) pair
//                that apply_scales hands to the callback.
// ascale/wscale: per-tile scales, forwarded unchanged to apply_scales.
//
// The integer -> float/half conversion policy is fixed at compile time:
//   __CUDA_ARCH__ == 800 -> i2f_sm80
//   __CUDA_ARCH__ == 750 -> i2f_sm75_fast if Config::FASTER_I2F, else i2f_sm75
//   anything else        -> Base::i2f_normal
// The comparisons are exact matches, so other architectures (and the host
// compilation pass) take the Base::i2f_normal branch.
template<bool ACT_UNSIGNED, typename T>
__device__ __forceinline__
static void compute(act_warp A, wgt_warp W, ascale_warp ascale, wscale_warp wscale, T &fpsum) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
    using int2half2 = i2f_sm80;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
    using int2half2 = std::conditional_t<Config::FASTER_I2F, i2f_sm75_fast, i2f_sm75>;
#else
    using int2half2 = Base::i2f_normal;
#endif
    // Single call into the shared implementation; the superseded
    // `apply_scales<true>(...)` opener was removed so the lambda and the
    // function body are balanced again.
    Base::template apply_scales<int2half2>([&](int i, int j) {
        return mma<ACT_UNSIGNED>(A[i], W[j]);
    }, ascale, wscale, fpsum);
}
......@@ -875,7 +872,7 @@ public:
}
// out: [M / BLOCK_M, N / BLOCK_N, NUM_WARPS, 1, NUM_M_TILES, NUM_N_TILES, WARP_SIZE] of fpsum_warp
template<typename Epilogue, bool ACT_UNSIGNED>
template<typename Epilogue, bool ACT_UNSIGNED, bool USE_FP32_ACCUM>
__device__ __forceinline__
static void gemm_w4a4_block(
const BlockInfo binfo,
......@@ -902,7 +899,7 @@ public:
wgt_warp W[NUM_STAGES]; // 32
ascale_warp ascale[NUM_STAGES]; // 1
wscale_warp wscale[NUM_STAGES]; // 2
fpsum_warp fpsum; // 64
std::conditional_t<USE_FP32_ACCUM, f32psum_warp, fpsum_warp> fpsum; // 64
// load_wscale<true>(wscales, wscale[0], true);
// load_wscale<false>(wscales, wscale[1], true);
......@@ -916,16 +913,16 @@ public:
}
for (auto &pack : fpsum) {
#if 1
for (int i = 0; i < 4; i++) {
pack.data[i].x = 0;
pack.data[i].y = 0;
}
#else
for (int i = 0; i < 8; i++) {
pack.data[i] = 0;
if constexpr (USE_FP32_ACCUM) {
for (int i = 0; i < 8; i++) {
pack.data[i] = 0;
}
} else {
for (int i = 0; i < 4; i++) {
pack.data[i].x = 0;
pack.data[i].y = 0;
}
}
#endif
}
int dummy = 0;
......@@ -949,9 +946,11 @@ public:
compute<ACT_UNSIGNED>(A[k2], W[k2], ascale[k2], wscale[k2], fpsum);
//#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
if (alwaysfalse) {
dummy = clock();
}
//#endif
// asm volatile ("membar.cta;");
}
......@@ -961,11 +960,12 @@ public:
#endif
#if 0
auto f16psum = packed_fp32_to_fp16(fpsum);
#else
auto f16psum = fpsum;
#endif
fpsum_warp f16psum;
if constexpr (USE_FP32_ACCUM) {
f16psum = packed_fp32_to_fp16(fpsum);
} else {
f16psum = fpsum;
}
CHECK_NAN(f16psum, "f16psum");
......@@ -1324,6 +1324,7 @@ public:
struct quantize_w4a4_fuse_lora_kernel {
using oscales_t = typename std::conditional_t<use_fp4, packed_amscale_t, packed_ascale_t>;
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<fuse_glu>::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
......@@ -2059,6 +2060,7 @@ public:
// q: [batch_size, #blocks, block_size, #heads, HEAD_DIM]
// vk: [batch_size, #heads, HEAD_DIM+1, HEAD_DIM]
struct vk_mul_q_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
// FIXME FIXME FIXME
__device__
void operator()(half_t *q, const float *vk, float eps, int num_tokens) {
......@@ -2116,6 +2118,9 @@ public:
template<typename Epilogue, bool ACT_UNSIGNED>
struct gemm_w4a4_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
static constexpr int MAX_ARCH = Config::FASTER_I2F ? 750 : INT_MAX; // FASTER_I2F is only needed on sm_75
__device__
void operator()(
const packed_act_t *act,
......@@ -2146,7 +2151,7 @@ public:
// bool fusequant = !out;
gemm_w4a4_block<Epilogue, ACT_UNSIGNED>(
gemm_w4a4_block<Epilogue, ACT_UNSIGNED, false>(
binfo,
act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
......@@ -2164,6 +2169,7 @@ public:
template<typename Epilogue>
struct test_epilogue_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<false>::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
......
......@@ -448,6 +448,7 @@ public:
// out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
template<typename Epilogue>
struct gemm_w8a8_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
__device__
void operator()(
const packed_act_t *act,
......
......@@ -69,6 +69,8 @@ void attention_fp16(
float scale
);
// EXPERIMENTAL, for sm_75
void set_faster_i2f_mode(std::string mode);
// FOR TEST ONLY
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb);
......
Subproject commit 0d23f715690c5171fd93679de8afd149376db167
Subproject commit 99511c34554a13ffaa81321834faf66389ffcb30
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment