Unverified Commit a0e50a42 authored by Hashem Hashemi's avatar Hashem Hashemi Committed by GitHub
Browse files

Convert wvSplitKQ to 16x16 MFMA in prep for mi4xx. (#34100)


Signed-off-by: default avatarHashem Hashemi <hashem.hashemi@amd.com>
parent 9fa5b25a
......@@ -1902,7 +1902,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float sB = *s_B;
while (m < M) {
floatx16 sum[N][YTILE] = {};
scalar8 sum[N][YTILE] = {};
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
......@@ -1936,7 +1936,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0,
0);
}
......@@ -1949,31 +1949,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8];
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
1); // row_shl8
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
1); // row_shl9
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
1); // row_shl10
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
1); // row_shl11
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16);
accm0 += __shfl_down(accm0, 20);
accm0 += __shfl_down(accm0, 40);
sum[n][y][0] = accm0;
}
}
......@@ -2064,7 +2048,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float sB = *s_B;
while (m < M) {
floatx16 sum[N][YTILE] = {};
scalar8 sum[N][YTILE] = {};
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
......@@ -2100,7 +2084,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0,
0);
}
......@@ -2113,31 +2097,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8];
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
1); // row_shl8
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
1); // row_shl9
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
1); // row_shl10
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
1); // row_shl11
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16);
accm0 += __shfl_down(accm0, 20);
accm0 += __shfl_down(accm0, 40);
sum[n][y][0] = accm0;
}
}
......@@ -2242,16 +2210,16 @@ void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
: nullptr;
switch (N_in) {
case 1:
WVSPLITKQ(12, 2, 2, 2, 2, 1)
WVSPLITKQ(16, 2, 2, 2, 2, 1)
break;
case 2:
WVSPLITKQ(12, 2, 2, 2, 2, 2)
WVSPLITKQ(16, 2, 2, 2, 2, 2)
break;
case 3:
WVSPLITKQ(8, 2, 2, 1, 1, 3)
WVSPLITKQ(16, 2, 2, 2, 2, 3)
break;
case 4:
WVSPLITKQ(4, 2, 2, 1, 1, 4)
WVSPLITKQ(16, 2, 2, 2, 2, 4)
break;
default:
throw std::runtime_error(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment