Unverified Commit d5c48001 authored by Hashem Hashemi's avatar Hashem Hashemi Committed by GitHub
Browse files

Adds padding and perf improvements to wvSplitK_fp8 (#33527)


Signed-off-by: default avatarHashem Hashemi <hashem.hashemi@amd.com>
parent 42d5d705
...@@ -1899,8 +1899,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -1899,8 +1899,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp, template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N> int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS) __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int Bx, wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp, const int M,
const int By, const fp8_t* B, const fp8_t* __restrict__ A, const int Bx, const int By, const fp8_t* B,
const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C, const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A, const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp, const float* __restrict__ s_B, const int _WvPrGrp,
...@@ -1924,9 +1925,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1924,9 +1925,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
__shared__ fp8_t s[max_lds_len]; __shared__ fp8_t s[max_lds_len];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
#else
*((bigType*)(&s[k])) = *((bigType*)(&A[k])); *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
#endif
} }
asm volatile("s_waitcnt vmcnt(0)");
__syncthreads(); __syncthreads();
if (threadIdx.y >= _WvPrGrp) return; if (threadIdx.y >= _WvPrGrp) return;
...@@ -1934,37 +1940,24 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1934,37 +1940,24 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
floatx16 sum[N][YTILE];
float sA = *s_A; float sA = *s_A;
float sB = *s_B; float sB = *s_B;
while (m < M) { while (m < M) {
for (int i = 0; i < YTILE; i++) floatx16 sum[N][YTILE] = {};
for (int n = 0; n < N; n++) sum[n][i] = {0.f};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
#pragma unroll bigType bigA[N][UNRL] = {};
for (uint32_t k2 = 0; k2 < UNRL; k2++) { bigType bigB[YTILE][UNRL];
#pragma unroll
for (uint32_t n = 0; n < N; ++n) bigA[n][k2].h8 = {0.f};
#pragma unroll
for (uint32_t y = 0; y < YTILE; ++y) bigB[y][k2].h8 = {0.f};
}
// Fetch the weight matrix from memory! // Fetch the weight matrix from memory!
#pragma unroll #pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break; const fp8_t* B_ = &B[min__(k_, K - A_CHUNK)];
const fp8_t* B_ = &B[(m + 0) * Kp + k_];
#pragma unroll #pragma unroll
for (uint32_t y = 0; y < YTILE; ++y) { for (uint32_t y = 0; y < YTILE; ++y) {
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp]))); bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
} }
} }
...@@ -1975,16 +1968,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1975,16 +1968,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k_ = k + threadIdx.x * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break; if (k_ >= K) break;
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
} }
} }
// Do the matrix multiplication in interleaved manner // Do the matrix multiplication in interleaved manner
#pragma unroll #pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
if (k >= K) break;
for (uint32_t n = 0; n < N; n++) { for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) { for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) { for (int y = 0; y < YTILE; ++y) {
...@@ -2002,48 +1992,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2002,48 +1992,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0]; float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8]; float accm16 = sum[n][y][8];
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
: "=v"(accm0) 1); // row_shl1
: "0"(accm0), "v"(sum[n][y][1]), "v"(accm0)); accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
: "=v"(accm16) 1); // row_shl2
: "0"(accm16), "v"(sum[n][y][9]), "v"(accm16)); accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
: "=v"(accm0) 1); // row_shl3
: "0"(accm0), "v"(sum[n][y][2]), "v"(accm0)); accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
: "=v"(accm16) 1); // row_shl8
: "0"(accm16), "v"(sum[n][y][10]), "v"(accm16)); accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
: "=v"(accm0) 1); // row_shl9
: "0"(accm0), "v"(sum[n][y][3]), "v"(accm0)); accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
: "=v"(accm16) 1); // row_shl10
: "0"(accm16), "v"(sum[n][y][11]), "v"(accm16)); accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
: "=v"(accm0) 1); // row_shl11
: "0"(accm0), "v"(sum[n][y][4]), "v"(accm0)); accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][12]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][5]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][13]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][6]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][14]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][7]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][15]), "v"(accm16));
accm0 += __shfl(accm0, 36); accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52); accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16); sum[n][y][0] = accm0 + __shfl(accm16, 16);
...@@ -2051,19 +2020,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2051,19 +2020,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
} }
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault. if (y + m >= M) break; // To avoid mem access fault.
sum[n][y][0] *= sA * sB; sum[n][y][0] *= sA * sB;
if constexpr (std::is_same_v<scalar_t, half>) { if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS) sum[n][y][0] += __half2float(biases[n][y]);
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS) sum[n][y][0] += __bfloat162float(biases[n][y]);
sum[n][y][0] +=
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
} }
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]); // * sA * sB); C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
} }
} }
} }
...@@ -2074,9 +2047,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2074,9 +2047,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support #else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp, template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N> int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, __global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
const int Bx, const int By, const fp8_t* B, const int M, const int Bx, const int By,
const fp8_t* __restrict__ A, const fp8_t* B, const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, const scalar_t* __restrict__ BIAS,
scalar_t* C, const float* __restrict__ s_A, scalar_t* C, const float* __restrict__ s_A,
const float* __restrict__ s_B, const float* __restrict__ s_B,
...@@ -2089,8 +2062,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, ...@@ -2089,8 +2062,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp, template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N> int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS) __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitKQ_hf_(const int K, const int Kp, const int M, const int Bx, wvSplitKQ_hf_(const int K, const int Kap, const int Kbp, const int M,
const int By, const fp8_t* B, const fp8_t* __restrict__ A, const int Bx, const int By, const fp8_t* B,
const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C, const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A, const float* __restrict__ s_B, const float* __restrict__ s_A, const float* __restrict__ s_B,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
...@@ -2113,9 +2087,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2113,9 +2087,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
__shared__ fp8_t s[max_lds_len]; __shared__ fp8_t s[max_lds_len];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
#else
*((bigType*)(&s[k])) = *((bigType*)(&A[k])); *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
#endif
} }
asm volatile("s_waitcnt vmcnt(0)");
__syncthreads(); __syncthreads();
if (threadIdx.y >= _WvPrGrp) return; if (threadIdx.y >= _WvPrGrp) return;
...@@ -2123,29 +2102,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2123,29 +2102,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
floatx16 sum[N][YTILE];
float sA = *s_A; float sA = *s_A;
float sB = *s_B; float sB = *s_B;
while (m < M) { while (m < M) {
for (int i = 0; i < YTILE; i++) floatx16 sum[N][YTILE] = {};
for (int n = 0; n < N; n++) sum[n][i] = {0}; for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL]; bigType bigB[YTILE][UNRL];
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
// Fetch the weight matrix from memory! // Fetch the weight matrix from memory!
#pragma unroll #pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break; const fp8_t* B_ = &B[min__(k_, K - A_CHUNK)];
const fp8_t* B_ = &B[(m + 0) * Kp + k_];
for (int y = 0; y < YTILE; ++y) { for (int y = 0; y < YTILE; ++y) {
if (y + m >= M) break; // To avoid mem access fault. bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp])));
} }
} }
...@@ -2156,20 +2129,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2156,20 +2129,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k_ = k + threadIdx.x * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break; if (k_ >= K) break;
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
if (k_ + K * n < max_lds_len) if (k_ + Kap * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
else else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n])));
} }
} }
// Do the matrix multiplication in interleaved manner // Do the matrix multiplication in interleaved manner
#pragma unroll #pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (uint32_t n = 0; n < N; n++) { for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) { for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) { for (int y = 0; y < YTILE; ++y) {
...@@ -2187,48 +2156,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2187,48 +2156,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0]; float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8]; float accm16 = sum[n][y][8];
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
: "=v"(accm0) 1); // row_shl1
: "0"(accm0), "v"(sum[n][y][1]), "v"(accm0)); accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
: "=v"(accm16) 1); // row_shl2
: "0"(accm16), "v"(sum[n][y][9]), "v"(accm16)); accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
: "=v"(accm0) 1); // row_shl3
: "0"(accm0), "v"(sum[n][y][2]), "v"(accm0)); accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
: "=v"(accm16) 1); // row_shl8
: "0"(accm16), "v"(sum[n][y][10]), "v"(accm16)); accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
: "=v"(accm0) 1); // row_shl9
: "0"(accm0), "v"(sum[n][y][3]), "v"(accm0)); accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
: "=v"(accm16) 1); // row_shl10
: "0"(accm16), "v"(sum[n][y][11]), "v"(accm16)); accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
: "=v"(accm0) 1); // row_shl11
: "0"(accm0), "v"(sum[n][y][4]), "v"(accm0)); accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][12]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][5]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][13]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][6]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][14]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][7]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][15]), "v"(accm16));
accm0 += __shfl(accm0, 36); accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52); accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16); sum[n][y][0] = accm0 + __shfl(accm16, 16);
...@@ -2236,17 +2184,21 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2236,17 +2184,21 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
} }
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault. if (y + m >= M) break; // To avoid mem access fault.
sum[n][y][0] *= sA * sB; sum[n][y][0] *= sA * sB;
if constexpr (std::is_same_v<scalar_t, half>) { if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS) sum[n][y][0] += __half2float(biases[n][y]);
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS) sum[n][y][0] += __bfloat162float(biases[n][y]);
sum[n][y][0] +=
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
} }
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]); C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
} }
...@@ -2259,9 +2211,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2259,9 +2211,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support #else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp, template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N> int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, __global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
const int Bx, const int By, const fp8_t* B, const int M, const int Bx, const int By,
const fp8_t* __restrict__ A, const fp8_t* B, const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C, const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A, const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp, const float* __restrict__ s_B, const int _WvPrGrp,
...@@ -2270,17 +2222,18 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, ...@@ -2270,17 +2222,18 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
} }
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support #endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
const std::optional<at::Tensor>& in_bias, at::Tensor& out_c, const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
const at::Tensor& scale_a, const at::Tensor& scale_b, const at::Tensor& scale_a, const at::Tensor& scale_b,
const int64_t CuCount) { const int64_t CuCount) {
static c10::ScalarType kFp8Type = is_fp8_ocp() static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn ? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz; : c10::ScalarType::Float8_e4m3fnuz;
auto M_in = in_a.size(0); auto M_in = in_b.size(0);
auto K_in = in_a.size(1); auto K_in = in_b.size(1);
auto N_in = in_b.size(0); auto N_in = in_a.size(0);
auto Kp_in = in_a.stride(0); auto Kap_in = in_a.stride(0);
auto Kbp_in = in_b.stride(0);
auto Bx_in = auto Bx_in =
(in_bias.has_value() && in_bias->numel() > 0) (in_bias.has_value() && in_bias->numel() > 0)
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0) ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
...@@ -2300,22 +2253,21 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -2300,22 +2253,21 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size(); const int max_lds_len = get_lds_size();
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ #define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \
_N) \
{ \ { \
dim3 block(64, _WvPrGrp); \ dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ if ((Kap_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEs, 16)); \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \ wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \ <<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
b_ptr, bias_ptr, c_ptr, s_a, s_b, \ By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
__wvPrGrp, CuCount); \ s_a, s_b, __wvPrGrp, CuCount); \
} else { \ } else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEm, 16)); \
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \ wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \ <<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
b_ptr, bias_ptr, c_ptr, s_a, s_b, \ By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
__wvPrGrp, CuCount); \ s_a, s_b, __wvPrGrp, CuCount); \
} \ } \
} }
...@@ -2332,16 +2284,16 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -2332,16 +2284,16 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
: nullptr; : nullptr;
switch (N_in) { switch (N_in) {
case 1: case 1:
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1) WVSPLITKQ(12, 2, 2, 2, 2, 1)
break; break;
case 2: case 2:
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 2) WVSPLITKQ(12, 2, 2, 2, 2, 2)
break; break;
case 3: case 3:
WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 3) WVSPLITKQ(8, 2, 2, 1, 1, 3)
break; break;
case 4: case 4:
WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 4) WVSPLITKQ(4, 2, 2, 1, 1, 4)
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(
......
...@@ -73,21 +73,40 @@ NKM_FACTORS_WVSPLITKRC = [ ...@@ -73,21 +73,40 @@ NKM_FACTORS_WVSPLITKRC = [
NKM_FACTORS_WVSPLITK_FP8 = [ NKM_FACTORS_WVSPLITK_FP8 = [
# FP8-specific cases with K % 16 == 0 # FP8-specific cases with K % 16 == 0
(1, 16, 16), (1, 16, 16),
(1, 32, 16 + 16),
(1, 64, 64), (1, 64, 64),
(1, 64, 64 + 16),
(1, 64 + 16, 64),
(1, 64 + 16, 64 + 16),
(4, 64, 64),
(4, 64, 64 + 16),
(4, 64 + 16, 64),
(4, 64 + 16, 64 + 16),
(2, 512, 512), (2, 512, 512),
(3, 512, 512),
(3, 512, 512 + 16),
(4, 512, 512),
(3, 2048, 2048), (3, 2048, 2048),
(3, 2048, 2048 + 16),
(4, 2048 + 16, 2048),
(4, 2048 + 16, 2048 + 16),
(4, 4096, 4096), (4, 4096, 4096),
(4, 16400, 2048), (4, 16400, 2048),
(4, 16400, 2048 + 16),
# Extended FP8 dimensions not covered by WVSPLITK # Extended FP8 dimensions not covered by WVSPLITK
(1, 14336, 1024), (1, 14336, 1024),
(2, 24576, 2048), (2, 24576, 2048),
(4, 32768, 28672), (4, 32768, 28672),
(4, 32768 * 2, 28672),
(4, 32768 * 2, 28672 + 16),
(4, 32768 * 2 + 16, 28672),
(4, 32768 * 2 + 16, 28672 + 16),
] ]
SEEDS = [0] SEEDS = [0]
def pad_weights_fp8(weight): def pad_fp8(weight):
num_pad = 256 // weight.element_size() num_pad = 256 // weight.element_size()
import torch.nn.functional as F import torch.nn.functional as F
...@@ -195,72 +214,41 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): ...@@ -195,72 +214,41 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
assert torch.allclose(out, ref_out, rtol=0.01) assert torch.allclose(out, ref_out, rtol=0.01)
@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded", [False, True]) @pytest.mark.parametrize("padded_a", [False, True])
@pytest.mark.parametrize("padded_b", [False, True])
@pytest.mark.parametrize("biased", [False, True])
@pytest.mark.skipif( @pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()), not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8", reason="only test for rocm fp8",
) )
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed, padded): def test_rocm_wvsplitk_fp8_kernel(
xnorm, n, k, m, dtype, seed, padded_a, padded_b, biased
):
torch.manual_seed(seed) torch.manual_seed(seed)
A = torch.rand(n, k, device="cuda") - 0.5 xavier = math.sqrt(2 / k) if xnorm else 1 # normalize to avoid large deltas
B = torch.rand(m, k, device="cuda") - 0.5 A = (torch.rand(n, k, device="cuda") * 2 - 1) * xavier
B = (torch.rand(m, k, device="cuda") * 2 - 1) * xavier
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
if padded: if padded_b:
B = pad_weights_fp8(B) B = pad_fp8(B)
if padded_a:
ref_out = torch._scaled_mm( A = pad_fp8(A)
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
)
out = ops.wvSplitKQ(
B,
A,
dtype,
scale_a,
scale_b,
get_cu_count(),
)
assert torch.allclose(out, ref_out, rtol=0.01)
BIAS = None if (not biased) else (torch.rand(m, dtype=dtype, device="cuda") * 2 - 1)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded", [False, True])
@pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed, padded):
torch.manual_seed(seed)
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, device="cuda") - 0.5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
if padded:
B = pad_weights_fp8(B)
ref_out = torch._scaled_mm( ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
) )
out = ops.wvSplitKQ( out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, get_cu_count(), BIAS)
B,
A,
dtype,
scale_a,
scale_b,
get_cu_count(),
BIAS,
)
assert torch.allclose(out, ref_out, rtol=0.01) if xnorm:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
else:
assert torch.allclose(out, ref_out, 0.01)
...@@ -25,10 +25,10 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl( ...@@ -25,10 +25,10 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
bias: torch.Tensor, bias: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
if ( if (
A.shape[0] == 1 A.shape[0] <= 4
and B.shape[1] % 16 == 0 and B.shape[0] % 16 == 0 # M TODO: needed?
and B.shape[1] % 16 == 0 # K
and ((bias is None) or (bias.dtype == out_dtype)) and ((bias is None) or (bias.dtype == out_dtype))
and A.is_contiguous()
): ):
output = ops.wvSplitKQ( output = ops.wvSplitKQ(
B.t(), B.t(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment