Unverified Commit 7600642e authored by Hashem Hashemi's avatar Hashem Hashemi Committed by GitHub
Browse files

Add padding support to wvSplitK solution for skinny GEMMs (#33762)


Signed-off-by: default avatarHashem Hashemi <hashem.hashemi@amd.com>
parent 1e69c048
...@@ -304,8 +304,9 @@ __device__ inline unsigned int min__(uint32_t a, uint32_t b) { ...@@ -304,8 +304,9 @@ __device__ inline unsigned int min__(uint32_t a, uint32_t b) {
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N> int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS) __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_sml_(const int K, const int M, const int Bx, const int By, wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap, const int M,
const scalar_t* B, const scalar_t* __restrict__ A, const int Bx, const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C, const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2; constexpr int max_lds_len = LDS_SIZE / 2;
...@@ -314,7 +315,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -314,7 +315,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else #else
constexpr bool use_mfma = false; constexpr bool use_mfma = false;
#endif #endif
using scalar8 = using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
using half4 = using half4 =
...@@ -346,13 +346,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -346,13 +346,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// - Then the WG will move to another 8 K elements // - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8 // TODO: Logic below will only work when K is multiple of 8
//---------------------------------------------------- //----------------------------------------------------
for (uint32_t k = 0; k < min__(K * N, max_lds_len); for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k += THRDS * WvPrGrp * A_CHUNK) { k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); #if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
if (k_in >= min__(K * N, max_lds_len)) break; #else
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); #endif
} }
__syncthreads(); __syncthreads();
...@@ -360,9 +360,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -360,9 +360,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//---------------------------------------------------- //----------------------------------------------------
// Each wave works on a single column of weight matrix. // Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is // There are 16 waves per WG, and hence, each WG is
...@@ -386,44 +383,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -386,44 +383,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// YTILE represents how many column of weight matrix // YTILE represents how many column of weight matrix
// are being worked on by each wave. // are being worked on by each wave.
//---------------------------------------------------- //----------------------------------------------------
for (int i = 0; i < YTILE; i++) float sum[N][YTILE] = {};
for (int n = 0; n < N; n++) scalar8 sum4[N][YTILE] = {};
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
// for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
// Fetch the weight matrix from memory! // Fetch the weight matrix from memory!
#pragma unroll #pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break; const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
const scalar_t* B_ = &B[(m + 0) * K + k_];
for (int y = 0; y < YTILE; y++) for (int y = 0; y < YTILE; y++)
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K]))); bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
} }
// Fetch activation matrix from either just LDS or from both LDS / memory // Fetch activation matrix from either just LDS or from both LDS / memory
...@@ -432,33 +405,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -432,33 +405,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break; if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
} }
} }
// Do the matrix multiplication in interleaved manner // Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (uint32_t n = 0; n < N; n++) { for (uint32_t n = 0; n < N; n++) {
#pragma unroll
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma) if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) { for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
} }
else else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++) for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
...@@ -466,46 +426,44 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -466,46 +426,44 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
} }
} }
} }
__builtin_amdgcn_sched_barrier(0);
//---------------------------------------------------- //----------------------------------------------------
// Final reduction step using shuffle // Final reduction step using shuffle
//---------------------------------------------------- //----------------------------------------------------
if constexpr (!use_mfma) { if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf,
: "=v"(sum[n][y]) 1); // row_shr8
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf,
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " 1); // row_shr4
: "=v"(sum[n][y]) sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf,
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); 1); // row_shr2
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
: "=v"(sum[n][y]) 1); // row_shr1
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" 1); // ROW_BCAST15
: "=v"(sum[n][y]) sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); 1); // ROW_BCAST31
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
} }
} }
if (threadIdx.x == 63) { if (threadIdx.x == 63) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) { for (int y = 0; y < YTILE; y++) {
if constexpr (std::is_same_v<scalar_t, half>) { if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS) sum[n][y] += __half2float(biases[n][y]);
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS) sum[n][y] += __bfloat162float(biases[n][y]);
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
} }
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]); C[m + y + n * M] = __float2s<scalar_t>(sum[n][y]);
} }
} }
} }
...@@ -514,45 +472,43 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -514,45 +472,43 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
#pragma unroll #pragma unroll
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
// float accm1 = 0; /*float accm1 = 0;
// for (int i=0; i<64; i++) for (int i=0; i<64; i++)
// accm1 += __shfl(sum4[n][y][i%4], i); accm1 += __shfl(sum4[n][y][i%4], i);
sum4[n][y][0] = accm1;*/
float accm = sum4[n][y][0]; float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf,
: "=v"(accm) 1); // row_shl1
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf,
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " 1); // row_shl2
: "=v"(accm) accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf,
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); 1); // row_shl3
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf,
: "=v"(accm) 1); // row_shl4
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf,
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " 1); // row_shl8
: "=v"(accm) accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf,
: "0"(accm), "v"(accm), "v"(accm)); 1); // row_shr15
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf,
: "=v"(accm) 1); // ROW_BCAST15
: "0"(accm), "v"(accm), "v"(accm)); accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf,
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " 1); // ROW_BCAST31
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
sum4[n][y][0] = accm; sum4[n][y][0] = accm;
} }
} }
if (threadIdx.x == 63) { if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) { scalar_t biases[N][YTILE] = {};
for (int i = 0; i < YTILE; i++) {
if (BIAS) if (BIAS)
sum4[n][i][0] += for (int n = 0; n < N; n++) {
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); for (int y = 0; y < YTILE; y++) {
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
sum4[n][y][0] += __bfloat162float(biases[n][y]);
C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]);
} }
} }
} }
...@@ -563,8 +519,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -563,8 +519,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support #else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N> int UNRL, int N>
__global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx, __global__ void wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap,
const int By, const scalar_t* B, const int M, const int Bx, const int By,
const scalar_t* B,
const scalar_t* __restrict__ A, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C, const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
...@@ -577,8 +534,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx, ...@@ -577,8 +534,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N> int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS) __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_(const int K, const int M, const int Bx, const int By, wvSplitK_hf_(const int K, const int Kbp, const int Kap, const int M,
const scalar_t* B, const scalar_t* __restrict__ A, const int Bx, const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C, const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2; constexpr int max_lds_len = LDS_SIZE / 2;
...@@ -601,13 +559,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -601,13 +559,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
scalar8 h8; scalar8 h8;
}; };
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not going to work!
//----------------------------------------------------
__shared__ scalar_t s[max_lds_len]; __shared__ scalar_t s[max_lds_len];
//---------------------------------------------------- //----------------------------------------------------
...@@ -618,12 +569,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -618,12 +569,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
commitColumn[i] = 1; commitColumn[i] = 1;
} }
//----------------------------------------------------
// Indexing function into the column of weight matrix B
// Algorithm does 64 lane k-splitting / wave and uses
// WG ID and Thread ID to find the index.
//----------------------------------------------------
// int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp);
uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE;
// Check whether there will be fragmentation! // Check whether there will be fragmentation!
...@@ -636,91 +581,34 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -636,91 +581,34 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m = startColumn; m = startColumn;
} }
//---------------------------------------------------- for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
// Fetch the activation matrix to LDS k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
// Loop iteration: #if defined(__gfx950__)
// - Each thread (lane) is fetching 8 elements (A_Chunk) __builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
// - Each wave will fetch 64*8=> 512 elements #else
// - Each WG will fetch 512 * 16 => 8K elements *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
// - Then the WG will move to another 8 K elements #endif
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k = 0; k < min__(K * N, max_lds_len);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min__(K * N, max_lds_len)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
} }
__syncthreads(); __syncthreads();
if (threadIdx.y >= _WvPrGrp) return; if (threadIdx.y >= _WvPrGrp) return;
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
// working on 16 columns of weight matrix. Moreover,
// we tile in column direction by YTILE, so when YTILE=1
// the above math is right, however, when YTILE=2 then
// each wave will be working on 2 columns and WG will
// be working on 32 columns.
//
// Top level loop that makes WGs persistent!
// - WGs iterates across columns of weight matrix
// - Each wave within WG works on a given column(s)
// - After completing first set of columns, WGs start
// working on the next set of available columns
//----------------------------------------------------
while (m < M) { while (m < M) {
//---------------------------------------------------- float sum[N][YTILE] = {};
// 'sum' accumulates the matrix A x B computation scalar8 sum4[N][YTILE] = {};
// split across 64 lanes.
//
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
// Fetch the weight matrix from memory! // Fetch the weight matrix from memory!
#pragma unroll #pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break; const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
for (int y = 0; y < YTILE; y++)
const scalar_t* B_ = &B[(m + 0) * K + k_]; bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
for (int b = 0; b < YTILE; b++)
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
} }
// Fetch activation matrix from either just LDS or from both LDS / memory // Fetch activation matrix from either just LDS or from both LDS / memory
...@@ -729,36 +617,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -729,36 +617,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break; if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
if (k_ + K * n < max_lds_len) if (k_ + Kap * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
else else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n])));
} }
} }
// Do the matrix multiplication in interleaved manner // Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t n = 0; n < N; n++) { for (uint32_t n = 0; n < N; n++) {
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma) if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) { for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
} }
else else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++) for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
...@@ -773,40 +648,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -773,40 +648,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if constexpr (!use_mfma) { if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf,
: "=v"(sum[n][y]) 1); // row_shr8
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf,
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " 1); // row_shr4
: "=v"(sum[n][y]) sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf,
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); 1); // row_shr2
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
: "=v"(sum[n][y]) 1); // row_shr1
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" 1); // ROW_BCAST15
: "=v"(sum[n][y]) sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); 1); // ROW_BCAST31
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
} }
} }
if (threadIdx.x == 63) { if (threadIdx.x == 63) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) { for (int y = 0; y < YTILE; y++) {
if (commitColumn[i]) { biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (commitColumn[y]) {
if constexpr (std::is_same_v<scalar_t, half>) { if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS) sum[n][y] += __half2float(biases[n][y]);
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS) sum[n][y] += __bfloat162float(biases[n][y]);
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
} }
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]); C[m + y + n * M] = __float2s<scalar_t>(sum[n][y]);
} }
} }
} }
...@@ -819,44 +692,39 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -819,44 +692,39 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// float accm1 = 0; // float accm1 = 0;
// for (int i=0; i<64; i++) // for (int i=0; i<64; i++)
// accm1 += __shfl(sum4[n][y][i%4], i); // accm1 += __shfl(sum4[n][y][i%4], i);
float accm = sum4[n][y][0]; float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf,
: "=v"(accm) 1); // row_shl1
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf,
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " 1); // row_shl2
: "=v"(accm) accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf,
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); 1); // row_shl3
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf,
: "=v"(accm) 1); // row_shl4
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf,
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " 1); // row_shl8
: "=v"(accm) accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf,
: "0"(accm), "v"(accm), "v"(accm)); 1); // row_shr15
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf,
: "=v"(accm) 1); // ROW_BCAST15
: "0"(accm), "v"(accm), "v"(accm)); accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf,
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " 1); // ROW_BCAST31
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
sum4[n][y][0] = accm; sum4[n][y][0] = accm;
} }
} }
if (threadIdx.x == 63) { if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) { scalar_t biases[N][YTILE] = {};
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i]) {
if (BIAS) if (BIAS)
sum4[n][i][0] += for (int n = 0; n < N; n++) {
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); for (int y = 0; y < YTILE; y++) {
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (commitColumn[y]) {
sum4[n][y][0] += __bfloat162float(biases[n][y]);
C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]);
} }
} }
} }
...@@ -880,9 +748,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -880,9 +748,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support #else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N> int UNRL, int N>
__global__ void wvSplitK_hf_(const int K, const int M, const int Bx, __global__ void wvSplitK_hf_(const int K, const int Kbp, const int Kap,
const int By, const scalar_t* B, const int M, const int Bx, const int By,
const scalar_t* __restrict__ A, const scalar_t* B, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C, const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE UNREACHABLE_CODE
...@@ -894,8 +762,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const int Bx, ...@@ -894,8 +762,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N> int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS) __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_big_(const int K, const int M, const int Bx, const int By, wvSplitK_hf_big_(const int K, const int Kbp, const int Kap, const int M,
const scalar_t* B, const scalar_t* __restrict__ A, const int Bx, const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C, const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2; constexpr int max_lds_len = LDS_SIZE / 2;
...@@ -966,13 +835,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -966,13 +835,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//---------------------------------------------------- //----------------------------------------------------
#define PCML #define PCML
#ifndef PCML #ifndef PCML
for (uint32_t k = 0; k < min__(K * N, max_lds_len); for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k += THRDS * WvPrGrp * A_CHUNK) { k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); #if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
if (k_in >= min__(K * N, max_lds_len)) break; #else
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); #endif
} }
__syncthreads(); __syncthreads();
#endif #endif
...@@ -987,10 +856,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -987,10 +856,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
? kFit ? kFit
: (kFit - kFit % TUC); // round up to multiple of TUC : (kFit - kFit % TUC); // round up to multiple of TUC
// if (kFit == 0) kFit = TUC; // if (kFit == 0) kFit = TUC;
kFit = min__(kFit, K); kFit = min__(kFit, Kap);
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//---------------------------------------------------- //----------------------------------------------------
// Each wave works on a single column of weight matrix. // Each wave works on a single column of weight matrix.
...@@ -1021,15 +887,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1021,15 +887,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// YTILE represents how many column of weight matrix // YTILE represents how many column of weight matrix
// are being worked on by each wave. // are being worked on by each wave.
//---------------------------------------------------- //----------------------------------------------------
for (int i = 0; i < YTILE; i++) float sum[N][YTILE] = {};
for (int n = 0; n < N; n++) scalar8 sum4[N][YTILE] = {};
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
//---------------------------------------------------- //----------------------------------------------------
// Fetch weight matrix B in interleaved K-split! // Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk) // - Each thread (lane) is fetching 8 elements (A_Chunk)
...@@ -1048,18 +908,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1048,18 +908,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// TODO: Logic below will only work when K is multiple of 8 // TODO: Logic below will only work when K is multiple of 8
//---------------------------------------------------- //----------------------------------------------------
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
#ifdef PCML #ifdef PCML
if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS
if (k1 != 0) kBase += kFit; if (k1 != 0) kBase += kFit;
__syncthreads(); __syncthreads();
for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) { for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) {
uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (kBase + kOff >= K) break; if (kBase + kOff >= Kap) break;
if (kOff >= kFit) break; if (kOff >= kFit) break;
for (uint32_t n = 0; n < N; n++) { for (uint32_t n = 0; n < N; n++) {
uint32_t k_in = kBase + n * K + kOff; uint32_t k_in = kBase + n * Kap + kOff;
uint32_t k_ot = n * kFit + kOff; uint32_t k_ot = n * kFit + kOff;
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k_in]), (int*)(&s[k_ot]),
16, 0, 0);
#else
*((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in]));
#endif
} }
} }
__syncthreads(); __syncthreads();
...@@ -1072,11 +940,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1072,11 +940,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break; const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
for (int y = 0; y < YTILE; y++)
const scalar_t* B_ = &B[(m + 0) * K + k_]; bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
for (int b = 0; b < YTILE; b++)
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
} }
// Fetch activation matrix from either just LDS or from both LDS / memory // Fetch activation matrix from either just LDS or from both LDS / memory
...@@ -1085,17 +951,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1085,17 +951,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break; if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
#ifdef PCML #ifdef PCML
bigA[n][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * n]))); bigA[n][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * n])));
#else #else
if (k_ + K * n < 32 * 1024) if (k_ + Kap * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
else else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n])));
#endif #endif
} }
} }
...@@ -1103,22 +966,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1103,22 +966,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Do the matrix multiplication in interleaved manner // Do the matrix multiplication in interleaved manner
#pragma unroll #pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
#pragma unroll
for (uint32_t n = 0; n < N; n++) { for (uint32_t n = 0; n < N; n++) {
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma) if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) { for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
} }
else else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++) for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
...@@ -1141,40 +995,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1141,40 +995,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if constexpr (!use_mfma) { if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf,
: "=v"(sum[n][y]) 1); // row_shr8
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf,
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " 1); // row_shr4
: "=v"(sum[n][y]) sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf,
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); 1); // row_shr2
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
: "=v"(sum[n][y]) 1); // row_shr1
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" 1); // ROW_BCAST15
: "=v"(sum[n][y]) sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); 1); // ROW_BCAST31
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
} }
} }
if (threadIdx.x == 63) { if (threadIdx.x == 63) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) { for (int y = 0; y < YTILE; y++) {
if (commitColumn[i]) { biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (commitColumn[y]) {
if constexpr (std::is_same_v<scalar_t, half>) { if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS) sum[n][y] += __half2float(biases[n][y]);
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS) sum[n][y] += __bfloat162float(biases[n][y]);
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
} }
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]); C[m + y + n * M] = __float2s<scalar_t>(sum[n][y]);
} }
} }
} }
...@@ -1185,42 +1037,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1185,42 +1037,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll #pragma unroll
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
float accm = sum4[n][y][0]; float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf,
: "=v"(accm) 1); // row_shl1
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf,
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " 1); // row_shl2
: "=v"(accm) accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf,
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); 1); // row_shl3
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf,
: "=v"(accm) 1); // row_shl4
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf,
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " 1); // row_shl8
: "=v"(accm) accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf,
: "0"(accm), "v"(accm), "v"(accm)); 1); // row_shr15
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf,
: "=v"(accm) 1); // ROW_BCAST15
: "0"(accm), "v"(accm), "v"(accm)); accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf,
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " 1); // ROW_BCAST31
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
sum4[n][y][0] = accm; sum4[n][y][0] = accm;
} }
} }
if (threadIdx.x == 63) { if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) { scalar_t biases[N][YTILE] = {};
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i]) {
if (BIAS) if (BIAS)
sum4[n][i][0] += for (int n = 0; n < N; n++) {
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); for (int y = 0; y < YTILE; y++) {
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (commitColumn[y]) {
sum4[n][y][0] += __bfloat162float(biases[n][y]);
C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]);
} }
} }
} }
...@@ -1244,8 +1092,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1244,8 +1092,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support #else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N> int UNRL, int N>
__global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx, __global__ void wvSplitK_hf_big_(const int K, const int Kbp, const int Kap,
const int By, const scalar_t* B, const int M, const int Bx, const int By,
const scalar_t* B,
const scalar_t* __restrict__ A, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C, const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
...@@ -1272,6 +1121,8 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -1272,6 +1121,8 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
auto M_in = in_a.size(0); auto M_in = in_a.size(0);
auto K_in = in_a.size(1); auto K_in = in_a.size(1);
auto N_in = in_b.size(0); auto N_in = in_b.size(0);
auto Kap_in = in_a.stride(0);
auto Kbp_in = in_b.stride(0);
auto Bx_in = auto Bx_in =
(in_bias.has_value() && in_bias->numel() > 0) (in_bias.has_value() && in_bias->numel() > 0)
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0) ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
...@@ -1300,23 +1151,26 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -1300,23 +1151,26 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
{ \ { \
dim3 block(64, 16); \ dim3 block(64, 16); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \ if ((Kbp_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \ wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ <<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
biasf4, c, __wvPrGrp, CuCount); \ By_in, af4, bf4, biasf4, c, __wvPrGrp, \
else if (K_in * N_in <= max_lds_len * 1.2) \ CuCount); \
else if (Kbp_in * N_in <= max_lds_len * 1.2) \
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \ wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ <<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
biasf4, c, __wvPrGrp, CuCount); \ By_in, af4, bf4, biasf4, c, __wvPrGrp, \
CuCount); \
else \ else \
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \ wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ <<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
biasf4, c, __wvPrGrp, CuCount); \ By_in, af4, bf4, biasf4, c, __wvPrGrp, \
CuCount); \
} }
#define WVSPLIT_TILE(_sYT, __N) \ #define WVSPLIT_TILE(_sYT, __N) \
{ \ { \
bool fit_lds = (K_in * N_in <= max_lds_len); \ bool fit_lds = (Kbp_in * N_in <= max_lds_len); \
if (_sYT <= 1) \ if (_sYT <= 1) \
WVSPLITK(1, 4, __N) \ WVSPLITK(1, 4, __N) \
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \ else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
......
...@@ -30,15 +30,22 @@ NKM_FACTORS_LLMM1 = [ ...@@ -30,15 +30,22 @@ NKM_FACTORS_LLMM1 = [
NKM_FACTORS_WVSPLITK = [ NKM_FACTORS_WVSPLITK = [
# Different batch sizes with key dimensions # Different batch sizes with key dimensions
(1, 16, 16), (1, 32, 16),
(1, 64, 64), (1, 64, 64),
(2, 256, 256), (2, 256, 256),
(3, 1024, 1024), (3, 1024, 1024),
(4, 4096, 4096), (4, 4096, 4096),
(4, 4096, 4096 + 1),
(4, 4096 + 16, 4096),
(4, 4096 + 16, 4096 + 1),
# Extended K values # Extended K values
(1, 9216, 512), (1, 9216, 512),
(2, 10240, 1024), (2, 10240, 1024),
(4, 16384, 8192), (4, 16384, 8192),
(4, 16384 * 2, 8192),
(4, 16384 * 2, 8192 + 1),
(4, 16384 * 2 + 16, 8192),
(4, 16384 * 2 + 16, 8192 + 1),
# Minimum M constraint validation (m >= 8) # Minimum M constraint validation (m >= 8)
(1, 64, 8), (1, 64, 8),
(2, 128, 8), (2, 128, 8),
...@@ -180,59 +187,44 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): ...@@ -180,59 +187,44 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2) torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("bias_mode", BIAS_MODES)
torch.manual_seed(seed) @pytest.mark.parametrize("padded_a", [False, True])
cu_count = num_compute_units() @pytest.mark.parametrize("padded_b", [False, True])
def test_rocm_wvsplitk_kernel(
A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5 xnorm, n, k, m, dtype, seed, bias_mode, padded_a, padded_b
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5 ):
ref_out = torch.nn.functional.linear(A, B)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed) torch.manual_seed(seed)
cu_count = num_compute_units() cu_count = num_compute_units()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas xavier = (
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier math.sqrt(2 / k) if xnorm else 1
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier ) # normalize to avoid large output-bias deltas
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5 A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) BIAS = None
@pytest.mark.parametrize("dtype", DTYPES) if bias_mode == 1:
@pytest.mark.parametrize("seed", SEEDS) BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") elif bias_mode == 2:
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
torch.manual_seed(seed)
cu_count = num_compute_units()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas if padded_a:
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier A = pad_fp8(A)
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier if padded_b:
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5 B = pad_fp8(B)
ref_out = torch.nn.functional.linear(A, B, BIAS) ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2) if xnorm:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
else:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2)
@pytest.mark.parametrize("xnorm", [False, True]) @pytest.mark.parametrize("xnorm", [False, True])
......
...@@ -191,7 +191,6 @@ def rocm_unquantized_gemm_impl( ...@@ -191,7 +191,6 @@ def rocm_unquantized_gemm_impl(
and on_gfx9() and on_gfx9()
and x.dtype in [torch.float16, torch.bfloat16] and x.dtype in [torch.float16, torch.bfloat16]
and k % 8 == 0 and k % 8 == 0
and x.is_contiguous()
) )
if use_skinny is not True: if use_skinny is not True:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment