Unverified Commit b7c0942b authored by Charlie Fu's avatar Charlie Fu Committed by GitHub
Browse files

[ROCm][Misc] Rename the context_len to seq_len in ROCm custom paged attention kernel (#22097)


Signed-off-by: default avatarcharlifu <charlifu@amd.com>
parent 9a0c5ded
...@@ -270,7 +270,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -270,7 +270,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int num_kv_heads, const int num_kv_heads,
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
...@@ -304,12 +304,12 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -304,12 +304,12 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const auto max_num_partitions = gridDim.y; const auto max_num_partitions = gridDim.y;
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int partition_start_token_idx = const int partition_start_token_idx =
partition_idx * T_PAR_SIZE; // partition_size; partition_idx * T_PAR_SIZE; // partition_size;
// exit if partition is out of context for seq // exit if partition is out of context for seq
if (partition_start_token_idx >= context_len) { if (partition_start_token_idx >= seq_len) {
return; return;
} }
...@@ -361,8 +361,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -361,8 +361,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens // output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens
// across 4 rows x 4 tokens per lane // across 4 rows x 4 tokens per lane
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int last_ctx_block = num_context_blocks - 1; const int last_seq_block = num_seq_blocks - 1;
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
...@@ -373,9 +373,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -373,9 +373,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int klocal_token_idx = const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
const int kblock_idx = (kglobal_token_idx < context_len) const int kblock_idx = (kglobal_token_idx < seq_len)
? kglobal_token_idx / BLOCK_SIZE ? kglobal_token_idx / BLOCK_SIZE
: last_ctx_block; : last_seq_block;
kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
} }
...@@ -476,9 +476,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -476,9 +476,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// tokens // tokens
const int vglobal_token_idx = const int vglobal_token_idx =
partition_start_token_idx + vlocal_token_idx; partition_start_token_idx + vlocal_token_idx;
const int vblock_idx = (vglobal_token_idx < context_len) const int vblock_idx = (vglobal_token_idx < seq_len)
? vglobal_token_idx / BLOCK_SIZE ? vglobal_token_idx / BLOCK_SIZE
: last_ctx_block; : last_seq_block;
vphysical_block_number[vtoken_depth][vblock_depth] = vphysical_block_number[vtoken_depth][vblock_depth] =
block_table_seq[vblock_idx]; block_table_seq[vblock_idx];
} }
...@@ -554,7 +554,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -554,7 +554,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
if constexpr (ALIBI_ENABLED) { if constexpr (ALIBI_ENABLED) {
for (int token_depth = 0; token_depth < TLOOP; token_depth++) { for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16; const int local_token_idx = qkout_token_idx + token_depth * 16;
const int alibi_offset = local_token_idx - context_len + 1; const int alibi_offset = local_token_idx - seq_len + 1;
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
d_out[token_depth][i] += alibi_slope * (alibi_offset + i); d_out[token_depth][i] += alibi_slope * (alibi_offset + i);
} }
...@@ -568,9 +568,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -568,9 +568,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) { for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16; const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
const float tmp = (local_token_idx + i < context_len) const float tmp =
? d_out[token_depth][i] (local_token_idx + i < seq_len) ? d_out[token_depth][i] : -FLT_MAX;
: -FLT_MAX;
qk_max = fmaxf(qk_max, tmp); qk_max = fmaxf(qk_max, tmp);
} }
} }
...@@ -582,7 +581,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -582,7 +581,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) { for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16; const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
const float tmp = (local_token_idx + i < context_len) const float tmp = (local_token_idx + i < seq_len)
? __expf(d_out[token_depth][i] - qk_max) ? __expf(d_out[token_depth][i] - qk_max)
: 0.0f; : 0.0f;
d_out[token_depth][i] = tmp; d_out[token_depth][i] = tmp;
...@@ -780,7 +779,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( ...@@ -780,7 +779,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int num_kv_heads, const int num_kv_heads,
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
...@@ -809,10 +808,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( ...@@ -809,10 +808,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const auto partition_size = blockDim.x; const auto partition_size = blockDim.x;
const auto max_num_partitions = gridDim.y; const auto max_num_partitions = gridDim.y;
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int partition_start_token_idx = partition_idx * partition_size; const int partition_start_token_idx = partition_idx * partition_size;
// exit if partition is out of context for seq // exit if partition is out of context for seq
if (partition_start_token_idx >= context_len) { if (partition_start_token_idx >= seq_len) {
return; return;
} }
// every 4 lanes fetch 4 different qheads // every 4 lanes fetch 4 different qheads
...@@ -855,7 +854,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( ...@@ -855,7 +854,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int warp_start_token_idx = const int warp_start_token_idx =
partition_start_token_idx + warpid * WARP_SIZE; partition_start_token_idx + warpid * WARP_SIZE;
if (warp_start_token_idx >= context_len) { // warp out of context if (warp_start_token_idx >= seq_len) { // warp out of context
#pragma unroll #pragma unroll
for (int h = 0; h < GQA_RATIO4; h++) { for (int h = 0; h < GQA_RATIO4; h++) {
shared_qk_max[warpid][h] = -FLT_MAX; shared_qk_max[warpid][h] = -FLT_MAX;
...@@ -863,8 +862,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( ...@@ -863,8 +862,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
} }
} else { // warp within context } else { // warp within context
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int last_ctx_block = num_context_blocks - 1; const int last_seq_block = num_seq_blocks - 1;
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
// token id within partition // token id within partition
...@@ -873,9 +872,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( ...@@ -873,9 +872,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int global_token_idx = partition_start_token_idx + local_token_idx; const int global_token_idx = partition_start_token_idx + local_token_idx;
// fetch block number for k // fetch block number for k
const int block_idx = (global_token_idx < context_len) const int block_idx = (global_token_idx < seq_len)
? global_token_idx / BLOCK_SIZE ? global_token_idx / BLOCK_SIZE
: last_ctx_block; : last_seq_block;
// fetch k physical block number // fetch k physical block number
// int32 physical_block_number leads to overflow when multiplied with // int32 physical_block_number leads to overflow when multiplied with
...@@ -888,7 +887,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( ...@@ -888,7 +887,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for (int b = 0; b < VBLOCKS; b++) { for (int b = 0; b < VBLOCKS; b++) {
const int vblock_idx = warp_start_block_idx + b; const int vblock_idx = warp_start_block_idx + b;
const int vblock_idx_ctx = const int vblock_idx_ctx =
(vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block; (vblock_idx <= last_seq_block) ? vblock_idx : last_seq_block;
vphysical_blocks[b] = block_table[vblock_idx_ctx]; vphysical_blocks[b] = block_table[vblock_idx_ctx];
} }
...@@ -1057,7 +1056,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( ...@@ -1057,7 +1056,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int lane4_token_idx = 4 * (global_token_idx >> 2); const int lane4_token_idx = 4 * (global_token_idx >> 2);
if constexpr (ALIBI_ENABLED) { if constexpr (ALIBI_ENABLED) {
const int alibi_offset = lane4_token_idx - context_len + 1; const int alibi_offset = lane4_token_idx - seq_len + 1;
for (int h = 0; h < QHLOOP; h++) { for (int h = 0; h < QHLOOP; h++) {
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
d_out[h][i] += alibi_slope[h] * (alibi_offset + i); d_out[h][i] += alibi_slope[h] * (alibi_offset + i);
...@@ -1070,7 +1069,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( ...@@ -1070,7 +1069,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for (int h = 0; h < QHLOOP; h++) { for (int h = 0; h < QHLOOP; h++) {
qk_max[h] = -FLT_MAX; qk_max[h] = -FLT_MAX;
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
qk_max[h] = (lane4_token_idx + i < context_len) qk_max[h] = (lane4_token_idx + i < seq_len)
? fmaxf(qk_max[h], d_out[h][i]) ? fmaxf(qk_max[h], d_out[h][i])
: qk_max[h]; : qk_max[h];
} }
...@@ -1101,7 +1100,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( ...@@ -1101,7 +1100,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for (int h = 0; h < QHLOOP; h++) { for (int h = 0; h < QHLOOP; h++) {
exp_sum[h] = 0.0f; exp_sum[h] = 0.0f;
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
d_out[h][i] = (lane4_token_idx + i < context_len) d_out[h][i] = (lane4_token_idx + i < seq_len)
? __expf(d_out[h][i] - qk_max[h]) ? __expf(d_out[h][i] - qk_max[h])
: 0.0f; : 0.0f;
exp_sum[h] += d_out[h][i]; exp_sum[h] += d_out[h][i];
...@@ -1181,7 +1180,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( ...@@ -1181,7 +1180,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
} }
} }
if (warp_start_token_idx >= context_len) { // warp out of context if (warp_start_token_idx >= seq_len) { // warp out of context
for (int qh = 0; qh < QHLOOP; qh++) { for (int qh = 0; qh < QHLOOP; qh++) {
for (int vh = 0; vh < VHELOOP; vh++) { for (int vh = 0; vh < VHELOOP; vh++) {
vout_shared[qh][vh][laneid][warpid] = {0}; vout_shared[qh][vh][laneid][warpid] = {0};
...@@ -1279,7 +1278,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( ...@@ -1279,7 +1278,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions] // max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size] // max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x; const auto num_heads = gridDim.x;
...@@ -1293,8 +1292,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( ...@@ -1293,8 +1292,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return; return;
} }
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
const auto warpid = threadIdx.x / WARP_SIZE; const auto warpid = threadIdx.x / WARP_SIZE;
__shared__ float shared_global_exp_sum; __shared__ float shared_global_exp_sum;
...@@ -1581,7 +1580,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -1581,7 +1580,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
...@@ -1615,11 +1614,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -1615,11 +1614,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int max_num_partitions = gridDim.y; const int max_num_partitions = gridDim.y;
const int context_len = context_lens[seq_idx]; // length of a seq const int seq_len = seq_lens[seq_idx]; // length of a seq
const int partition_start_token_idx = partition_idx * T_PAR_SIZE; const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
// exit if partition is out of context for seq // exit if partition is out of context for seq
if (partition_start_token_idx >= context_len) { if (partition_start_token_idx >= seq_len) {
return; return;
} }
...@@ -1715,8 +1714,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -1715,8 +1714,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
} }
} }
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int last_ctx_block = num_context_blocks - 1; const int last_seq_block = num_seq_blocks - 1;
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
...@@ -1727,9 +1726,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -1727,9 +1726,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int klocal_token_idx = const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
const int kblock_idx = (kglobal_token_idx < context_len) const int kblock_idx = (kglobal_token_idx < seq_len)
? kglobal_token_idx / BLOCK_SIZE ? kglobal_token_idx / BLOCK_SIZE
: last_ctx_block; : last_seq_block;
kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
} }
...@@ -1781,9 +1780,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -1781,9 +1780,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
vblock_depth * BLOCK_SIZE; vblock_depth * BLOCK_SIZE;
const int vglobal_token_idx = const int vglobal_token_idx =
partition_start_token_idx + vlocal_token_idx; partition_start_token_idx + vlocal_token_idx;
const int vblock_idx = (vglobal_token_idx < context_len) const int vblock_idx = (vglobal_token_idx < seq_len)
? vglobal_token_idx / BLOCK_SIZE ? vglobal_token_idx / BLOCK_SIZE
: last_ctx_block; : last_seq_block;
vphysical_block_number[vtoken_depth][vblock_depth] = vphysical_block_number[vtoken_depth][vblock_depth] =
block_table_seq[vblock_idx]; block_table_seq[vblock_idx];
} }
...@@ -1836,9 +1835,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -1836,9 +1835,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) { for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16; const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
const float tmp = (local_token_idx + 2 * i < context_len) const float tmp =
? dout[token_depth][i] (local_token_idx + 2 * i < seq_len) ? dout[token_depth][i] : -FLT_MAX;
: -FLT_MAX;
qk_max = fmaxf(qk_max, tmp); qk_max = fmaxf(qk_max, tmp);
} }
} }
...@@ -1848,7 +1846,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -1848,7 +1846,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) { for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16; const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
const float tmp = (local_token_idx + 2 * i < context_len) const float tmp = (local_token_idx + 2 * i < seq_len)
? __expf(dout[token_depth][i] - qk_max) ? __expf(dout[token_depth][i] - qk_max)
: 0.0f; : 0.0f;
dout[token_depth][i] = tmp; dout[token_depth][i] = tmp;
...@@ -2019,7 +2017,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( ...@@ -2019,7 +2017,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
...@@ -2046,7 +2044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( ...@@ -2046,7 +2044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions] // max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size] // max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x; const auto num_heads = gridDim.x;
...@@ -2060,8 +2058,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( ...@@ -2060,8 +2058,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return; return;
} }
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
const int warpid = threadIdx.x / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE;
__shared__ float shared_global_exp_sum; __shared__ float shared_global_exp_sum;
...@@ -2349,7 +2347,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -2349,7 +2347,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
...@@ -2382,11 +2380,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -2382,11 +2380,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int max_num_partitions = gridDim.y; const int max_num_partitions = gridDim.y;
const int context_len = context_lens[seq_idx]; // length of a seq const int seq_len = seq_lens[seq_idx]; // length of a seq
const int partition_start_token_idx = partition_idx * T_PAR_SIZE; const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
// exit if partition is out of context for seq // exit if partition is out of context for seq
if (partition_start_token_idx >= context_len) { if (partition_start_token_idx >= seq_len) {
return; return;
} }
...@@ -2482,8 +2480,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -2482,8 +2480,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
} }
} }
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int last_ctx_block = num_context_blocks - 1; const int last_seq_block = num_seq_blocks - 1;
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
...@@ -2494,9 +2492,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -2494,9 +2492,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int klocal_token_idx = const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
const int kblock_idx = (kglobal_token_idx < context_len) const int kblock_idx = (kglobal_token_idx < seq_len)
? kglobal_token_idx / BLOCK_SIZE ? kglobal_token_idx / BLOCK_SIZE
: last_ctx_block; : last_seq_block;
kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
} }
...@@ -2548,9 +2546,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -2548,9 +2546,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE; rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE;
const int vglobal_token_idx = const int vglobal_token_idx =
partition_start_token_idx + vlocal_token_idx; partition_start_token_idx + vlocal_token_idx;
const int vblock_idx = (vglobal_token_idx < context_len) const int vblock_idx = (vglobal_token_idx < seq_len)
? vglobal_token_idx / BLOCK_SIZE ? vglobal_token_idx / BLOCK_SIZE
: last_ctx_block; : last_seq_block;
vphysical_block_number[vtoken_depth][vblock_depth] = vphysical_block_number[vtoken_depth][vblock_depth] =
block_table_seq[vblock_idx]; block_table_seq[vblock_idx];
} }
...@@ -2604,7 +2602,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -2604,7 +2602,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int local_token_idx = qkout_token_idx + token_depth * 16; const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
const float tmp = const float tmp =
(local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX; (local_token_idx + i < seq_len) ? dout[token_depth][i] : -FLT_MAX;
qk_max = fmaxf(qk_max, tmp); qk_max = fmaxf(qk_max, tmp);
} }
} }
...@@ -2614,7 +2612,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -2614,7 +2612,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) { for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16; const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
const float tmp = (local_token_idx + i < context_len) const float tmp = (local_token_idx + i < seq_len)
? __expf(dout[token_depth][i] - qk_max) ? __expf(dout[token_depth][i] - qk_max)
: 0.0f; : 0.0f;
dout[token_depth][i] = tmp; dout[token_depth][i] = tmp;
...@@ -2751,7 +2749,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( ...@@ -2751,7 +2749,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
...@@ -2778,7 +2776,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( ...@@ -2778,7 +2776,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions] // max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size] // max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x; const auto num_heads = gridDim.x;
...@@ -2792,8 +2790,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( ...@@ -2792,8 +2790,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return; return;
} }
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
const int warpid = threadIdx.x / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE;
__shared__ float shared_global_exp_sum; __shared__ float shared_global_exp_sum;
...@@ -2980,7 +2978,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( ...@@ -2980,7 +2978,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int num_kv_heads, const int num_kv_heads,
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
...@@ -3007,7 +3005,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( ...@@ -3007,7 +3005,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int num_kv_heads, const int num_kv_heads,
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
...@@ -3031,7 +3029,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( ...@@ -3031,7 +3029,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
UNREACHABLE_CODE UNREACHABLE_CODE
...@@ -3046,7 +3044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( ...@@ -3046,7 +3044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
GQA_RATIO> \ GQA_RATIO> \
<<<grid, block, 0, stream>>>( \ <<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr); max_ctx_blocks, k_scale_ptr, v_scale_ptr);
...@@ -3057,7 +3055,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( ...@@ -3057,7 +3055,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
GQA_RATIO> \ GQA_RATIO> \
<<<grid, block, 0, stream>>>( \ <<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr); max_ctx_blocks, k_scale_ptr, v_scale_ptr);
...@@ -3066,9 +3064,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( ...@@ -3066,9 +3064,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \ paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \ PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \ <<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
context_lens_ptr, query_start_loc_ptr, max_num_partitions, \ query_start_loc_ptr, max_num_partitions, fp8_out_scale_ptr);
fp8_out_scale_ptr);
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE, template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD, int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
...@@ -3077,8 +3074,8 @@ void paged_attention_custom_launcher( ...@@ -3077,8 +3074,8 @@ void paged_attention_custom_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale, torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens, torch::Tensor& block_tables, torch::Tensor& seq_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_context_len, const std::optional<torch::Tensor>& query_start_loc, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale, const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale) { torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale) {
int num_seqs = block_tables.size(0); int num_seqs = block_tables.size(0);
...@@ -3109,7 +3106,7 @@ void paged_attention_custom_launcher( ...@@ -3109,7 +3106,7 @@ void paged_attention_custom_launcher(
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr()); KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr()); KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr()); const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// NOTE: fp8_out_scale is optional. // NOTE: fp8_out_scale is optional.
...@@ -3119,13 +3116,12 @@ void paged_attention_custom_launcher( ...@@ -3119,13 +3116,12 @@ void paged_attention_custom_launcher(
: nullptr; : nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr()); OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE);
// partition size is fixed at 256 since both mfma4 and mfma16 kernels support // partition size is fixed at 256 since both mfma4 and mfma16 kernels support
// it mfma4 kernel also supports partition size 512 // it mfma4 kernel also supports partition size 512
constexpr int PARTITION_SIZE = 256; constexpr int PARTITION_SIZE = 256;
const int max_num_partitions = const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
const int gqa_ratio = num_heads / num_kv_heads; const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0); assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE); assert(head_size == HEAD_SIZE);
...@@ -3234,8 +3230,8 @@ void paged_attention_custom_launcher_navi( ...@@ -3234,8 +3230,8 @@ void paged_attention_custom_launcher_navi(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale, torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens, torch::Tensor& block_tables, torch::Tensor& seq_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_context_len, const std::optional<torch::Tensor>& query_start_loc, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale, const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale) { torch::Tensor& v_scale) {
int num_seqs = block_tables.size(0); int num_seqs = block_tables.size(0);
...@@ -3263,7 +3259,7 @@ void paged_attention_custom_launcher_navi( ...@@ -3263,7 +3259,7 @@ void paged_attention_custom_launcher_navi(
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr()); KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr()); KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr()); const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
...@@ -3271,11 +3267,10 @@ void paged_attention_custom_launcher_navi( ...@@ -3271,11 +3267,10 @@ void paged_attention_custom_launcher_navi(
const auto fp8_out_scale_ptr = nullptr; const auto fp8_out_scale_ptr = nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr()); OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE);
constexpr int PARTITION_SIZE = 256; constexpr int PARTITION_SIZE = 256;
const int max_num_partitions = const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
const int gqa_ratio = num_heads / num_kv_heads; const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0); assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE); assert(head_size == HEAD_SIZE);
...@@ -3407,14 +3402,14 @@ void paged_attention_custom_launcher_navi( ...@@ -3407,14 +3402,14 @@ void paged_attention_custom_launcher_navi(
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
OUTT, PSIZE, ALIBI_ENABLED>( \ OUTT, PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \ max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
} else { \ } else { \
paged_attention_custom_launcher_navi< \ paged_attention_custom_launcher_navi< \
T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \ T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale); \ max_seq_len, alibi_slopes, k_scale, v_scale); \
} }
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
...@@ -3502,9 +3497,9 @@ void paged_attention( ...@@ -3502,9 +3497,9 @@ void paged_attention(
int64_t num_kv_heads, int64_t num_kv_heads,
double scale, double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
const std::optional<torch::Tensor>& query_start_loc, // [num_seqs] const std::optional<torch::Tensor>& query_start_loc, // [num_seqs]
int64_t block_size, int64_t max_context_len, int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, torch::Tensor& v_scale,
......
...@@ -15,8 +15,8 @@ void paged_attention( ...@@ -15,8 +15,8 @@ void paged_attention(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& context_lens, torch::Tensor& block_tables, torch::Tensor& seq_lens,
const std::optional<torch::Tensor>& query_start_loc, int64_t block_size, const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
int64_t max_context_len, const std::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale); torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale);
...@@ -41,10 +41,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { ...@@ -41,10 +41,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" Tensor query, Tensor key_cache," " Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads," " Tensor value_cache, int num_kv_heads,"
" float scale, Tensor block_tables," " float scale, Tensor block_tables,"
" Tensor context_lens," " Tensor seq_lens,"
" Tensor? query_start_loc," " Tensor? query_start_loc,"
" int block_size," " int block_size,"
" int max_context_len," " int max_seq_len,"
" Tensor? alibi_slopes," " Tensor? alibi_slopes,"
" str kv_cache_dtype," " str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale," " Tensor k_scale, Tensor v_scale,"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment