"vllm/vscode:/vscode.git/clone" did not exist on "3a0e1fc070dc7482ab1c8fcdc961e5729a4cb0b3"
Unverified commit 721ae79f, authored by Hashem Hashemi, committed by GitHub
Browse files

Improvements to wvSplitKrc skinny GEMM solution (#34304)


Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
parent aefc59f0
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "../cuda_compat.h" #include "../cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "quantization/w8a8/fp8/common.cuh" #include "quantization/w8a8/fp8/common.cuh"
#include "core/batch_invariant.hpp"
// TODO(rasmith): The kernels in this file are susceptible to integer overflow // TODO(rasmith): The kernels in this file are susceptible to integer overflow
// issues, do not take strides, and are unable to handle PyTorch tensors that // issues, do not take strides, and are unable to handle PyTorch tensors that
...@@ -1224,17 +1225,14 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -1224,17 +1225,14 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
#if defined(__gfx950__) #if defined(__gfx950__)
#define WVSPLITKRC_1KPASS #define WVSPLITKRC_1KPASS
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N, int GrpsShrB, int CHUNKK> int UNRL, int N, int GrpsShrB, int CHUNKK, int DTRMNSTC>
__global__ void __launch_bounds__(WvPrGrp* THRDS) __global__ void __launch_bounds__(WvPrGrp* THRDS)
__attribute__((amdgpu_waves_per_eu(1, 1))) __attribute__((amdgpu_waves_per_eu(1, 1)))
wvSplitKrc_(const int actlN, const int K, const int M, const int Bx, wvSplitKrc_(const int actlN, const int K, const int Kap, const int M,
const int By, const scalar_t* __restrict__ B, const int Bx, const int By, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ A, const scalar_t* __restrict__ B,
const scalar_t* __restrict__ BIAS, float* glbl, scalar_t* C, const scalar_t* __restrict__ BIAS, float* glbl, int* cntr,
const int CuCount) { scalar_t* C, const int CuCount) {
// Use upper half of glbl buffer for atomic reduce counting
int* cntr = (int*)(&glbl[M * N]);
constexpr int NTILE = 16; constexpr int NTILE = 16;
constexpr int APAD = 1; constexpr int APAD = 1;
constexpr int ASTRD = 64; constexpr int ASTRD = 64;
...@@ -1425,10 +1423,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1425,10 +1423,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
unsigned int kOffcp = min__(K - A_CHUNK, k_str + kOff); unsigned int kOffcp = min__(K - A_CHUNK, k_str + kOff);
for (unsigned int n = 0; n < N; n += CHUNKK * sprdN) { for (unsigned int n = 0; n < N; n += CHUNKK * sprdN) {
__builtin_amdgcn_global_load_lds( __builtin_amdgcn_global_load_lds(
(int*)(&A[min__( (int*)(&A[min__(Kap * actlN - A_CHUNK,
K * actlN - A_CHUNK, kOffcp + Kap * (n / CHUNKK +
kOffcp + K * (n / CHUNKK + (N / CHUNKK) * (threadIdx.x /
(N / CHUNKK) * (threadIdx.x / (64 / CHUNKK)) + (64 / CHUNKK)) +
(threadIdx.y % sprdN)))]), (threadIdx.y % sprdN)))]),
(int*)(&s[(k + (int*)(&s[(k +
kFitPdd * ((n / CHUNKK) + (threadIdx.y % sprdN)))]), kFitPdd * ((n / CHUNKK) + (threadIdx.y % sprdN)))]),
...@@ -1533,30 +1531,66 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1533,30 +1531,66 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
} }
} }
union flt4 {
scalar8 s8;
float2 f2[2];
float4 f4;
};
if (m + (threadIdx.x % 16) < M) { if (m + (threadIdx.x % 16) < M) {
int my_cntr; int my_cntr;
int mindx = m + (threadIdx.x % 16); int mindx = m + (threadIdx.x % 16);
int g_mindx = m * 4 + (threadIdx.x % 64); // coalesced atomic reduction int g_mindx = m * 4 + (threadIdx.x % 64); // coalesced atomic reduction
scalar_t biases[N / NTILE / GrpsShrB][4] = {}; scalar_t biases[N / NTILE / GrpsShrB][4] = {};
// Atomic add the output, read biases // Atomic add the output, read biases
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
// int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
// (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
// int adr = mindx + M * nindx;
int g_nindx = int g_nindx =
j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4; (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
int g_adr = g_mindx + M * g_nindx * 4; int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4;
atomicAdd(&glbl[g_adr], sum4[nt][0][j]); if (DTRMNSTC) {
flt4 flt4_ = {.s8 = sum4[nt][0]};
__hip_atomic_store((float2*)&glbl[g_adr + M * N * (m0 / Mmod)],
flt4_.f2[0], __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_AGENT);
__hip_atomic_store((float2*)&glbl[g_adr + 2 + M * N * (m0 / Mmod)],
flt4_.f2[1], __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_AGENT);
} else {
for (uint32_t j = 0; j < 4; j++)
atomicAdd((&glbl[g_adr + j]), sum4[nt][0][j]);
} }
}
__atomic_signal_fence(__ATOMIC_SEQ_CST);
asm volatile("s_waitcnt vmcnt(0)" ::: "memory");
__atomic_signal_fence(__ATOMIC_SEQ_CST);
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE + int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB); (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
int adr_ = mindx + M * nindx_ / 4; int adr_ = mindx + M * nindx_ / 4;
// Update the complete counter
my_cntr = atomicAdd(&cntr[adr_], 1); my_cntr = atomicAdd(&cntr[adr_], 1);
float vals[N / NTILE / GrpsShrB][4] = {};
// make sure LDS is free for write out staging
if (DTRMNSTC) __syncthreads();
// Update the complete counter
flt4 vals[N / NTILE / GrpsShrB] = {};
// If we're the last k-shard, read back the value and convert... // If we're the last k-shard, read back the value and convert...
if (my_cntr + 1 == k_rnd) { if (my_cntr + 1 == k_rnd) {
cntr[adr_] = 0; // clear for next round
if constexpr (DTRMNSTC) {
#pragma unroll
for (int ks = 0; ks < k_rnd; ks++) {
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
int g_nindx =
(nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4;
__builtin_amdgcn_global_load_lds(
(float4*)(&glbl[g_adr + M * N * ks]),
&(((float4*)s)[(threadIdx.y * THRDS) + ks * THRDS * 4 +
nt * THRDS * 4 * k_rnd]),
16, 0, 0);
}
}
if (BIAS) if (BIAS)
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) { for (uint32_t j = 0; j < 4; j++) {
...@@ -1565,12 +1599,29 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1565,12 +1599,29 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx]; biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
} }
} }
asm volatile("s_waitcnt 0");
for (int ks = 0; ks < k_rnd; ks++) {
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
float4 eval = ((float4*)s)[(threadIdx.x + threadIdx.y * THRDS) +
ks * THRDS * 4 + nt * THRDS * 4 * k_rnd];
vals[nt].f4 += eval;
}
}
} else {
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
int g_nindx = int g_nindx =
j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4; (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
int g_adr = g_mindx + M * g_nindx * 4; int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4;
vals[nt][j] = glbl[g_adr]; vals[nt].f4 = *(float4*)(&glbl[g_adr]);
*(float4*)(&glbl[g_adr]) = {}; // clear out for next round
}
if (BIAS)
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
}
} }
} }
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -1581,11 +1632,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1581,11 +1632,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (nindx < actlN) { if (nindx < actlN) {
int adr = mindx + M * nindx; int adr = mindx + M * nindx;
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
vals[nt][j] += __bfloat162float(biases[nt][j]); vals[nt].s8[j] += __bfloat162float(biases[nt][j]);
C[adr] = __float2bfloat16(vals[nt][j]); C[adr] = __float2bfloat16(vals[nt].s8[j]);
} else { } else {
vals[nt][j] += __half2float(biases[nt][j]); vals[nt].s8[j] += __half2float(biases[nt][j]);
C[adr] = __float2half(vals[nt][j]); C[adr] = __float2half(vals[nt].s8[j]);
} }
} }
} }
...@@ -1604,21 +1655,25 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1604,21 +1655,25 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
} }
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support #else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N, int GrpsShrB, int CHUNKK> int UNRL, int N, int GrpsShrB, int CHUNKK, int DTRMNSTC>
__global__ void wvSplitKrc_(const int actlN, const int K, const int M, __global__ void wvSplitKrc_(const int actlN, const int K, const int Kap,
const int Bx, const int By, const scalar_t* B, const int M, const int Bx, const int By,
const scalar_t* __restrict__ A, const scalar_t* B, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, float* glbl, const scalar_t* __restrict__ BIAS, float* glbl,
// int* cntr, int* cntr, scalar_t* C,
scalar_t* C, const int CuCount){UNREACHABLE_CODE} const int CuCount){UNREACHABLE_CODE}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support #endif // defined(__HIP__GFX9__) TODO: Add NAVI support
torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
const std::optional<at::Tensor>& in_bias, const std::optional<at::Tensor>& in_bias,
const int64_t CuCount) { const int64_t CuCount) {
auto M_in = in_a.size(0); int _DTRMNSTC = 1; // vllm::vllm_is_batch_invariant();
auto N_in = in_b.size(0);
auto K_in = in_a.size(1); auto M_in = in_b.size(0);
auto N_in = in_a.size(0);
auto K_in = in_b.size(1);
auto Kap_in = in_a.stride(0);
auto Bx_in = auto Bx_in =
(in_bias.has_value() && in_bias->numel() > 0) (in_bias.has_value() && in_bias->numel() > 0)
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0) ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
...@@ -1635,13 +1690,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -1635,13 +1690,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
auto out_c = torch::empty( auto out_c = torch::empty(
{N_in, M_in}, {N_in, M_in},
torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device())); torch::TensorOptions().dtype(in_a.dtype()).device(in_a.device()));
auto N_p2 = 1U << (32 - __builtin_clz(N_in - 1)); auto N_p2 = 1U << (32 - __builtin_clz(N_in - 1));
auto axl_glbl = torch::empty(
{N_p2 + N_p2 / 4, M_in + M_in / 4},
torch::TensorOptions().dtype(torch::kFloat32).device(in_b.device()));
axl_glbl.zero_(); // disable for FAST_UNSAFE_RDC_INIT
dim3 grid(CuCount); dim3 grid(CuCount);
...@@ -1649,25 +1700,6 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -1649,25 +1700,6 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// const int max_lds_len = get_lds_size() / 2; // const int max_lds_len = get_lds_size() / 2;
#define WVSPLITKrc(_N, _GrpsShrB, _CHUNKK) \
{ \
dim3 block(64, 4); \
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK> \
<<<grid, block, 0, stream>>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, glbl, c, CuCount); \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitKrc", [&] {
using fptype = typename scalar<scalar_t>::type;
fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr());
const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
const fptype* biasf4 =
(in_bias.has_value() && in_bias->numel() > 0)
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
: nullptr;
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
auto glbl = axl_glbl.data_ptr<float>();
// With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile), // With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
// and each working on a 512-shard of K, how many CUs would we need? // and each working on a 512-shard of K, how many CUs would we need?
int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512); int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512);
...@@ -1680,24 +1712,58 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -1680,24 +1712,58 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
// Given the above, how many CUs would we need? // Given the above, how many CUs would we need?
int CuNeeded = rndup_cus * GrpsShrB; int CuNeeded = rndup_cus * GrpsShrB;
if (CuNeeded > CuCount) std::runtime_error("Invalid wvSplitKrc size"); if (CuNeeded > CuCount) throw std::runtime_error("Invalid wvSplitKrc size");
// Can we increase SplitK by shrinking the K-shared to 256? // Can we increase SplitK by shrinking the K-shared to 256?
int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1; int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1;
static torch::Tensor axl_glbl =
torch::zeros(
128 * 1024 * (_DTRMNSTC ? 12 : 1),
torch::TensorOptions().dtype(torch::kFloat32).device(in_a.device()))
.detach();
static torch::Tensor axl_cntr =
torch::zeros(
128 * 1024 * (_DTRMNSTC ? 12 : 1) / 4,
torch::TensorOptions().dtype(torch::kInt).device(in_a.device()))
.detach();
auto glbl = axl_glbl.data_ptr<float>();
auto cntr = axl_cntr.data_ptr<int>();
#define WVSPLITKrc(_N, _GrpsShrB, _CHUNKK) \
{ \
dim3 block(64, 4); \
if (_DTRMNSTC) \
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK, 1> \
<<<grid, block, 0, stream>>>(N_in, K_in, Kap_in, M_in, Bx_in, By_in, \
af4, bf4, biasf4, glbl, cntr, c, \
CuCount); \
else \
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK, 0> \
<<<grid, block, 0, stream>>>(N_in, K_in, Kap_in, M_in, Bx_in, By_in, \
af4, bf4, biasf4, glbl, cntr, c, \
CuCount); \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_a.scalar_type(), "wvSplitKrc", [&] {
using fptype = typename scalar<scalar_t>::type;
const fptype* af4 = reinterpret_cast<const fptype*>(in_a.data_ptr());
const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
const fptype* biasf4 =
(in_bias.has_value() && in_bias->numel() > 0)
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
: nullptr;
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
switch (N_p2) { switch (N_p2) {
case 16: case 16:
WVSPLITKrc(16, 1, 1) break; WVSPLITKrc(16, 1, 1) break;
case 32: case 32:
if (chunkk == 2) if (chunkk == 2) WVSPLITKrc(32, 2, 2) else WVSPLITKrc(32, 2, 1) break;
WVSPLITKrc(32, 2, 2) else if (chunkk == 1) WVSPLITKrc(32, 2, 1) break;
case 64: case 64:
if (chunkk == 2) if (chunkk == 2) WVSPLITKrc(64, 4, 2) else WVSPLITKrc(64, 4, 1) break;
WVSPLITKrc(64, 4, 2) else if (chunkk == 1) WVSPLITKrc(64, 4, 1) break;
case 128: case 128:
if (chunkk == 2) if (chunkk == 2) WVSPLITKrc(128, 4, 2) else WVSPLITKrc(128, 4, 1) break;
WVSPLITKrc(128, 4, 2) else if (chunkk == 1)
WVSPLITKrc(128, 4, 1) break;
default: default:
throw std::runtime_error( throw std::runtime_error(
"Unsupported N value: " + std::to_string(M_in) + "," + "Unsupported N value: " + std::to_string(M_in) + "," +
......
...@@ -70,7 +70,6 @@ N_FACTORS_WVSPLITKRC = [ ...@@ -70,7 +70,6 @@ N_FACTORS_WVSPLITKRC = [
117, 117,
128, 128,
] ]
K_FACTORS_WVSPLITKRC = [2880, 2880 + 8, 3072, 3072 + 8] K_FACTORS_WVSPLITKRC = [2880, 2880 + 8, 3072, 3072 + 8]
M_FACTORS_WVSPLITKRC = [128, 128 + 16, 256, 256 + 16, 640, 640 + 16] M_FACTORS_WVSPLITKRC = [128, 128 + 16, 256, 256 + 16, 640, 640 + 16]
...@@ -123,10 +122,11 @@ def pad_fp8(weight): ...@@ -123,10 +122,11 @@ def pad_fp8(weight):
@pytest.mark.parametrize("m", M_FACTORS_WVSPLITKRC) @pytest.mark.parametrize("m", M_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded_a", [False, True])
@pytest.mark.parametrize("bias_mode", BIAS_MODES) @pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950") @pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode): def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, padded_a, bias_mode):
torch.manual_seed(seed) torch.manual_seed(seed)
cu_count = num_compute_units() cu_count = num_compute_units()
...@@ -141,7 +141,8 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode): ...@@ -141,7 +141,8 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
# Given the above, how many CUs would we need? # Given the above, how many CUs would we need?
CuNeeded = rndup_cus * GrpsShrB CuNeeded = rndup_cus * GrpsShrB
# candidate for atomic reduce count splitk? # candidate for atomic reduce count splitk?
fits_wvsplitkrc = CuNeeded <= cu_count fits_wvsplitkrc = (N_p2 * m * ((k + 512 - 1) // 512)) <= 128 * 1024 * 12
fits_wvsplitkrc &= CuNeeded <= cu_count
if not fits_wvsplitkrc: if not fits_wvsplitkrc:
pytest.skip("Too large for wvSplitKrc") pytest.skip("Too large for wvSplitKrc")
...@@ -151,6 +152,8 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode): ...@@ -151,6 +152,8 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
) # normalize to avoid large output-bias deltas ) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
if padded_a:
A = pad_fp8(A)
BIAS = None BIAS = None
if bias_mode == 1: if bias_mode == 1:
...@@ -159,7 +162,7 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode): ...@@ -159,7 +162,7 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1 BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
ref_out = torch.nn.functional.linear(A, B, BIAS) ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS) out = ops.wvSplitKrc(A, B, cu_count, BIAS)
if xnorm: if xnorm:
torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8) torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8)
......
...@@ -129,10 +129,6 @@ def rocm_unquantized_gemm_impl( ...@@ -129,10 +129,6 @@ def rocm_unquantized_gemm_impl(
k = weight.shape[1] k = weight.shape[1]
cu_count = num_compute_units() cu_count = num_compute_units()
if use_aiter_triton_gemm(n, m, k, x.dtype):
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
return gemm_a16w16(x, weight, bias)
# Next ^2 of n # Next ^2 of n
N_p2 = 1 << (n - 1).bit_length() N_p2 = 1 << (n - 1).bit_length()
...@@ -145,7 +141,10 @@ def rocm_unquantized_gemm_impl( ...@@ -145,7 +141,10 @@ def rocm_unquantized_gemm_impl(
# Given the above, how many CUs would we need? # Given the above, how many CUs would we need?
CuNeeded = rndup_cus * GrpsShrB CuNeeded = rndup_cus * GrpsShrB
# candidate for atomic reduce count splitk? # candidate for atomic reduce count splitk?
fits_wvsplitkrc = CuNeeded <= cu_count fits_wvsplitkrc = (
N_p2 * m * ((k + 512 - 1) // 512)
) <= 128 * 1024 * 12 # deterministic
fits_wvsplitkrc &= CuNeeded <= cu_count
use_skinny_reduce_counting = ( use_skinny_reduce_counting = (
envs.VLLM_ROCM_USE_SKINNY_GEMM envs.VLLM_ROCM_USE_SKINNY_GEMM
...@@ -157,13 +156,16 @@ def rocm_unquantized_gemm_impl( ...@@ -157,13 +156,16 @@ def rocm_unquantized_gemm_impl(
and k > 512 and k > 512
and m % 16 == 0 and m % 16 == 0
and fits_wvsplitkrc and fits_wvsplitkrc
and x.is_contiguous() and weight.is_contiguous()
) )
) )
if use_skinny_reduce_counting: if use_skinny_reduce_counting:
x_view = x.reshape(-1, x.size(-1)) return ops.wvSplitKrc(x, weight, cu_count, bias)
out = ops.wvSplitKrc(weight, x_view, cu_count, bias)
return out.reshape(*x.shape[:-1], weight.shape[0]) if use_aiter_triton_gemm(n, m, k, x.dtype):
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
return gemm_a16w16(x, weight, bias)
use_skinny = ( use_skinny = (
envs.VLLM_ROCM_USE_SKINNY_GEMM envs.VLLM_ROCM_USE_SKINNY_GEMM
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment