Unverified commit ed17f54c, authored by Hashem Hashemi, committed by GitHub
Browse files

Perf tuning and expansion of cases covered for wvSplitKrc (#33493)


Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
parent 860981d8
...@@ -1365,13 +1365,12 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -1365,13 +1365,12 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
return out_c; return out_c;
} }
#if defined(__gfx950__) // TODO: Add NAVI support // This version targets skinny cases where CUs are not filled
// This version targets big A[] cases, where it is much larger than LDS // Wave-SplitK is used with reduction done via atomics.
// capacity #if defined(__gfx950__)
#define WVSPLITKRC_1KPASS #define WVSPLITKRC_1KPASS
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N, int GrpsShrB> int UNRL, int N, int GrpsShrB, int CHUNKK>
__global__ void __launch_bounds__(WvPrGrp* THRDS) __global__ void __launch_bounds__(WvPrGrp* THRDS)
__attribute__((amdgpu_waves_per_eu(1, 1))) __attribute__((amdgpu_waves_per_eu(1, 1)))
wvSplitKrc_(const int actlN, const int K, const int M, const int Bx, wvSplitKrc_(const int actlN, const int K, const int M, const int Bx,
...@@ -1383,12 +1382,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1383,12 +1382,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
int* cntr = (int*)(&glbl[M * N]); int* cntr = (int*)(&glbl[M * N]);
constexpr int NTILE = 16; constexpr int NTILE = 16;
constexpr int WVLDS_ = (NTILE * THRDS * A_CHUNK);
constexpr int APAD = 1; constexpr int APAD = 1;
constexpr int ASTRD = 64; constexpr int ASTRD = 64;
constexpr int BPAD = 1; constexpr int BPAD = 1;
constexpr int BSTRD = 64; constexpr int WVLDS_ = THRDS * A_CHUNK / CHUNKK;
constexpr int WVLDS = ((WVLDS_ + (WVLDS_ / BSTRD) * 4 * BPAD)); constexpr int WVLDS = ((WVLDS_ + A_CHUNK * BPAD)) * YTILE;
constexpr int max_lds_len = LDS_SIZE / 2; constexpr int max_lds_len = LDS_SIZE / 2;
...@@ -1442,17 +1440,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1442,17 +1440,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
break; break;
} }
#else #else
int constexpr kFit = 512; int constexpr kFit = 512 / CHUNKK;
int constexpr kfitsPerRdc = 1; int constexpr kfitsPerRdc = 1;
#endif #endif
bool doRdc = (kfitsPerRdc * kFit < K); bool doRdc = true; // Assuming (kfitsPerRdc * kFit < K) is always true
uint32_t numCuWithFullK = uint32_t numCuWithFullK =
((M + (WvPrGrp * YTILE / GrpsShrB) - 1) / (WvPrGrp * YTILE / GrpsShrB)); ((M + (WvPrGrp * YTILE / GrpsShrB) - 1) / (WvPrGrp * YTILE / GrpsShrB));
uint32_t Mmod = numCuWithFullK * (WvPrGrp * YTILE / GrpsShrB); uint32_t Mmod = numCuWithFullK * (WvPrGrp * YTILE / GrpsShrB);
// given above k-split, find this wave's position // given above k-split, find this wave's position
uint32_t kFitPdd = kFit + (kFit / ASTRD) * APAD; uint32_t kFitPdd = kFit * CHUNKK + ((kFit * CHUNKK) / ASTRD) * APAD;
uint32_t m0 = (blockIdx.x * WvPrGrp / GrpsShrB) * YTILE; uint32_t m0 = (blockIdx.x * WvPrGrp / GrpsShrB) * YTILE;
uint32_t m1 = ((threadIdx.y % WvPrGrp) / GrpsShrB) * YTILE; uint32_t m1 = ((threadIdx.y % WvPrGrp) / GrpsShrB) * YTILE;
uint32_t m = (m0 + m1) % Mmod; uint32_t m = (m0 + m1) % Mmod;
...@@ -1460,8 +1458,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1460,8 +1458,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k_end = (m0 / Mmod + 1) * kFit * kfitsPerRdc; uint32_t k_end = (m0 / Mmod + 1) * kFit * kfitsPerRdc;
const uint32_t k_rnd = (K + kFit * kfitsPerRdc - 1) / (kFit * kfitsPerRdc); const uint32_t k_rnd = (K + kFit * kfitsPerRdc - 1) / (kFit * kfitsPerRdc);
scalar8 sum4[N / NTILE / GrpsShrB][1]; scalar8 sum4[N / NTILE / GrpsShrB][1] = {0};
bigType bigB_[YTILE / GrpsShrB][UNRL]; bigType bigB_[YTILE / GrpsShrB / CHUNKK][UNRL];
const uint32_t bLoader = (threadIdx.y % GrpsShrB); const uint32_t bLoader = (threadIdx.y % GrpsShrB);
uint32_t kBase = 0; uint32_t kBase = 0;
if (k_str >= K) return; if (k_str >= K) return;
...@@ -1498,12 +1496,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1498,12 +1496,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll #pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k_str + k2 * THRDS * A_CHUNK; uint32_t k = k_str + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK; uint32_t k_ = k + (threadIdx.x % (THRDS / CHUNKK)) * A_CHUNK;
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)]; const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
#pragma unroll #pragma unroll
for (uint32_t y = 0; y < YTILE / GrpsShrB; y++) for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK)
bigB_[y][k2].h8 = (loadnt( bigB_[y / CHUNKK][k2].h8 = (loadnt(
(scalar8*)(&B_[min__(y * GrpsShrB + bLoader + m, M - 1) * K]))); (scalar8*)(&B_[min__((y + threadIdx.x / (THRDS / CHUNKK)) * GrpsShrB +
bLoader + m,
M - 1) *
K])));
} }
{ {
#else #else
...@@ -1556,48 +1557,51 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1556,48 +1557,51 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (reloada) { if (reloada) {
#endif #endif
constexpr int sprdN = 4; constexpr int sprdN = 4;
const uint32_t thrd = ((threadIdx.y / sprdN) * THRDS + threadIdx.x); const uint32_t thrd = threadIdx.x % (THRDS / CHUNKK);
#ifndef WVSPLITKRC_1KPASS #ifndef WVSPLITKRC_1KPASS
#pragma unroll #pragma unroll
for (int k = 0; k < kFit; k += THRDS * (WvPrGrp / sprdN) * A_CHUNK) { for (int k = 0; k < kFit;
k += (THRDS * (WvPrGrp / sprdN) * A_CHUNK) / CHUNKK) {
#else #else
const unsigned int k = 0; const unsigned int k = 0;
{ {
#endif #endif
unsigned int kOff = k + (thrd * A_CHUNK); unsigned int kOff = k + (thrd * A_CHUNK);
unsigned int kOffcp = min__(K - A_CHUNK, k_str + kOff); unsigned int kOffcp =
const unsigned int k_in = kOffcp + ((threadIdx.y % sprdN)) * K; k_str + kOff; // min__(K - A_CHUNK, k_str + kOff);
const unsigned int k_ot = kOff + ((threadIdx.y % sprdN)) * kFitPdd; for (unsigned int n = 0; n < N; n += CHUNKK * sprdN) {
for (unsigned int n = 0; n < N / 2; n += sprdN) {
__builtin_amdgcn_global_load_lds((int*)(&A[k_in + n * K]),
(int*)(&s[(k_ot + n * kFitPdd)]),
16, 0, 0);
if (((threadIdx.y % sprdN)) + n + N / 2 >= actlN) continue;
__builtin_amdgcn_global_load_lds( __builtin_amdgcn_global_load_lds(
(int*)(&A[k_in + (n + N / 2) * K]), (int*)(&A[min__(
(int*)(&s[(k_ot + (n + N / 2) * kFitPdd)]), 16, 0, 0); K * actlN - A_CHUNK,
kOffcp + K * (n / CHUNKK +
(N / CHUNKK) * (threadIdx.x / (64 / CHUNKK)) +
(threadIdx.y % sprdN)))]),
(int*)(&s[(k +
kFitPdd * ((n / CHUNKK) + (threadIdx.y % sprdN)))]),
16, 0, 0);
} }
// Stage loaded B[] to LDS for MFMA swizzling... // Stage loaded B[] to LDS for MFMA swizzling...
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK; uint32_t k_ = k + (threadIdx.x % (THRDS / CHUNKK)) * A_CHUNK;
const bool oob_k = (k_ >= K); const bool oob_k = (k_ >= K);
for (uint32_t y = 0; y < YTILE / GrpsShrB; y++) { for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK) {
uint32_t idx = threadIdx.x * 4 + uint32_t idx =
(y * GrpsShrB + bLoader) * ((THRDS + BPAD) * 4); (threadIdx.x % (THRDS / CHUNKK)) * 4 +
((y + threadIdx.x / (THRDS / CHUNKK)) * GrpsShrB + bLoader) *
((THRDS / CHUNKK + BPAD) * 4);
// zero out if oob // zero out if oob
*((scalar8*)&myStg[idx]) = *((scalar8*)&myStg[idx]) =
(oob_k || (y * GrpsShrB + bLoader + m >= M)) (oob_k) // TODO: ever necessary (y*GrpsShrB+bLoader+m>=M) ?
? 0 ? 0
: bigB_[y][k2].h8; : bigB_[y / CHUNKK][k2].h8;
} }
} }
} }
} }
} }
#ifndef WVSPLITKRC_1KPASS #ifndef WVSPLITKRC_1KPASS
// Fire load of next B[] chunk... // Fire load of next B[] chunk...
if ((k1 + THRDS * A_CHUNK * UNRL < k_end) && if ((k1 + THRDS * A_CHUNK * UNRL < k_end) &&
...@@ -1608,40 +1612,50 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1608,40 +1612,50 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k_ = k + threadIdx.x * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK;
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)]; const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
#pragma unroll #pragma unroll
for (uint32_t y = 0; y < YTILE / GrpsShrB; y++) for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK)
bigB_[y][k2].h8 = (loadnt( bigB_[y / CHUNKK][k2].h8 = (loadnt(
(scalar8*)(&B_[min__(y * GrpsShrB + bLoader + m, M - 1) * K]))); (scalar8*)(&B_[min__((y + threadIdx.x / (THRDS / CHUNKK)) *
GrpsShrB +
bLoader + m,
M - 1) *
K])));
} }
#endif #endif
// B[] staging is cooperative across GrpsShrB, so sync here before reading // B[] staging is cooperative across GrpsShrB, so sync here before reading
// back // back. This wait is currently inserted by the compiler, but is not guaranteed.
asm volatile("s_waitcnt 0");
__syncthreads(); __syncthreads();
// read back B[] swizzled for MFMA... // read back B[] swizzled for MFMA...
bigType bigB[YTILE][UNRL]; bigType bigB[YTILE / CHUNKK][UNRL];
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
for (uint32_t y = 0; y < YTILE; y++) { for (uint32_t y = 0; y < YTILE / CHUNKK; y++) {
unsigned int idx = (threadIdx.x % YTILE) * ((THRDS + BPAD) * 4) + unsigned int idx =
(threadIdx.x / YTILE) * 4 + y * 16; (threadIdx.x % YTILE) * ((THRDS / CHUNKK + BPAD) * 4) +
(threadIdx.x / YTILE) * 4 + y * 16;
bigB[y][k2].h8 = *((scalar8*)&myStg[idx]); bigB[y][k2].h8 = *((scalar8*)&myStg[idx]);
} }
} }
// Read back A[] swizzled for MFMA... // Read back A[] swizzled for MFMA...
bigType bigA[N / GrpsShrB][UNRL]; bigType bigA[N / GrpsShrB / CHUNKK][UNRL];
#pragma unroll #pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK - kBase - k_str; uint32_t k = k1 + k2 * THRDS * A_CHUNK - kBase - k_str;
#pragma unroll #pragma unroll
for (uint32_t nt = 0; nt < N / GrpsShrB; nt += NTILE) for (uint32_t nt = 0; nt < N / GrpsShrB; nt += NTILE)
#pragma unroll #pragma unroll
for (uint32_t n = 0; n < NTILE; n++) { for (uint32_t n = 0; n < NTILE / CHUNKK; n++) {
uint32_t idxa = (nt + (threadIdx.x % NTILE) + uint32_t idxa =
(N / GrpsShrB) * (threadIdx.y % GrpsShrB)) * ((nt + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) % (N / CHUNKK) +
kFitPdd + (threadIdx.x % NTILE)) *
A_CHUNK * ((threadIdx.x / NTILE) + n * 4) + k; kFitPdd +
bigA[nt + n][k2] = *((const bigType*)(&(s[idxa]))); ((nt + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) /
(N / CHUNKK)) *
A_CHUNK * (64 / CHUNKK) +
A_CHUNK * ((threadIdx.x / NTILE) + n * 4) + k;
bigA[nt / CHUNKK + n][k2] = *((const bigType*)(&(s[idxa])));
} }
} }
...@@ -1650,152 +1664,75 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1650,152 +1664,75 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
#pragma unroll #pragma unroll
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
if constexpr (std::is_same_v<scalar_t, half>) {
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
bigA[nt * NTILE + 0][k2].h4[0], bigB[0][k2].h4[0],
(k1 == k_str) ? ((scalar8){0}) : sum4[nt][0], 0, 0, 0);
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
bigA[nt * NTILE + 0][k2].h4[1], bigB[0][k2].h4[1], sum4[nt][0], 0,
0, 0);
} else { // bf16
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
bigA[nt * NTILE + 0][k2].h4[0], bigB[0][k2].h4[0],
(k1 == k_str) ? ((scalar8){0}) : sum4[nt][0], 0, 0, 0);
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
bigA[nt * NTILE + 0][k2].h4[1], bigB[0][k2].h4[1], sum4[nt][0], 0,
0, 0);
}
#pragma unroll #pragma unroll
for (uint32_t j = 1; j < YTILE; j++) { for (uint32_t j = 0; j < YTILE / CHUNKK; j++) {
if constexpr (std::is_same_v<scalar_t, half>) { if constexpr (std::is_same_v<scalar_t, half>) {
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16( sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x32_f16(
bigA[nt * NTILE + j][k2].h4[0], bigB[j][k2].h4[0], sum4[nt][0], bigA[nt * (YTILE / CHUNKK) + j][k2].h8, bigB[j][k2].h8,
0, 0, 0); sum4[nt][0], 0, 0, 0);
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
bigA[nt * NTILE + j][k2].h4[1], bigB[j][k2].h4[1], sum4[nt][0],
0, 0, 0);
} else { // bf16 } else { // bf16
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
bigA[nt * NTILE + j][k2].h4[0], bigB[j][k2].h4[0], sum4[nt][0], bigA[nt * (YTILE / CHUNKK) + j][k2].h8, bigB[j][k2].h8,
0, 0, 0); sum4[nt][0], 0, 0, 0);
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
bigA[nt * NTILE + j][k2].h4[1], bigB[j][k2].h4[1], sum4[nt][0],
0, 0, 0);
} }
} }
} }
} }
} }
if (!doRdc) { if (m + (threadIdx.x % 16) < M) {
if (m + (threadIdx.x % 16) < M) { int my_cntr;
scalar_t biases[N / NTILE / GrpsShrB][4] = {0}; int mindx = m + (threadIdx.x % 16);
int g_mindx = m * 4 + (threadIdx.x % 64); // coalesced atomic reduction
scalar_t biases[N / NTILE / GrpsShrB][4] = {};
// Atomic add the output, read biases
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
for (uint32_t j = 0; j < 4; j++) {
// int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
// (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
// int adr = mindx + M * nindx;
int g_nindx =
j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
int g_adr = g_mindx + M * g_nindx * 4;
atomicAdd(&glbl[g_adr], sum4[nt][0][j]);
}
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
int adr_ = mindx + M * nindx_ / 4;
// Update the complete counter
my_cntr = atomicAdd(&cntr[adr_], 1);
float vals[N / NTILE / GrpsShrB][4] = {};
// If we're the last k-shard, read back the value and convert...
if (my_cntr + 1 == k_rnd) {
if (BIAS) if (BIAS)
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) { for (uint32_t j = 0; j < 4; j++) {
int mindx = m + (threadIdx.x % 16);
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB); (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * M]; biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
} }
} }
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) { for (uint32_t j = 0; j < 4; j++) {
int mindx = m + (threadIdx.x % 16); int g_nindx =
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
(N / GrpsShrB) * (threadIdx.y % GrpsShrB); int g_adr = g_mindx + M * g_nindx * 4;
int adr = mindx + M * nindx; vals[nt][j] = glbl[g_adr];
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS) sum4[nt][0][j] += __bfloat162float(biases[nt][j]);
C[adr] = __float2bfloat16(sum4[nt][0][j]);
} else {
if (BIAS) sum4[nt][0][j] += __half2float(biases[nt][j]);
C[adr] = __float2half(sum4[nt][0][j]);
}
} }
} }
} __builtin_amdgcn_sched_barrier(0);
} else { for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
if (m + (threadIdx.x % 16) < M) { for (uint32_t j = 0; j < 4; j++) {
int my_cntr; int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
if (!BIAS) { (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
int mindx = m + (threadIdx.x % 16); if (nindx < actlN) {
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
for (uint32_t j = 0; j < 4; j++) {
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
int adr = mindx + M * nindx;
atomicAdd(&glbl[adr], sum4[nt][0][j]);
}
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
int adr_ = mindx + M * nindx_ / 4;
my_cntr = atomicAdd(&cntr[adr_], 1);
float vals[N / NTILE / GrpsShrB][4] = {};
if (my_cntr + 1 == k_rnd) {
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
int adr = mindx + M * nindx;
vals[nt][j] = glbl[adr];
}
}
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
if (nindx >= actlN) break;
int adr = mindx + M * nindx;
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
C[adr] = __float2bfloat16(vals[nt][j]);
} else {
C[adr] = __float2half(vals[nt][j]);
}
}
}
}
} else {
int mindx = m + (threadIdx.x % 16);
scalar_t biases[N / NTILE / GrpsShrB][4] = {};
// Atomic add the output, read biases
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
for (uint32_t j = 0; j < 4; j++) {
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
int adr = mindx + M * nindx; int adr = mindx + M * nindx;
atomicAdd(&glbl[adr], sum4[nt][0][j]); if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * M]; vals[nt][j] += __bfloat162float(biases[nt][j]);
} C[adr] = __float2bfloat16(vals[nt][j]);
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE + } else {
(N / GrpsShrB) * (threadIdx.y % GrpsShrB); vals[nt][j] += __half2float(biases[nt][j]);
int adr_ = mindx + M * nindx_ / 4; C[adr] = __float2half(vals[nt][j]);
// Update the complete counter
my_cntr = atomicAdd(&cntr[adr_], 1);
float vals[N / NTILE / GrpsShrB][4] = {};
// If we're the last k-shard, read back the value and convert...
if (my_cntr + 1 == k_rnd) {
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
int adr = mindx + M * nindx;
vals[nt][j] = glbl[adr];
}
}
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
if (nindx >= actlN) break;
int adr = mindx + M * nindx;
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
vals[nt][j] += __bfloat162float(biases[nt][j]);
C[adr] = __float2bfloat16(vals[nt][j]);
} else {
vals[nt][j] += __half2float(biases[nt][j]);
C[adr] = __float2half(vals[nt][j]);
}
} }
} }
} }
...@@ -1814,7 +1751,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1814,7 +1751,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
} }
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support #else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N, int GrpsShrB> int UNRL, int N, int GrpsShrB, int CHUNKK>
__global__ void wvSplitKrc_(const int actlN, const int K, const int M, __global__ void wvSplitKrc_(const int actlN, const int K, const int M,
const int Bx, const int By, const scalar_t* B, const int Bx, const int By, const scalar_t* B,
const scalar_t* __restrict__ A, const scalar_t* __restrict__ A,
...@@ -1859,10 +1796,10 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -1859,10 +1796,10 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// const int max_lds_len = get_lds_size() / 2; // const int max_lds_len = get_lds_size() / 2;
#define WVSPLITKrc(_WvPrGrp, _YTILE, _UNRL, _N, _GrpsShrB) \ #define WVSPLITKrc(_N, _GrpsShrB, _CHUNKK) \
{ \ { \
dim3 block(64, _WvPrGrp); \ dim3 block(64, 4); \
wvSplitKrc_<fptype, 64, _YTILE, _WvPrGrp, 8, _UNRL, _N, _GrpsShrB> \ wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK> \
<<<grid, block, 0, stream>>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \ <<<grid, block, 0, stream>>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, glbl, c, CuCount); \ biasf4, glbl, c, CuCount); \
} }
...@@ -1877,15 +1814,37 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -1877,15 +1814,37 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
: nullptr; : nullptr;
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr()); fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
auto glbl = axl_glbl.data_ptr<float>(); auto glbl = axl_glbl.data_ptr<float>();
// With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
// and each working on a 512-shard of K, how many CUs would we need?
int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512);
// How many of 4 waves in a group can work on same 16 Ms at same time? First
// try to maximize this. This reduces the Ms each group works on, i.e.
// increasing the number of CUs needed.
int GrpsShrB = min(N_p2 / 16, 4);
// Given the above, how many CUs would we need?
int CuNeeded = rndup_cus * GrpsShrB;
if (CuNeeded > CuCount) std::runtime_error("Invalid wvSplitKrc size");
// Can we increase SplitK by shrinking the K-shared to 256?
int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1;
switch (N_p2) { switch (N_p2) {
case 16: case 16:
WVSPLITKrc(4, 16, 1, 16, 1) break; WVSPLITKrc(16, 1, 1) break;
case 32: case 32:
WVSPLITKrc(4, 16, 1, 32, 2) break; if (chunkk == 2)
WVSPLITKrc(32, 2, 2) else if (chunkk == 1) WVSPLITKrc(32, 2, 1) break;
case 64: case 64:
WVSPLITKrc(4, 16, 1, 64, 2) break; if (chunkk == 2)
WVSPLITKrc(64, 4, 2) else if (chunkk == 1) WVSPLITKrc(64, 4, 1) break;
case 128: case 128:
WVSPLITKrc(4, 16, 1, 128, 4) break; if (chunkk == 2)
WVSPLITKrc(128, 4, 2) else if (chunkk == 1)
WVSPLITKrc(128, 4, 1) break;
default: default:
throw std::runtime_error( throw std::runtime_error(
"Unsupported N value: " + std::to_string(M_in) + "," + "Unsupported N value: " + std::to_string(M_in) + "," +
......
...@@ -45,31 +45,28 @@ NKM_FACTORS_WVSPLITK = [ ...@@ -45,31 +45,28 @@ NKM_FACTORS_WVSPLITK = [
(4, 256, 8), (4, 256, 8),
] ]
NKM_FACTORS_WVSPLITKRC = [ N_FACTORS_WVSPLITKRC = [
(16, 2880, 128), 13,
(16, 2880, 640), 16,
(17, 2880, 128), 17,
(17, 2880, 640), 25,
(25, 2880, 128), 29,
(25, 2880, 640), 31,
(31, 2880, 128), 32,
(31, 2880, 640), 41,
(32, 2880, 128), 51,
(32, 2880, 640), 64,
(40, 2880, 128), 71,
(40, 2880, 640), 81,
(60, 2880, 128), 91,
(60, 2880, 640), 103,
(64, 2880, 128), 117,
(64, 2880, 640), 128,
(81, 2880, 128),
(81, 2880, 640),
(98, 2880, 128),
(98, 2880, 640),
(128, 2880, 128),
(128, 2880, 640),
] ]
K_FACTORS_WVSPLITKRC = [2880, 2880 + 8, 3072, 3072 + 8]
M_FACTORS_WVSPLITKRC = [128, 128 + 16, 256, 256 + 16, 640, 640 + 16]
NKM_FACTORS_WVSPLITK_FP8 = [ NKM_FACTORS_WVSPLITK_FP8 = [
# FP8-specific cases with K % 16 == 0 # FP8-specific cases with K % 16 == 0
(1, 16, 16), (1, 16, 16),
...@@ -113,30 +110,54 @@ def pad_fp8(weight): ...@@ -113,30 +110,54 @@ def pad_fp8(weight):
return F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] return F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITKRC) @pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n", N_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("k", K_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("m", M_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("bias_mode", BIAS_MODES) @pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950") @pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
def test_rocm_wvsplitkrc_kernel(n, k, m, dtype, seed, bias_mode): def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
torch.manual_seed(seed) torch.manual_seed(seed)
cu_count = get_cu_count() cu_count = get_cu_count()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas # Next ^2 of n
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier N_p2 = 1 << (n - 1).bit_length()
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier # With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
# and each working on a 512-shard of K, how many CUs would we need?
rndup_cus = ((m + 64 - 1) // 64) * ((k + 512 - 1) // 512)
# How many of 4 waves in a group can work on same 16 Ms at same time?
# This reduces the Ms each group works on, i.e. increasing the number of CUs needed.
GrpsShrB = min(N_p2 // 16, 4)
# Given the above, how many CUs would we need?
CuNeeded = rndup_cus * GrpsShrB
# candidate for atomic reduce count splitk?
fits_wvsplitkrc = CuNeeded <= cu_count
if not fits_wvsplitkrc:
pytest.skip("Too large for wvSplitKrc")
xavier = (
math.sqrt(2 / k) if xnorm else 1
) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
BIAS = None BIAS = None
if bias_mode == 1: if bias_mode == 1:
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5 BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
elif bias_mode == 2: elif bias_mode == 2:
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5 BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
ref_out = torch.nn.functional.linear(A, B, BIAS) ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS) out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS)
assert torch.allclose(out, ref_out, rtol=0.01) if xnorm:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
else:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
......
...@@ -145,32 +145,43 @@ def rocm_unquantized_gemm_impl( ...@@ -145,32 +145,43 @@ def rocm_unquantized_gemm_impl(
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.platforms.rocm import on_gfx9, on_gfx950 from vllm.platforms.rocm import on_gfx9, on_gfx950
n = x.numel() / x.size(-1) n = x.numel() // x.size(-1)
m = weight.shape[0] m = weight.shape[0]
k = weight.shape[1] k = weight.shape[1]
import math cu_count = get_cu_count()
if use_aiter_triton_gemm(n, m, k, x.dtype): if use_aiter_triton_gemm(n, m, k, x.dtype):
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
return gemm_a16w16(x, weight, bias) return gemm_a16w16(x, weight, bias)
# Smallest power of two >= n
N_p2 = 1 << (n - 1).bit_length()
# With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
# and each working on a 512-shard of K, how many CUs would we need?
rndup_cus = ((m + 64 - 1) // 64) * ((k + 512 - 1) // 512)
# How many of 4 waves in a group can work on same 16 Ms at same time?
# This reduces the Ms each group works on, i.e. increasing the number of CUs needed.
GrpsShrB = min(N_p2 // 16, 4)
# Given the above, how many CUs would we need?
CuNeeded = rndup_cus * GrpsShrB
# candidate for atomic reduce count splitk?
fits_wvsplitkrc = CuNeeded <= cu_count
use_skinny_reduce_counting = ( use_skinny_reduce_counting = (
envs.VLLM_ROCM_USE_SKINNY_GEMM envs.VLLM_ROCM_USE_SKINNY_GEMM
and on_gfx950() and on_gfx950()
and x.dtype in [torch.float16, torch.bfloat16] and x.dtype in [torch.float16, torch.bfloat16]
and ( and (
n >= 16 10 <= n <= 128
and n <= 128 and k % 8 == 0
and k > 512 and k > 512
and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count() and m % 16 == 0
and fits_wvsplitkrc
and x.is_contiguous() and x.is_contiguous()
) )
# k == 2880 and (m == 640 or m == 128))
) )
if use_skinny_reduce_counting: if use_skinny_reduce_counting:
cu_count = get_cu_count()
x_view = x.reshape(-1, x.size(-1)) x_view = x.reshape(-1, x.size(-1))
out = ops.wvSplitKrc(weight, x_view, cu_count, bias) out = ops.wvSplitKrc(weight, x_view, cu_count, bias)
return out.reshape(*x.shape[:-1], weight.shape[0]) return out.reshape(*x.shape[:-1], weight.shape[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment