Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
749242a0
Commit
749242a0
authored
Aug 09, 2024
by
flyingdown
Browse files
Revert "pa add v prefetch for gemm1"
This reverts commit
f38bd872
.
parent
dcaabcf7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
365 additions
and
492 deletions
+365
-492
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+365
-492
No files found.
csrc/attention/attention_kernels.cu
View file @
749242a0
...
@@ -20,6 +20,7 @@ typedef __hip_bfloat16 __nv_bfloat16;
...
@@ -20,6 +20,7 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define WARP_SIZE warpSize
#define WARP_SIZE warpSize
#endif
#endif
#include "static_switch.h"
#include "static_switch.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
...
@@ -80,7 +81,7 @@ __device__ void paged_attention_kernel(
...
@@ -80,7 +81,7 @@ __device__ void paged_attention_kernel(
// head_size/x, block_size, x]
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
// head_size, block_size]
const
int
num_heads
,
// [num_heads]
const
int
num_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_kv_heads]
const
int
num_kv_heads
,
// [num_kv_heads]
const
float
scale
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
...
@@ -105,8 +106,7 @@ __device__ void paged_attention_kernel(
...
@@ -105,8 +106,7 @@ __device__ void paged_attention_kernel(
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
int
num_blocks_per_partition
=
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
partition_size
=
const
int
partition_size
=
USE_PARTITIONING
?
PARTITION_SIZE
:
num_seq_blocks
*
BLOCK_SIZE
;
USE_PARTITIONING
?
PARTITION_SIZE
:
num_seq_blocks
*
BLOCK_SIZE
;
// [start_block_idx, end_block_idx) is the range of blocks to process.
// [start_block_idx, end_block_idx) is the range of blocks to process.
const
int
start_block_idx
=
const
int
start_block_idx
=
USE_PARTITIONING
?
partition_idx
*
num_blocks_per_partition
:
0
;
USE_PARTITIONING
?
partition_idx
*
num_blocks_per_partition
:
0
;
...
@@ -139,15 +139,15 @@ __device__ void paged_attention_kernel(
...
@@ -139,15 +139,15 @@ __device__ void paged_attention_kernel(
// const int lane = thread_idx % WARP_SIZE;
// const int lane = thread_idx % WARP_SIZE;
//
const int warp_idx = thread_idx / WARP_SIZE;
//const int warp_idx = thread_idx / WARP_SIZE;
const
int
lane
=
thread_idx
%
WARP_SIZE
;
const
int
lane
=
thread_idx
%
WARP_SIZE
;
int
warp_id_vec
=
threadIdx
.
x
/
WARP_SIZE
;
//
warp id in a block
int
warp_id_vec
=
threadIdx
.
x
/
WARP_SIZE
;
//warp id in a block
int
warp_idx
=
0
;
int
warp_idx
=
0
;
asm
volatile
(
"v_readfirstlane_b32 %0,%1"
asm
volatile
(
"v_readfirstlane_b32 %0,%1"
:
"=s"
(
warp_idx
)
:
"=s"
(
warp_idx
)
:
"v"
(
warp_id_vec
)
:
"v"
(
warp_id_vec
)
:
);
:
);
// const int head_idx = blockIdx.x;
// const int head_idx = blockIdx.x;
// const int num_heads = gridDim.x;
// const int num_heads = gridDim.x;
...
@@ -180,18 +180,16 @@ __device__ void paged_attention_kernel(
...
@@ -180,18 +180,16 @@ __device__ void paged_attention_kernel(
// const scalar_t* q_ptr = q + seq_idx * q_stride;
// const scalar_t* q_ptr = q + seq_idx * q_stride;
const
scalar_t
*
q_ptr_offset
=
q
+
seq_idx
*
q_stride
;
const
scalar_t
*
q_ptr_offset
=
q
+
seq_idx
*
q_stride
;
__shared__
Q_vec
__shared__
Q_vec
q_vecs
[
REUSE_KV_TIMES
*
THREAD_GROUP_SIZE
][
NUM_VECS_PER_THREAD
];
q_vecs
[
REUSE_KV_TIMES
*
THREAD_GROUP_SIZE
][
NUM_VECS_PER_THREAD
];
// #pragma unroll
// #pragma unroll
// for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
// for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
// i += NUM_THREAD_GROUPS) {
// i += NUM_THREAD_GROUPS) {
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// q_vecs[thread_group_offset][i] =
// q_vecs[thread_group_offset][i] =
// *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// }
// }
// __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// __syncthreads(); // TODO(naed90): possible speedup if this is replaced
// // memory wall right before we use q_vecs
// with a
// // memory wall right before we use q_vecs
// Memory planning.
// Memory planning.
extern
__shared__
char
shared_mem
[];
extern
__shared__
char
shared_mem
[];
...
@@ -199,8 +197,7 @@ __device__ void paged_attention_kernel(
...
@@ -199,8 +197,7 @@ __device__ void paged_attention_kernel(
float
*
logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
float
*
logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
// Workspace for reduction.
// Workspace for reduction.
__shared__
float
red_smem
[
REUSE_KV_TIMES
][
2
*
NUM_WARPS
];
__shared__
float
red_smem
[
REUSE_KV_TIMES
][
2
*
NUM_WARPS
];
// float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 *
// float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 * NUM_WARPS]>(&shared_mem[10*1024]);
// NUM_WARPS]>(&shared_mem[10*1024]);
// __shared__ char shared_mem[12 * 1024];
// __shared__ char shared_mem[12 * 1024];
// float* logits = reinterpret_cast<float*>(shared_mem);
// float* logits = reinterpret_cast<float*>(shared_mem);
...
@@ -211,173 +208,146 @@ __device__ void paged_attention_kernel(
...
@@ -211,173 +208,146 @@ __device__ void paged_attention_kernel(
constexpr
int
x
=
16
/
sizeof
(
cache_t
);
constexpr
int
x
=
16
/
sizeof
(
cache_t
);
float
qk_max
[
REUSE_KV_TIMES
];
float
qk_max
[
REUSE_KV_TIMES
];
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
qk_max
[
reuse_kv_idx
]
=
-
FLT_MAX
;
qk_max
[
reuse_kv_idx
]
=
-
FLT_MAX
;
}
}
const
int
num_blocks_per_kv
=
const
int
num_blocks_per_kv
=
((
num_queries_per_kv
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
);
((
num_queries_per_kv
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
);
const
int
head_idx_soffset
=
(
blockIdx
.
x
/
num_blocks_per_kv
)
*
num_queries_per_kv
+
(
blockIdx
.
x
%
num_blocks_per_kv
)
*
REUSE_KV_TIMES
;
const
int
head_idx_soffset
=
(
blockIdx
.
x
/
num_blocks_per_kv
)
*
num_queries_per_kv
+
(
blockIdx
.
x
%
num_blocks_per_kv
)
*
REUSE_KV_TIMES
;
const
int
kv_head_idx
=
head_idx_soffset
/
num_queries_per_kv
;
const
int
kv_head_idx
=
head_idx_soffset
/
num_queries_per_kv
;
const
int
q_boundary
=
(
kv_head_idx
+
1
)
*
num_queries_per_kv
;
const
int
q_boundary
=
(
kv_head_idx
+
1
)
*
num_queries_per_kv
;
#pragma unroll
#pragma unroll
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
const
int
head_idx
=
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
head_idx_soffset
+
reuse_kv_idx
;
// blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
const
scalar_t
*
q_ptr
=
q_ptr_offset
+
head_idx
*
HEAD_SIZE
;
const
scalar_t
*
q_ptr
=
q_ptr_offset
+
head_idx
*
HEAD_SIZE
;
#pragma unroll
#pragma unroll
for
(
int
i
=
thread_group_idx
;
i
<
NUM_VECS_PER_THREAD
;
for
(
int
i
=
thread_group_idx
;
i
<
NUM_VECS_PER_THREAD
;
i
+=
NUM_THREAD_GROUPS
)
{
i
+=
NUM_THREAD_GROUPS
)
{
const
int
vec_idx
=
thread_group_offset
+
i
*
THREAD_GROUP_SIZE
;
const
int
vec_idx
=
thread_group_offset
+
i
*
THREAD_GROUP_SIZE
;
q_vecs
[
reuse_kv_idx
*
THREAD_GROUP_SIZE
+
thread_group_offset
][
i
]
=
q_vecs
[
reuse_kv_idx
*
THREAD_GROUP_SIZE
+
thread_group_offset
][
i
]
=
*
reinterpret_cast
<
const
Q_vec
*>
(
q_ptr
+
vec_idx
*
VEC_SIZE
);
*
reinterpret_cast
<
const
Q_vec
*>
(
q_ptr
+
vec_idx
*
VEC_SIZE
);
}
}
}
}
__syncthreads
();
// TODO(naed90): possible speedup if this is replaced with a
__syncthreads
();
// TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
// memory wall right before we use q_vecs
// Iterate over the key blocks.
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
// dot product with the query.
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
block_idx
+=
NUM_WARPS
)
{
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// For blocksparse attention: skip computation on blocks that are not
// attended
// attended
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
const
int
head_idx
=
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
head_idx_soffset
+
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
reuse_kv_idx
;
// blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
// blocksparse specific vars
// blocksparse specific vars
int
bs_block_offset
;
int
bs_block_offset
;
int
q_bs_block_id
;
int
q_bs_block_id
;
if
constexpr
(
IS_BLOCK_SPARSE
)
{
if
constexpr
(
IS_BLOCK_SPARSE
)
{
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
// blocksparse_block_size);
q_bs_block_id
=
(
seq_len
-
1
)
/
blocksparse_block_size
;
q_bs_block_id
=
(
seq_len
-
1
)
/
blocksparse_block_size
;
if
(
blocksparse_head_sliding_step
>=
0
)
if
(
blocksparse_head_sliding_step
>=
0
)
// sliding on q heads
// sliding on q heads
bs_block_offset
=
(
tp_rank
*
num_heads
+
head_idx
)
*
bs_block_offset
=
blocksparse_head_sliding_step
+
(
tp_rank
*
num_heads
+
head_idx
)
*
blocksparse_head_sliding_step
+
1
;
1
;
else
else
// sliding on kv heads
// sliding on kv heads
bs_block_offset
=
(
tp_rank
*
num_kv_heads
+
kv_head_idx
)
*
bs_block_offset
=
(
tp_rank
*
num_kv_heads
+
kv_head_idx
)
*
(
-
blocksparse_head_sliding_step
)
+
(
-
blocksparse_head_sliding_step
)
+
1
;
1
;
}
}
if
constexpr
(
IS_BLOCK_SPARSE
)
{
if
constexpr
(
IS_BLOCK_SPARSE
)
{
const
int
k_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
const
int
k_bs_block_id
=
const
bool
is_remote
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
((
k_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
);
const
bool
is_remote
=
const
bool
is_local
=
((
k_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
(
k_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
);
0
);
if
(
!
is_remote
&&
!
is_local
)
{
const
bool
is_local
=
(
k_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
);
if
(
!
is_remote
&&
!
is_local
)
{
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
const
int
physical_block_offset
=
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
if
(
thread_group_offset
==
0
)
{
// NOTE(linxihui): assign a very large negative value (-FLT_MAX) to skipped
// tokens so they do not contribute to the sumexp softmax normalizer. This
// value will not be used when computing sum(softmax*v), as the blocks
// will be skipped.
logits
[
token_idx
-
start_token_idx
]
=
-
FLT_MAX
;
}
}
continue
;
}
}
const
float
alibi_slope
=
alibi_slopes
==
nullptr
?
0.
f
:
alibi_slopes
[
head_idx
];
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the thread group size is 4, then the first thread
// in the group has 0, 4, 8, ... th vectors of the key, and the second
// thread has 1, 5, 9, ... th vectors of the key, and so on.
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
const
int
physical_block_offset
=
const
int
physical_block_offset
=
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
K_vec
k_vecs
[
NUM_VECS_PER_THREAD
];
if
(
reuse_kv_idx
==
0
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
NUM_VECS_PER_THREAD
;
j
++
)
{
const
cache_t
*
k_ptr
=
k_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
physical_block_offset
*
x
;
const
int
vec_idx
=
thread_group_offset
+
j
*
THREAD_GROUP_SIZE
;
const
int
offset1
=
(
vec_idx
*
VEC_SIZE
)
/
x
;
const
int
offset2
=
(
vec_idx
*
VEC_SIZE
)
%
x
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
k_vecs
[
j
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
}
else
{
// Vector conversion from Quant_vec to K_vec.
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vec_quant
,
kv_scale
);
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
// Compute dot product.
// This includes a reduction across the threads in the same thread
// group.
float
qk
=
scale
*
Qk_dot
<
scalar_t
,
THREAD_GROUP_SIZE
>::
dot
(
q_vecs
[
reuse_kv_idx
*
THREAD_GROUP_SIZE
+
thread_group_offset
],
k_vecs
);
// Add the ALiBi bias if slopes are given.
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq_len
+
1
)
:
0
;
__builtin_amdgcn_sched_barrier
(
0
);
if
(
thread_group_offset
==
0
)
{
if
(
thread_group_offset
==
0
)
{
// Store the partial reductions to shared memory.
// NOTE(linxihui): assign very large number to skipped tokens to
// NOTE(woosuk): It is required to zero out the masked logits.
// avoid contribution to the sumexp softmax normalizer. This will
const
bool
mask
=
token_idx
>=
seq_len
;
// not be used at computing sum(softmax*v) as the blocks will be
logits
[(
reuse_kv_idx
*
partition_size
)
+
// skipped.
(
token_idx
-
start_token_idx
)]
=
mask
?
0.
f
:
qk
;
logits
[
token_idx
-
start_token_idx
]
=
-
FLT_MAX
;
// Update the max value.
}
qk_max
[
reuse_kv_idx
]
=
}
mask
?
qk_max
[
reuse_kv_idx
]
:
fmaxf
(
qk_max
[
reuse_kv_idx
],
qk
);
continue
;
}
}
const
float
alibi_slope
=
alibi_slopes
==
nullptr
?
0.
f
:
alibi_slopes
[
head_idx
];
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
const
int
physical_block_offset
=
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
K_vec
k_vecs
[
NUM_VECS_PER_THREAD
];
if
(
reuse_kv_idx
==
0
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
NUM_VECS_PER_THREAD
;
j
++
)
{
const
cache_t
*
k_ptr
=
k_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
physical_block_offset
*
x
;
const
int
vec_idx
=
thread_group_offset
+
j
*
THREAD_GROUP_SIZE
;
const
int
offset1
=
(
vec_idx
*
VEC_SIZE
)
/
x
;
const
int
offset2
=
(
vec_idx
*
VEC_SIZE
)
%
x
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
k_vecs
[
j
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
}
else
{
// Vector conversion from Quant_vec to K_vec.
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vec_quant
,
kv_scale
);
}
}
}
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float
qk
=
scale
*
Qk_dot
<
scalar_t
,
THREAD_GROUP_SIZE
>::
dot
(
q_vecs
[
reuse_kv_idx
*
THREAD_GROUP_SIZE
+
thread_group_offset
],
k_vecs
);
// Add the ALiBi bias if slopes are given.
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq_len
+
1
)
:
0
;
__builtin_amdgcn_sched_barrier
(
0
);
if
(
thread_group_offset
==
0
)
{
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const
bool
mask
=
token_idx
>=
seq_len
;
logits
[(
reuse_kv_idx
*
partition_size
)
+
(
token_idx
-
start_token_idx
)]
=
mask
?
0.
f
:
qk
;
// Update the max value.
qk_max
[
reuse_kv_idx
]
=
mask
?
qk_max
[
reuse_kv_idx
]
:
fmaxf
(
qk_max
[
reuse_kv_idx
],
qk
);
}
}
}
}
}
}
}
// Get the sum of the exp values.
// Get the sum of the exp values.
float
exp_sum
[
REUSE_KV_TIMES
]
=
{
0.
f
};
float
exp_sum
[
REUSE_KV_TIMES
]
=
{
0.
f
};
// Perform reduction across the threads in the same warp to get the
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
// The 0-th thread of each thread group already has its max qk value.
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
#pragma unroll
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREAD_GROUP_SIZE
;
mask
/=
2
)
{
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREAD_GROUP_SIZE
;
mask
/=
2
)
{
qk_max
[
reuse_kv_idx
]
=
qk_max
[
reuse_kv_idx
]
=
fmaxf
(
qk_max
[
reuse_kv_idx
],
VLLM_SHFL_XOR_SYNC
(
qk_max
[
reuse_kv_idx
],
mask
));
fmaxf
(
qk_max
[
reuse_kv_idx
],
VLLM_SHFL_XOR_SYNC
(
qk_max
[
reuse_kv_idx
],
mask
));
}
}
if
(
lane
==
0
)
{
if
(
lane
==
0
)
{
red_smem
[
reuse_kv_idx
][
warp_idx
]
=
qk_max
[
reuse_kv_idx
];
red_smem
[
reuse_kv_idx
][
warp_idx
]
=
qk_max
[
reuse_kv_idx
];
...
@@ -386,25 +356,20 @@ __device__ void paged_attention_kernel(
...
@@ -386,25 +356,20 @@ __device__ void paged_attention_kernel(
// TODO(woosuk): Refactor this part.
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
// Get the max qk value for the sequence.
qk_max
[
reuse_kv_idx
]
=
qk_max
[
reuse_kv_idx
]
=
lane
<
NUM_WARPS
?
red_smem
[
reuse_kv_idx
][
lane
]
:
-
FLT_MAX
;
lane
<
NUM_WARPS
?
red_smem
[
reuse_kv_idx
][
lane
]
:
-
FLT_MAX
;
#pragma unroll
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk_max
[
reuse_kv_idx
]
=
qk_max
[
reuse_kv_idx
]
=
fmaxf
(
qk_max
[
reuse_kv_idx
],
VLLM_SHFL_XOR_SYNC
(
qk_max
[
reuse_kv_idx
],
mask
));
fmaxf
(
qk_max
[
reuse_kv_idx
],
VLLM_SHFL_XOR_SYNC
(
qk_max
[
reuse_kv_idx
],
mask
));
}
}
// Broadcast the max qk value to all threads.
// Broadcast the max qk value to all threads.
qk_max
[
reuse_kv_idx
]
=
VLLM_SHFL_SYNC
(
qk_max
[
reuse_kv_idx
],
0
);
qk_max
[
reuse_kv_idx
]
=
VLLM_SHFL_SYNC
(
qk_max
[
reuse_kv_idx
],
0
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
float
val
=
__expf
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
-
float
val
=
__expf
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
-
qk_max
[
reuse_kv_idx
]);
qk_max
[
reuse_kv_idx
]);
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
=
val
;
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
=
val
;
exp_sum
[
reuse_kv_idx
]
+=
val
;
exp_sum
[
reuse_kv_idx
]
+=
val
;
}
}
exp_sum
[
reuse_kv_idx
]
=
block_sum
<
NUM_WARPS
>
(
exp_sum
[
reuse_kv_idx
]
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
reuse_kv_idx
][
NUM_WARPS
],
exp_sum
[
reuse_kv_idx
]);
&
red_smem
[
reuse_kv_idx
][
NUM_WARPS
],
exp_sum
[
reuse_kv_idx
]);
// Compute softmax.
// Compute softmax.
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
[
reuse_kv_idx
]
+
1e-6
f
);
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
[
reuse_kv_idx
]
+
1e-6
f
);
...
@@ -419,8 +384,7 @@ __device__ void paged_attention_kernel(
...
@@ -419,8 +384,7 @@ __device__ void paged_attention_kernel(
seq_idx
*
num_heads
*
max_num_partitions
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
+
partition_idx
;
head_idx
*
max_num_partitions
+
partition_idx
;
*
max_logits_ptr
=
qk_max
[
reuse_kv_idx
];
*
max_logits_ptr
=
qk_max
[
reuse_kv_idx
];
float
*
exp_sums_ptr
=
exp_sums
+
float
*
exp_sums_ptr
=
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
+
partition_idx
;
head_idx
*
max_num_partitions
+
partition_idx
;
*
exp_sums_ptr
=
exp_sum
[
reuse_kv_idx
];
*
exp_sums_ptr
=
exp_sum
[
reuse_kv_idx
];
}
}
...
@@ -441,11 +405,11 @@ __device__ void paged_attention_kernel(
...
@@ -441,11 +405,11 @@ __device__ void paged_attention_kernel(
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float
accs
[
REUSE_KV_TIMES
][
NUM_ROWS_PER_THREAD
];
float
accs
[
REUSE_KV_TIMES
][
NUM_ROWS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
accs
[
reuse_kv_idx
][
i
]
=
0.
f
;
accs
[
reuse_kv_idx
][
i
]
=
0.
f
;
}
}
}
}
scalar_t
zero_value
;
scalar_t
zero_value
;
...
@@ -457,230 +421,155 @@ __device__ void paged_attention_kernel(
...
@@ -457,230 +421,155 @@ __device__ void paged_attention_kernel(
const
int
physical_block_offset
=
(
lane
%
NUM_V_VECS_PER_ROW
)
*
V_VEC_SIZE
;
const
int
physical_block_offset
=
(
lane
%
NUM_V_VECS_PER_ROW
)
*
V_VEC_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
L_vec
logits_vec
;
L_vec
logits_vec
;
V_vec
v_vec
[
2
];
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
;
const
int
start_row_idx
=
lane
/
NUM_V_VECS_PER_ROW
;
if
(
start_row_idx
<
HEAD_SIZE
)
{
const
int
offset
=
start_row_idx
*
BLOCK_SIZE
+
physical_block_offset
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
v_vec
[
0
]
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
}
else
{
V_quant_vec
v_quant_vec
=
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
// Vector conversion from V_quant_vec to V_vec.
v_vec
[
0
]
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
kv_scale
);
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
[
0
]);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
V_vec
v_vec
;
}
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
// blocksparse specific vars
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
int
bs_block_offset
;
int
q_bs_block_id
;
if
constexpr
(
IS_BLOCK_SPARSE
)
{
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id
=
(
seq_len
-
1
)
/
blocksparse_block_size
;
if
(
blocksparse_head_sliding_step
>=
0
)
// sliding on q heads
bs_block_offset
=
(
tp_rank
*
num_heads
+
head_idx
)
*
blocksparse_head_sliding_step
+
1
;
else
// sliding on kv heads
bs_block_offset
=
(
tp_rank
*
num_kv_heads
+
kv_head_idx
)
*
(
-
blocksparse_head_sliding_step
)
+
1
;
}
}
}
if
constexpr
(
IS_BLOCK_SPARSE
)
{
#pragma unroll
int
v_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
for
(
int
i
=
1
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
if
(
!
((
v_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
)
&&
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
!
((
v_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
)))
{
reuse_kv_idx
++
)
{
continue
;
// NOTE(woosuk): The block number is stored in int32. However, we cast
// it to int64 because int32 can lead to overflow when this variable is
// multiplied by large numbers (e.g., kv_block_stride). For blocksparse
// attention: skip computation on blocks that are not attended
// blocksparse specific vars
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
int
bs_block_offset
;
int
q_bs_block_id
;
if
constexpr
(
IS_BLOCK_SPARSE
)
{
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id
=
(
seq_len
-
1
)
/
blocksparse_block_size
;
if
(
blocksparse_head_sliding_step
>=
0
)
// sliding on q heads
bs_block_offset
=
(
tp_rank
*
num_heads
+
head_idx
)
*
blocksparse_head_sliding_step
+
1
;
else
// sliding on kv heads
bs_block_offset
=
(
tp_rank
*
num_kv_heads
+
kv_head_idx
)
*
(
-
blocksparse_head_sliding_step
)
+
1
;
}
}
if
constexpr
(
IS_BLOCK_SPARSE
)
{
}
int
v_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
if
(
!
((
v_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
)
&&
!
((
v_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
)))
{
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
continue
;
+
kv_head_idx
*
kv_head_stride
;
}
from_float
(
logits_vec
,
*
reinterpret_cast
<
Float_L_vec
*>
(
logits
+
(
reuse_kv_idx
*
partition_size
)
+
token_idx
-
start_token_idx
));
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// from_float(*(logits_vec_ptr+i), 1000);
// }
if
(
reuse_kv_idx
==
0
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
)
{
const
int
offset
=
row_idx
*
BLOCK_SIZE
+
physical_block_offset
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
}
else
{
V_quant_vec
v_quant_vec
=
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
// Vector conversion from V_quant_vec to V_vec.
v_vec
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
kv_scale
);
}
}
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
if
(
block_idx
==
num_seq_blocks
-
1
)
{
from_float
(
logits_vec
,
*
reinterpret_cast
<
Float_L_vec
*>
(
// NOTE(woosuk): When v_vec contains the tokens that are out of the
logits
+
(
reuse_kv_idx
*
partition_size
)
+
// context, we should explicitly zero out the values since they may
token_idx
-
start_token_idx
));
// contain NaNs. See
// scalar_t* logits_vec_ptr =
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
// reinterpret_cast<scalar_t*>(&logits_vec); for(int i=0;i<8;++i){
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
// from_float(*(logits_vec_ptr+i), 1000);
// }
if
(
reuse_kv_idx
==
0
)
{
const
int
row_idx
=
start_row_idx
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
)
{
const
int
offset
=
row_idx
*
BLOCK_SIZE
+
physical_block_offset
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
v_vec
[
i
%
2
]
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
}
else
{
V_quant_vec
v_quant_vec
=
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
// Vector conversion from V_quant_vec to V_vec.
v_vec
[
i
%
2
]
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
kv_scale
);
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
// NOTE(woosuk): When v_vec contains the tokens that are out of
// the context, we should explicitly zero out the values since
// they may contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
[
i
%
2
]);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
v_vec_ptr
[
j
]
=
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
}
// if(threadIdx.x==0){
// scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// scalar_t* logits_vec_ptr =
// reinterpret_cast<scalar_t*>(&logits_vec); for(int
// i=0;i<8;++i){
// printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i]));
// // from_float(*(v_vec_ptr + i), 1000);
// }
// for(int i=0;i<8;++i){
// printf("logits_vec[%d] =
// %f\n",i,half_to_float(logits_vec_ptr[i]));
// // from_float(*(logits_vec_ptr + i), 1000);
// }
// }
// accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
}
accs
[
reuse_kv_idx
][
i
-
1
]
+=
dot
(
logits_vec
,
v_vec
[(
i
-
1
)
%
2
]);
}
}
// if(threadIdx.x==0){
// scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i]));
// // from_float(*(v_vec_ptr + i), 1000);
// }
// for(int i=0;i<8;++i){
// printf("logits_vec[%d] = %f\n",i,half_to_float(logits_vec_ptr[i]));
// // from_float(*(logits_vec_ptr + i), 1000);
// }
// }
// accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
accs
[
reuse_kv_idx
][
i
]
+=
dot
(
logits_vec
,
v_vec
);
}
}
}
// tail
{
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
int
bs_block_offset
;
int
q_bs_block_id
;
if
constexpr
(
IS_BLOCK_SPARSE
)
{
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id
=
(
seq_len
-
1
)
/
blocksparse_block_size
;
if
(
blocksparse_head_sliding_step
>=
0
)
// sliding on q heads
bs_block_offset
=
(
tp_rank
*
num_heads
+
head_idx
)
*
blocksparse_head_sliding_step
+
1
;
else
// sliding on kv heads
bs_block_offset
=
(
tp_rank
*
num_kv_heads
+
kv_head_idx
)
*
(
-
blocksparse_head_sliding_step
)
+
1
;
}
if
constexpr
(
IS_BLOCK_SPARSE
)
{
int
v_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
if
(
!
((
v_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
)
&&
!
((
v_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
)))
{
continue
;
}
}
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
from_float
(
logits_vec
,
*
reinterpret_cast
<
Float_L_vec
*>
(
logits
+
(
reuse_kv_idx
*
partition_size
)
+
token_idx
-
start_token_idx
));
accs
[
reuse_kv_idx
][
NUM_ROWS_PER_THREAD
-
1
]
+=
dot
(
logits_vec
,
v_vec
[(
NUM_ROWS_PER_THREAD
-
1
)
%
2
]);
}
}
}
}
}
}
}
// Perform reduction within each warp.
// Perform reduction within each warp.
#pragma unroll
#pragma unroll
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
float
acc
=
accs
[
reuse_kv_idx
][
i
];
float
acc
=
accs
[
reuse_kv_idx
][
i
];
#pragma unroll
#pragma unroll
for
(
int
mask
=
NUM_V_VECS_PER_ROW
/
2
;
mask
>=
1
;
mask
/=
2
)
{
for
(
int
mask
=
NUM_V_VECS_PER_ROW
/
2
;
mask
>=
1
;
mask
/=
2
)
{
acc
+=
VLLM_SHFL_XOR_SYNC
(
acc
,
mask
);
acc
+=
VLLM_SHFL_XOR_SYNC
(
acc
,
mask
);
}
}
accs
[
reuse_kv_idx
][
i
]
=
acc
;
accs
[
reuse_kv_idx
][
i
]
=
acc
;
}
}
// NOTE(woosuk): A barrier is required because the shared memory space for
// NOTE(woosuk): A barrier is required because the shared memory space for
// logits is reused for the output.
// logits is reused for the output.
__syncthreads
();
__syncthreads
();
// Perform reduction across warps.
// Perform reduction across warps.
float
*
out_smem
=
reinterpret_cast
<
float
*>
(
shared_mem
);
float
*
out_smem
=
reinterpret_cast
<
float
*>
(
shared_mem
);
#pragma unroll
#pragma unroll
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
int
mid
=
i
/
2
;
int
mid
=
i
/
2
;
// Upper warps write to shared memory.
// Upper warps write to shared memory.
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
float
*
dst
=
&
out_smem
[(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
float
*
dst
=
&
out_smem
[(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
(
warp_idx
-
mid
)
*
HEAD_SIZE
];
(
warp_idx
-
mid
)
*
HEAD_SIZE
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
dst
[
row_idx
]
=
accs
[
reuse_kv_idx
][
i
];
dst
[
row_idx
]
=
accs
[
reuse_kv_idx
][
i
];
}
}
}
}
__syncthreads
();
}
}
__syncthreads
();
// Lower warps update the output.
// Lower warps update the output.
if
(
warp_idx
<
mid
)
{
if
(
warp_idx
<
mid
)
{
const
float
*
src
=
const
float
*
src
=
&
out_smem
[(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
warp_idx
*
HEAD_SIZE
];
&
out_smem
[(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
warp_idx
*
HEAD_SIZE
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
accs
[
reuse_kv_idx
][
i
]
+=
src
[
row_idx
];
accs
[
reuse_kv_idx
][
i
]
+=
src
[
row_idx
];
}
}
}
}
__syncthreads
();
}
}
}
__syncthreads
();
}
// Write the final output.
// Write the final output.
if
(
warp_idx
==
0
)
{
if
(
warp_idx
==
0
)
{
scalar_t
*
out_ptr
=
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
partition_idx
*
HEAD_SIZE
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
...
@@ -694,18 +583,21 @@ __device__ void paged_attention_kernel(
...
@@ -694,18 +583,21 @@ __device__ void paged_attention_kernel(
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
// Grid: (num_heads, num_seqs, 1).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
REUSE_KV_TIMES
=
1
,
bool
IS_BLOCK_SPARSE
,
bool
odd_nheads
=
false
>
int
REUSE_KV_TIMES
=
1
,
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v1_kernel
(
bool
IS_BLOCK_SPARSE
,
bool
odd_nheads
=
false
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v1_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
// head_size, block_size]
const
int
num_heads
,
// [num_heads]
const
int
num_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
...
@@ -716,22 +608,24 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v1_kernel(
...
@@ -716,22 +608,24 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v1_kernel(
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
>
(
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
kv_head_stride
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_head_sliding_step
);
}
}
// Grid: (num_heads, num_seqs, max_num_partitions).
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
,
int
PARTITION_SIZE
,
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
,
int
PARTITION_SIZE
,
bool
odd_nheads
=
false
>
bool
odd_nheads
=
false
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_kernel
(
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
// max_num_partitions]
...
@@ -742,7 +636,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel(
...
@@ -742,7 +636,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel(
// head_size/x, block_size, x]
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
// head_size, block_size]
const
int
num_heads
,
// [num_heads]
const
int
num_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_kv_heads]
const
int
num_kv_heads
,
// [num_kv_heads]
const
float
scale
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
...
@@ -753,19 +647,19 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel(
...
@@ -753,19 +647,19 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel(
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
,
PARTITION_SIZE
>
(
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_heads
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
kv_block_stride
,
kv_head_stride
,
kv_scale
,
tp_rank
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
kv_scale
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_head_sliding_step
);
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
}
// Grid: (num_heads, num_seqs).
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
>
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_reduce_kernel
(
int
PARTITION_SIZE
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_reduce_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
// max_num_partitions]
...
@@ -871,22 +765,21 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel(
...
@@ -871,22 +765,21 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel(
}
// namespace vllm
}
// namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel< \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
BLOCK_SIZE, NUM_THREADS, \
REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \
KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \
shared_mem_size); \
shared_mem_size); \
hipLaunchKernelGGL( \
hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
(vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \
NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, \
, dim3(grid), dim3(block), shared_mem_size, stream, \
IS_BLOCK_SPARSE, odd_nheads>), \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
dim3(grid), dim3(block), shared_mem_size, stream, out_ptr, query_ptr, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, scale, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
kv_scale, tp_rank, blocksparse_local_blocks, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, kv_scale, \
blocksparse_vert_stride, blocksparse_block_size, \
tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
...
@@ -918,8 +811,8 @@ void paged_attention_v1_launcher(
...
@@ -918,8 +811,8 @@ void paged_attention_v1_launcher(
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
int
num_threads
=
128
;
int
num_threads
=
128
;
if
(
num_heads
!=
num_kv_heads
)
{
if
(
num_heads
!=
num_kv_heads
){
num_threads
=
256
;
num_threads
=
256
;
}
}
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
assert
(
head_size
%
thread_group_size
==
0
);
assert
(
head_size
%
thread_group_size
==
0
);
...
@@ -937,42 +830,31 @@ void paged_attention_v1_launcher(
...
@@ -937,42 +830,31 @@ void paged_attention_v1_launcher(
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
int
padded_max_seq_len
=
int
padded_max_seq_len
=
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
REUSEKV_SWITCH_V1
(
num_heads
*
num_seqs
,
[
&
]
{
REUSEKV_SWITCH_V1
(
num_heads
*
num_seqs
,
[
&
]
{
BOOL_SWITCH
((
num_heads
/
num_kv_heads
%
REUSE_KV_TIMES
!=
0
),
odd_nheads
,
[
&
]
{
BOOL_SWITCH
(
HEADSIZE_SWITCH
(
head_size
,
[
&
]
{
(
num_heads
/
num_kv_heads
%
REUSE_KV_TIMES
!=
0
),
odd_nheads
,
[
&
]
{
NUM_THREADS_SWITCH
(
num_threads
,
[
&
]
{
HEADSIZE_SWITCH
(
head_size
,
[
&
]
{
OPT_SWITCH
(
num_heads
==
num_kv_heads
,
[
&
]
{
NUM_THREADS_SWITCH
(
num_threads
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
OPT_SWITCH
(
num_heads
==
num_kv_heads
,
[
&
]
{
int
logits_size
=
REUSE_KV_TIMES
*
padded_max_seq_len
*
sizeof
(
float
);
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
outputs_size
=
REUSE_KV_TIMES
*
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
int
logits_size
=
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
REUSE_KV_TIMES
*
padded_max_seq_len
*
sizeof
(
float
);
// Keep that in sync with the logic here!
int
outputs_size
=
REUSE_KV_TIMES
*
(
NUM_WARPS
/
2
)
*
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
head_size
*
sizeof
(
float
);
if
(
num_heads
==
num_kv_heads
)
shared_mem_size
=
::
max
(
12
*
1024
,
shared_mem_size
);
// Python-side check in
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// vllm.worker.worker._check_if_can_support_max_seq_len Keep
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
// that in sync with the logic here!
dim3
grid
((
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
,
1
,
num_seqs
);
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
dim3
block
(
NUM_THREADS
);
if
(
num_heads
==
num_kv_heads
)
const
at
::
hip
::
OptionalHIPGuardMasqueradingAsCUDA
device_guard
(
device_of
(
query
));
shared_mem_size
=
::
max
(
12
*
1024
,
shared_mem_size
);
const
hipStream_t
stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
// int shared_mem_size = ::max(31*1024, ::max(logits_size,
LAUNCH_PAGED_ATTENTION_V1
(
HEAD_SIZE
);
// outputs_size)); std::cout<<"shared_mem_size =
// "<<shared_mem_size<<std::endl;
dim3
grid
((
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
,
1
,
num_seqs
);
dim3
block
(
NUM_THREADS
);
const
at
::
hip
::
OptionalHIPGuardMasqueradingAsCUDA
device_guard
(
device_of
(
query
));
const
hipStream_t
stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
LAUNCH_PAGED_ATTENTION_V1
(
HEAD_SIZE
);
});
});
});
});
});
});
});
});
});
});
}
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
...
@@ -1040,23 +922,21 @@ void paged_attention_v1(
...
@@ -1040,23 +922,21 @@ void paged_attention_v1(
}
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
hipLaunchKernelGGL( \
hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
(vllm::paged_attention_v2_kernel< \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \
IS_BLOCK_SPARSE, REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>), \
, dim3(grid), dim3(block), shared_mem_size, stream, \
dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_block_size, blocksparse_head_sliding_step); \
blocksparse_head_sliding_step); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
hipLaunchKernelGGL( \
PARTITION_SIZE>) \
(vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
PARTITION_SIZE>), \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, out_ptr, \
max_num_partitions);
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
...
@@ -1099,39 +979,32 @@ void paged_attention_v2_launcher(
...
@@ -1099,39 +979,32 @@ void paged_attention_v2_launcher(
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
REUSEKV_SWITCH
(
num_heads
*
max_num_partitions
*
num_seqs
,
[
&
]
{
REUSEKV_SWITCH
(
num_heads
*
max_num_partitions
*
num_seqs
,
[
&
]
{
BOOL_SWITCH
(
BOOL_SWITCH
((
num_heads
/
num_kv_heads
%
REUSE_KV_TIMES
!=
0
),
odd_nheads
,
[
&
]
{
(
num_heads
/
num_kv_heads
%
REUSE_KV_TIMES
!=
0
),
odd_nheads
,
[
&
]
{
HEADSIZE_SWITCH
(
head_size
,
[
&
]
{
HEADSIZE_SWITCH
(
head_size
,
[
&
]
{
OPT_SWITCH
(
num_heads
==
num_kv_heads
,
[
&
]
{
OPT_SWITCH
(
num_heads
==
num_kv_heads
,
[
&
]
{
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
sizeof
(
float
);
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
sizeof
(
float
);
int
outputs_size
=
REUSE_KV_TIMES
*
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
int
outputs_size
=
REUSE_KV_TIMES
*
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
// For paged attention v2 kernel.
// dim3 grid(num_heads, max_num_partitions, num_seqs);
// For paged attention v2 kernel.
// dim3 grid(num_heads, max_num_partitions, num_seqs);
dim3
grid
;
grid
.
x
=
(
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
;
dim3
grid
;
grid
.
y
=
max_num_partitions
;
grid
.
x
=
(
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
grid
.
z
=
num_seqs
;
REUSE_KV_TIMES
*
num_kv_heads
;
// int shared_mem_size = ::max(1024*32, ::max(logits_size, outputs_size));
grid
.
y
=
max_num_partitions
;
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
grid
.
z
=
num_seqs
;
// For paged attention v2 reduce kernel.
// int shared_mem_size = ::max(1024*32, ::max(logits_size,
dim3
reduce_grid
(
num_heads
,
num_seqs
);
// outputs_size));
int
reduce_shared_mem_size
=
2
*
max_num_partitions
*
sizeof
(
float
);
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
dim3
block
(
NUM_THREADS
);
// For paged attention v2 reduce kernel.
const
at
::
hip
::
OptionalHIPGuardMasqueradingAsCUDA
device_guard
(
device_of
(
query
));
dim3
reduce_grid
(
num_heads
,
num_seqs
);
const
hipStream_t
stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
int
reduce_shared_mem_size
=
LAUNCH_PAGED_ATTENTION_V2
(
HEAD_SIZE
);
2
*
max_num_partitions
*
sizeof
(
float
);
dim3
block
(
NUM_THREADS
);
const
at
::
hip
::
OptionalHIPGuardMasqueradingAsCUDA
device_guard
(
device_of
(
query
));
const
hipStream_t
stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
LAUNCH_PAGED_ATTENTION_V2
(
HEAD_SIZE
);
});
});
});
});
});
});
});
});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment