Commit 4c676e3d authored by zhuwenwen
Browse files

Merge tag 'v0.9.1' into v0.9.1-dev

parents b4c4464d b6553be1
......@@ -25,8 +25,17 @@
#include "../attention/dtype_fp8.cuh"
#include "../quantization/fp8/amd/quant_utils.cuh"
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
#define __HIP__MI300_MI250__
// Architecture-family selectors. Each macro is defined only when the
// translation unit is compiled by HIPCC for one of the listed gfx targets;
// code in this file tests these macros to choose a per-family code path.

// gfx90a, gfx942 or gfx950.
#if defined(__HIPCC__) && \
(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
#define __HIP__GFX9__
#endif
// gfx1100 or gfx1101.
#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__))
#define __HIP__GFX11__
#endif
// gfx1200 or gfx1201.
#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__))
#define __HIP__GFX12__
#endif
#if defined(NDEBUG)
......@@ -42,7 +51,7 @@
// Smaller of a and b. This is a function-like macro, so the selected argument
// is evaluated twice; do not pass expressions with side effects.
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Computes (a + b - 1) / b, i.e. a / b rounded up when a is a non-negative
// integer and b is a positive integer. b is evaluated twice.
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
#if defined(__HIP__GFX9__)
#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16
......@@ -1286,7 +1295,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions) {
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x;
const auto head_idx = blockIdx.x;
const auto seq_idx = blockIdx.y;
......@@ -1464,8 +1473,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const float inv_global_exp_sum =
__fdividef(1.0f, shared_global_exp_sum + 1e-6f);
const float out_scale =
(fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
acc *= inv_global_exp_sum;
acc *= out_scale;
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
......@@ -1479,192 +1490,1697 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#elif defined(__HIP__GFX11__)
// clang-format off
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED,
int GQA_RATIO>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
UNREACHABLE_CODE
// Vector and union helper types used by the WMMA-based code below.

// Vector of 8 floats (accumulator / result type of the 16x16x16 wrapper).
using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float;
// Raw 16-bit storage element.
using bit16_t = uint16_t;
// Vector of 4 raw 16-bit elements (8 bytes).
using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t;
typedef bit16x4 _B16x4;
// Vector of 8 raw 16-bit elements (16 bytes).
using bit16x8 = __attribute__((__vector_size__(8 * sizeof(uint16_t)))) uint16_t;
// 16-byte value viewable either as one 8-element vector or as two 4-element
// halves.
union b16x8_u {
bit16x8 u16x8;
_B16x4 xy[2];
};
typedef b16x8_u _B16x8;
// Vector of 16 raw 16-bit elements (32 bytes); operand type of the 16x16x16
// wrapper.
using bit16x16 =
__attribute__((__vector_size__(16 * sizeof(uint16_t)))) uint16_t;
// 32-byte value viewable either as one 16-element vector or as two _B16x8
// halves.
union b16x16_u {
bit16x16 u16x16;
_B16x8 xy[2];
};
typedef b16x16_u _B16x16;
// 8 bytes carried as a uint2 (presumably eight 8-bit cache elements - confirm
// against the fp8 code paths).
using _B8x8 = uint2;
// Raw 8-bit storage element.
using bit8_t = uint8_t;
// Two _B8x8 halves, 16 bytes in total.
typedef struct _B8x16 {
_B8x8 xy[2];
} _B8x16;
// Issues one 16x16x16 WMMA multiply-accumulate with fp32 accumulation,
// selecting the builtin from the 16-bit element type T.
//
// Template parameters:
//   T                - element type of the packed inputs; _Float16 and
//                      __hip_bfloat16 are supported, anything else is a
//                      compile-time error.
//   absz, cbid, blgp - not referenced in this body (presumably kept so call
//                      sites match the MFMA wrapper's signature - confirm).
// Parameters:
//   inpA, inpB - 16 packed 16-bit elements each, passed as raw bit16x16.
//   inpC       - fp32 accumulator input.
// Returns the floatx8 result of the selected builtin.
// NOTE(review): the `_w32` builtin suffix suggests the wave32 variants - confirm.
template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx8 gcn_wmma16x16x16_instr(const bit16x16& inpA,
                                                          const bit16x16& inpB,
                                                          const floatx8& inpC) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(inpA, inpB, inpC);
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(inpA, inpB, inpC);
  } else {
    // The condition must depend on T: a plain static_assert(false) in a
    // discarded branch is rejected by compilers that predate CWG 2518.
    static_assert(!std::is_same<T, T>::value, "unsupported 16b dtype");
  }
}
// Converts one 16-bit floating-point value to float.
//
// T must be _Float16 (converted with a cast) or __hip_bfloat16 (converted
// with __bfloat162float); any other type is a compile-time error.
template <typename T>
__device__ __forceinline__ float to_float(const T& inp) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return (float)inp;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __bfloat162float(inp);
  } else {
    // The condition must depend on T: a plain static_assert(false) in a
    // discarded branch is rejected by compilers that predate CWG 2518.
    static_assert(!std::is_same<T, T>::value, "unsupported 16b dtype");
  }
}
// Converts one float to the 16-bit floating-point type T.
//
// T must be _Float16 (converted with a cast) or __hip_bfloat16 (converted
// with __float2bfloat16); any other type is a compile-time error.
template <typename T>
__device__ __forceinline__ T from_float(const float& inp) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return (_Float16)inp;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __float2bfloat16(inp);
  } else {
    // The condition must depend on T: a plain static_assert(false) in a
    // discarded branch is rejected by compilers that predate CWG 2518.
    static_assert(!std::is_same<T, T>::value, "unsupported 16b dtype");
  }
}
// Converts eight floats to the 16-bit type T and returns them packed in a
// _B16x8, preserving element order (inp[0] lands in the first 16-bit slot).
//
// Elements are converted two at a time with the paired intrinsics
// (__float22half2_rn / __float22bfloat162_rn; the `_rn` suffix presumably
// means round-to-nearest - confirm) and reinterpreted as raw bits through a
// local union.
// T must be _Float16 or __hip_bfloat16; any other type is a compile-time
// error.
template <typename T>
__device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
  if constexpr (std::is_same<T, _Float16>::value) {
    union h2cvt {
      __half2 h2[4];
      _B16x8 b16x8;
    } u;
    u.h2[0] = __float22half2_rn(make_float2(inp[0], inp[1]));
    u.h2[1] = __float22half2_rn(make_float2(inp[2], inp[3]));
    u.h2[2] = __float22half2_rn(make_float2(inp[4], inp[5]));
    u.h2[3] = __float22half2_rn(make_float2(inp[6], inp[7]));
    return u.b16x8;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    union b2cvt {
      __hip_bfloat162 b2[4];
      _B16x8 b16x8;
    } u;
    u.b2[0] = __float22bfloat162_rn(make_float2(inp[0], inp[1]));
    u.b2[1] = __float22bfloat162_rn(make_float2(inp[2], inp[3]));
    u.b2[2] = __float22bfloat162_rn(make_float2(inp[4], inp[5]));
    u.b2[3] = __float22bfloat162_rn(make_float2(inp[6], inp[7]));
    return u.b16x8;
  } else {
    // The condition must depend on T: a plain static_assert(false) in a
    // discarded branch is rejected by compilers that predate CWG 2518.
    static_assert(!std::is_same<T, T>::value, "unsupported 16b dtype");
  }
}
// clang-format off
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED,
int GQA_RATIO>
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
__launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
UNREACHABLE_CODE
}
// clang-format on
constexpr int NWARPS = NUM_THREADS / WARP_SIZE; // 8 warps on gfx11
const int warpid = threadIdx.x / WARP_SIZE;
const int laneid = threadIdx.x % WARP_SIZE;
const int lane2id = laneid % 2;
const int lane4id = laneid % 4;
const int lane16id = laneid % 16;
const int rowid = laneid / 16;
// Grid: (num_heads, num_seqs).
template <typename scalar_t, typename OUTT, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE, int NPAR_LOOPS>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
OUTT* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions) {
UNREACHABLE_CODE
}
// clang-format on
const int seq_idx = blockIdx.x;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx]) != 1) {
return;
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
const int partition_idx = blockIdx.y;
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
constexpr int T_PAR_SIZE = 256; // token partition size set to 256
#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
const int max_num_partitions = gridDim.y;
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, query_start_loc_ptr, max_num_partitions);
const int context_len = context_lens[seq_idx]; // length of a seq
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
bool ALIBI_ENABLED>
void paged_attention_custom_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
int num_seqs = block_tables.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
// exit if partition is out of context for seq
if (partition_start_token_idx >= context_len) {
return;
}
// NOTE: query start location is optional for V0 decode should not be used.
// If batch contains mix of prefills and decode, prefills should be skipped.
const int* query_start_loc_ptr =
query_start_loc
? reinterpret_cast<const int*>(query_start_loc.value().data_ptr())
: nullptr;
constexpr int GQA_RATIO2 = DIVIDE_ROUND_UP(GQA_RATIO, 2);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
__shared__ float shared_qk_max[NWARPS][16 + 1];
__shared__ float shared_exp_sum[NWARPS][16 + 1];
// shared_logits is used for multiple purposes
__shared__ _B16x16 shared_logits[NWARPS][2][16][2];
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
// for QK wmma16x16, layout is QHead/Tokenx16 across every 16 lanes,
// 32 Bytes HeadElements in each lane, 2x16B HeadElements across a row of warp
constexpr int ROWS_PER_WARP =
WARP_SIZE / 16 / 2; // rows refers to 16 lanes; refer dpp terminology
constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD =
16 / sizeof(cache_t); // 8 for 16 bit cache type, 16 for 8 bit types
constexpr int QKHE_PER_FETCH =
CONTIGUOUS_KV_ELEMS_16B_LOAD *
ROWS_PER_WARP; // each fetch across a warp fetches these many elements
constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 2xQKHE_16B across
// warp
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
_B16x16 Qlocal[QKHELOOP / 2]; // note that 16 contiguous elements of Q should
// be fetched per lane for 16 bit cache types
// partition size is fixed at 256 since both mfma4 and mfma16 kernels support
// it mfma4 kernel also supports partition size 512
constexpr int PARTITION_SIZE = 256;
const int max_num_partitions =
DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE);
constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t);
constexpr int NTHR = 256;
dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
dim3 block(NTHR);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
constexpr int TOKENS_PER_WARP =
T_PAR_SIZE /
NWARPS; // sub partition of tokens per warp for qk calculation
constexpr int TLOOP =
TOKENS_PER_WARP /
16; // each wmma16x16x16 instruction processes 16 tokens
// mfma4 kernel is faster than mfma16 for gqa_ratio <= 4
switch (gqa_ratio) {
case 1:
LAUNCH_CUSTOM_ATTENTION_MFMA4(1);
break;
case 2:
LAUNCH_CUSTOM_ATTENTION_MFMA4(2);
break;
case 3:
LAUNCH_CUSTOM_ATTENTION_MFMA4(3);
break;
case 4:
LAUNCH_CUSTOM_ATTENTION_MFMA4(4);
break;
case 5:
LAUNCH_CUSTOM_ATTENTION_MFMA16(5);
break;
case 6:
LAUNCH_CUSTOM_ATTENTION_MFMA16(6);
break;
case 7:
_B16x16 Klocal[TLOOP]
[QKHELOOP / 2]; // can be interpreted as B8x16 for 8 bit types
const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
const int wg_start_kv_head_idx = blockIdx.z;
const int total_num_heads = gridDim.z * GQA_RATIO;
// for QK wmma, tokens in multiples of TOKENS_PER_WARP are spread across warps
// each wmma takes QH16xT16x16HE across warp
// repeat wmma across QKHELOOP dimension
// output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens
// across 2 rows x 8 tokens per lane
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
if (GQA_RATIO == 1) {
const int local_qhead_idx = lane16id % GQA_RATIO;
const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
const scalar_t* q_ptr =
q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;
if (lane16id < GQA_RATIO) {
#pragma unroll
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) {
const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * QKHE_PER_FETCH * 2;
const _B16x16* q_fetch_ptr_32B =
reinterpret_cast<const _B16x16*>(q_fetch_ptr);
Qlocal[qkhe_depth] = *q_fetch_ptr_32B;
}
}
} else {
// fetch Q in shared across warps and then write to registers
const int local_qhead_idx = 2 * warpid + rowid;
const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
const scalar_t* q_ptr =
q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;
const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B;
if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) {
const scalar_t* q_fetch_ptr = q_ptr + qhead_element;
const _B16x8* q_fetch_ptr_16B =
reinterpret_cast<const _B16x8*>(q_fetch_ptr);
_B16x8 tmp = *q_fetch_ptr_16B;
const int offset1 =
lane16id /
2; // 16 contiguous chunks of head elems are spread across 8x2lanes
shared_logits[offset1][lane2id][local_qhead_idx][0].xy[0] = tmp;
}
__syncthreads();
#pragma unroll
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) {
Qlocal[qkhe_depth].xy[0] =
shared_logits[qkhe_depth][0][lane16id % GQA_RATIO][0].xy[0];
Qlocal[qkhe_depth].xy[1] =
shared_logits[qkhe_depth][1][lane16id % GQA_RATIO][0].xy[0];
}
}
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int last_ctx_block = num_context_blocks - 1;
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
int kphysical_block_number[TLOOP];
// fetch k physical block numbers
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
const int kblock_idx = (kglobal_token_idx < context_len)
? kglobal_token_idx / BLOCK_SIZE
: last_ctx_block;
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
}
constexpr int KX = 16 / sizeof(cache_t);
const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride;
const int row_head_elem = 0;
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int64_t kblock_number =
static_cast<int64_t>(kphysical_block_number[token_depth]);
const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride;
const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE;
const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX;
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH;
const int offset1 = head_elem / KX;
const int offset2 = head_elem % KX;
const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2;
const _B16x8* k_fetch_ptr_16B =
reinterpret_cast<const _B16x8*>(k_fetch_ptr);
Klocal[token_depth][qkhe_depth / 2].xy[qkhe_depth % 2] = *k_fetch_ptr_16B;
}
}
constexpr int VTOKENS_PER_LANE =
TOKENS_PER_WARP / ROWS_PER_WARP; // 32/1 = 32 vtokens per lane
constexpr int VBLOCKS_PER_LANE = 2; // assumes block size >=16
constexpr int VTLOOP = NWARPS; // corresponds to tokens across warps
constexpr int VTLANELOOP = DIVIDE_ROUND_UP(
VTOKENS_PER_LANE,
CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes
// minimum block size is 16
constexpr int VHELOOP = DIVIDE_ROUND_UP(
(HEAD_SIZE / 16), NWARPS); // head_size distributed across warps; each
// wmma instr works on 16 head elements
int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE];
// fetch v physical block numbers
for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE;
vblock_depth++) {
const int vlocal_token_idx =
vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP +
vblock_depth * BLOCK_SIZE;
const int vglobal_token_idx =
partition_start_token_idx + vlocal_token_idx;
const int vblock_idx = (vglobal_token_idx < context_len)
? vglobal_token_idx / BLOCK_SIZE
: last_ctx_block;
vphysical_block_number[vtoken_depth][vblock_depth] =
block_table_seq[vblock_idx];
}
}
_B16x16 Vlocal[VTLOOP][VHELOOP]
[VTLANELOOP / 2]; // this can be interpreted as B8x16 too
const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
// v fetches are 16head elems across lanes x (16x2) tokens per lane
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id;
const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE;
for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) {
const int64_t vblock_number = static_cast<int64_t>(
vphysical_block_number[vtoken_depth]
[vfetch_depth / VBLOCKS_PER_LANE]);
const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride);
const cache_t* v_fetch_ptr =
v_ptr3 +
(vfetch_depth % VBLOCKS_PER_LANE) * CONTIGUOUS_KV_ELEMS_16B_LOAD;
const _B16x8* v_fetch_ptr_16B =
reinterpret_cast<const _B16x8*>(v_fetch_ptr);
Vlocal[vtoken_depth][vhe_depth][vfetch_depth / 2].xy[vfetch_depth % 2] =
*v_fetch_ptr_16B;
}
}
}
floatx8 dout[TLOOP];
// qk wmma
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
dout[token_depth] = {0};
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) {
dout[token_depth] = gcn_wmma16x16x16_instr<scalar_t, 0, 0, 0>(
Klocal[token_depth][qkhe_depth].u16x16, Qlocal[qkhe_depth].u16x16,
dout[token_depth]);
}
dout[token_depth] *= scale;
}
// calculate qk_max and exp_sum per warp and write to shared memory
float qk_max = -FLT_MAX;
float exp_sum = 0.0f;
const int qkout_token_idx =
partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid;
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
const float tmp = (local_token_idx + 2 * i < context_len)
? dout[token_depth][i]
: -FLT_MAX;
qk_max = fmaxf(qk_max, tmp);
}
}
qk_max = fmaxf(qk_max, __shfl_xor(qk_max, 16));
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
const float tmp = (local_token_idx + 2 * i < context_len)
? __expf(dout[token_depth][i] - qk_max)
: 0.0f;
dout[token_depth][i] = tmp;
exp_sum += tmp;
}
}
exp_sum += __shfl_xor(exp_sum, 16);
__syncthreads();
if (laneid < 16) {
shared_qk_max[warpid][lane16id] = qk_max;
shared_exp_sum[warpid][lane16id] = exp_sum;
}
__syncthreads();
// calculate partition qk_max and exp_sum
float partition_qk_max = -FLT_MAX;
float warp_qk_max_exp[NWARPS];
float partition_exp_sum = 0.0f;
#pragma unroll
for (int w = 0; w < NWARPS; w++) {
warp_qk_max_exp[w] = shared_qk_max[w][lane16id];
partition_qk_max = fmaxf(partition_qk_max, warp_qk_max_exp[w]);
}
for (int w = 0; w < NWARPS; w++) {
warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max);
partition_exp_sum += shared_exp_sum[w][lane16id] * warp_qk_max_exp[w];
}
const float inv_sum_scale =
__fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid];
__syncthreads();
// write logits to shared mem
#pragma unroll
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
dout[token_depth] *= inv_sum_scale;
shared_logits[warpid][token_depth][lane16id][0].xy[rowid] =
from_floatx8<scalar_t>(dout[token_depth]);
}
__syncthreads();
_B16x8 swp_buf[TLOOP][2];
#pragma unroll
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
swp_buf[token_depth][0] =
shared_logits[warpid][token_depth][lane16id][0].xy[0];
swp_buf[token_depth][1] =
shared_logits[warpid][token_depth][lane16id][0].xy[1];
}
#pragma unroll
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
#pragma unroll
for (int i = 0; i < 8; i++) {
shared_logits[warpid][token_depth][lane16id][0].xy[rowid].u16x8[i] =
swp_buf[token_depth][i % 2].u16x8[4 * rowid + (i / 2)];
}
}
// write out partition max_logits and exp_sum
if (threadIdx.x < GQA_RATIO) {
const int qhead_idx = lane16id;
const int offset = seq_idx * total_num_heads * max_num_partitions +
(wg_start_head_idx + qhead_idx) * max_num_partitions +
partition_idx;
max_logits[offset] = partition_qk_max;
exp_sums[offset] = partition_exp_sum;
}
__syncthreads();
_B16x8 outelems[VHELOOP];
// Softmax V wmma
// v layout: 16he across lanes x (16x2) tokens per lane
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
floatx8 tmp_out = {0};
for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP / 2;
vfetch_depth++) {
const int offset = vfetch_depth;
// if output format is 16 qheads across 16 lanes, 16 head elems spread
// across rows
tmp_out = gcn_wmma16x16x16_instr<scalar_t, 0, 0, 0>(
Vlocal[vtoken_depth][vhe_depth][vfetch_depth].u16x16,
shared_logits[vtoken_depth][offset][lane16id][0].u16x16, tmp_out);
}
}
outelems[vhe_depth] = from_floatx8<scalar_t>(tmp_out);
}
__syncthreads();
#pragma unroll
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
shared_logits[warpid][vhe_depth][lane16id][0].xy[rowid] =
outelems[vhe_depth]; // lane16 id head dimension; rowid head element
// dimension
}
__syncthreads();
#pragma unroll
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
swp_buf[vhe_depth][0] = shared_logits[warpid][vhe_depth][lane16id][0].xy[0];
swp_buf[vhe_depth][1] = shared_logits[warpid][vhe_depth][lane16id][0].xy[1];
}
#pragma unroll
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
#pragma unroll
for (int i = 0; i < 8; i++) {
shared_logits[warpid][vhe_depth][lane16id][0].xy[rowid].u16x8[i] =
swp_buf[vhe_depth][i % 2].u16x8[4 * rowid + (i / 2)];
}
}
__syncthreads();
// write to tmp_out with coalesced writes after reading from shared mem
if (warpid == 0) {
_B16x8 vout[GQA_RATIO2];
// each lane writes out 16Bytes of tmp_out along head elem dimension
const int head_elem_idx = lane16id * 8;
if (head_elem_idx < HEAD_SIZE) {
for (int h = 0; h < GQA_RATIO2; h++) {
const int local_head_idx = 2 * h + rowid;
const int offset1 = (head_elem_idx / 16) % NWARPS;
const int offset2 = head_elem_idx / 16 / NWARPS;
const int offset3 = (head_elem_idx / 8) % 2; // num_he % num_row
vout[h] =
shared_logits[offset1][offset2][local_head_idx][0].xy[offset3];
}
const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions;
scalar_t* out_ptr = out + seq_idx * total_num_heads * hsz_maxp_mult +
partition_idx * HEAD_SIZE;
for (int h = 0; h < GQA_RATIO2; h++) {
const int local_head_idx = 2 * h + rowid;
if (local_head_idx < GQA_RATIO) {
const int out_head_idx = wg_start_head_idx + local_head_idx;
scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult;
scalar_t* out_ptr3 = out_ptr2 + head_elem_idx;
_B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3);
*out_ptr_B16x8 = vout[h];
}
}
}
}
}
// Placeholder definition of the mfma4 attention kernel: the body consists
// solely of UNREACHABLE_CODE (a macro defined outside this view), so the
// kernel performs no attention work here. It keeps the same parameter list as
// the mfma16 kernel, presumably so the launch macros compile for every
// architecture branch - confirm against the launcher.
// NOTE(review): appears to belong to the __HIP__GFX11__ branch - confirm.
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED,
int GQA_RATIO>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
// Intentionally empty apart from the unreachable marker.
UNREACHABLE_CODE
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, typename OUTT, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE, int NPAR_LOOPS>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
OUTT* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x;
const auto head_idx = blockIdx.x;
const auto seq_idx = blockIdx.y;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
return;
}
const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
[[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int warpid = threadIdx.x / WARP_SIZE;
[[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;
__shared__ float shared_global_exp_sum;
// max num partitions supported is warp_size * NPAR_LOOPS
__shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE];
if (warpid == 0) {
const float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
// valid partition is the last valid partition in case threadid > num
// partitions
int valid_partition[NPAR_LOOPS];
float reg_max_logit[NPAR_LOOPS];
const int last_valid_partition = num_partitions - 1;
#pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) {
const int partition_no = i * WARP_SIZE + threadIdx.x;
valid_partition[i] =
(partition_no < num_partitions) ? partition_no : last_valid_partition;
}
#pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) {
reg_max_logit[i] = max_logits_ptr[valid_partition[i]];
}
float max_logit = reg_max_logit[0];
#pragma unroll
for (int i = 1; i < NPAR_LOOPS; i++) {
max_logit = fmaxf(max_logit, reg_max_logit[i]);
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask));
}
const float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float rescaled_exp_sum[NPAR_LOOPS];
#pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) {
rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]];
}
#pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) {
const int partition_no = i * WARP_SIZE + threadIdx.x;
rescaled_exp_sum[i] *= (partition_no < num_partitions)
? expf(reg_max_logit[i] - max_logit)
: 0.0f;
}
float global_exp_sum = rescaled_exp_sum[0];
#pragma unroll
for (int i = 1; i < NPAR_LOOPS; i++) {
global_exp_sum += rescaled_exp_sum[i];
}
#pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) {
const int partition_no = i * WARP_SIZE + threadIdx.x;
shared_exp_sums[partition_no] = rescaled_exp_sum[i];
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
global_exp_sum += __shfl_xor(global_exp_sum, mask);
}
if (threadIdx.x == 0) {
shared_global_exp_sum = global_exp_sum;
}
} // warpid == 0
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x;
constexpr int MAX_NPAR = 32;
scalar_t tmps[MAX_NPAR];
const float dzero = 0.0f;
#pragma unroll
for (int j = 0; j < MAX_NPAR; j++) {
tmps[j] = from_float<scalar_t>(dzero);
}
const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE;
const int num_partition_offset = (num_partitions)*HEAD_SIZE;
int idx = 0;
constexpr int JCHUNK = 16;
#pragma unroll
for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) {
// lastj is last valid partition
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
__syncthreads();
if (num_partitions > JCHUNK) {
#pragma unroll
for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE;
j += HEAD_SIZE) {
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
if (num_partitions > 2 * JCHUNK) {
#pragma unroll
for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE;
j += HEAD_SIZE) {
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
}
} // num_partitions > JCHUNK
// Aggregate tmp_out to out.
float acc = 0.0f;
#pragma unroll
for (int j = 0; j < JCHUNK; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
if (num_partitions > JCHUNK) {
#pragma unroll
for (int j = JCHUNK; j < 2 * JCHUNK; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
if (num_partitions > 2 * JCHUNK) {
#pragma unroll
for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
}
}
for (int p = 1; p < NPAR_LOOPS; p++) {
if (num_partitions > p * MAX_NPAR) {
idx = 0;
#pragma unroll
for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE;
j += HEAD_SIZE) {
// lastj is last valid partition
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
#pragma unroll
for (int j = 0; j < MAX_NPAR; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR];
}
}
}
const float inv_global_exp_sum =
__fdividef(1.0f, shared_global_exp_sum + 1e-6f);
acc *= inv_global_exp_sum;
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
static_cast<int64_t>(head_idx) * HEAD_SIZE;
out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
}
#elif defined(__HIP__GFX12__)
// Packed register types used by the gfx12 code below.

// Vector of 8 floats (vector_size = 8 * sizeof(float)); it is the accumulator
// and return type of gcn_wmma16x16x16_instr.
using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float;
// Raw 16-bit storage; to_float_b16 reinterprets it as _Float16 or
// __hip_bfloat16.
using bit16_t = uint16_t;
// Vector of 4 x 16-bit lanes.
using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t;
typedef bit16x4 _B16x4;
// Vector of 8 x 16-bit lanes (16 bytes).
using bit16x8 = __attribute__((__vector_size__(8 * sizeof(uint16_t)))) uint16_t;
// Two overlapping views of the same storage: one 8-lane vector (u16x8) or two
// 4-lane halves (xy[0], xy[1]).
union b16x8_u {
bit16x8 u16x8;
_B16x4 xy[2];
};
typedef b16x8_u _B16x8;
// NOTE(review): uint2 is presumably 2 x 32 bits, i.e. storage for eight 8-bit
// values as the name suggests -- confirm against the HIP vector types.
using _B8x8 = uint2;
using bit8_t = uint8_t;
// Two _B8x8 values back to back.
typedef struct _B8x16 {
_B8x8 xy[2];
} _B8x16;
// Selects the gfx12 wave32 16x16x16 WMMA builtin matching the 16-bit element
// type T and returns its result.
//   T          : _Float16 or __hip_bfloat16; any other type is rejected at
//                compile time.
//   inpA, inpB : packed 8 x 16-bit operands forwarded to the builtin.
//   inpC       : fp32 accumulator forwarded to the builtin.
// The absz / cbid / blgp template parameters are not referenced in the body.
template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx8 gcn_wmma16x16x16_instr(const bit16x8& inpA,
                                                          const bit16x8& inpB,
                                                          const floatx8& inpC) {
  constexpr bool kIsHalf = std::is_same<T, _Float16>::value;
  constexpr bool kIsBf16 = std::is_same<T, __hip_bfloat16>::value;
  if constexpr (kIsBf16) {
    return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(inpA, inpB, inpC);
  } else if constexpr (kIsHalf) {
    return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(inpA, inpB, inpC);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}
// Converts a single 16-bit floating point value to float.
//   T : _Float16 (plain cast) or __hip_bfloat16 (__bfloat162float); any other
//       type is rejected at compile time.
template <typename T>
__device__ __forceinline__ float to_float(const T& inp) {
  constexpr bool kIsHalf = std::is_same<T, _Float16>::value;
  constexpr bool kIsBf16 = std::is_same<T, __hip_bfloat16>::value;
  if constexpr (kIsBf16) {
    return __bfloat162float(inp);
  } else if constexpr (kIsHalf) {
    return static_cast<float>(inp);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}
// Reinterprets the raw 16-bit pattern `inp` as a value of type T and converts
// it to float.
//   T : _Float16 or __hip_bfloat16; any other type is rejected at compile
//       time.
template <typename T>
__device__ __forceinline__ float to_float_b16(const bit16_t& inp) {
  // All three members alias the same 16 bits; the raw pattern is stored
  // through `u` and read back through the member that matches T.
  union raw16_view {
    bit16_t u;
    _Float16 f;
    __hip_bfloat16 b;
  } view;
  view.u = inp;
  constexpr bool kIsHalf = std::is_same<T, _Float16>::value;
  constexpr bool kIsBf16 = std::is_same<T, __hip_bfloat16>::value;
  if constexpr (kIsBf16) {
    return __bfloat162float(view.b);
  } else if constexpr (kIsHalf) {
    return static_cast<float>(view.f);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}
// Converts a float to the 16-bit floating point type T.
//   T : _Float16 (plain cast) or __hip_bfloat16 (__float2bfloat16); any other
//       type is rejected at compile time.
template <typename T>
__device__ __forceinline__ T from_float(const float& inp) {
  constexpr bool kIsHalf = std::is_same<T, _Float16>::value;
  constexpr bool kIsBf16 = std::is_same<T, __hip_bfloat16>::value;
  if constexpr (kIsBf16) {
    return __float2bfloat16(inp);
  } else if constexpr (kIsHalf) {
    return static_cast<_Float16>(inp);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}
// Converts eight floats to eight packed 16-bit values of type T.
// Elements are converted two at a time: (inp[2k], inp[2k+1]) becomes the k-th
// packed pair of the result, for k = 0..3.
//   T : _Float16 (via __float22half2_rn) or __hip_bfloat16 (via
//       __float22bfloat162_rn); any other type is rejected at compile time.
template <typename T>
__device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
  constexpr int kNumPairs = 4;
  if constexpr (std::is_same<T, _Float16>::value) {
    // Written as four half2 pairs, read back as one _B16x8.
    union half_pairs {
      __half2 h2[4];
      _B16x8 b16x8;
    } packed;
#pragma unroll
    for (int k = 0; k < kNumPairs; k++) {
      packed.h2[k] =
          __float22half2_rn(make_float2(inp[2 * k], inp[2 * k + 1]));
    }
    return packed.b16x8;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    // Written as four bfloat162 pairs, read back as one _B16x8.
    union bf16_pairs {
      __hip_bfloat162 b2[4];
      _B16x8 b16x8;
    } packed;
#pragma unroll
    for (int k = 0; k < kNumPairs; k++) {
      packed.b2[k] =
          __float22bfloat162_rn(make_float2(inp[2 * k], inp[2 * k + 1]));
    }
    return packed.b16x8;
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}
// clang-format off
// gfx12 decode-time paged attention: one work-group handles one
// (sequence, 256-token partition, KV head) triple using the wave32 16x16x16
// WMMA builtins.
//
// Launch layout, as read by the kernel:
//   blockIdx.x -> sequence index
//   blockIdx.y -> partition index (gridDim.y is used as max_num_partitions)
//   blockIdx.z -> KV head (gridDim.z * GQA_RATIO is used as the total number
//                 of query heads)
// Static shared memory: shared_qk_max, shared_exp_sum and shared_logits.
//
// Per partition the kernel writes the partition max logit and exp-sum to
// max_logits / exp_sums and the partition-normalized partial output to `out`.
// Sequences whose query length is not 1 (when query_start_loc_ptr is given)
// and partitions that start past context_len exit without writing anything.
//
// num_kv_heads, alibi_slopes, final_out, max_ctx_blocks, k_scale and v_scale
// are not referenced in this gfx12 body.
//
// All offsets into out / exp_sums / max_logits are computed in 64 bits so that
// large num_seqs * num_heads * max_num_partitions * HEAD_SIZE products do not
// wrap a 32-bit int.
template <typename scalar_t, typename cache_t,
          vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
          int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
__global__
__launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads, const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int* __restrict__ query_start_loc_ptr,  // [num_seqs + 1]; entry
                                                  // seq_idx + 1 is read below
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                     // max_num_partitions]
    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
                                 // head_size]
    OUTT* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
  // clang-format on
  // e.g. 8 warps when NUM_THREADS == 256 and WARP_SIZE == 32
  constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
  const int warpid = threadIdx.x / WARP_SIZE;
  const int laneid = threadIdx.x % WARP_SIZE;
  const int lane2id = laneid % 2;
  const int lane16id = laneid % 16;
  const int rowid = laneid / 16;

  const int seq_idx = blockIdx.x;
  // NOTE queries with sequence len > 1 are prefills and taken care by another
  // kernel.
  if (query_start_loc_ptr != nullptr &&
      (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
    return;
  }

  const int partition_idx = blockIdx.y;

  constexpr int T_PAR_SIZE = 256;  // token partition size set to 256

  const int max_num_partitions = gridDim.y;

  const int context_len = context_lens[seq_idx];  // length of a seq

  const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
  // exit if partition is out of context for seq; the condition is uniform
  // across the work-group, so no thread is left waiting at a later barrier
  if (partition_start_token_idx >= context_len) {
    return;
  }

  constexpr int GQA_RATIO2 = DIVIDE_ROUND_UP(GQA_RATIO, 2);

  __shared__ float shared_qk_max[NWARPS][16 + 1];
  __shared__ float shared_exp_sum[NWARPS][16 + 1];
  // shared_logits is used for multiple purposes: staging Q, holding the
  // normalized logits, and finally staging the per-warp outputs
  __shared__ _B16x8 shared_logits[NWARPS][2][16][2];

  // for QK wmma16x16_gfx12, layout is QHead/Tokenx16 across every 16 lanes,
  // 16 Bytes HeadElements in each lane, 2x16B HeadElements across 2 rows of
  // warp
  constexpr int ROWS_PER_WARP =
      WARP_SIZE / 16;  // rows refers to 16 lanes; refer dpp terminology
  constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD =
      16 / sizeof(cache_t);  // 8 for 16 bit cache type, 16 for 8 bit types
  constexpr int QKHE_PER_FETCH =
      CONTIGUOUS_KV_ELEMS_16B_LOAD *
      ROWS_PER_WARP;  // each fetch across a warp fetches these many elements
  constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH;  // 2xQKHE_16B across
                                                        // warp

  _B16x8 Qlocal[QKHELOOP];  // note that 16 contiguous elements of Q should
                            // be fetched per lane for 16 bit cache types

  constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t);

  constexpr int TOKENS_PER_WARP =
      T_PAR_SIZE /
      NWARPS;  // sub partition of tokens per warp for qk calculation
  constexpr int TLOOP =
      TOKENS_PER_WARP /
      16;  // each wmma16x16x16 instruction processes 16 tokens

  _B16x8 Klocal[TLOOP]
               [QKHELOOP];  // can be interpreted as B8x16 for 8 bit types

  const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
  const int wg_start_kv_head_idx = blockIdx.z;
  const int total_num_heads = gridDim.z * GQA_RATIO;

  // for QK wmma, tokens in multiples of TOKENS_PER_WARP are spread across warps
  // each wmma takes QH16xT16x16HE across warp
  // repeat wmma across QKHELOOP dimension
  // output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens
  // across 2 rows x 8 tokens per lane

  const int64_t query_start_off = static_cast<int64_t>(
      query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);

  if (GQA_RATIO == 1) {
    // single query head: load Q straight from global memory into registers
    const int local_qhead_idx = lane16id % GQA_RATIO;
    const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
    const scalar_t* q_ptr = q + query_start_off * q_stride +
                            global_qhead_idx * HEAD_SIZE +
                            rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD;
    if (lane16id < GQA_RATIO) {
#pragma unroll
      for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
        const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * QKHE_PER_FETCH;
        const _B16x8* q_fetch_ptr_16B =
            reinterpret_cast<const _B16x8*>(q_fetch_ptr);
        Qlocal[qkhe_depth] = *q_fetch_ptr_16B;
      }
    }
  } else {
    // fetch Q in shared across warps and then write to registers
    const int local_qhead_idx = 2 * warpid + rowid;
    const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
    const scalar_t* q_ptr =
        q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;

    const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B;
    if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) {
      const scalar_t* q_fetch_ptr = q_ptr + qhead_element;
      const _B16x8* q_fetch_ptr_16B =
          reinterpret_cast<const _B16x8*>(q_fetch_ptr);
      _B16x8 tmp = *q_fetch_ptr_16B;

      const int offset1 =
          lane16id /
          2;  // 16 contiguous chunks of head elems are spread across 8x2lanes
      shared_logits[offset1][lane2id][local_qhead_idx][0] = tmp;
    }

    // Q staged in shared memory must be visible to every warp before reading
    __syncthreads();

#pragma unroll
    for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
      Qlocal[qkhe_depth] =
          shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO][0];
    }
  }

  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
  const int last_ctx_block = num_context_blocks - 1;

  const int* block_table_seq =
      block_tables + static_cast<int64_t>(seq_idx) * max_num_blocks_per_seq;

  int kphysical_block_number[TLOOP];

  // fetch k physical block numbers; tokens past context_len fall back to the
  // last valid block so the loads below stay in bounds
  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
    const int klocal_token_idx =
        TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
    const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
    const int kblock_idx = (kglobal_token_idx < context_len)
                               ? kglobal_token_idx / BLOCK_SIZE
                               : last_ctx_block;
    kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
  }

  constexpr int KX = 16 / sizeof(cache_t);
  const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride;

  const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD;

  // fetch K with 16-byte loads
  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
    const int64_t kblock_number =
        static_cast<int64_t>(kphysical_block_number[token_depth]);
    const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride;
    const int klocal_token_idx =
        TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
    const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE;
    const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX;

    for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
      const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH;
      const int offset1 = head_elem / KX;
      const int offset2 = head_elem % KX;
      const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2;
      const _B16x8* k_fetch_ptr_16B =
          reinterpret_cast<const _B16x8*>(k_fetch_ptr);
      Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B;
    }
  }

  constexpr int VTOKENS_PER_LANE =
      TOKENS_PER_WARP / ROWS_PER_WARP;  // 32/2 = 16 vtokens per lane
  constexpr int VBLOCKS_PER_LANE = 1;   // assumes block size >=16
  constexpr int VTLOOP = NWARPS;        // corresponds to tokens across warps
  constexpr int VTLANELOOP = DIVIDE_ROUND_UP(
      VTOKENS_PER_LANE,
      CONTIGUOUS_KV_ELEMS_16B_LOAD);  // optimized for 16B fetches; assumes
                                      // minimum block size is 16
  constexpr int VHELOOP = DIVIDE_ROUND_UP(
      (HEAD_SIZE / 16), NWARPS);  // head_size distributed across warps; each
                                  // wmma instr works on 16 head elements

  int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE];

  // fetch v physical block numbers
  for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
    for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE;
         vblock_depth++) {
      const int vlocal_token_idx =
          vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP +
          rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE;
      const int vglobal_token_idx =
          partition_start_token_idx + vlocal_token_idx;
      const int vblock_idx = (vglobal_token_idx < context_len)
                                 ? vglobal_token_idx / BLOCK_SIZE
                                 : last_ctx_block;
      vphysical_block_number[vtoken_depth][vblock_depth] =
          block_table_seq[vblock_idx];
    }
  }

  _B16x8 Vlocal[VTLOOP][VHELOOP]
               [VTLANELOOP];  // this can be interpreted as B8x16 too

  const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride +
                         ((rowid * VTOKENS_PER_LANE) % BLOCK_SIZE);

  // v fetches are 16head elems across lanes x 16 tokens per lane
  for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
    const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id;
    const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE;

    for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
      for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) {
        const int vblock_depth = 0;
        const int64_t vblock_number = static_cast<int64_t>(
            vphysical_block_number[vtoken_depth][vblock_depth]);
        const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride);

        const cache_t* v_fetch_ptr =
            v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD;
        const _B16x8* v_fetch_ptr_16B =
            reinterpret_cast<const _B16x8*>(v_fetch_ptr);
        Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B;
      }
    }
  }

  floatx8 dout[TLOOP];
  // qk wmma
  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
    dout[token_depth] = {0};
    for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
      dout[token_depth] = gcn_wmma16x16x16_instr<scalar_t, 0, 0, 0>(
          Klocal[token_depth][qkhe_depth].u16x8, Qlocal[qkhe_depth].u16x8,
          dout[token_depth]);
    }
    dout[token_depth] *= scale;
  }

  // calculate qk_max and exp_sum per warp and write to shared memory
  float qk_max = -FLT_MAX;
  float exp_sum = 0.0f;
  const int qkout_token_idx =
      partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 8;
  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
    const int local_token_idx = qkout_token_idx + token_depth * 16;
    for (int i = 0; i < 8; i++) {
      // tokens past the end of the context are masked out of the max
      const float tmp =
          (local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX;
      qk_max = fmaxf(qk_max, tmp);
    }
  }

  // combine the two 16-lane rows of the warp
  qk_max = fmaxf(qk_max, __shfl_xor(qk_max, 16));

  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
    const int local_token_idx = qkout_token_idx + token_depth * 16;
    for (int i = 0; i < 8; i++) {
      const float tmp = (local_token_idx + i < context_len)
                            ? __expf(dout[token_depth][i] - qk_max)
                            : 0.0f;
      dout[token_depth][i] = tmp;
      exp_sum += tmp;
    }
  }

  exp_sum += __shfl_xor(exp_sum, 16);

  __syncthreads();

  if (laneid < 16) {
    shared_qk_max[warpid][lane16id] = qk_max;
    shared_exp_sum[warpid][lane16id] = exp_sum;
  }

  __syncthreads();

  // calculate partition qk_max and exp_sum
  float partition_qk_max = -FLT_MAX;
  float warp_qk_max_exp[NWARPS];
  float partition_exp_sum = 0.0f;

#pragma unroll
  for (int w = 0; w < NWARPS; w++) {
    warp_qk_max_exp[w] = shared_qk_max[w][lane16id];
    partition_qk_max = fmaxf(partition_qk_max, warp_qk_max_exp[w]);
  }

  for (int w = 0; w < NWARPS; w++) {
    // rescale every warp's exp-sum from its own max to the partition max
    warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max);
    partition_exp_sum += shared_exp_sum[w][lane16id] * warp_qk_max_exp[w];
  }

  const float inv_sum_scale =
      __fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid];

  __syncthreads();

  // write logits to shared mem
#pragma unroll
  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
    dout[token_depth] *= inv_sum_scale;
    shared_logits[warpid][token_depth][lane16id][rowid] =
        from_floatx8<scalar_t>(dout[token_depth]);
  }

  // write out partition max_logits and exp_sum
  if (threadIdx.x < GQA_RATIO) {
    const int qhead_idx = lane16id;
    // 64-bit so that seq * heads * partitions cannot wrap a 32-bit int
    const int64_t offset =
        (static_cast<int64_t>(seq_idx) * total_num_heads +
         (wg_start_head_idx + qhead_idx)) *
            max_num_partitions +
        partition_idx;
    max_logits[offset] = partition_qk_max;
    exp_sums[offset] = partition_exp_sum;
  }

  __syncthreads();

  _B16x8 outelems[VHELOOP];
  // Softmax V wmma
  // v layout: 16he across lanes x 16 tokens per lane
  for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
    floatx8 tmp_out = {0};

    for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
      for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) {
        const int offset = rowid * VTLANELOOP + vfetch_depth;
        const int offset1 = offset % ROWS_PER_WARP;
        const int offset2 = offset / ROWS_PER_WARP;
        // if output format is 16 qheads across 16 lanes, 16 head elems spread
        // across rows
        tmp_out = gcn_wmma16x16x16_instr<scalar_t, 0, 0, 0>(
            Vlocal[vtoken_depth][vhe_depth][vfetch_depth].u16x8,
            shared_logits[vtoken_depth][offset2][lane16id][offset1].u16x8,
            tmp_out);
      }
    }
    outelems[vhe_depth] = from_floatx8<scalar_t>(tmp_out);
  }

  // all warps must finish reading the logits before shared_logits is reused
  __syncthreads();

#pragma unroll
  for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
    shared_logits[warpid][vhe_depth][lane16id][rowid] =
        outelems[vhe_depth];  // lane16 id head dimension; rowid head element
                              // dimension
  }

  __syncthreads();

  // write to tmp_out with coalesced writes after reading from shared mem
  if (warpid == 0) {
    _B16x8 vout[GQA_RATIO2];
    // each lane writes out 16Bytes of tmp_out along head elem dimension
    const int head_elem_idx = lane16id * 8;
    if (head_elem_idx < HEAD_SIZE) {
      for (int h = 0; h < GQA_RATIO2; h++) {
        const int local_head_idx = 2 * h + rowid;
        const int offset1 = (head_elem_idx / 16) % NWARPS;
        const int offset2 = head_elem_idx / 16 / NWARPS;
        const int offset3 = (head_elem_idx / 8) % 2;  // num_he % num_row
        vout[h] = shared_logits[offset1][offset2][local_head_idx][offset3];
      }

      // 64-bit strides: num_seqs * num_heads * max_num_partitions * HEAD_SIZE
      // can exceed the range of a 32-bit int
      const int64_t hsz_maxp_mult =
          static_cast<int64_t>(HEAD_SIZE) * max_num_partitions;
      scalar_t* out_ptr =
          out +
          static_cast<int64_t>(seq_idx) * total_num_heads * hsz_maxp_mult +
          static_cast<int64_t>(partition_idx) * HEAD_SIZE;
      for (int h = 0; h < GQA_RATIO2; h++) {
        const int local_head_idx = 2 * h + rowid;
        if (local_head_idx < GQA_RATIO) {
          const int out_head_idx = wg_start_head_idx + local_head_idx;
          scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult;
          scalar_t* out_ptr3 = out_ptr2 + head_elem_idx;
          _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3);
          *out_ptr_B16x8 = vout[h];
        }
      }
    }
  }
}
// gfx12 placeholder for the mfma4 kernel variant. It keeps the same template
// and parameter list as paged_attention_ll4mi_QKV_mfma16_kernel so that
// LAUNCH_CUSTOM_ATTENTION_MFMA4 still compiles, but its body is only the
// UNREACHABLE_CODE macro (defined earlier in this file): no argument is read
// or written here.
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED,
int GQA_RATIO>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
UNREACHABLE_CODE
}
// Grid: (num_heads, num_seqs).
//
// gfx12 reduction step: merges the per-partition partial outputs in tmp_out
// into the final attention output for one (head, sequence) pair.
//   - Warp 0 finds the maximum of the partition max-logits, rescales every
//     partition exp-sum to that maximum and publishes the rescaled values
//     (shared_exp_sums) and their total (shared_global_exp_sum).
//   - Every thread then owns head element threadIdx.x and accumulates
//     tmp_out[partition][threadIdx.x] * shared_exp_sums[partition], divides by
//     the total and stores the result in `out`.
// At most NPAR_LOOPS * WARP_SIZE partitions are supported (size of
// shared_exp_sums). Sequences whose query length is not 1 (when
// query_start_loc_ptr is given) are skipped.
//
// fp8_out_scale_ptr is accepted so the signature matches the launch macro but
// is not referenced in this gfx12 body.
//
// Offsets into max_logits / exp_sums / tmp_out are computed in 64 bits, and
// the "last valid partition" fallback index is clamped to 0 so that a
// context_len of 0 does not index with -1.
template <typename scalar_t, typename OUTT, int HEAD_SIZE, int NUM_THREADS,
          int PARTITION_SIZE, int NPAR_LOOPS>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    OUTT* __restrict__ out,                // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
                                           // max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                           // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                           // max_num_partitions, head_size]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int* __restrict__ query_start_loc_ptr,  // [num_seqs + 1]; entry
                                                  // seq_idx + 1 is read below
    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
  const auto num_heads = gridDim.x;
  const auto head_idx = blockIdx.x;
  const auto seq_idx = blockIdx.y;

  // NOTE queries with sequence len > 1 are prefills and taken care by another
  // kernel.
  if (query_start_loc_ptr != nullptr &&
      (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
    return;
  }

  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  // Fallback index for lanes / slots beyond num_partitions; clamped so an
  // empty context never produces a negative index.
  const int last_valid_partition =
      (num_partitions > 0) ? num_partitions - 1 : 0;
  [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warpid = threadIdx.x / WARP_SIZE;
  [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;

  // Element offset of this (seq, head) row in the [.., max_num_partitions]
  // arrays; 64-bit so seq * heads * partitions (* HEAD_SIZE) cannot wrap.
  const int64_t head_part_off =
      (static_cast<int64_t>(seq_idx) * num_heads + head_idx) *
      max_num_partitions;

  __shared__ float shared_global_exp_sum;
  // max num partitions supported is warp_size * NPAR_LOOPS
  __shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE];

  if (warpid == 0) {
    const float* max_logits_ptr = max_logits + head_part_off;

    // valid partition is the last valid partition in case threadid > num
    // partitions
    int valid_partition[NPAR_LOOPS];
    float reg_max_logit[NPAR_LOOPS];

#pragma unroll
    for (int i = 0; i < NPAR_LOOPS; i++) {
      const int partition_no = i * WARP_SIZE + threadIdx.x;
      valid_partition[i] =
          (partition_no < num_partitions) ? partition_no : last_valid_partition;
    }
#pragma unroll
    for (int i = 0; i < NPAR_LOOPS; i++) {
      reg_max_logit[i] = max_logits_ptr[valid_partition[i]];
    }
    float max_logit = reg_max_logit[0];
#pragma unroll
    for (int i = 1; i < NPAR_LOOPS; i++) {
      max_logit = fmaxf(max_logit, reg_max_logit[i]);
    }

    // butterfly max across the lanes of warp 0
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
      max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask));
    }

    const float* exp_sums_ptr = exp_sums + head_part_off;

    float rescaled_exp_sum[NPAR_LOOPS];
#pragma unroll
    for (int i = 0; i < NPAR_LOOPS; i++) {
      rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]];
    }
#pragma unroll
    for (int i = 0; i < NPAR_LOOPS; i++) {
      const int partition_no = i * WARP_SIZE + threadIdx.x;
      // slots beyond num_partitions contribute a weight of zero
      rescaled_exp_sum[i] *= (partition_no < num_partitions)
                                 ? expf(reg_max_logit[i] - max_logit)
                                 : 0.0f;
    }
    float global_exp_sum = rescaled_exp_sum[0];
#pragma unroll
    for (int i = 1; i < NPAR_LOOPS; i++) {
      global_exp_sum += rescaled_exp_sum[i];
    }
#pragma unroll
    for (int i = 0; i < NPAR_LOOPS; i++) {
      const int partition_no = i * WARP_SIZE + threadIdx.x;
      shared_exp_sums[partition_no] = rescaled_exp_sum[i];
    }

    // butterfly sum across the lanes of warp 0
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
      global_exp_sum += __shfl_xor(global_exp_sum, mask);
    }
    if (threadIdx.x == 0) {
      shared_global_exp_sum = global_exp_sum;
    }
  }  // warpid == 0

  const scalar_t* tmp_out_ptr =
      tmp_out + head_part_off * HEAD_SIZE + threadIdx.x;
  constexpr int MAX_NPAR = 32;
  scalar_t tmps[MAX_NPAR];
  const float dzero = 0.0f;
#pragma unroll
  for (int j = 0; j < MAX_NPAR; j++) {
    tmps[j] = from_float<scalar_t>(dzero);
  }
  const int last_partition_offset = last_valid_partition * HEAD_SIZE;
  const int num_partition_offset = (num_partitions)*HEAD_SIZE;
  int idx = 0;

  constexpr int JCHUNK = 16;

  // The first JCHUNK loads are issued before the barrier; they do not depend
  // on warp 0's shared-memory results.
#pragma unroll
  for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) {
    // lastj is last valid partition
    const int lastj_offset =
        (j < num_partition_offset) ? j : last_partition_offset;
    tmps[idx] = tmp_out_ptr[lastj_offset];
    idx++;
  }
  // makes shared_exp_sums / shared_global_exp_sum visible to all threads
  __syncthreads();

  if (num_partitions > JCHUNK) {
#pragma unroll
    for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE;
         j += HEAD_SIZE) {
      const int lastj_offset =
          (j < num_partition_offset) ? j : last_partition_offset;
      tmps[idx] = tmp_out_ptr[lastj_offset];
      idx++;
    }

    if (num_partitions > 2 * JCHUNK) {
#pragma unroll
      for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE;
           j += HEAD_SIZE) {
        const int lastj_offset =
            (j < num_partition_offset) ? j : last_partition_offset;
        tmps[idx] = tmp_out_ptr[lastj_offset];
        idx++;
      }
    }
  }  // num_partitions > JCHUNK

  // Aggregate tmp_out to out.
  float acc = 0.0f;
#pragma unroll
  for (int j = 0; j < JCHUNK; j++) {
    acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
  }
  if (num_partitions > JCHUNK) {
#pragma unroll
    for (int j = JCHUNK; j < 2 * JCHUNK; j++) {
      acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
    }
    if (num_partitions > 2 * JCHUNK) {
#pragma unroll
      for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) {
        acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
      }
    }
  }

  // remaining partitions, MAX_NPAR at a time
  for (int p = 1; p < NPAR_LOOPS; p++) {
    if (num_partitions > p * MAX_NPAR) {
      idx = 0;
#pragma unroll
      for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE;
           j += HEAD_SIZE) {
        // lastj is last valid partition
        const int lastj_offset =
            (j < num_partition_offset) ? j : last_partition_offset;
        tmps[idx] = tmp_out_ptr[lastj_offset];
        idx++;
      }

#pragma unroll
      for (int j = 0; j < MAX_NPAR; j++) {
        acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR];
      }
    }
  }

  // the small epsilon keeps the division finite when the sum is zero
  const float inv_global_exp_sum =
      __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
  acc *= inv_global_exp_sum;

  const int64_t query_start_off = static_cast<int64_t>(
      query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
  OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
                  static_cast<int64_t>(head_idx) * HEAD_SIZE;
  out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
}
#else
// clang-format off
// Fallback definition used when none of the architecture-specific branches of
// the surrounding #if chain is selected. It keeps the kernel's template and
// parameter list so that LAUNCH_CUSTOM_ATTENTION_MFMA16 still compiles; the
// body is only the UNREACHABLE_CODE macro (defined earlier in this file), so
// no argument is read or written here.
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED,
int GQA_RATIO>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
UNREACHABLE_CODE
}
// Fallback definition of the mfma4 kernel variant for the #else branch of the
// surrounding architecture #if chain. Same template and parameter list as the
// mfma16 kernel so that LAUNCH_CUSTOM_ATTENTION_MFMA4 still compiles; the body
// is only the UNREACHABLE_CODE macro (defined earlier in this file).
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED,
int GQA_RATIO>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
UNREACHABLE_CODE
}
// Grid: (num_heads, num_seqs).
// Fallback definition of the reduce kernel for the #else branch of the
// surrounding architecture #if chain. It keeps the signature expected by
// LAUNCH_CUSTOM_REDUCTION; the body is only the UNREACHABLE_CODE macro
// (defined earlier in this file), so no argument is read or written here.
template <typename scalar_t, typename OUTT, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE, int NPAR_LOOPS>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
OUTT* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
UNREACHABLE_CODE
}
// clang-format on
#endif
// Launches paged_attention_ll4mi_QKV_mfma16_kernel for a compile-time
// GQA_RATIO with no dynamic shared memory (third launch argument is 0).
// The expansion relies on names from the scope where it is used: the types /
// constants T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, HEAD_SIZE, NTHR and
// ALIBI_ENABLED, the launch configuration grid / block / stream, and the
// pointer and stride variables listed in the argument list. The expansion
// already ends with a semicolon.
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
// Launches paged_attention_ll4mi_QKV_mfma4_kernel for a compile-time
// GQA_RATIO. It takes the same template arguments, launch configuration
// (grid / block / stream, no dynamic shared memory) and kernel arguments as
// LAUNCH_CUSTOM_ATTENTION_MFMA16, all taken from the scope where the macro is
// expanded. The expansion already ends with a semicolon.
#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
// Launches paged_attention_ll4mi_reduce_kernel for a compile-time NPAR_LOOPS.
// HEAD_SIZE is passed for both the HEAD_SIZE and NUM_THREADS template
// parameters of the kernel. reduce_grid, reduce_block, stream and the pointer
// arguments come from the scope where the macro is expanded.
// NOTE(review): reduce_block is presumably HEAD_SIZE threads to match the
// NUM_THREADS template argument -- confirm at the definition site.
// The expansion already ends with a semicolon.
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
fp8_out_scale_ptr);
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
bool ALIBI_ENABLED>
void paged_attention_custom_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale) {
int num_seqs = block_tables.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
// NOTE: query start location is optional for V0 decode should not be used.
// If batch contains mix of prefills and decode, prefills should be skipped.
const int* query_start_loc_ptr =
query_start_loc
? reinterpret_cast<const int*>(query_start_loc.value().data_ptr())
: nullptr;
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// NOTE: fp8_out_scale is optional.
const auto fp8_out_scale_ptr =
fp8_out_scale
? static_cast<const float*>(fp8_out_scale.value().data_ptr())
: nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
// partition size is fixed at 256 since both mfma4 and mfma16 kernels support
// it mfma4 kernel also supports partition size 512
constexpr int PARTITION_SIZE = 256;
const int max_num_partitions =
DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE);
constexpr int NTHR = 256;
dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
dim3 block(NTHR);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// mfma4 kernel is faster than mfma16 for gqa_ratio <= 4
switch (gqa_ratio) {
case 1:
LAUNCH_CUSTOM_ATTENTION_MFMA4(1);
break;
case 2:
LAUNCH_CUSTOM_ATTENTION_MFMA4(2);
break;
case 3:
LAUNCH_CUSTOM_ATTENTION_MFMA4(3);
break;
case 4:
LAUNCH_CUSTOM_ATTENTION_MFMA4(4);
break;
case 5:
LAUNCH_CUSTOM_ATTENTION_MFMA16(5);
break;
case 6:
LAUNCH_CUSTOM_ATTENTION_MFMA16(6);
break;
case 7:
LAUNCH_CUSTOM_ATTENTION_MFMA16(7);
break;
case 8:
......@@ -1735,33 +3251,236 @@ void paged_attention_custom_launcher(
}
}
// Seven-parameter form of CALL_CUSTOM_LAUNCHER: instantiates
// paged_attention_custom_launcher with the output type fixed to T and forwards
// the tensors/scalars of the enclosing scope (no fp8_out_scale argument).
// NOTE(review): the same macro name is defined again further down with an
// extra OUTT parameter and an fp8_out_scale argument; this form looks like the
// superseded variant - confirm only one definition survives in the real file.
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE,    \
ALIBI_ENABLED)                                   \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,   \
PSIZE, ALIBI_ENABLED>(                      \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,      \
num_kv_heads, scale, block_tables, context_lens, query_start_loc,       \
max_context_len, alibi_slopes, k_scale, v_scale);
// Six-parameter form of CALL_CUSTOM_LAUNCHER_ALIBI: turns the runtime presence
// of `alibi_slopes` (read from the enclosing scope) into the compile-time
// ALIBI_ENABLED template flag of CALL_CUSTOM_LAUNCHER.
// NOTE(review): a seven-parameter definition with the same name (adding OUTT)
// appears further down; this one looks like the superseded variant - confirm.
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,     \
PSIZE)                                     \
if (alibi_slopes) {                                                         \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true); \
} else {                                                                    \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false); \
}
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
switch (block_size) { \
case 16: \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 16, HEAD_SIZE, 256); \
break; \
case 32: \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 32, HEAD_SIZE, 256); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
// Host-side launcher for the Navi code path of the custom paged attention.
//
// Steps performed:
//   1. Derives sizes/strides from the tensors and extracts raw device pointers.
//   2. Launches the mfma16 attention kernel through
//      LAUNCH_CUSTOM_ATTENTION_MFMA16, with gqa_ratio (1..16) turned into a
//      template argument; grid = (num_seqs, max_num_partitions, num_kv_heads),
//      block = NTHR (256) threads.
//   3. Launches the reduce kernel through LAUNCH_CUSTOM_REDUCTION, with
//      npar_loops (1..16) turned into a template argument;
//      grid = (num_heads, num_seqs), block = head_size threads.
//
// Differences from the non-Navi launcher that are visible in this body:
//   - `alibi_slopes` is accepted but never read; the kernel always receives a
//     null alibi pointer (ALIBI_ENABLED is still forwarded as a template
//     argument).
//   - there is no fp8_out_scale parameter; the reduce kernel always receives a
//     null fp8 output-scale pointer.
//   - the partition-count arithmetic uses a warp size of 32.
//
// PARTITION_SIZE_OLD is not used in the body; a fixed PARTITION_SIZE of 256 is
// used instead.
//
// Fails via TORCH_CHECK when gqa_ratio or npar_loops is outside 1..16.
//
// NOTE: the local variable names below are referenced verbatim by the
// LAUNCH_* macros and must not be renamed.
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
bool ALIBI_ENABLED>
void paged_attention_custom_launcher_navi(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
// Problem sizes and strides, read from the tensor metadata.
int num_seqs = block_tables.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
// NOTE: query start location is optional for V0 decode should not be used.
// If batch contains mix of prefills and decode, prefills should be skipped.
const int* query_start_loc_ptr =
query_start_loc
? reinterpret_cast<const int*>(query_start_loc.value().data_ptr())
: nullptr;
// NOTE: Navi does not support alibi_slopes.
const float* alibi_slopes_ptr = nullptr;
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// NOTE: Navi does not support fp8.
const auto fp8_out_scale_ptr = nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
// Partition size is fixed here; PARTITION_SIZE_OLD is ignored.
constexpr int PARTITION_SIZE = 256;
const int max_num_partitions =
DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
// The division happens before the divisibility assert below, so
// num_kv_heads must be non-zero on entry.
const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE);
constexpr int NTHR = 256;
// One block per (sequence, partition, kv head).
dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
dim3 block(NTHR);
// Make the device of `query` current and launch on the current stream.
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Runtime gqa_ratio -> compile-time GQA_RATIO; every ratio uses the mfma16
// kernel on this path.
switch (gqa_ratio) {
case 1:
LAUNCH_CUSTOM_ATTENTION_MFMA16(1);
break;
case 2:
LAUNCH_CUSTOM_ATTENTION_MFMA16(2);
break;
case 3:
LAUNCH_CUSTOM_ATTENTION_MFMA16(3);
break;
case 4:
LAUNCH_CUSTOM_ATTENTION_MFMA16(4);
break;
case 5:
LAUNCH_CUSTOM_ATTENTION_MFMA16(5);
break;
case 6:
LAUNCH_CUSTOM_ATTENTION_MFMA16(6);
break;
case 7:
LAUNCH_CUSTOM_ATTENTION_MFMA16(7);
break;
case 8:
LAUNCH_CUSTOM_ATTENTION_MFMA16(8);
break;
case 9:
LAUNCH_CUSTOM_ATTENTION_MFMA16(9);
break;
case 10:
LAUNCH_CUSTOM_ATTENTION_MFMA16(10);
break;
case 11:
LAUNCH_CUSTOM_ATTENTION_MFMA16(11);
break;
case 12:
LAUNCH_CUSTOM_ATTENTION_MFMA16(12);
break;
case 13:
LAUNCH_CUSTOM_ATTENTION_MFMA16(13);
break;
case 14:
LAUNCH_CUSTOM_ATTENTION_MFMA16(14);
break;
case 15:
LAUNCH_CUSTOM_ATTENTION_MFMA16(15);
break;
case 16:
LAUNCH_CUSTOM_ATTENTION_MFMA16(16);
break;
default:
TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio);
break;
}
// Second pass: one block per (head, sequence), one thread per head element.
dim3 reduce_grid(num_heads, num_seqs);
dim3 reduce_block(head_size);
const int warp_size = 32;
const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, warp_size);
// reduction kernel supports up to 16 NPAR_loops * 32 (warp_size) * 256
// (partition size) = 128K context length
switch (npar_loops) {
case 1:
LAUNCH_CUSTOM_REDUCTION(1);
break;
case 2:
LAUNCH_CUSTOM_REDUCTION(2);
break;
case 3:
LAUNCH_CUSTOM_REDUCTION(3);
break;
case 4:
LAUNCH_CUSTOM_REDUCTION(4);
break;
case 5:
LAUNCH_CUSTOM_REDUCTION(5);
break;
case 6:
LAUNCH_CUSTOM_REDUCTION(6);
break;
case 7:
LAUNCH_CUSTOM_REDUCTION(7);
break;
case 8:
LAUNCH_CUSTOM_REDUCTION(8);
break;
case 9:
LAUNCH_CUSTOM_REDUCTION(9);
break;
case 10:
LAUNCH_CUSTOM_REDUCTION(10);
break;
case 11:
LAUNCH_CUSTOM_REDUCTION(11);
break;
case 12:
LAUNCH_CUSTOM_REDUCTION(12);
break;
case 13:
LAUNCH_CUSTOM_REDUCTION(13);
break;
case 14:
LAUNCH_CUSTOM_REDUCTION(14);
break;
case 15:
LAUNCH_CUSTOM_REDUCTION(15);
break;
case 16:
LAUNCH_CUSTOM_REDUCTION(16);
break;
default:
TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops);
break;
}
}
// Instantiates and calls the host launcher for one fully specified template
// configuration, choosing at runtime between the two launchers:
//   - !is_navi: paged_attention_custom_launcher, which also receives
//     fp8_out_scale;
//   - is_navi:  paged_attention_custom_launcher_navi, which takes no
//     fp8_out_scale argument.
// Both branches are instantiated for every expansion because the choice is a
// runtime `if`.
// The macro is not hygienic: is_navi, fp8_out_scale and all tensor/scalar
// arguments are resolved at the expansion site.
// NOTE(review): on the Navi branch OUTT is still forwarded even though
// fp8_out_scale is dropped - confirm that callers never reach this branch with
// an fp8 output type.
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT,      \
PSIZE, ALIBI_ENABLED)                             \
if (!is_navi) {                                                              \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,     \
OUTT, PSIZE, ALIBI_ENABLED>(               \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,     \
num_kv_heads, scale, block_tables, context_lens, query_start_loc,      \
max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);       \
} else {                                                                     \
paged_attention_custom_launcher_navi<                                      \
T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>(    \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,     \
num_kv_heads, scale, block_tables, context_lens, query_start_loc,      \
max_context_len, alibi_slopes, k_scale, v_scale);                      \
}
// Turns the runtime presence of `alibi_slopes` (an optional read from the
// enclosing scope) into the compile-time ALIBI_ENABLED flag and forwards all
// other parameters unchanged to CALL_CUSTOM_LAUNCHER.
// Expands to a bare if/else statement, so it must be used where a statement
// is allowed and not as the body of an unbraced outer `if`.
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,      \
OUTT, PSIZE)                                \
if (alibi_slopes) {                                                          \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE,   \
true);                                                \
} else {                                                                     \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE,   \
false);                                               \
}
// Selects the output element type from the runtime presence of
// `fp8_out_scale` (read from the enclosing scope) and forwards to
// CALL_CUSTOM_LAUNCHER_ALIBI with a partition-size argument of 256:
//   - fp8_out_scale set:   OUTT = uint8_t (or a hard error in the gfx90a
//     variant of the macro);
//   - fp8_out_scale unset: OUTT = T.
// NOTE(review): __gfx90a__ is normally defined only during the device
// compilation pass for that target, while this macro is expanded in host code;
// if so, the gfx90a variant is never selected for the host and the
// "unsupported" check never fires - confirm and consider a runtime
// architecture check instead.
#if defined(__HIPCC__) && defined(__gfx90a__)
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)    \
if (fp8_out_scale) {                                                     \
TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a");            \
} else {                                                                 \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,   \
256);                                       \
}
#else
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)    \
if (fp8_out_scale) {                                                     \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,      \
uint8_t, 256);                              \
} else {                                                                 \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,   \
256);                                       \
}
#endif
// Turns the runtime `block_size` (read from the enclosing scope) into the
// compile-time BLK_SIZE argument of CALL_CUSTOM_LAUNCHER_OUT.
// Only block sizes 16 and 32 are dispatched; any other value fails with
// TORCH_CHECK("Unsupported block size: ...").
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE)     \
switch (block_size) {                                           \
case 16:                                                      \
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE);  \
break;                                                      \
case 32:                                                      \
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE);  \
break;                                                      \
default:                                                      \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break;                                                      \
}
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
......@@ -1777,6 +3496,24 @@ void paged_attention_custom_launcher(
break; \
}
bool is_navi_gpu() {
static bool is_cached = false;
static bool result;
if (!is_cached) {
int device_id;
hipDeviceProp_t deviceProp;
hipGetDevice(&device_id);
hipGetDeviceProperties(&deviceProp, device_id);
std::string arch = deviceProp.gcnArchName;
result = arch.find("gfx11") == 0 || arch.find("gfx12") == 0;
is_cached = true;
}
return result;
}
// clang-format off
void paged_attention(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
......@@ -1794,8 +3531,11 @@ void paged_attention(
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
torch::Tensor& v_scale,
const std::optional<torch::Tensor>& fp8_out_scale) {
// clang-format on
bool is_navi = is_navi_gpu();
const int head_size = query.size(2);
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Half) {
......
......@@ -11,14 +11,12 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads,
double scale, torch::Tensor& block_tables,
torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc,
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale);
void paged_attention(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
int64_t max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale);
......@@ -13,14 +13,34 @@
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
#define __HIP__MI300_MI250__
#if defined(__HIPCC__) && \
(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
#define __HIP__GFX9__
#endif
#if defined(__HIPCC__) && defined(__gfx942__)
#define __HIP__MI300__
#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
#define __HIP__MI3XX__
#endif
// Compile-time LDS capacity in bytes assumed for the target being compiled:
// 160 KiB when __gfx950__ is defined, 64 KiB otherwise.
// The value is parenthesised so the macro behaves as a single operand in any
// expression (e.g. `x / LDS_SIZE` or `x % LDS_SIZE`), not only in
// `LDS_SIZE / 2`.
#if defined(__gfx950__)
  #define LDS_SIZE (160 * 1024)
#else
  #define LDS_SIZE (64 * 1024)
#endif
// Returns the LDS capacity in bytes for the current device: 160 KiB when the
// device's gcnArchName contains "gfx95", 64 KiB otherwise.
// The device properties are read on the first call only; later calls return
// the cached value.
int get_lds_size() {
  static bool initialized = false;
  static int lds_bytes;
  if (!initialized) {
    const std::string arch_name =
        at::cuda::getCurrentDeviceProperties()->gcnArchName;
    const bool is_gfx95 = arch_name.find("gfx95") != std::string::npos;
    lds_bytes = is_gfx95 ? 160 * 1024 : 64 * 1024;
    initialized = true;
  }
  return lds_bytes;
}
#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
......@@ -126,8 +146,8 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
const int warp = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
const int num_warps = blockDim.x / WARP_SIZE;
const int qwarpid = threadid / num_warps;
const int qthreadid = threadid % num_warps;
const int qwarpid = threadid / 16;
const int qthreadid = threadid % 16;
float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK];
scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w;
float acc[NUM_A_ROWS_PER_BLOCK];
......@@ -142,15 +162,13 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
// rowA_elem4[i] holds 8 * half numbers seen as a single float4.
rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]);
}
colB_elem4x = bf4[threadid * 4 + 0];
colB_elem4y = bf4[threadid * 4 + 1];
colB_elem4z = bf4[threadid * 4 + 2];
colB_elem4w = bf4[threadid * 4 + 3];
}
colB_elem4x = bf4[threadid * 4 + 0];
colB_elem4y = bf4[threadid * 4 + 1];
colB_elem4z = bf4[threadid * 4 + 2];
colB_elem4w = bf4[threadid * 4 + 3];
scalar2_t Af2;
[[maybe_unused]] scalar2_t Bf2;
float2 S;
auto Ah2ptr = reinterpret_cast<scalar2_t*>(&rowA_elem4);
......@@ -193,12 +211,13 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
if (qwarpid < NUM_A_ROWS_PER_BLOCK) {
acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f;
for (int mask = num_warps / 2; mask >= 1; mask /= 2) {
#pragma unroll
for (int mask = 16 / 2; mask >= 1; mask /= 2) {
acc[qwarpid] += __shfl_xor(acc[qwarpid], mask);
}
float oval2 = __shfl_xor(acc[qwarpid], num_warps);
float oval2 = __shfl_xor(acc[qwarpid], 16);
if (lane % (num_warps * 2) == 0) {
if (lane % 32 == 0) {
oval = __float22s2_rn<scalar2_t>(make_float2(acc[qwarpid], oval2));
c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval;
}
......@@ -222,9 +241,10 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
// NUM_THREADS needs to be a multiple of WARP_SIZE, as we are using warp
// shuffle operations.
const int NUM_THREADS =
K * 2 / 16 % WARP_SIZE == 0
? K * 2 / 16
: K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE);
max(rows_per_block * 16,
K * 2 / 16 % WARP_SIZE == 0
? K * 2 / 16
: K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE));
int NUM_BLOCKS = M / rows_per_block;
......@@ -267,7 +287,7 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
V0 += (s.x + s.y); \
}
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
// This version targets cases where A[] fits LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
......@@ -275,24 +295,34 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2;
#if defined(__HIP__MI3XX__)
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
constexpr bool use_mfma = false;
#endif
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
using half4 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
half4 h4[A_CHUNK / 4];
scalar8 h8;
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Reserving 64/160 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not going to work!
//----------------------------------------------------
__shared__ scalar_t s[1024 * 32];
__shared__ scalar_t s[max_lds_len];
//----------------------------------------------------
// Fetch the activation matrix to LDS
......@@ -303,11 +333,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
for (uint32_t k = 0; k < min(K * N, max_lds_len);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min(K * N, 32 * 1024)) break;
if (k_in >= min(K * N, max_lds_len)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
}
......@@ -318,6 +348,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
......@@ -343,7 +374,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = 0;
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
......@@ -374,24 +409,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
for (int y = 0; y < YTILE; y++)
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
......@@ -419,32 +438,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b])
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
}
}
}
......@@ -453,41 +457,88 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
}
if (threadIdx.x == 63) {
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
} else {
#pragma unroll
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
#pragma unroll
for (int y = 0; y < YTILE; y++) {
// float accm1 = 0;
// for (int i=0; i<64; i++)
// accm1 += __shfl(sum4[n][y][i%4], i);
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
}
}
}
}
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
......@@ -495,9 +546,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
// This version targets cases where A[] marginally exceeds LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
......@@ -505,13 +556,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2;
#if defined(__HIP__MI3XX__)
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
constexpr bool use_mfma = false;
#endif
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
using half4 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
half4 h4[A_CHUNK / 4];
scalar8 h8;
};
......@@ -522,7 +583,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// TODO: When activation matrix is larger than 64 KB
// then this is not going to work!
//----------------------------------------------------
__shared__ scalar_t s[1024 * 32];
__shared__ scalar_t s[max_lds_len];
//----------------------------------------------------
// Computation of columns that need to be committed to memory!
......@@ -559,11 +620,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
for (uint32_t k = 0; k < min(K * N, max_lds_len);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min(K * N, 32 * 1024)) break;
if (k_in >= min(K * N, max_lds_len)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
}
......@@ -573,6 +634,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.y >= _WvPrGrp) return;
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
......@@ -598,7 +660,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = 0;
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
......@@ -628,24 +694,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
for (int b = 0; b < YTILE; b++)
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
......@@ -658,7 +708,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) {
if (k_ + K * n < 32 * 1024)
if (k_ + K * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
......@@ -676,32 +726,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
}
}
}
......@@ -710,34 +745,82 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
}
if (threadIdx.x == 63) {
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
} else {
#pragma unroll
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
#pragma unroll
for (int y = 0; y < YTILE; y++) {
// float accm1 = 0;
// for (int i=0; i<64; i++)
// accm1 += __shfl(sum4[n][y][i%4], i);
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
}
}
}
}
......@@ -756,7 +839,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
......@@ -764,9 +847,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
// This version targets big A[] cases, where it is much larger than LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
......@@ -774,25 +857,34 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2;
#if defined(__HIP__MI3XX__)
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
constexpr bool use_mfma = false;
#endif
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
using half4 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
half4 h4[A_CHUNK / 4];
scalar8 h8;
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Reserving 64/160 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not going to work!
//----------------------------------------------------
__shared__ scalar_t s[1024 * 32];
__shared__ scalar_t s[max_lds_len];
//----------------------------------------------------
// Computation of columns that need to be committed to memory!
......@@ -833,11 +925,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
#define PCML
#ifndef PCML
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
for (uint32_t k = 0; k < min(K * N, max_lds_len);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min(K * N, 32 * 1024)) break;
if (k_in >= min(K * N, max_lds_len)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
}
......@@ -847,7 +939,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#define TUC (THRDS * UNRL * A_CHUNK)
uint32_t kBase = 0;
// find biggest k size that fits in LDS
uint32_t kFit = (32 * 1024) / N;
uint32_t kFit = (max_lds_len) / N;
// kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple
// of TUC
kFit = (kFit % TUC == 0)
......@@ -857,6 +949,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
kFit = min(kFit, K);
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
......@@ -888,7 +981,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = 0;
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
......@@ -937,24 +1034,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
for (int b = 0; b < YTILE; b++)
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
......@@ -989,32 +1070,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
}
}
}
......@@ -1031,34 +1097,78 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
}
if (threadIdx.x == 63) {
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
} else {
#pragma unroll
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
#pragma unroll
for (int y = 0; y < YTILE; y++) {
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
}
}
}
}
......@@ -1077,7 +1187,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
......@@ -1085,7 +1195,7 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
int mindiv(int N, int div1, int div2) {
int nPrRnd = div1 * div2;
......@@ -1135,17 +1245,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size() / 2;
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
} else if (K_in * N_in <= 32 * 1024 * 1.2) { \
} else if (K_in * N_in <= max_lds_len * 1.2) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
......@@ -1185,7 +1296,7 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
return out_c;
}
#if defined(__HIP__MI300__) // TODO: Add NAVI support
#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
......@@ -1194,6 +1305,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp,
const int CuCount) {
constexpr int max_lds_len = LDS_SIZE;
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
......@@ -1209,10 +1321,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
scalar8 h8;
};
__shared__ fp8_t s[1024 * 64];
__shared__ fp8_t s[max_lds_len];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
}
__syncthreads();
......@@ -1349,7 +1461,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
......@@ -1359,9 +1471,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
#if defined(__HIP__MI300__) // TODO: Add NAVI support
#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
......@@ -1369,6 +1481,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
const fp8_t* __restrict__ A, scalar_t* C,
const float* __restrict__ s_A, const float* __restrict__ s_B,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE;
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
......@@ -1384,10 +1497,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
scalar8 h8;
};
__shared__ fp8_t s[1024 * 64];
__shared__ fp8_t s[max_lds_len];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
}
__syncthreads();
......@@ -1430,7 +1543,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (int n = 0; n < N; n++) {
if (k_ + K * n < 64 * 1024)
if (k_ + K * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
......@@ -1521,7 +1634,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
......@@ -1531,7 +1644,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b,
......@@ -1551,12 +1664,13 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
dim3 grid(CuCount);
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size();
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) { \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
......
......@@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
" Tensor k_scale, Tensor v_scale,"
" Tensor? fp8_out_scale) -> ()");
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
}
......
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace vllm {

// Applies repetition penalties to `logits` in place.
//
// Launch layout (as used by the host wrapper in this file):
//   blockIdx.x  -> sequence index
//   blockIdx.y  -> tile index along the vocab dimension
//   threadIdx.x -> strides over the vocab entries of the tile
//
// For every (sequence, token) position flagged in either `prompt_mask` or
// `output_mask`, a positive logit is divided by the sequence's penalty and a
// non-positive logit is multiplied by it. Unflagged positions are untouched.
template <typename scalar_t>
__global__ void apply_repetition_penalties_kernel(
    scalar_t* __restrict__ logits,         // [num_seqs, vocab_size]
    const bool* __restrict__ prompt_mask,  // [num_seqs, vocab_size]
    const bool* __restrict__ output_mask,  // [num_seqs, vocab_size]
    const scalar_t* __restrict__ repetition_penalties,  // [num_seqs]
    const int num_seqs, const int vocab_size, const int tile_size) {
  // One sequence per block row; surplus blocks have nothing to do.
  const int seq = blockIdx.x;
  if (seq >= num_seqs) return;

  // Half-open vocab range [first, last) owned by this block, clamped so the
  // final tile does not run past the end of the row.
  const int first = blockIdx.y * tile_size;
  const int last = min(first + tile_size, vocab_size);

  // The penalty is constant for the whole sequence: read it once.
  const scalar_t penalty = repetition_penalties[seq];

  // 64-bit row offset so num_seqs * vocab_size may exceed the int range.
  const int64_t row_base = static_cast<int64_t>(seq) * vocab_size;

  for (int col = first + threadIdx.x; col < last; col += blockDim.x) {
    const int64_t idx = row_base + col;

    // Only tokens seen in the prompt or in the generated output are penalized.
    if (!(prompt_mask[idx] || output_mask[idx])) continue;

    const scalar_t logit = logits[idx];
    // Either branch moves the logit toward "less likely" for penalty > 1.
    logits[idx] = (logit > 0) ? logit / penalty : logit * penalty;
  }
}

}  // namespace vllm
// Host entry point: applies repetition penalties to `logits` in place by
// launching vllm::apply_repetition_penalties_kernel on the current stream.
//
// Arguments:
//   logits               [num_seqs, vocab_size], modified in place
//   prompt_mask          [num_seqs, vocab_size], bool
//   output_mask          [num_seqs, vocab_size], bool
//   repetition_penalties [num_seqs], same dtype as `logits`
//
// All tensors must be contiguous (enforced with TORCH_CHECK). An empty batch
// or an empty vocabulary is a no-op.
void apply_repetition_penalties_(
    torch::Tensor& logits,             // [num_seqs, vocab_size], in-place
    const torch::Tensor& prompt_mask,  // [num_seqs, vocab_size]
    const torch::Tensor& output_mask,  // [num_seqs, vocab_size]
    const torch::Tensor& repetition_penalties) {  // [num_seqs]
  TORCH_CHECK(logits.is_contiguous());
  TORCH_CHECK(prompt_mask.is_contiguous());
  TORCH_CHECK(output_mask.is_contiguous());
  TORCH_CHECK(repetition_penalties.is_contiguous());

  int vocab_size = logits.size(-1);
  int num_seqs = logits.size(0);

  // Nothing to penalize. Returning here also avoids the integer divisions by
  // zero in the tiling arithmetic below (by num_seqs, and by tile_num when
  // vocab_size == 0) and a launch with a zero-sized grid or block.
  if (num_seqs == 0 || vocab_size == 0) return;

  // Get number of SMs on the current device. If the query fails, `sms` stays
  // 0 and the tiling below degrades to a single tile per sequence.
  int sms = 0;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
                         logits.get_device());

  // Compute tile_num and tile_size: split each row into enough tiles that the
  // grid has roughly one block per SM, but never more tiles than vocab
  // entries and never fewer than one.
  int tile_num =
      std::min(vocab_size, std::max(1, (sms + num_seqs - 1) / num_seqs));
  int tile_size = (vocab_size + tile_num - 1) / tile_num;

  // Each block handles one sequence and a tile of vocab
  dim3 grid(num_seqs, tile_num);
  dim3 block(std::min(tile_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(logits));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      logits.scalar_type(), "apply_repetition_penalties_kernel", [&] {
        vllm::apply_repetition_penalties_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(
                logits.data_ptr<scalar_t>(), prompt_mask.data_ptr<bool>(),
                output_mask.data_ptr<bool>(),
                repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
                tile_size);
      });
}
\ No newline at end of file
......@@ -8,6 +8,8 @@
#include <ATen/cuda/CUDAContext.h>
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
......@@ -95,9 +97,9 @@ struct cutlass_sparse_3x_gemm {
// clang-format off
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
ElementAB, cutlass::layout::RowMajor, AlignmentAB,
ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
ElementAB, cutlass::layout::RowMajor, AlignmentAB,
ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
ElementAcc, TileShape, ClusterShape,
Stages,
KernelSchedule>::CollectiveOp;
......
......@@ -229,13 +229,40 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_output,"
" Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
ops.def(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
" Tensor! column_count, Tensor! column_index, "
" Tensor q_seqlens, Tensor q_seqlens, "
" Tensor vertical_indexes, Tensor slash_indexes, "
" int context_size, int block_size_M, int block_size_N, "
" bool causal) -> ()");
ops.impl("convert_vertical_slash_indexes", torch::kCUDA,
&convert_vertical_slash_indexes);
ops.def(
"convert_vertical_slash_indexes_mergehead("
" Tensor! block_count, Tensor! block_offset, "
" Tensor! column_count, Tensor! column_index, "
" Tensor q_seqlens, Tensor q_seqlens, "
" Tensor vertical_indexes, Tensor slash_indexes, "
" Tensor vertical_indices_count, Tensor slash_indices_count, "
" int context_size, int block_size_M, int block_size_N, "
" bool causal) -> ()");
ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA,
&convert_vertical_slash_indexes_mergehead);
#endif
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
ops.def(
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
......@@ -294,13 +321,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> ()");
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
// Compute MLA decode using cutlass.
// ops.def(
// "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
// " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
// " Tensor page_table, float scale) -> ()");
// ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
......@@ -314,7 +334,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
// Apply Root Mean Square (RMS) Normalization to the input tensor. (opt)
// Apply repetition penalties to logits in-place
ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
"Tensor output_mask, Tensor repetition_penalties) -> ()");
ops.impl("apply_repetition_penalties_", torch::kCUDA,
&apply_repetition_penalties_);
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()");
......@@ -355,7 +383,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
......@@ -372,7 +400,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// (supports multiple loras).
ops.def(
"batched_rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox,"
" int rot_dim,"
" Tensor cos_sin_cache_offsets) -> ()");
......@@ -483,12 +511,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"int b_q_type, "
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool has_zp, bool use_atomic_add, bool use_fp32_reduce, "
"bool is_zp_float) -> Tensor",
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
{stride_tag});
// conditionally compiled so impl registration is in source file
......@@ -530,17 +557,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
ops.def(
"ggml_moe_a8_vec(Tensor X, Tensor W, "
"Tensor topk_ids, int top_k, "
"int type, SymInt row, SymInt tokens) -> Tensor");
ops.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
#ifndef USE_ROCM
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops.def(
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
"SymInt size_k) -> Tensor",
{stride_tag});
// conditionally compiled so impl registration is in source file
// marlin_qqq_gemm for QQQ.
ops.def(
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
......@@ -558,6 +583,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
{stride_tag});
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()",
{stride_tag});
ops.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
......@@ -592,7 +625,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor c_strides) -> ()",
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
" bool per_out_ch) -> ()",
{stride_tag});
ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);
......@@ -607,10 +641,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()",
" int n, int k, Tensor? blockscale_offsets) -> ()",
{stride_tag});
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
// as an input, and computes expert_offsets (token start indices of each
// expert). In addition to this, it computes problem sizes for each expert's
// multiplication used by the two mms called from fused MoE operation.
ops.def(
"get_cutlass_pplx_moe_mm_data(Tensor! expert_offsets, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" Tensor expert_num_tokens, "
" int num_local_experts, int padded_m, "
" int n, int k) -> ()",
{stride_tag});
ops.impl("get_cutlass_pplx_moe_mm_data", torch::kCUDA,
&get_cutlass_pplx_moe_mm_data);
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
ops.def(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
......@@ -639,40 +689,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);
// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states,"
"int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
// CUTLASS MLA decode
ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
" Tensor page_table, float scale) -> ()");
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
// Compute NVFP4 block quantized tensor.
ops.def(
......@@ -680,6 +702,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! output_scale, Tensor input_scale) -> ()");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
// Compute NVFP4 experts quantization.
ops.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
......@@ -735,6 +764,41 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
&dynamic_scaled_int8_quant);
// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states,"
"int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#ifndef USE_ROCM
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops.def(
......
......@@ -2,14 +2,14 @@
# to run the OpenAI compatible server.
# Please update any changes made here to
# docs/source/contributing/dockerfile/dockerfile.md and
# docs/source/assets/contributing/dockerfile-stages-dependency.png
# docs/contributing/dockerfile/dockerfile.md and
# docs/assets/contributing/dockerfile-stages-dependency.png
ARG CUDA_VERSION=12.4.1
ARG CUDA_VERSION=12.8.1
#################### BASE BUILD IMAGE ####################
# prepare basic build environment
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
ARG CUDA_VERSION=12.4.1
ARG CUDA_VERSION=12.8.1
ARG PYTHON_VERSION=3.12
ARG TARGETPLATFORM
ENV DEBIAN_FRONTEND=noninteractive
......@@ -19,7 +19,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl sudo \
&& add-apt-repository ppa:deadsnakes/ppa \
&& for i in 1 2 3; do \
add-apt-repository -y ppa:deadsnakes/ppa && break || \
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
done \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
......@@ -34,6 +37,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
# as it was causing spam when compiling the CUTLASS kernels
......@@ -66,13 +70,14 @@ RUN --mount=type=cache,target=/root/.cache/uv \
COPY requirements/common.txt requirements/common.txt
COPY requirements/cuda.txt requirements/cuda.txt
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/cuda.txt
uv pip install --system -r requirements/cuda.txt \
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
# cuda arch list used by torch
# can be useful for both `dev` and `test`
# explicitly set the list to avoid issues with torch 2.2
# see https://github.com/pytorch/pytorch/pull/123243
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0+PTX'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# Override the arch list for flash-attn to reduce the binary size
ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
......@@ -89,9 +94,11 @@ COPY requirements/build.txt requirements/build.txt
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/build.txt
uv pip install --system -r requirements/build.txt \
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
COPY . .
ARG GIT_REPO_CHECK=0
......@@ -158,27 +165,32 @@ FROM base as dev
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Workaround for #17068
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4"
COPY requirements/lint.txt requirements/lint.txt
COPY requirements/test.txt requirements/test.txt
COPY requirements/dev.txt requirements/dev.txt
# Workaround for #17068
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system mamba-ssm==2.2.4 --no-build-isolation
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/dev.txt
uv pip install --system -r requirements/dev.txt \
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
#################### DEV IMAGE ####################
#################### vLLM installation IMAGE ####################
# image with vLLM installed
# TODO: Restore to base image after FlashInfer AOT wheel fixed
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
ARG CUDA_VERSION=12.4.1
ARG CUDA_VERSION=12.8.1
ARG PYTHON_VERSION=3.12
WORKDIR /vllm-workspace
ENV DEBIAN_FRONTEND=noninteractive
ARG TARGETPLATFORM
SHELL ["/bin/bash", "-c"]
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
......@@ -188,7 +200,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
&& add-apt-repository ppa:deadsnakes/ppa \
&& for i in 1 2 3; do \
add-apt-repository -y ppa:deadsnakes/ppa && break || \
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
done \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
......@@ -203,6 +218,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
......@@ -223,7 +239,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Install vllm wheel first, so that torch etc will be installed.
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system dist/*.whl --verbose
uv pip install --system dist/*.whl --verbose \
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
# If we need to build FlashInfer wheel before its release:
# $ export FLASHINFER_ENABLE_AOT=1
......@@ -240,19 +257,34 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
RUN --mount=type=cache,target=/root/.cache/uv \
. /etc/environment && \
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
if [[ "$CUDA_VERSION" == 12.8* ]]; then \
uv pip install --system https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.5%2Bcu128torch2.7-cp38-abi3-linux_x86_64.whl; \
else \
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX'; \
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
if [ "$CUDA_MAJOR" -lt 12 ]; then \
export FLASHINFER_ENABLE_SM90=0; \
fi; \
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@21ea1d2545f74782b91eb8c08fd503ac4c0743fc" ; \
fi \
fi
COPY examples examples
COPY benchmarks benchmarks
COPY ./vllm/collect_env.py .
# Although we build Flashinfer with AOT mode, there's still
RUN --mount=type=cache,target=/root/.cache/uv \
. /etc/environment && \
uv pip list
# Even when we build Flashinfer with AOT mode, there's still
# some issues w.r.t. JIT compilation. Therefore we need to
# install build dependencies for JIT compilation.
# TODO: Remove this once FlashInfer AOT wheel is fixed
COPY requirements/build.txt requirements/build.txt
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/build.txt
uv pip install --system -r requirements/build.txt \
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
#################### vLLM installation IMAGE ####################
......@@ -266,13 +298,18 @@ ADD . /vllm-workspace/
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
# install development dependencies (for testing)
# Workaround for #17068
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system mamba-ssm==2.2.4 --no-build-isolation
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/dev.txt
uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4"
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
if [ "$CUDA_MAJOR" -ge 12 ]; then \
uv pip install --system -r requirements/dev.txt; \
fi
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
......@@ -291,7 +328,9 @@ COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1
# will not be imported by other tests
RUN mkdir test_docs
RUN mv docs test_docs/
RUN cp -r examples test_docs/
RUN mv vllm test_docs/
RUN mv mkdocs.yaml test_docs/
#################### TEST IMAGE ####################
#################### OPENAI API SERVER ####################
......
......@@ -51,9 +51,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --upgrade pip && \
uv pip install -r requirements/cpu.txt
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install intel-openmp==2024.2.1 intel_extension_for_pytorch==2.6.0
ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/opt/venv/lib/libiomp5.so:$LD_PRELOAD"
RUN echo 'ulimit -c 0' >> ~/.bashrc
......@@ -78,6 +75,7 @@ RUN --mount=type=bind,source=.git,target=.git \
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/workspace/vllm/.deps,sharing=locked \
--mount=type=bind,source=.git,target=.git \
VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel
......@@ -88,7 +86,7 @@ WORKDIR /workspace/vllm
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get install -y --no-install-recommends vim numactl
apt-get install -y --no-install-recommends vim numactl xz-utils
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
......@@ -111,8 +109,11 @@ FROM base AS vllm-test
WORKDIR /workspace/
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,src=requirements/test.txt,target=requirements/test.txt \
uv pip install -r requirements/test.txt
--mount=type=bind,src=requirements/test.in,target=requirements/test.in \
cp requirements/test.in requirements/test-cpu.in && \
sed -i '/mamba_ssm/d' requirements/test-cpu.in && \
uv pip compile requirements/test-cpu.in -o requirements/cpu-test.txt && \
uv pip install -r requirements/cpu-test.txt
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=vllm-build,src=/workspace/vllm/dist,target=dist \
......
# default base image
# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.5.1-neuronx-py310-sdk2.22.0-ubuntu22.04"
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.6.0-neuronx-py310-sdk2.23.0-ubuntu22.04"
FROM $BASE_IMAGE
......@@ -22,8 +22,7 @@ WORKDIR ${APP_MOUNT}/vllm
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas tenacity
RUN python3 -m pip install sentencepiece transformers==4.48.0 -U
RUN python3 -m pip install neuronx-cc==2.17.194.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
RUN python3 -m pip install neuronx-cc==2.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
RUN python3 -m pip install pytest
# uninstall transformers-neuronx package explicitly to avoid version conflict
......@@ -35,7 +34,7 @@ RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
RUN python3 -m pip install -U \
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
'cmake>=3.26.1' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-r requirements/neuron.txt
ENV VLLM_TARGET_DEVICE neuron
......@@ -49,6 +48,8 @@ RUN python3 -m pip install -e tests/vllm_test_utils
# FIXME: `--no-deps` argument is temporarily added to resolve transformers package version conflict
RUN python3 -m pip install transformers-neuronx==0.13.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U --no-deps
RUN python3 -m pip install sentencepiece transformers==4.48.0 -U
# overwrite entrypoint to run bash script
RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py
......
......@@ -16,7 +16,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl sudo \
&& add-apt-repository ppa:deadsnakes/ppa \
&& for i in 1 2 3; do \
add-apt-repository -y ppa:deadsnakes/ppa && break || \
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
done \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
......@@ -197,7 +200,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
&& add-apt-repository ppa:deadsnakes/ppa \
&& for i in 1 2 3; do \
add-apt-repository -y ppa:deadsnakes/ppa && break || \
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
done \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
......@@ -303,5 +309,10 @@ ENV HF_HUB_ENABLE_HF_TRANSFER 1
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/nightly_torch_test.txt
#################### UNITTEST IMAGE #############################
# Logging to confirm the torch versions
RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'
# Logging to confirm all the packages are installed
RUN pip freeze
#################### UNITTEST IMAGE #############################
ARG BASE_UBI_IMAGE_TAG=9.5-1741850109
###############################################################
# Stage to build openblas
###############################################################
FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS openblas-builder

# Builds OpenBLAS from the upstream release archive with gcc-toolset-13.
# The build tree is left in /OpenBLAS-$OPENBLAS_VERSION so that other stages
# can bind-mount it and run `make install` themselves.
#
# Build args:
#   MAX_JOBS          parallel make jobs; defaults to the number of CPUs
#   OPENBLAS_VERSION  OpenBLAS release to download and build
ARG MAX_JOBS
ARG OPENBLAS_VERSION=0.3.29

# MAX_JOBS has no default value. Fall back to $(nproc) so that an unset
# MAX_JOBS does not expand to a bare `make -j`, which places no limit on the
# number of parallel jobs.
# The trailing `touch control` creates the dummy /tmp/control file that other
# stages COPY to order themselves after this one.
RUN microdnf install -y dnf && dnf install -y gcc-toolset-13 make wget unzip \
    && source /opt/rh/gcc-toolset-13/enable \
    && wget https://github.com/OpenMathLib/OpenBLAS/releases/download/v$OPENBLAS_VERSION/OpenBLAS-$OPENBLAS_VERSION.zip \
    && unzip OpenBLAS-$OPENBLAS_VERSION.zip \
    && cd OpenBLAS-$OPENBLAS_VERSION \
    && make -j${MAX_JOBS:-$(nproc)} TARGET=POWER9 BINARY=64 USE_OPENMP=1 USE_THREAD=1 NUM_THREADS=120 DYNAMIC_ARCH=1 INTERFACE64=0 \
    && cd /tmp && touch control
###############################################################
# base stage with dependencies coming from centos mirrors
###############################################################
FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS centos-deps-builder
RUN microdnf install -y dnf && \
dnf install -y https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-gpg-keys-9.0-24.el9.noarch.rpm \
https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-stream-repos-9.0-24.el9.noarch.rpm \
https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm && \
dnf config-manager --set-enabled crb
RUN dnf install -y openjpeg2-devel lcms2-devel tcl-devel tk-devel fribidi-devel && \
dnf remove -y centos-gpg-keys-9.0-24.el9.noarch centos-stream-repos-9.0-24.el9.noarch
###############################################################
# base stage with basic dependencies
###############################################################
FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base-builder
FROM centos-deps-builder AS base-builder
ARG PYTHON_VERSION=3.12
ARG OPENBLAS_VERSION=0.3.29
......@@ -20,29 +51,27 @@ ENV UV_LINK_MODE=copy
# Note: A symlink for libatomic.so is created for gcc-13 (linker fails to find libatomic otherwise - reqd. for sentencepiece)
# Note: A dummy file 'control' is created in /tmp/ to artificially create dependencies between stages when building stages in parallel
# when `--jobs=<N>` is passed with podman build command
RUN microdnf install -y openssl-devel dnf \
&& dnf install -y https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-gpg-keys-9.0-24.el9.noarch.rpm \
https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-stream-repos-9.0-24.el9.noarch.rpm \
https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm \
&& dnf config-manager --add-repo https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os \
&& dnf config-manager --add-repo https://mirror.stream.centos.org/9-stream/AppStream/`arch`/os \
&& dnf config-manager --set-enabled crb \
COPY --from=openblas-builder /tmp/control /dev/null
RUN --mount=type=bind,from=openblas-builder,source=/OpenBLAS-$OPENBLAS_VERSION/,target=/openblas/,rw \
dnf install -y openssl-devel \
&& dnf install -y \
git tar gcc-toolset-13 automake libtool numactl-devel lapack-devel \
git tar gcc-toolset-13 automake libtool \
pkgconfig xsimd zeromq-devel kmod findutils protobuf* \
libtiff-devel libjpeg-devel openjpeg2-devel zlib-devel \
freetype-devel lcms2-devel libwebp-devel tcl-devel tk-devel \
harfbuzz-devel fribidi-devel libraqm-devel libimagequant-devel libxcb-devel \
libtiff-devel libjpeg-devel zlib-devel freetype-devel libwebp-devel \
harfbuzz-devel libraqm-devel libimagequant-devel libxcb-devel \
python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip \
&& dnf clean all \
&& PREFIX=/usr/local make -C /openblas install \
&& ln -sf /usr/lib64/libatomic.so.1 /usr/lib64/libatomic.so \
&& python${PYTHON_VERSION} -m venv ${VIRTUAL_ENV} \
&& python -m pip install -U pip uv \
&& uv pip install wheel build "setuptools<70" setuptools_scm setuptools_rust meson-python 'cmake<4' ninja cython scikit_build_core scikit_build \
&& curl -sL https://ftp2.osuosl.org/pub/ppc64el/openblas/latest/Openblas_${OPENBLAS_VERSION}_ppc64le.tar.gz | tar xvf - -C /usr/local \
&& curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
&& cd /tmp && touch control
###############################################################
# Stage to build torch family
###############################################################
......@@ -52,6 +81,8 @@ FROM base-builder AS torch-builder
ARG MAX_JOBS
ARG TORCH_VERSION=2.6.0
ARG _GLIBCXX_USE_CXX11_ABI=1
ARG OPENBLAS_VERSION=0.3.29
RUN --mount=type=cache,target=/root/.cache/uv \
source /opt/rh/gcc-toolset-13/enable && \
git clone --recursive https://github.com/pytorch/pytorch.git -b v${TORCH_VERSION} && \
......@@ -113,7 +144,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
.. && \
make install -j ${MAX_JOBS:-$(nproc)} && \
cd ../../python/ && \
uv pip install -v -r requirements-wheel-build.txt && \
uv pip install -v -r requirements-build.txt && uv pip install numpy==2.1.3 && \
pip show numpy && ls -lrt /opt/vllm/lib/python3.12/site-packages/numpy && \
PYARROW_PARALLEL=${PYARROW_PARALLEL:-$(nproc)} \
python setup.py build_ext \
--build-type=release --bundle-arrow-cpp \
......@@ -136,8 +168,25 @@ RUN --mount=type=cache,target=/root/.cache/uv \
cd opencv-python && \
sed -i -E -e 's/"setuptools.+",/"setuptools",/g' pyproject.toml && \
cd opencv && git cherry-pick --no-commit $OPENCV_PATCH && cd .. && \
uv pip install scikit-build && \
python -m build --wheel --installer=uv --outdir /opencvwheels/
###############################################################
# Stage to build numactl
###############################################################
FROM base-builder AS numa-builder
# Note: Building numactl with gcc-11. Compiling with gcc-13 in this builder stage will
# trigger recompilation with gcc-11 (and require libtool) in the final stage where we do not have gcc-13
ARG MAX_JOBS
ARG NUMACTL_VERSION=2.0.19
RUN git clone --recursive https://github.com/numactl/numactl.git -b v${NUMACTL_VERSION} \
&& cd numactl \
&& autoreconf -i && ./configure \
&& make -j ${MAX_JOBS:-$(nproc)}
###############################################################
# Stage to build vllm - this stage builds and installs
# vllm, tensorizer and vllm-tgis-adapter and builds uv cache
......@@ -149,6 +198,7 @@ FROM base-builder AS vllmcache-builder
COPY --from=torch-builder /tmp/control /dev/null
COPY --from=arrow-builder /tmp/control /dev/null
COPY --from=cv-builder /tmp/control /dev/null
COPY --from=numa-builder /tmp/control /dev/null
ARG VLLM_TARGET_DEVICE=cpu
ARG GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1
......@@ -164,11 +214,13 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
--mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \
--mount=type=bind,from=numa-builder,source=/numactl/,target=/numactl/,rw \
--mount=type=bind,src=.,dst=/src/,rw \
source /opt/rh/gcc-toolset-13/enable && \
uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl && \
sed -i -e 's/.*torch.*//g' /src/pyproject.toml /src/requirements/*.txt && \
uv pip install pandas pythran pybind11 /hf_wheels/*.whl && \
make -C /numactl install && \
# sentencepiece.pc is in some pkgconfig inside uv cache
export PKG_CONFIG_PATH=$(find / -type d -name "pkgconfig" 2>/dev/null | tr '\n' ':') && \
uv pip install -r /src/requirements/common.txt -r /src/requirements/cpu.txt -r /src/requirements/build.txt --no-build-isolation && \
......@@ -177,21 +229,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install /vllmwheel/*.whl
###############################################################
# Stage to build numactl
###############################################################
FROM base-builder AS numa-builder
# Note: Building numactl with gcc-11. Compiling with gcc-13 in this builder stage will
# trigger recompilation with gcc-11 (and require libtool) in the final stage where we do not have gcc-13
ARG MAX_JOBS
ARG NUMACTL_VERSION=2.0.19
RUN git clone --recursive https://github.com/numactl/numactl.git -b v${NUMACTL_VERSION} \
&& cd numactl \
&& autoreconf -i && ./configure \
&& make -j ${MAX_JOBS:-$(nproc)}
###############################################################
# Stage to build lapack
###############################################################
......@@ -221,6 +258,7 @@ ENV PATH=${VIRTUAL_ENV}/bin:$PATH
ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig/
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64:/usr/local/lib:/usr/lib64:/usr/lib
ENV UV_LINK_MODE=copy
ENV OMP_NUM_THREADS=16
# create artificial dependencies between stages for independent stages to build in parallel
COPY --from=torch-builder /tmp/control /dev/null
......@@ -229,11 +267,13 @@ COPY --from=cv-builder /tmp/control /dev/null
COPY --from=vllmcache-builder /tmp/control /dev/null
COPY --from=numa-builder /tmp/control /dev/null
COPY --from=lapack-builder /tmp/control /dev/null
COPY --from=openblas-builder /tmp/control /dev/null
# install gcc-11, python, openblas, numactl, lapack
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=numa-builder,source=/numactl/,target=/numactl/,rw \
--mount=type=bind,from=lapack-builder,source=/lapack/,target=/lapack/,rw \
--mount=type=bind,from=openblas-builder,source=/OpenBLAS-$OPENBLAS_VERSION/,target=/openblas/,rw \
rpm -ivh https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm && \
microdnf install --nodocs -y \
tar findutils openssl \
......@@ -245,8 +285,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
&& microdnf clean all \
&& python${PYTHON_VERSION} -m venv ${VIRTUAL_ENV} \
&& python -m pip install -U pip uv --no-cache \
&& curl -sL https://ftp2.osuosl.org/pub/ppc64el/openblas/latest/Openblas_${OPENBLAS_VERSION}_ppc64le.tar.gz | tar xvf - -C /usr/local \
&& make -C /numactl install \
&& PREFIX=/usr/local make -C /openblas install \
&& uv pip install 'cmake<4' \
&& cmake --install /lapack/build \
&& uv pip uninstall cmake
......
# default base image
ARG REMOTE_VLLM="0"
ARG USE_CYTHON="0"
ARG BUILD_RPD="1"
ARG COMMON_WORKDIR=/app
ARG BASE_IMAGE=rocm/vllm-dev:base
......@@ -15,7 +13,7 @@ RUN apt-get update -q -y && apt-get install -q -y \
sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev \
apt-transport-https ca-certificates wget curl
# Remove sccache
RUN python3 -m pip install --upgrade pip && pip install setuptools_scm
RUN python3 -m pip install --upgrade pip
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
ARG COMMON_WORKDIR
WORKDIR ${COMMON_WORKDIR}
......@@ -30,18 +28,17 @@ ARG VLLM_REPO="https://github.com/vllm-project/vllm.git"
ARG VLLM_BRANCH="main"
ONBUILD RUN git clone ${VLLM_REPO} \
&& cd vllm \
&& git checkout ${VLLM_BRANCH}
&& git fetch -v --prune -- origin ${VLLM_BRANCH} \
&& git checkout FETCH_HEAD
FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm
# -----------------------
# vLLM build stages
FROM fetch_vllm AS build_vllm
ARG USE_CYTHON
# Build vLLM
RUN cd vllm \
&& python3 -m pip install -r requirements/rocm.txt \
&& python3 setup.py clean --all \
&& if [ ${USE_CYTHON} -eq "1" ]; then python3 tests/build_cython.py build_ext --inplace; fi \
&& python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch AS export_vllm
ARG COMMON_WORKDIR
......@@ -90,13 +87,6 @@ RUN case "$(which python3)" in \
*) ;; esac
RUN python3 -m pip install --upgrade huggingface-hub[cli]
ARG BUILD_RPD
RUN if [ ${BUILD_RPD} -eq "1" ]; then \
git clone -b nvtx_enabled https://github.com/ROCm/rocmProfileData.git \
&& cd rocmProfileData/rpd_tracer \
&& pip install -r requirements.txt && cd ../ \
&& make && make install \
&& cd hipMarker && python3 setup.py install ; fi
# Install vLLM
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
......@@ -114,8 +104,10 @@ COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
ENV TOKENIZERS_PARALLELISM=false
# ENV that can improve safe tensor loading, and end-to-end time
ENV SAFETENSORS_FAST_GPU=1
# Performance environment variable.
ENV HIP_FORCE_DEV_KERNARG=1
CMD ["/bin/bash"]
......@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="7e1ed08"
ARG AITER_BRANCH="c1debd8"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
FROM ${BASE_IMAGE} AS base
......@@ -32,7 +32,10 @@ ENV DEBIAN_FRONTEND=noninteractive
# Install Python and other dependencies
RUN apt-get update -y \
&& apt-get install -y software-properties-common git curl sudo vim less libgfortran5 \
&& add-apt-repository ppa:deadsnakes/ppa \
&& for i in 1 2 3; do \
add-apt-repository -y ppa:deadsnakes/ppa && break || \
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
done \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
python${PYTHON_VERSION}-lib2to3 python-is-python3 \
......
......@@ -16,7 +16,7 @@ ENV LANG=C.UTF-8 \
RUN microdnf install -y \
which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \
libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \
openssl-devel openblas openblas-devel autoconf automake libtool cmake && \
openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy && \
microdnf clean all
# Python Installation
......@@ -84,16 +84,40 @@ RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \
rustup default stable && \
rustup show
FROM python-install AS torch

# Builds a PyTorch wheel from source; the result is left in /tmp/pytorch/dist
# for later stages to bind-mount.
#
# Build args:
#   TORCH_VERSION  PyTorch release to build; the git tag v${TORCH_VERSION}
#                  is checked out.
ARG TORCH_VERSION=2.7.0

# Use the key=value form with no `export` keyword: `ENV export FOO=1` would
# define a variable literally named "export" and leave FOO unset.
ENV _GLIBCXX_USE_CXX11_ABI=1
ENV CARGO_HOME=/root/.cargo
ENV RUSTUP_HOME=/root/.rustup
ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH"

WORKDIR /tmp

# The Rust toolchain is bind-mounted from the `rust` stage rather than
# reinstalled here. The checkout follows TORCH_VERSION so that overriding the
# build arg actually changes which release is built.
RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \
    --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \
    git clone https://github.com/pytorch/pytorch.git && \
    cd pytorch && \
    git checkout "v${TORCH_VERSION}" && \
    git submodule sync && \
    git submodule update --init --recursive && \
    uv pip install cmake ninja && \
    uv pip install -r requirements.txt && \
    python setup.py bdist_wheel
FROM python-install AS torch-vision
# Install torchvision
ARG TORCH_VERSION=2.7.0.dev20250304
ARG TORCH_VERSION=2.7.0
ARG TORCH_VISION_VERSION=v0.20.1
WORKDIR /tmp
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \
git clone https://github.com/pytorch/vision.git && \
cd vision && \
git checkout $TORCH_VISION_VERSION && \
uv pip install -v torch==${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu && \
TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl | head -n 1) && \
uv pip install -v $TORCH_WHL_FILE && \
python setup.py bdist_wheel
FROM python-install AS hf-xet-builder
......@@ -123,6 +147,7 @@ ENV UV_LINK_MODE=copy
ENV CARGO_HOME=/root/.cargo
ENV RUSTUP_HOME=/root/.rustup
ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH"
ENV GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1
COPY . /workspace/vllm
WORKDIR /workspace/vllm
......@@ -137,15 +162,17 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \
--mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \
--mount=type=bind,from=hf-xet-builder,source=/tmp/hf-xet/dist,target=/tmp/hf-xet-wheels/ \
--mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \
sed -i '/^torch/d' requirements/build.txt && \
ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \
VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \
HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl | head -n 1) && \
TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl | head -n 1) && \
uv pip install -v \
$ARROW_WHL_FILE \
$VISION_WHL_FILE \
$HF_XET_WHL_FILE \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
$TORCH_WHL_FILE \
--index-strategy unsafe-best-match \
-r requirements/build.txt \
-r requirements/cpu.txt
......
......@@ -23,7 +23,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=.git,target=.git \
python3 -m pip install \
-r requirements/tpu.txt
RUN python3 setup.py develop
RUN python3 -m pip install -e .
# install development dependencies (for testing)
RUN python3 -m pip install -e tests/vllm_test_utils
......
......@@ -40,12 +40,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=.git,target=.git \
python3 setup.py install
# See the XPU docs: intel-extension-for-pytorch 2.6.10+xpu must be installed manually because its dependencies conflict with torch 2.6.0+xpu.
# FIXME: This will be fixed in IPEX 2.7; left here for awareness.
RUN --mount=type=cache,target=/root/.cache/pip \
pip install intel-extension-for-pytorch==2.6.10+xpu \
--extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
CMD ["/bin/bash"]
FROM vllm-base AS vllm-openai
......
nav:
- Home:
- vLLM: README.md
- Getting Started:
- getting_started/quickstart.md
- getting_started/installation
- Examples:
- Offline Inference: examples/offline_inference
- Online Serving: examples/online_serving
- Others: examples/others
- Quick Links:
- User Guide: usage/README.md
- Developer Guide: contributing/README.md
- API Reference: api/README.md
- CLI Reference: cli/README.md
- Timeline:
- Roadmap: https://roadmap.vllm.ai
- Releases: https://github.com/vllm-project/vllm/releases
- User Guide:
- Summary: usage/README.md
- usage/v1_guide.md
- General:
- usage/*
- Inference and Serving:
- serving/offline_inference.md
- serving/openai_compatible_server.md
- serving/*
- serving/integrations
- Deployment:
- deployment/*
- deployment/frameworks
- deployment/integrations
- Training: training
- Configuration:
- Summary: configuration/README.md
- configuration/*
- Models:
- models/supported_models.md
- models/generative_models.md
- models/pooling_models.md
- models/extensions
- Features:
- features/compatibility_matrix.md
- features/*
- features/quantization
- Developer Guide:
- Summary: contributing/README.md
- General:
- glob: contributing/*
flatten_single_child_sections: true
- Model Implementation: contributing/model
- Design Documents:
- V0: design
- V1: design/v1
- API Reference:
- Summary: api/README.md
- Contents:
- glob: api/vllm/*
preserve_directory_names: true
- CLI Reference:
- Summary: cli/README.md
- Community:
- community/*
- Blog: https://blog.vllm.ai
- Forum: https://discuss.vllm.ai
- Slack: https://slack.vllm.ai
# Minimal makefile for Sphinx documentation
#
# Every target except `clean` is forwarded unchanged to `sphinx-build -M`.
# NOTE: recipe lines below must start with a TAB character.

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

# `clean` is listed as phony so that a stray file named "clean" in this
# directory cannot turn the target into a no-op.
.PHONY: help clean Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

# Explicit rule (takes precedence over the catch-all): run Sphinx's own clean,
# then remove the examples directory under the source tree, which Sphinx's
# clean does not touch.
clean:
	@$(SPHINXBUILD) -M clean "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
	rm -rf "$(SOURCEDIR)/getting_started/examples"
# vLLM documents
## Build the docs
- Make sure you are in the `docs` directory:
```bash
cd docs
```
- Install the dependencies:
```bash
pip install -r ../requirements/docs.txt
```
- Clean the previous build (optional but recommended):
```bash
make clean
```
- Generate the HTML documentation:
```bash
make html
```
## Open the docs with your browser
- Serve the documentation locally:
```bash
python -m http.server -d build/html/
```
This will start a local server at http://localhost:8000. You can now open your browser and view the documentation.
If port 8000 is already in use, you can specify a different port, for example:
```bash
python -m http.server 3000 -d build/html/
```
# Welcome to vLLM
<figure markdown="span">
![](./assets/logos/vllm-logo-text-light.png){ align="center" alt="vLLM" class="no-scaled-link" width="60%" }
</figure>
<p style="text-align:center">
<strong>Easy, fast, and cheap LLM serving for everyone
</strong>
</p>
<p style="text-align:center">
<script async defer src="https://buttons.github.io/buttons.js"></script>
<a class="github-button" href="https://github.com/vllm-project/vllm" data-show-count="true" data-size="large" aria-label="Star">Star</a>
<a class="github-button" href="https://github.com/vllm-project/vllm/subscription" data-show-count="true" data-icon="octicon-eye" data-size="large" aria-label="Watch">Watch</a>
<a class="github-button" href="https://github.com/vllm-project/vllm/fork" data-show-count="true" data-icon="octicon-repo-forked" data-size="large" aria-label="Fork">Fork</a>
</p>
vLLM is a fast and easy-to-use library for LLM inference and serving.
Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
vLLM is fast with:
- State-of-the-art serving throughput
- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
- Continuous batching of incoming requests
- Fast model execution with CUDA/HIP graph
- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8
- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer.
- Speculative decoding
- Chunked prefill
vLLM is flexible and easy to use with:
- Seamless integration with popular HuggingFace models
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
- Tensor parallelism and pipeline parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPUs, and AWS Trainium and Inferentia accelerators
- Prefix caching support
- Multi-LoRA support
For more information, check out the following:
- [vLLM announcing blog post](https://vllm.ai) (intro to PagedAttention)
- [vLLM paper](https://arxiv.org/abs/2309.06180) (SOSP 2023)
- [How continuous batching enables 23x throughput in LLM inference while reducing p50 latency](https://www.anyscale.com/blog/continuous-batching-llm-inference) by Cade Daniel et al.
- [vLLM Meetups][meetups]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment