Unverified Commit 188b7f9b authored by Charlie Fu's avatar Charlie Fu Committed by GitHub
Browse files

[Performance][ROCm] Add skinny gemms for unquantized linear on ROCm (#15830)


Signed-off-by: default avatarcharlifu <charlifu@amd.com>
Co-authored-by: default avatarTyler Michael Smith <tysmith@redhat.com>
parent b9b47469
......@@ -678,6 +678,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
#
set(VLLM_ROCM_EXT_SRC
"csrc/rocm/torch_bindings.cpp"
"csrc/rocm/skinny_gemms.cu"
"csrc/rocm/attention.cu")
define_gpu_extension_target(
......
......@@ -2,6 +2,15 @@
#include <torch/all.h>
torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
const int64_t rows_per_block);
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
const int64_t CuCount);
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,
......
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <stdexcept>
#include <algorithm>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
#define __HIP__MI300_MI250__
#endif
#if defined(__HIPCC__) && defined(__gfx942__)
#define __HIP__MI300__
#endif
#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
#define UNREACHABLE_CODE assert(false);
#define NDEBUG
#else
#define UNREACHABLE_CODE assert(false);
#endif
template <typename T>
struct scalar {};
template <typename T>
struct scalar2 {};
template <typename T>
__device__ __forceinline__ float2 __s22float2(T v);
template <typename T>
__device__ __forceinline__ T __float2s(float v);
template <typename T>
__device__ __forceinline__ T __float22s2_rn(float2 v);
// Definitions and cvt functions for fp16
template <>
struct scalar<c10::Half> {
using type = half;
};
template <>
struct scalar2<c10::Half> {
using type = __half2;
};
template <>
__device__ __forceinline__ half __float2s(float v) {
return __float2half(v);
}
template <>
__device__ __forceinline__ float2 __s22float2(__half2 v) {
return __half22float2(v);
}
template <>
__device__ __forceinline__ __half2 __float22s2_rn(float2 v) {
return __float22half2_rn(v);
}
// Definitions and cvt functions for bf16
template <>
struct scalar<c10::BFloat16> {
using type = __hip_bfloat16;
};
template <>
struct scalar2<c10::BFloat16> {
using type = __hip_bfloat162;
};
template <>
__device__ __forceinline__ __hip_bfloat16 __float2s(float v) {
return __float2bfloat16(v);
}
template <>
__device__ __forceinline__ float2 __s22float2(__hip_bfloat162 v) {
return __bfloat1622float2(v);
}
template <>
__device__ __forceinline__ __hip_bfloat162 __float22s2_rn(float2 v) {
return __float22bfloat162_rn(v);
}
template <typename T>
__device__ __forceinline__ T loadnt(T* addr) {
return __builtin_nontemporal_load(addr);
}
__device__ __forceinline__ float4 load_ntmprl(const float4* addr) {
auto addr_alias = reinterpret_cast<const float*>(addr);
auto dat0 = loadnt(addr_alias);
auto dat1 = loadnt(addr_alias + 1);
auto dat2 = loadnt(addr_alias + 2);
auto dat3 = loadnt(addr_alias + 3);
return make_float4(dat0, dat1, dat2, dat3);
}
// TBlock fetches entire rows of A, and entire col of B (K dimension); assume
// N=1 for time being grid is M/A_NUM_ROWS blocks
template <typename scalar_t, int NUM_A_ROWS_PER_BLOCK>
__global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
scalar_t* out_c, const int K) {
using scalar2_t = typename scalar2<scalar_t>::type;
auto af4 = reinterpret_cast<const float4*>(in_a);
auto bf4 = reinterpret_cast<const scalar2_t*>(in_b);
auto c = reinterpret_cast<scalar2_t*>(out_c);
__shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE];
const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8;
const int threadid = threadIdx.x;
const int warp = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
const int num_warps = blockDim.x / WARP_SIZE;
const int qwarpid = threadid / num_warps;
const int qthreadid = threadid % num_warps;
float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK];
scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w;
float acc[NUM_A_ROWS_PER_BLOCK];
scalar2_t acch2;
scalar2_t oval;
// As we later use warp shuffle operations, we may have more threads in the
// block than the actual available data, hence the if guard here.
if (threadid * 8 < K) {
#pragma unroll
for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
// rowA_elem4[i] holds 8 * half numbers seen as a single float4.
rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]);
}
}
colB_elem4x = bf4[threadid * 4 + 0];
colB_elem4y = bf4[threadid * 4 + 1];
colB_elem4z = bf4[threadid * 4 + 2];
colB_elem4w = bf4[threadid * 4 + 3];
scalar2_t Af2;
scalar2_t Bf2;
float2 S;
auto Ah2ptr = reinterpret_cast<scalar2_t*>(&rowA_elem4);
scalar2_t* ah2lptr;
#pragma unroll
for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
// Multiply-add on 8 scalar_t.
ah2lptr = Ah2ptr + i * 4;
Af2 = *(ah2lptr);
acch2 = __hmul2(Af2, colB_elem4x);
Af2 = *(ah2lptr + 1);
acch2 = __hfma2(Af2, colB_elem4y, acch2);
Af2 = *(ah2lptr + 2);
acch2 = __hfma2(Af2, colB_elem4z, acch2);
Af2 = *(ah2lptr + 3);
acch2 = __hfma2(Af2, colB_elem4w, acch2);
S = __s22float2(acch2);
// See comment above concerning the if guard.
acc[i] = (threadid * 8 < K ? S.x + S.y : 0.f);
}
// all reduce across warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
#pragma unroll
for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
acc[i] += __shfl_xor(acc[i], mask);
}
}
// Warp leaders store the data to shared memory.
if (lane < NUM_A_ROWS_PER_BLOCK) {
red_smem[lane][warp] = acc[lane];
}
// Make sure the data is in shared memory.
__syncthreads();
if (qwarpid < NUM_A_ROWS_PER_BLOCK) {
acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f;
for (int mask = num_warps / 2; mask >= 1; mask /= 2) {
acc[qwarpid] += __shfl_xor(acc[qwarpid], mask);
}
float oval2 = __shfl_xor(acc[qwarpid], num_warps);
if (lane % (num_warps * 2) == 0) {
oval = __float22s2_rn<scalar2_t>(make_float2(acc[qwarpid], oval2));
c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval;
}
}
}
torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
const int64_t rows_per_block) {
auto M = in_a.size(0);
auto K = in_a.size(1);
auto N = in_b.size(0);
TORCH_CHECK(N == 1, "Row number of activation tensor must be 1.");
TORCH_CHECK(in_a.dtype() == in_b.dtype());
TORCH_CHECK(in_b.dtype() == torch::kFloat16 ||
in_b.dtype() == torch::kBFloat16);
auto out_c = torch::empty(
{N, M}, torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device()));
// NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle
// operations.
const int NUM_THREADS =
K * 2 / 16 % WARP_SIZE == 0
? K * 2 / 16
: K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE);
int NUM_BLOCKS = M / rows_per_block;
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_b));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// call the kernel function...
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "LLGemm1", [&] {
auto a_ptr = in_a.data_ptr<scalar_t>();
auto b_ptr = in_b.data_ptr<scalar_t>();
auto c_ptr = out_c.data_ptr<scalar_t>();
if (rows_per_block == 2) {
LLGemm1_kernel<scalar_t, 2>
<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(a_ptr, b_ptr, c_ptr, K);
} else if (rows_per_block == 4) {
LLGemm1_kernel<scalar_t, 4>
<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(a_ptr, b_ptr, c_ptr, K);
} else if (rows_per_block == 8) {
LLGemm1_kernel<scalar_t, 8>
<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(a_ptr, b_ptr, c_ptr, K);
} else if (rows_per_block == 16) {
LLGemm1_kernel<scalar_t, 16>
<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(a_ptr, b_ptr, c_ptr, K);
} else {
NUM_BLOCKS = M / 4;
LLGemm1_kernel<scalar_t, 4>
<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(a_ptr, b_ptr, c_ptr, K);
}
});
return out_c;
}
#define DOT2C(V0, V2, V3) \
if constexpr (std::is_same_v<scalar_t, half>) { \
asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { \
float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \
__bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \
V0 += (s.x + s.y); \
}
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
// This version targets cases where A[] fits LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
scalar8 h8;
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not goint to work!
//----------------------------------------------------
__shared__ scalar_t s[1024 * 32];
//----------------------------------------------------
// Fetch the activation matrix to LDS
// Loop iteration:
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements
// - Each WG will fetch 512 * 16 => 8K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min(K * N, 32 * 1024)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
}
__syncthreads();
if (threadIdx.y >= _WvPrGrp) return;
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
float sum[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
// working on 16 columns of weight matrix. Moreover,
// we tile in column direction by YTILE, so when YTILE=1
// the above math is right, however, when YTILE=2 then
// each wave will be working on 2 columns and WG will
// be working on 32 columns.
//
// Top level loop that makes WGs persistent!
// - WGs iterates across columns of weight matrix
// - Each wave within WG works on a given column(s)
// - After completing first set of columns, WGs start
// working on the next set of available columns
//----------------------------------------------------
while (m < M) {
//----------------------------------------------------
// 'sum' accumulates the matrix A x B computation
// split across 64 lanes.
//
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = 0;
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
// for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) {
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b])
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
}
}
}
}
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
// This version targets cases where A[] marginally exceeds LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
scalar8 h8;
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not goint to work!
//----------------------------------------------------
__shared__ scalar_t s[1024 * 32];
//----------------------------------------------------
// Computation of columns that need to be committed to memory!
//----------------------------------------------------
uint32_t commitColumn[YTILE];
for (uint32_t i = 0; i < YTILE; i++) {
commitColumn[i] = 1;
}
//----------------------------------------------------
// Indexing function into the column of weight matrix B
// Algorithm does 64 lane k-splitting / wave and uses
// WG ID and Thread ID to find the index.
//----------------------------------------------------
// int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp);
uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE;
// Check whether there will be fragmenation!
// This will happen only for the last wave!
if (m < M && (m + YTILE) >= M) {
uint32_t startColumn = M - YTILE;
for (uint32_t i = 0; i < (m - startColumn); i++) {
commitColumn[i] = 0;
}
m = startColumn;
}
//----------------------------------------------------
// Fetch the activation matrix to LDS
// Loop iteration:
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements
// - Each WG will fetch 512 * 16 => 8K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min(K * N, 32 * 1024)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
}
__syncthreads();
if (threadIdx.y >= _WvPrGrp) return;
float sum[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
// working on 16 columns of weight matrix. Moreover,
// we tile in column direction by YTILE, so when YTILE=1
// the above math is right, however, when YTILE=2 then
// each wave will be working on 2 columns and WG will
// be working on 32 columns.
//
// Top level loop that makes WGs persistent!
// - WGs iterates across columns of weight matrix
// - Each wave within WG works on a given column(s)
// - After completing first set of columns, WGs start
// working on the next set of available columns
//----------------------------------------------------
while (m < M) {
//----------------------------------------------------
// 'sum' accumulates the matrix A x B computation
// split across 64 lanes.
//
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = 0;
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) {
if (k_ + K * n < 32 * 1024)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
}
}
}
}
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
m += CuCount * _WvPrGrp * YTILE;
// Check whether there will be fragmenation!
// This will happen only for the last wave!
if (m < M && (m + YTILE) >= M) {
uint32_t startColumn = M - YTILE;
for (uint32_t i = 0; i < (m - startColumn); i++) {
commitColumn[i] = 0;
}
m = startColumn;
}
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
// This version targets big A[] cases, where it is much larger than LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
scalar8 h8;
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not goint to work!
//----------------------------------------------------
__shared__ scalar_t s[1024 * 32];
//----------------------------------------------------
// Computation of columns that need to be committed to memory!
//----------------------------------------------------
uint32_t commitColumn[YTILE];
for (uint32_t i = 0; i < YTILE; i++) {
commitColumn[i] = 1;
}
// int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp);
if (threadIdx.y >= _WvPrGrp) return;
//----------------------------------------------------
// Indexing function into the column of weight matrix B
// Algorithm does 64 lane k-splitting / wave and uses
// WG ID and Thread ID to find the index.
//----------------------------------------------------
uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE;
// Check whether there will be fragmenation!
// This will happen only for the last wave!
if (m < M && (m + YTILE) >= M) {
uint32_t startColumn = M - YTILE;
for (uint32_t i = 0; i < (m - startColumn); i++) {
commitColumn[i] = 0;
}
m = startColumn;
}
//----------------------------------------------------
// Fetch the activation matrix to LDS
// Loop iteration:
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements
// - Each WG will fetch 512 * 16 => 8K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
#define PCML
#ifndef PCML
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min(K * N, 32 * 1024)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
}
__syncthreads();
#endif
#define TUC (THRDS * UNRL * A_CHUNK)
uint32_t kBase = 0;
// find biggest k size that fits in LDS
uint32_t kFit = (32 * 1024) / N;
// kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple
// of TUC
kFit = (kFit % TUC == 0)
? kFit
: (kFit - kFit % TUC); // round up to multiple of TUC
// if (kFit == 0) kFit = TUC;
kFit = min(kFit, K);
float sum[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
// working on 16 columns of weight matrix. Moreover,
// we tile in column direction by YTILE, so when YTILE=1
// the above math is right, however, when YTILE=2 then
// each wave will be working on 2 columns and WG will
// be working on 32 columns.
//
// Top level loop that makes WGs persistent!
// - WGs iterates across columns of weight matrix
// - Each wave within WG works on a given column(s)
// - After completing first set of columns, WGs start
// working on the next set of available columns
//----------------------------------------------------
#ifdef PCML
int YW = (YTILE * _WvPrGrp);
uint32_t Mrndp = (M % YW == 0) ? M : (M - M % YW + YW);
while (m < Mrndp) {
#else
while (m < M) {
#endif
//----------------------------------------------------
// 'sum' accumulates the matrix A x B computation
// split across 64 lanes.
//
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = 0;
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
#ifdef PCML
if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS
if (k1 != 0) kBase += kFit;
__syncthreads();
for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) {
uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (kBase + kOff >= K) break;
if (kOff >= kFit) break;
for (uint32_t n = 0; n < N; n++) {
uint32_t k_in = kBase + n * K + kOff;
uint32_t k_ot = n * kFit + kOff;
*((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in]));
}
}
__syncthreads();
}
if (m >= M) continue;
#endif
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) {
#ifdef PCML
bigA[n][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * n])));
#else
if (k_ + K * n < 32 * 1024)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
#endif
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
}
}
}
}
#ifdef PCML
if (m >= M) {
m += CuCount * _WvPrGrp * YTILE;
kBase = 0;
continue;
}
#endif
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
m += CuCount * _WvPrGrp * YTILE;
kBase = 0;
// Check whether there will be fragmenation!
// This will happen only for the last wave!
if (m < M && (m + YTILE) >= M) {
uint32_t startColumn = M - YTILE;
for (uint32_t i = 0; i < (m - startColumn); i++) {
commitColumn[i] = 0;
}
m = startColumn;
}
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
int mindiv(int N, int div1, int div2) {
int nPrRnd = div1 * div2;
int rnds0 = N / nPrRnd;
nPrRnd -= div1 * 3;
int rnds3 = N / nPrRnd;
nPrRnd -= div1;
int rnds4 = N / nPrRnd;
nPrRnd -= div1;
int rnds5 = N / nPrRnd;
nPrRnd -= div1;
int rnds6 = N / nPrRnd;
nPrRnd -= div1;
int rnds7 = N / nPrRnd;
nPrRnd -= div1;
int rnds8 = N / nPrRnd;
nPrRnd -= div1;
int rnds9 = N / nPrRnd;
nPrRnd -= div1;
int rtn = div2;
if (rnds0 == rnds3) rtn = div2 - 3;
if (rnds0 == rnds4) rtn = div2 - 4;
if (rnds0 == rnds5) rtn = div2 - 5;
if (rnds0 == rnds6) rtn = div2 - 6;
if (rnds0 == rnds7) rtn = div2 - 7;
if (rnds0 == rnds8) rtn = div2 - 8;
if (rnds0 == rnds9) rtn = div2 - 9;
return rtn;
}
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
const int64_t CuCount) {
auto M_in = in_a.size(0);
auto K_in = in_a.size(1);
auto N_in = in_b.size(0);
TORCH_CHECK(in_a.dtype() == in_b.dtype());
TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0");
TORCH_CHECK(in_a.dtype() == torch::kFloat16 ||
in_a.dtype() == torch::kBFloat16);
auto out_c = torch::empty(
{N_in, M_in},
torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device()));
dim3 grid(CuCount);
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
} else if (K_in * N_in <= 32 * 1024 * 1.2) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
} else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
} \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
using fptype = typename scalar<scalar_t>::type;
fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr());
const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
switch (N_in) {
case 1:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1)
break;
case 2:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2)
break;
case 3:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3)
break;
case 4:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4)
break;
default:
throw std::runtime_error(
"Unsupported N value: " + std::to_string(M_in) + "," +
std::to_string(K_in) + "," + std::to_string(N_in));
}
});
return out_c;
}
#if defined(__HIP__MI300__) // TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const fp8_t* B,
const fp8_t* __restrict__ A, scalar_t* C,
const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp,
const int CuCount) {
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
union bigType {
char f8[A_CHUNK];
char2 c2[A_CHUNK / 2];
scalar_t h[A_CHUNK / 2];
float f[A_CHUNK / 4];
int i[A_CHUNK / 4];
long l[A_CHUNK / 8];
intx4 l2[A_CHUNK / 16];
scalar8 h8;
};
__shared__ fp8_t s[1024 * 64];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
}
__syncthreads();
if (threadIdx.y >= _WvPrGrp) return;
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
floatx16 sum[N][YTILE];
float sA = *s_A;
float sB = *s_B;
while (m < M) {
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = {0.f};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
#pragma unroll
for (uint32_t n = 0; n < N; ++n) bigA[n][k2].h8 = {0.f};
#pragma unroll
for (uint32_t y = 0; y < YTILE; ++y) bigB[y][k2].h8 = {0.f};
}
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const fp8_t* B_ = &B[(m + 0) * Kp + k_];
#pragma unroll
for (uint32_t y = 0; y < YTILE; ++y) {
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp])));
}
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (int n = 0; n < N; n++) {
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
if (k >= K) break;
for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0,
0);
}
}
}
}
}
// Final reduction
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8];
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][1]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][9]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][2]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][10]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][3]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][11]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][4]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][12]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][5]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][13]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][6]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][14]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][7]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][15]), "v"(accm16));
accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16);
}
}
if (threadIdx.x == 0) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0] * sA * sB);
}
}
}
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
const fp8_t* B, const fp8_t* __restrict__ A,
scalar_t* C, const float* __restrict__ s_A,
const float* __restrict__ s_B,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
#if defined(__HIP__MI300__) // TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B,
const fp8_t* __restrict__ A, scalar_t* C,
const float* __restrict__ s_A, const float* __restrict__ s_B,
const int _WvPrGrp, const int CuCount) {
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
union bigType {
char f8[A_CHUNK];
char2 c2[A_CHUNK / 2];
scalar_t h[A_CHUNK / 2];
float f[A_CHUNK / 4];
int i[A_CHUNK / 4];
long l[A_CHUNK / 8];
intx4 l2[A_CHUNK / 16];
scalar8 h8;
};
__shared__ fp8_t s[1024 * 64];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
}
__syncthreads();
if (threadIdx.y >= _WvPrGrp) return;
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
floatx16 sum[N][YTILE];
float sA = *s_A;
float sB = *s_B;
while (m < M) {
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = {0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const fp8_t* B_ = &B[(m + 0) * Kp + k_];
for (int y = 0; y < YTILE; ++y) {
if (y + m >= M) break; // To avoid mem access fault.
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp])));
}
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (int n = 0; n < N; n++) {
if (k_ + K * n < 64 * 1024)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0,
0);
}
}
}
}
}
// Final reduction
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8];
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][1]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][9]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][2]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][10]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][3]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][11]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][4]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][12]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][5]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][13]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][6]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][14]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][7]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][15]), "v"(accm16));
accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16);
}
}
if (threadIdx.x == 0) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault.
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0] * sA * sB);
}
}
}
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
const fp8_t* B, const fp8_t* __restrict__ A,
scalar_t* C, const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp,
const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b,
const int64_t CuCount) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
auto M_in = in_a.size(0);
auto K_in = in_a.size(1);
auto N_in = in_b.size(0);
auto Kp_in = in_a.stride(0);
TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0");
TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type);
TORCH_CHECK(out_c.dtype() == torch::kFloat16 ||
out_c.dtype() == torch::kBFloat16);
dim3 grid(CuCount);
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \
} else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \
} \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] {
using fptype = typename scalar<scalar_t>::type;
auto c_ptr = reinterpret_cast<fptype*>(out_c.data_ptr());
auto s_a = scale_a.data_ptr<float>();
auto s_b = scale_b.data_ptr<float>();
VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] {
auto a_ptr = in_a.data_ptr<fp8_t>();
auto b_ptr = in_b.data_ptr<fp8_t>();
switch (N_in) {
case 1:
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1)
break;
case 2:
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 2)
break;
case 3:
WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 3)
break;
case 4:
WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 4)
break;
default:
throw std::runtime_error(
"Unsupported N value: " + std::to_string(M_in) + "," +
std::to_string(K_in) + "," + std::to_string(N_in));
}
});
});
}
\ No newline at end of file
......@@ -14,6 +14,24 @@
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
// vLLM custom ops for rocm
// Custom gemm op for matrix-vector multiplication
rocm_ops.def(
"LLMM1(Tensor in_a, Tensor in_b, int rows_per_block) -> "
"Tensor");
rocm_ops.impl("LLMM1", torch::kCUDA, &LLMM1);
// Custom gemm op for skinny matrix-matrix multiplication
rocm_ops.def(
"wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> "
"Tensor");
rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);
// wvSplitK for fp8
rocm_ops.def(
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, "
" Tensor scale_b, int CuCount) -> ()");
rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);
// Custom attention op
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
......
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float16]
M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # k % 8 == 0
N = [1, 2, 3, 4]
SEEDS = [0]
@pytest.mark.parametrize("n", [1]) # only test for batch size 1
@pytest.mark.parametrize("k", K)
@pytest.mark.parametrize("m", M)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@torch.inference_mode()
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
torch.manual_seed(seed)
A = torch.rand(n, k, dtype=dtype, device="cuda")
B = torch.rand(m, k, dtype=dtype, device="cuda")
ref_out = torch.matmul(A, B.t())
out = ops.LLMM1(B, A, rows_per_block)
assert torch.allclose(out, ref_out, rtol=0.01)
@pytest.mark.parametrize("n", N) # only test for batch size <= 4
@pytest.mark.parametrize("k", K + [9216, 10240, 16384])
@pytest.mark.parametrize("m", [8] + M) # m >= 8
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
A = torch.rand(n, k, dtype=dtype, device="cuda")
B = torch.rand(m, k, dtype=dtype, device="cuda")
ref_out = torch.matmul(A, B.t())
out = ops.wvSplitK(B, A, cu_count)
assert torch.allclose(out, ref_out, rtol=0.01)
@pytest.mark.parametrize("n", N) # only test for batch size <= 4
@pytest.mark.parametrize("k", K[1:] + [14336, 24576, 32768]) # k % 16 == 0
@pytest.mark.parametrize("m", M + [28672]) # m >= 16
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
A = torch.rand(n, k, device="cuda")
B = torch.rand(m, k, device="cuda")
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
ref_out = torch._scaled_mm(A,
B.t(),
out_dtype=dtype,
scale_a=scale_a,
scale_b=scale_b)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
current_platform.get_cu_count())
assert torch.allclose(out, ref_out, rtol=0.01)
......@@ -1196,6 +1196,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
ssm_states, pad_slot_id)
# ROCm skinny gemms
def LLMM1(a: torch.Tensor, b: torch.Tensor,
rows_per_block: int) -> torch.Tensor:
return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)
def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor:
return torch.ops._rocm_C.wvSplitK(a, b, cu_count)
def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor,
cu_count: int) -> torch.Tensor:
out = torch.empty((b.shape[0], a.shape[0]),
dtype=out_dtype,
device=b.device)
torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count)
return out
# moe
def moe_sum(input: torch.Tensor, output: torch.Tensor):
torch.ops._moe_C.moe_sum(input, output)
......
......@@ -78,6 +78,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
......@@ -550,6 +551,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
("true", "1")),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
("true", "1")),
# Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
......
......@@ -6,7 +6,6 @@ from typing import Any, Literal, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
......@@ -17,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
BlockQuantScaleParameter,
......@@ -188,7 +188,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return F.linear(x, layer.weight, bias)
return dispatch_unquantized_gemm()(x, layer.weight, bias)
class LinearBase(torch.nn.Module):
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform
......@@ -17,6 +18,7 @@ TORCH_DEVICE_IDENTITY = None
# The condition is determined once as the operations
# are time consuming.
USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm()
and torch.__version__[0:3] >= "2.7"
and current_platform.has_device_capability(94))
......@@ -131,6 +133,159 @@ def maybe_create_device_identity():
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
out_dtype: torch.dtype, scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
output_shape: List, **kwargs) -> torch.Tensor:
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
return output.view(*output_shape)
def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
if envs.VLLM_ROCM_USE_SKINNY_GEMM and qinput.shape[
0] == 1 and qinput.shape[1] % 16 == 0:
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
current_platform.get_cu_count())
else:
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it.
# For now it has only been validated on ROCm platform.
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
#
# For CUDA platform please validate if the torch._scaled_mm supports
# rowwise scaled GEMM before using it
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b.t(),
bias=bias)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
output = output.view(*output_shape)
return output
def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List,
**kwargs) -> torch.Tensor:
# Use unfused DQ due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * scale_b.t()
if bias is not None:
output = output + bias
return output.to(out_dtype).view(*output_shape)
def dispatch_w8a8_scaled_mm(
cutlass_fp8_supported: bool, per_tensor_weights: bool,
per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool]
) -> Callable[..., torch.Tensor]:
if cutlass_fp8_supported:
return cutlass_w8a8_scaled_mm
if per_tensor_weights and per_tensor_activations:
if current_platform.is_rocm():
return rocm_per_tensor_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
if (use_per_token_if_dynamic and not per_tensor_weights
and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM):
return torch_per_token_w8a8_scaled_mm
return torch_channelwise_w8a8_scaled_mm
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
class Fp8LinearOp:
......@@ -156,7 +311,8 @@ class Fp8LinearOp:
if pad_output is None:
config = get_current_vllm_config().compilation_config
pad_output = config.level < CompilationLevel.PIECEWISE
self.output_padding = 17 if pad_output else None
self.output_padding = 17 if (
pad_output and not current_platform.is_rocm()) else None
def apply(
self,
......@@ -195,18 +351,6 @@ class Fp8LinearOp:
input_scale,
scale_ub=input_scale_ub,
use_per_token_if_dynamic=use_per_token_if_dynamic)
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
return output.view(*output_shape)
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
else:
if input.dtype != current_platform.fp8_dtype():
# Maybe apply padding to output, see comment in __init__
......@@ -218,84 +362,21 @@ class Fp8LinearOp:
else:
qinput, x_scale = input_2d, input_scale
per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1)
if per_tensor_weights and per_tensor_activations:
# Fused GEMM_DQ
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0,
input_2d.shape[0]).view(*output_shape)
elif (use_per_token_if_dynamic and not per_tensor_weights
and not per_tensor_activations
and USE_ROWWISE_TORCH_SCALED_MM):
# For now validated on ROCm platform
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
# and ROCm 6.3, which only exists in torch 2.7 and above.
# For CUDA platform please validate if the
# torch._scaled_mm support rowwise scaled GEMM
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale.t(),
bias=bias)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
output = output.view(*output_shape)
return output
else:
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * weight_scale.t()
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1)
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
self.cutlass_fp8_supported, per_tensor_weights,
per_tensor_activations, use_per_token_if_dynamic)
return w8a8_scaled_mm_func(qinput=qinput,
weight=weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
input_2d=input_2d,
output_shape=output_shape)
def normalize_e4m3fn_to_e4m3fnuz(
......
# SPDX-License-Identifier: Apache-2.0
"""Utility methods for model layers."""
from typing import Tuple
from typing import Callable, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.platforms import current_platform
def get_token_bin_counts_and_mask(
tokens: torch.Tensor,
......@@ -61,3 +65,34 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
return logits
def rocm_unquantized_gemm(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None):
k = weight.shape[1]
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and \
x.dtype in [torch.float16, torch.bfloat16] \
and k % 8 == 0 and bias is None)
if use_skinny is not True:
return torch.nn.functional.linear(x, weight, bias)
x_view = x.view(-1, x.size(-1))
n = x_view.shape[0]
m = weight.shape[0]
cu_count = current_platform.get_cu_count()
if m > 8 and n < 4:
out = ops.wvSplitK(weight, x_view, cu_count)
return out.view(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192:
out = ops.LLMM1(weight, x_view, out, 4)
return out.view(*x.shape[:-1], weight.shape[0])
return torch.nn.functional.linear(x, weight, bias)
def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
if current_platform.is_rocm():
return rocm_unquantized_gemm
return torch.nn.functional.linear
......@@ -413,6 +413,13 @@ class Platform:
self.device_name, key)
return None
@classmethod
def get_cu_count(cls, device_id: int = 0) -> int:
"""
Returns the total number of compute units (CU) on single GPU.
"""
raise NotImplementedError
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
......
......@@ -312,3 +312,8 @@ class RocmPlatform(Platform):
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
supported_archs = ['gfx94']
return any(gfx in gcn_arch for gfx in supported_archs)
@classmethod
def get_cu_count(cls, device_id: int = 0) -> int:
return torch.cuda.get_device_properties(
device_id).multi_processor_count
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment