Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fbde1e5a
Commit
fbde1e5a
authored
Mar 13, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'origin/0.7.2-zhangshao' into v0.7.2-pa
parents
146eb9d3
228a714a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
562 additions
and
891 deletions
+562
-891
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+401
-589
csrc/attention/attention_with_mask_kernels_opt_tc.cu
csrc/attention/attention_with_mask_kernels_opt_tc.cu
+7
-97
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+154
-205
No files found.
csrc/attention/attention_kernels_opt_tc.cu
View file @
fbde1e5a
...
@@ -45,8 +45,9 @@ static inline int get_env_(const char *env_var) {
...
@@ -45,8 +45,9 @@ static inline int get_env_(const char *env_var) {
}
}
return
0
;
return
0
;
}
}
static
const
int
PA_USE_V1
=
get_env_
(
"PA_USE_V1"
);
static
const
int
PA_REUSE_KV_TIMES
=
get_env_
(
"PA_REUSE_KV_TIMES"
);
static
const
int
PA_REUSE_KV_TIMES
=
get_env_
(
"PA_REUSE_KV_TIMES"
);
static
const
int
PA_PARTITION_SIZE
=
get_env_
(
"PA_PARTITION_SIZE"
);
static
const
int
PA_BLOCK_SIZE
=
get_env_
(
"PA_BLOCK_SIZE"
);
static
const
int
PA_BLOCK_SIZE
=
get_env_
(
"PA_BLOCK_SIZE"
);
static
const
int
PA_PRINT_PARAM
=
get_env_
(
"PA_PRINT_PARAM"
);
static
const
int
PA_PRINT_PARAM
=
get_env_
(
"PA_PRINT_PARAM"
);
namespace
vllm
{
namespace
vllm
{
...
@@ -90,10 +91,16 @@ inline __device__ float block_sum(float* red_smem, float sum) {
...
@@ -90,10 +91,16 @@ inline __device__ float block_sum(float* red_smem, float sum) {
using
half4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
_Float16
))
))
_Float16
;
using
half4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
_Float16
))
))
_Float16
;
using
v4bh
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
short
))
))
short
;
using
v4bh
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
short
))
))
short
;
using
float4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
float
))
))
float
;
using
float4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
float
))
))
float
;
using
float2_t
=
__attribute__
(
(
__vector_size__
(
2
*
sizeof
(
float
))
))
float
;
struct
half4x2
{
struct
half4x2
{
half4_t
data
[
2
];
half4_t
data
[
2
];
};
};
template
<
typename
scalar_t
>
struct
vec2data
{
scalar_t
data
[
2
];
};
template
<
bool
is_half
>
template
<
bool
is_half
>
inline
__device__
void
float4_2_half4
(
half4_t
&
dst
,
const
float4_t
&
src
)
inline
__device__
void
float4_2_half4
(
half4_t
&
dst
,
const
float4_t
&
src
)
{
{
...
@@ -126,15 +133,12 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4
...
@@ -126,15 +133,12 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4
}
}
}
}
template
<
bool
is_half
,
bool
use_vmac
>
template
<
bool
is_half
>
inline
__device__
void
builtin_amdgcn_mmac
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
&
reg_c
)
inline
__device__
void
builtin_amdgcn_mmac
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
&
reg_c
)
{
{
if
constexpr
(
use_vmac
){
v
_mmac_f32_16x16x16
_
f16
<
is_half
>
(
reg_a
,
reg_b
,
reg_c
);}
if
constexpr
(
is_half
){
reg_c
=
__builtin_amdgcn
_mmac_f32_16x16x16f16
(
reg_a
,
reg_b
,
reg_c
);}
else
{
else
{
if
constexpr
(
is_half
){
reg_c
=
__builtin_amdgcn_mmac_f32_16x16x16f16
(
reg_a
,
reg_b
,
reg_c
);}
reg_c
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
*
(
v4bh
*
)
&
reg_a
,
*
(
v4bh
*
)
&
reg_b
,
reg_c
);
else
{
reg_c
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
*
(
v4bh
*
)
&
reg_a
,
*
(
v4bh
*
)
&
reg_b
,
reg_c
);
}
}
}
}
}
...
@@ -142,13 +146,12 @@ inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t&
...
@@ -142,13 +146,12 @@ inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t&
// Grid: (num_heads, num_seqs, max_num_partitions).
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
,
bool
use_vmac
,
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
>
// Zero means no partitioning.
__
device
__
void
paged_attention_kernel_TC
(
__
global
__
void
paged_attention_kernel_TC
_with_mask
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
// max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads,head_size]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions,
scalar_t
*
__restrict__
out_tmp
,
// [num_seqs, num_heads, max_num_partitions,head_size]
// head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
// head_size/x, block_size, x]
...
@@ -164,33 +167,27 @@ __device__ void paged_attention_kernel_TC(
...
@@ -164,33 +167,27 @@ __device__ void paged_attention_kernel_TC(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
int
PARTITION_SIZE
=
0
)
{
#if defined(__gfx936__) || defined(__gfx928__)
const
int
seq_idx
=
blockIdx
.
z
;
const
int
seq_idx
=
blockIdx
.
z
;
const
int
partition_idx
=
blockIdx
.
x
;
const
int
partition_idx
=
blockIdx
.
x
;
const
int
max_num_partitions
=
gridDim
.
x
;
const
int
max_num_partitions
=
gridDim
.
x
;
constexpr
bool
USE_PARTITIONING
=
PARTITION_SIZE
>
0
;
const
int
seq_len
=
__builtin_amdgcn_readfirstlane
(
seq_lens
[
seq_idx
]);
const
int
seq_len
=
__builtin_amdgcn_readfirstlane
(
seq_lens
[
seq_idx
]);
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
{
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
// No work to do. Terminate the thread block.
const
bool
USE_PARTITIONING
=
PARTITION_SIZE
<
num_seq_blocks
*
BLOCK_SIZE
&&
PARTITION_SIZE
>
0
;
return
;
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
}
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
partition_size
=
USE_PARTITIONING
?
PARTITION_SIZE
:
num_seq_blocks
*
BLOCK_SIZE
;
const
int
partition_size
=
USE_PARTITIONING
?
PARTITION_SIZE
:
num_seq_blocks
*
BLOCK_SIZE
;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const
int
start_block_idx
=
partition_idx
*
num_blocks_per_partition
;
const
int
start_block_idx
=
partition_idx
*
num_blocks_per_partition
;
//0,64,128…
const
int
end_block_idx
=
MIN
(
start_block_idx
+
num_blocks_per_partition
,
num_seq_blocks
);
const
int
end_block_idx
=
MIN
(
start_block_idx
+
num_blocks_per_partition
,
num_seq_blocks
);
//64,128,192…
const
int
num_blocks
=
end_block_idx
-
start_block_idx
;
const
int
num_blocks
=
end_block_idx
-
start_block_idx
;
//64 or 1-63
const
int
start_token_idx
=
start_block_idx
*
BLOCK_SIZE
;
const
int
end_token_idx
=
MIN
(
start_token_idx
+
num_blocks
*
BLOCK_SIZE
,
seq_len
);
// [start_token_idx, end_token_idx) is the range of tokens to process.
const
int
num_tokens
=
end_token_idx
-
start_token_idx
;
const
int
start_token_idx
=
start_block_idx
*
BLOCK_SIZE
;
//0,1024,2048…
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
end_token_idx
=
MIN
(
start_token_idx
+
num_blocks
*
BLOCK_SIZE
,
seq_len
);
//1024,2048,3072…
constexpr
int
x
=
16
/
sizeof
(
cache_t
);
const
int
num_tokens
=
end_token_idx
-
start_token_idx
;
//1024 or 1-1023
// divides NUM_THREADS
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
//4
constexpr
int
x
=
16
/
sizeof
(
cache_t
);
//8
const
int
thread_idx
=
threadIdx
.
x
;
const
int
thread_idx
=
threadIdx
.
x
;
const
int
warp_idx
=
__builtin_amdgcn_readfirstlane
(
thread_idx
/
WARP_SIZE
);
const
int
warp_idx
=
__builtin_amdgcn_readfirstlane
(
thread_idx
/
WARP_SIZE
);
const
int
lane
=
thread_idx
%
WARP_SIZE
;
const
int
lane
=
thread_idx
%
WARP_SIZE
;
...
@@ -225,51 +222,42 @@ __device__ void paged_attention_kernel_TC(
...
@@ -225,51 +222,42 @@ __device__ void paged_attention_kernel_TC(
q_vec
.
data
[
1
]
=
{
0
,
0
,
0
,
0
};
q_vec
.
data
[
1
]
=
{
0
,
0
,
0
,
0
};
__shared__
half4x2
q_vecs
[
REUSE_KV_TIMES
][
16
];
__shared__
half4x2
q_vecs
[
REUSE_KV_TIMES
][
16
];
//if(thread_idx==0)printf("blockIdx.y==%d,q_boundary=%d,head_idx=%d,kv_head_idx=%d\n",blockIdx.y,q_boundary,head_idx,kv_head_idx);
for
(
int
i
=
0
;
i
<
q_boundary
;
i
++
){
for
(
int
i
=
0
;
i
<
q_boundary
;
i
++
){
if
(
thread_idx
<
16
){
if
(
thread_idx
<
16
){
q_vecs
[
i
][
thread_idx
]
=*
reinterpret_cast
<
const
half4x2
*>
(
q_ptr
+
i
*
HEAD_SIZE
+
thread_idx
*
8
);
half4x2
temp
=
*
reinterpret_cast
<
const
half4x2
*>
(
q_ptr
+
i
*
HEAD_SIZE
+
thread_idx
*
8
);
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
){
temp
.
data
[
0
][
k
]
=
((
float
)
temp
.
data
[
0
][
k
])
*
scale
;
temp
.
data
[
1
][
k
]
=
((
float
)
temp
.
data
[
1
][
k
])
*
scale
;
}
q_vecs
[
i
][
thread_idx
]
=
temp
;
}
}
}
}
__syncthreads
();
__syncthreads
();
// Memory planning.
extern
__shared__
char
shared_mem
[];
extern
__shared__
char
shared_mem
[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
scalar_t
*
logits
=
reinterpret_cast
<
scalar_t
*>
(
shared_mem
);
scalar_t
*
logits
=
reinterpret_cast
<
scalar_t
*>
(
shared_mem
);
// Workspace for reduction.
// __shared__ float red_smem[2 * NUM_WARPS];
__shared__
float
red_smem
[
2
*
NUM_WARPS
];
__shared__
float
s_max
[
REUSE_KV_TIMES
][
NUM_WARPS
];
__shared__
float
s_logit
[
NUM_WARPS
];
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
// blocksparse specific vars
int
bs_block_offset
;
int
q_bs_block_id
;
const
cache_t
*
k_ptr_base
=
k_cache
+
kv_head_idx
*
kv_head_stride
+
lane
*
8
;
const
cache_t
*
k_ptr_base
=
k_cache
+
kv_head_idx
*
kv_head_stride
+
lane
*
8
;
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
cache_t
*
k_ptr
=
k_ptr_base
+
physical_block_number
*
kv_block_stride
;
const
cache_t
*
k_ptr
=
k_ptr_base
+
physical_block_number
*
kv_block_stride
;
float4_t
qk_vec
=
{
0
,
0
,
0
,
0
};
float4_t
qk_vec
=
{
0
,
0
,
0
,
0
};
half4x2
k_vec
[
2
];
half4x2
k_vec
[
2
];
k_vec
[
0
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
);
k_vec
[
0
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
3
;
i
++
){
for
(
int
i
=
0
;
i
<
3
;
i
++
){
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
i
*
4
+
rows
];
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
i
*
4
+
rows
];
k_vec
[
1
-
i
%
2
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
+
(
i
+
1
)
*
512
);
k_vec
[
1
-
i
%
2
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
+
(
i
+
1
)
*
512
);
builtin_amdgcn_mmac
<
is_half
,
use_vmac
>
(
k_vec
[
i
%
2
].
data
[
0
],
q_vec
.
data
[
0
],
qk_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
i
%
2
].
data
[
0
],
q_vec
.
data
[
0
],
qk_vec
);
builtin_amdgcn_mmac
<
is_half
,
use_vmac
>
(
k_vec
[
i
%
2
].
data
[
1
],
q_vec
.
data
[
1
],
qk_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
i
%
2
].
data
[
1
],
q_vec
.
data
[
1
],
qk_vec
);
}
}
//tail
//tail
{
{
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
3
*
4
+
rows
];
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
3
*
4
+
rows
];
builtin_amdgcn_mmac
<
is_half
,
use_vmac
>
(
k_vec
[
1
].
data
[
0
],
q_vec
.
data
[
0
],
qk_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
1
].
data
[
0
],
q_vec
.
data
[
0
],
qk_vec
);
v_mmac_f32_16x16x16_f16
<
is_half
>
(
k_vec
[
1
].
data
[
1
],
q_vec
.
data
[
1
],
qk_vec
);
v_mmac_f32_16x16x16_f16
<
is_half
>
(
k_vec
[
1
].
data
[
1
],
q_vec
.
data
[
1
],
qk_vec
);
}
}
#pragma unroll
#pragma unroll
...
@@ -277,7 +265,6 @@ __device__ void paged_attention_kernel_TC(
...
@@ -277,7 +265,6 @@ __device__ void paged_attention_kernel_TC(
int
reuse_kv_idx
=
rows
+
i
*
4
;
int
reuse_kv_idx
=
rows
+
i
*
4
;
if
(
reuse_kv_idx
<
REUSE_KV_TIMES
){
if
(
reuse_kv_idx
<
REUSE_KV_TIMES
){
if
(
reuse_kv_idx
>=
q_boundary
)
qk_vec
[
i
]
=
0
;
if
(
reuse_kv_idx
>=
q_boundary
)
qk_vec
[
i
]
=
0
;
else
qk_vec
[
i
]
*=
scale
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rowid
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rowid
;
if
(
alibi_slope
[
i
]
!=
0
){
if
(
alibi_slope
[
i
]
!=
0
){
float
alibi
=
alibi_slope
[
i
]
*
(
token_idx
-
seq_len
+
1
);
float
alibi
=
alibi_slope
[
i
]
*
(
token_idx
-
seq_len
+
1
);
...
@@ -295,56 +282,49 @@ __device__ void paged_attention_kernel_TC(
...
@@ -295,56 +282,49 @@ __device__ void paged_attention_kernel_TC(
}
}
}
}
}
}
// if(blockIdx.y==0)printf("%d,qkmax=%f\n",threadIdx.x,qk_max[0]);
// compute max
// Perform reduction across the threads in the same warp to get the
#pragma unroll
// max qk value for each "warp" (not across the thread block yet).
for
(
int
mask
=
8
;
mask
>=
1
;
mask
/=
2
)
{
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for
(
int
r
=
0
;
r
<
reuse_group
;
r
++
){
qk_max
[
r
]
=
fmaxf
(
qk_max
[
r
],
__shfl_xor
(
qk_max
[
r
],
mask
));
}
}
#pragma unroll
for
(
int
r
=
0
;
r
<
reuse_group
;
r
++
){
if
(
rowid
==
0
&&
r
*
4
+
rows
<
q_boundary
){
s_max
[
r
*
4
+
rows
][
warp_idx
]
=
qk_max
[
r
];
}
}
__syncthreads
();
__shared__
float
max_out
[
REUSE_KV_TIMES
];
__shared__
float
expsum_out
[
REUSE_KV_TIMES
];
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
q_boundary
;
reuse_kv_idx
++
)
{
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
q_boundary
;
reuse_kv_idx
++
)
{
const
int
head_idx_
=
head_idx
+
reuse_kv_idx
;
const
int
head_idx_
=
head_idx
+
reuse_kv_idx
;
float
qk_max_tmp
=
qk_max
[
reuse_kv_idx
/
4
]
;
float
qk_max_tmp
=
lane
<
NUM_WARPS
?
s_max
[
reuse_kv_idx
][
lane
]
:
-
FLT_MAX
;
float
exp_sum
=
0.
f
;
float
exp_sum
=
0.
f
;
#pragma unroll
#pragma unroll
for
(
int
mask
=
8
;
mask
>=
1
;
mask
/=
2
)
{
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk_max_tmp
=
fmaxf
(
qk_max_tmp
,
VLLM_SHFL_XOR_SYNC
(
qk_max_tmp
,
mask
));
qk_max_tmp
=
fmaxf
(
qk_max_tmp
,
__shfl_xor
(
qk_max_tmp
,
mask
));
}
}
if
(
rowid
==
0
&&
reuse_kv_idx
%
4
==
rows
)
{
qk_max_tmp
=
__shfl
(
qk_max_tmp
,
0
);
red_smem
[
warp_idx
]
=
qk_max_tmp
;
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
float
val
=
__expf
(
to_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
])
-
qk_max_tmp
);
from_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
,
val
);
exp_sum
+=
val
;
}
exp_sum
=
block_sum
<
NUM_WARPS
>
(
s_logit
,
exp_sum
);
// Compute softmax.
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
+
1e-6
f
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
from_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
,
to_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
])
*
inv_sum
);
}
if
(
USE_PARTITIONING
&&
thread_idx
==
0
){
max_out
[
reuse_kv_idx
]
=
qk_max_tmp
;
expsum_out
[
reuse_kv_idx
]
=
exp_sum
;
}
}
__syncthreads
();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max_tmp
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk_max_tmp
=
fmaxf
(
qk_max_tmp
,
VLLM_SHFL_XOR_SYNC
(
qk_max_tmp
,
mask
));
}
// Broadcast the max qk value to all threads.
qk_max_tmp
=
VLLM_SHFL_SYNC
(
qk_max_tmp
,
0
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
float
val
=
__expf
(
to_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
])
-
qk_max_tmp
);
from_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
,
val
);
exp_sum
+=
val
;
}
exp_sum
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
NUM_WARPS
],
exp_sum
);
// Compute softmax.
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
+
1e-6
f
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
from_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
,
to_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
])
*
inv_sum
);
}
__syncthreads
();
// If partitioning is enabled, store the max logit and exp_sum.
if
(
USE_PARTITIONING
&&
thread_idx
==
0
)
{
float
*
max_logits_ptr
=
max_logits
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx_
*
max_num_partitions
+
partition_idx
;
*
max_logits_ptr
=
qk_max_tmp
;
float
*
exp_sums_ptr
=
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx_
*
max_num_partitions
+
partition_idx
;
*
exp_sums_ptr
=
exp_sum
;
}
}
}
__syncthreads
();
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
WARP_SIZE
);
//2
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
WARP_SIZE
);
//2
if
constexpr
(
REUSE_KV_TIMES
<=
2
){
if
constexpr
(
REUSE_KV_TIMES
<=
2
){
float
accs
[
REUSE_KV_TIMES
][
NUM_ROWS_PER_THREAD
];
float
accs
[
REUSE_KV_TIMES
][
NUM_ROWS_PER_THREAD
];
...
@@ -355,7 +335,6 @@ __device__ void paged_attention_kernel_TC(
...
@@ -355,7 +335,6 @@ __device__ void paged_attention_kernel_TC(
{
{
accs
[
k
][
i
]
=
0.
f
;
accs
[
k
][
i
]
=
0.
f
;
}
}
}
}
scalar_t
zero_value
;
scalar_t
zero_value
;
zero
(
zero_value
);
zero
(
zero_value
);
...
@@ -384,7 +363,7 @@ __device__ void paged_attention_kernel_TC(
...
@@ -384,7 +363,7 @@ __device__ void paged_attention_kernel_TC(
}
}
}
}
float4_t
out_vec
=
{
0
,
0
,
0
,
0
};
float4_t
out_vec
=
{
0
,
0
,
0
,
0
};
builtin_amdgcn_mmac
<
is_half
,
use_vmac
>
(
v_vec
,
logits_vec
,
out_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
v_vec
,
logits_vec
,
out_vec
);
if
(
rows
==
k
){
if
(
rows
==
k
){
for
(
int
resuseid
=
0
;
resuseid
<
REUSE_KV_TIMES
;
resuseid
++
){
for
(
int
resuseid
=
0
;
resuseid
<
REUSE_KV_TIMES
;
resuseid
++
){
accs
[
resuseid
][
i
]
+=
out_vec
[
resuseid
];
accs
[
resuseid
][
i
]
+=
out_vec
[
resuseid
];
...
@@ -396,7 +375,6 @@ __device__ void paged_attention_kernel_TC(
...
@@ -396,7 +375,6 @@ __device__ void paged_attention_kernel_TC(
__syncthreads
();
__syncthreads
();
using
floatV_t
=
__attribute__
(
(
__vector_size__
(
NUM_ROWS_PER_THREAD
*
sizeof
(
float
))
))
float
;
using
floatV_t
=
__attribute__
(
(
__vector_size__
(
NUM_ROWS_PER_THREAD
*
sizeof
(
float
))
))
float
;
// Perform reduction across warps.
// Perform reduction across warps.
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
q_boundary
;
reuse_kv_idx
++
)
{
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
q_boundary
;
reuse_kv_idx
++
)
{
if
constexpr
(
NUM_THREADS
>
64
){
if
constexpr
(
NUM_THREADS
>
64
){
floatV_t
*
out_smem
=
reinterpret_cast
<
floatV_t
*>
(
shared_mem
);
floatV_t
*
out_smem
=
reinterpret_cast
<
floatV_t
*>
(
shared_mem
);
...
@@ -419,12 +397,17 @@ __device__ void paged_attention_kernel_TC(
...
@@ -419,12 +397,17 @@ __device__ void paged_attention_kernel_TC(
__syncthreads
();
__syncthreads
();
}
}
}
}
// Write the final output.
if
(
warp_idx
==
0
)
{
if
(
warp_idx
==
0
)
{
scalar_t
*
out_ptr
=
scalar_t
*
out_ptr
;
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
int
out_offset
;
(
head_idx
+
reuse_kv_idx
)
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
if
(
USE_PARTITIONING
){
#pragma unroll
out_offset
=
max_num_partitions
*
HEAD_SIZE
;
out_ptr
=
out_tmp
+
seq_idx
*
num_heads
*
out_offset
+
head_idx
*
out_offset
+
partition_idx
*
HEAD_SIZE
;
}
else
{
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
}
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
+
i
*
WARP_SIZE
;
const
int
row_idx
=
lane
+
i
*
WARP_SIZE
;
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
reuse_kv_idx
][
i
]);
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
reuse_kv_idx
][
i
]);
...
@@ -472,7 +455,7 @@ __device__ void paged_attention_kernel_TC(
...
@@ -472,7 +455,7 @@ __device__ void paged_attention_kernel_TC(
}
}
}
}
float4_t
out_vec
=
{
0
,
0
,
0
,
0
};
float4_t
out_vec
=
{
0
,
0
,
0
,
0
};
builtin_amdgcn_mmac
<
is_half
,
use_vmac
>
(
v_vec
,
logits_vec
,
out_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
v_vec
,
logits_vec
,
out_vec
);
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
accs
[
g
*
4
+
k
][
i
]
+=
out_vec
[
g
];
accs
[
g
*
4
+
k
][
i
]
+=
out_vec
[
g
];
}
}
...
@@ -508,12 +491,20 @@ __device__ void paged_attention_kernel_TC(
...
@@ -508,12 +491,20 @@ __device__ void paged_attention_kernel_TC(
}
}
}
}
if
(
warp_idx
==
0
)
{
if
(
warp_idx
==
0
)
{
scalar_t
*
out_ptr_base
;
int
out_offset
;
if
(
USE_PARTITIONING
){
out_offset
=
max_num_partitions
*
HEAD_SIZE
;
out_ptr_base
=
out_tmp
+
seq_idx
*
num_heads
*
out_offset
+
head_idx
*
out_offset
+
partition_idx
*
HEAD_SIZE
;
}
else
{
out_offset
=
HEAD_SIZE
;
out_ptr_base
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
}
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
int
reusekvid
=
g
*
4
+
rows
;
int
reusekvid
=
g
*
4
+
rows
;
if
(
reusekvid
<
q_boundary
){
if
(
reusekvid
<
q_boundary
){
scalar_t
*
out_ptr
=
scalar_t
*
out_ptr
=
out_ptr_base
+
reusekvid
*
out_offset
;
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
(
head_idx
+
reusekvid
)
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
k
=
0
;
k
<
4
;
k
++
){
for
(
int
k
=
0
;
k
<
4
;
k
++
){
...
@@ -564,7 +555,7 @@ __device__ void paged_attention_kernel_TC(
...
@@ -564,7 +555,7 @@ __device__ void paged_attention_kernel_TC(
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
}
}
}
builtin_amdgcn_mmac
<
is_half
,
use_vmac
>
(
v_vec
,
logits_vec
,
accs
[
k
][
i
]);
builtin_amdgcn_mmac
<
is_half
>
(
v_vec
,
logits_vec
,
accs
[
k
][
i
]);
}
}
}
}
}
}
...
@@ -600,12 +591,20 @@ __device__ void paged_attention_kernel_TC(
...
@@ -600,12 +591,20 @@ __device__ void paged_attention_kernel_TC(
}
}
}
}
if
(
warp_idx
==
0
)
{
if
(
warp_idx
==
0
)
{
scalar_t
*
out_ptr_base
;
int
out_offset
;
if
(
USE_PARTITIONING
){
out_offset
=
max_num_partitions
*
HEAD_SIZE
;
out_ptr_base
=
out_tmp
+
seq_idx
*
num_heads
*
out_offset
+
head_idx
*
out_offset
+
partition_idx
*
HEAD_SIZE
;
}
else
{
out_offset
=
HEAD_SIZE
;
out_ptr_base
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
}
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
int
reusekvid
=
g
*
4
+
rows
;
int
reusekvid
=
g
*
4
+
rows
;
if
(
reusekvid
<
q_boundary
){
if
(
reusekvid
<
q_boundary
){
scalar_t
*
out_ptr
=
scalar_t
*
out_ptr
=
out_ptr_base
+
reusekvid
*
out_offset
;
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
(
head_idx
+
reusekvid
)
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
k
=
0
;
k
<
4
;
k
++
){
for
(
int
k
=
0
;
k
<
4
;
k
++
){
...
@@ -618,84 +617,19 @@ __device__ void paged_attention_kernel_TC(
...
@@ -618,84 +617,19 @@ __device__ void paged_attention_kernel_TC(
}
}
}
}
#endif
#endif
}
if
(
USE_PARTITIONING
&&
thread_idx
<
q_boundary
){
int
offset
=
seq_idx
*
num_heads
*
max_num_partitions
+
(
head_idx
+
thread_idx
)
*
max_num_partitions
+
partition_idx
;
*
(
max_logits
+
offset
)
=
max_out
[
thread_idx
];
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
*
(
exp_sums
+
offset
)
=
expsum_out
[
thread_idx
];
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
}
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
,
bool
use_vmac
>
#endif
__global__
void
paged_attention_v1_kernel_TC
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_heads
,
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
#if defined(__gfx936__) || defined(__gfx928__)
paged_attention_kernel_TC
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
use_vmac
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
#endif
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
,
bool
use_vmac
,
int
PARTITION_SIZE
,
bool
odd_nheads
=
false
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_kernel_TC
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
// max_num_partitions, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_kv_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
#if defined(__gfx936__) || defined(__gfx928__)
paged_attention_kernel_TC
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
use_vmac
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
#endif
}
}
// Grid: (num_heads, num_seqs).
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
>
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_reduce_kernel_opt_tc
(
__global__
__launch_bounds__
(
NUM_THREADS
,
1
)
void
paged_attention_v2_reduce_kernel_opt_tc
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
// max_num_partitions]
...
@@ -704,431 +638,249 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
...
@@ -704,431 +638,249 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
// max_num_partitions, head_size]
// max_num_partitions, head_size]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_partitions
)
{
const
int
max_num_partitions
,
int
PARTITION_SIZE
=
512
)
{
const
int
num_heads
=
gridDim
.
x
;
const
int
num_heads
=
gridDim
.
x
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
seq_idx
=
blockIdx
.
y
;
const
int
seq_idx
=
blockIdx
.
y
;
const
int
seq_len
=
seq_lens
[
seq_idx
];
const
int
seq_len
=
seq_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
seq_len
,
PARTITION_SIZE
);
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
seq_len
,
PARTITION_SIZE
);
if
(
num_partitions
==
1
)
{
if
(
num_partitions
==
1
)
return
;
// No need to reduce. Only copy tmp_out to out.
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
const
scalar_t
*
tmp_out_ptr
=
tmp_out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
blockDim
.
x
)
{
out_ptr
[
i
]
=
tmp_out_ptr
[
i
];
}
// Terminate the thread block.
return
;
}
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
warp_idx
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
thread_idx
=
threadIdx
.
x
;
const
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warp_idx
=
__builtin_amdgcn_readfirstlane
(
thread_idx
/
WARP_SIZE
);
const
int
lane
=
thread_idx
%
WARP_SIZE
;
// Size: 2 * num_partitions.
extern
__shared__
char
shared_mem
[];
// Workspace for reduction.
__shared__
float
red_smem
[
2
*
NUM_WARPS
];
// Load max logits to shared memory.
int
offset
=
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
;
float
*
shared_max_logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
const
float
*
max_logits_ptr
=
max_logits
+
offset
;
const
float
*
max_logits_ptr
=
max_logits
+
const
float
*
exp_sums_ptr
=
exp_sums
+
offset
;
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
;
float
max_logit
=
-
FLT_MAX
;
float
max_logit
=
-
FLT_MAX
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_partitions
;
i
+=
blockDim
.
x
)
{
float
global_max_logit
=
-
FLT_MAX
;
const
float
l
=
max_logits_ptr
[
i
];
shared_max_logits
[
i
]
=
l
;
max_logit
=
fmaxf
(
max_logit
,
l
);
}
__syncthreads
();
// Get the global max logit.
// Reduce within the warp.
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
}
if
(
lane
==
0
)
{
red_smem
[
warp_idx
]
=
max_logit
;
}
__syncthreads
();
// Reduce across warps.
max_logit
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
}
// Broadcast the max value to all threads.
max_logit
=
VLLM_SHFL_SYNC
(
max_logit
,
0
);
// Load rescaled exp sums to shared memory.
float
*
shared_exp_sums
=
reinterpret_cast
<
float
*>
(
shared_mem
+
sizeof
(
float
)
*
num_partitions
);
const
float
*
exp_sums_ptr
=
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
;
float
global_exp_sum
=
0.0
f
;
float
global_exp_sum
=
0.0
f
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_partitions
;
i
+=
blockDim
.
x
)
{
if
constexpr
(
NUM_THREADS
==
64
&&
HEAD_SIZE
==
128
){
float
l
=
shared_max_logits
[
i
];
__shared__
float
shared_exp_sums
[
64
];
float
rescaled_exp_sum
=
exp_sums_ptr
[
i
]
*
expf
(
l
-
max_logit
);
if
(
thread_idx
<
num_partitions
){
global_exp_sum
+=
rescaled_exp_sum
;
max_logit
=
max_logits_ptr
[
thread_idx
];
shared_exp_sums
[
i
]
=
rescaled_exp_sum
;
global_exp_sum
=
exp_sums_ptr
[
thread_idx
];
global_max_logit
=
max_logit
;
}
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
global_max_logit
=
fmaxf
(
global_max_logit
,
VLLM_SHFL_XOR_SYNC
(
global_max_logit
,
mask
));
}
if
(
thread_idx
<
num_partitions
){
global_exp_sum
=
global_exp_sum
*
__expf
(
max_logit
-
global_max_logit
);
shared_exp_sums
[
thread_idx
]
=
global_exp_sum
;
}
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
global_exp_sum
+=
VLLM_SHFL_XOR_SYNC
(
global_exp_sum
,
mask
);
}
const
float
inv_global_exp_sum
=
__fdividef
(
1.0
f
,
global_exp_sum
+
1e-6
f
);
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
const
scalar_t
*
tmp_out_ptr
=
tmp_out
+
offset
*
HEAD_SIZE
;
using
half2_t
=
vec2data
<
scalar_t
>
;
float2_t
acc
=
{
0.0
f
,
0.0
f
};
half2_t
acc_half
;
for
(
int
j
=
0
;
j
<
num_partitions
;
++
j
)
{
half2_t
tout
=
*
(
half2_t
*
)(
tmp_out_ptr
+
j
*
HEAD_SIZE
+
thread_idx
*
2
);
float
temp_sum
=
shared_exp_sums
[
j
]
*
inv_global_exp_sum
;
#pragma unroll
for
(
int
i
=
0
;
i
<
2
;
i
++
){
acc
[
i
]
+=
to_float
(
tout
.
data
[
i
])
*
temp_sum
;
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
2
;
i
++
){
from_float
(
acc_half
.
data
[
i
],
acc
[
i
]);
}
*
(
half2_t
*
)(
out_ptr
+
thread_idx
*
2
)
=
acc_half
;
}
}
__syncthreads
();
else
{
global_exp_sum
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
NUM_WARPS
],
global_exp_sum
);
// Size: 2 * num_partitions.
const
float
inv_global_exp_sum
=
__fdividef
(
1.0
f
,
global_exp_sum
+
1e-6
f
);
extern
__shared__
char
shared_mem
[];
// Workspace for reduction.
__shared__
float
red_smem
[
2
*
NUM_WARPS
];
// Load max logits to shared memory.
float
*
shared_max_logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
for
(
int
i
=
threadIdx
.
x
;
i
<
num_partitions
;
i
+=
blockDim
.
x
)
{
const
float
l
=
max_logits_ptr
[
i
];
shared_max_logits
[
i
]
=
l
;
max_logit
=
fmaxf
(
max_logit
,
l
);
}
__syncthreads
();
// Aggregate tmp_out to out.
// Get the global max logit.
const
scalar_t
*
tmp_out_ptr
=
// Reduce within the warp.
tmp_out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
#pragma unroll
#pragma unroll
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
NUM_THREADS
)
{
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
float
acc
=
0.0
f
;
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
for
(
int
j
=
0
;
j
<
num_partitions
;
++
j
)
{
}
acc
+=
to_float
(
tmp_out_ptr
[
j
*
HEAD_SIZE
+
i
])
*
shared_exp_sums
[
j
]
*
if
(
lane
==
0
)
{
inv_global_exp_sum
;
red_smem
[
warp_idx
]
=
max_logit
;
}
__syncthreads
();
// Reduce across warps.
max_logit
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
}
// Broadcast the max value to all threads.
max_logit
=
VLLM_SHFL_SYNC
(
max_logit
,
0
);
// Load rescaled exp sums to shared memory.
float
*
shared_exp_sums
=
reinterpret_cast
<
float
*>
(
shared_mem
+
sizeof
(
float
)
*
num_partitions
);
for
(
int
i
=
threadIdx
.
x
;
i
<
num_partitions
;
i
+=
blockDim
.
x
)
{
float
l
=
shared_max_logits
[
i
];
float
rescaled_exp_sum
=
exp_sums_ptr
[
i
]
*
expf
(
l
-
max_logit
);
global_exp_sum
+=
rescaled_exp_sum
;
shared_exp_sums
[
i
]
=
rescaled_exp_sum
;
}
__syncthreads
();
global_exp_sum
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
NUM_WARPS
],
global_exp_sum
);
const
float
inv_global_exp_sum
=
__fdividef
(
1.0
f
,
global_exp_sum
+
1e-6
f
);
// Aggregate tmp_out to out.
const
scalar_t
*
tmp_out_ptr
=
tmp_out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
#pragma unroll
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
NUM_THREADS
)
{
float
acc
=
0.0
f
;
for
(
int
j
=
0
;
j
<
num_partitions
;
++
j
)
{
acc
+=
to_float
(
tmp_out_ptr
[
j
*
HEAD_SIZE
+
i
])
*
shared_exp_sums
[
j
]
*
inv_global_exp_sum
;
}
from_float
(
out_ptr
[
i
],
acc
);
}
}
from_float
(
out_ptr
[
i
],
acc
);
}
}
}
}
}
// namespace vllm
}
// namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE) \
#define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
hipLaunchKernelGGL( \
((void*)vllm::paged_attention_v1_kernel_TC<T, CACHE_T, HEAD_SIZE, \
(vllm::paged_attention_kernel_TC_with_mask< \
BLOCK_SIZE, NUM_THREADS, \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
KV_DTYPE, IS_BLOCK_SPARSE,REUSE_KV_TIMES,use_vmac>), \
IS_BLOCK_SPARSE, REUSE_KV_TIMES>), \
shared_mem_size); \
dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \
vllm::paged_attention_v1_kernel_TC<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
max_logits_ptr,out_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,\
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE,REUSE_KV_TIMES,use_vmac> \
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
<<<grid, block, shared_mem_size, stream>>>( \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \
kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
blocksparse_vert_stride, blocksparse_block_size, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
blocksparse_head_sliding_step,PARTITION_SIZE);\
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
if (max_num_partitions<=64&&max_num_partitions>1){ \
blocksparse_vert_stride, blocksparse_block_size, \
hipLaunchKernelGGL( \
blocksparse_head_sliding_step);
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, 64>), \
dim3(reduce_grid), dim3(64), 0, stream, out_ptr, \
void
get_numberthread_and_reuse_kv_v1
(
int
&
num_thread
,
int
&
reusekv
,
int
batchsize
,
int
seq
,
int
qheads
,
int
kvheads
){
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
//mha
max_num_partitions,PARTITION_SIZE); \
}else if(max_num_partitions>64){ \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, 128>), \
dim3(reduce_grid), dim3(128), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions,PARTITION_SIZE);}
void
get_numberthread_and_reuse_kv_v2
(
int
&
num_thread
,
int
&
reusekv
,
int
&
PARTITION_SIZE
,
int
&
max_num_partitions
,
int
batchsize
,
int
max_seq_len
,
int
qheads
,
int
kvheads
,
int
num_blocks
)
{
reusekv
=
1
;
reusekv
=
1
;
num_thread
=
256
;
num_thread
=
256
;
if
(
device_name
==
"gfx936"
){
//bw
PARTITION_SIZE
=
512
;
if
(
qheads
==
kvheads
){
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
if
(
seq
<
16
){
num_thread
=
64
;
return
;}
if
(
max_seq_len
==
8192
&&
num_blocks
==
1024
){
//ali test
if
(
batchsize
>=
32
&&
seq
>=
1000
)
return
;
if
(
batchsize
==
1
&&
qheads
==
16
&&
kvheads
==
16
){
num_thread
=
128
;
return
;}
if
(
batchsize
*
qheads
>=
512
)
num_thread
=
64
;
if
(
batchsize
==
1
&&
qheads
==
32
&&
kvheads
==
32
){
num_thread
=
64
;
return
;}
return
;
if
(
batchsize
==
1
){
}
if
(
qheads
==
52
){
reusekv
=
8
;
return
;}
if
(
seq
<=
16
){
if
(
qheads
==
13
){
reusekv
=
2
;
return
;}
num_thread
=
64
;
reusekv
=
4
;
return
;
if
(
qheads
*
batchsize
>
1000
)
reusekv
=
4
;
return
;
}
}
if
(
seq
<=
64
){
if
(
batchsize
==
64
){
if
(
qheads
*
batchsize
>
1000
)
reusekv
=
4
;
if
(
qheads
==
13
){
PARTITION_SIZE
=
256
;
num_thread
=
128
;
reusekv
=
8
;}
else
if
(
qheads
==
32
){
PARTITION_SIZE
=
1024
;
reusekv
=
8
;}
else
if
(
qheads
==
52
||
qheads
==
26
){
reusekv
=
16
;}
else
reusekv
=
8
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
return
;
return
;
}
}
if
(
seq
<=
200
){
if
(
qheads
*
batchsize
>
400
)
reusekv
=
4
;
return
;
}
if
(
seq
<=
500
){
if
(
qheads
*
batchsize
>
200
)
reusekv
=
4
;
return
;
}
if
(((
qheads
-
1
)
/
16
+
1
)
*
batchsize
>=
64
&&
qheads
/
kvheads
>
4
&&
seq
<
7800
)
reusekv
=
8
;
else
if
(
qheads
*
batchsize
>
100
)
reusekv
=
4
;
return
;
}
}
if
(
qheads
==
kvheads
){
if
(
qheads
==
kvheads
){
//llama 7B ,其他模型未可知
if
(
max_seq_len
<=
8192
){
if
(
seq
<=
16
||
batchsize
>=
32
)
num_thread
=
64
;
if
(
batchsize
*
qheads
>=
512
){
else
if
(
batchsize
<=
2
)
num_thread
=
256
;
max_num_partitions
=
1
;
else
if
(
batchsize
<
8
)
num_thread
=
128
;
num_thread
=
64
;
else
num_thread
=
64
;
}
if
(
qheads
==
32
&&
max_seq_len
<=
1024
)
max_num_partitions
=
1
;
}
return
;
return
;
}
}
// mqa
if
(
max_seq_len
<
800
)
max_num_partitions
=
1
;
if
(
qheads
>
kvheads
*
4
){
if
(
qheads
>
kvheads
*
4
){
if
(
seq
<
64
){
if
(
max_seq_len
<=
1000
||
if
(
batchsize
<=
64
){
reusekv
=
1
;
num_thread
=
64
;}
max_seq_len
<
1500
&&
(
batchsize
>=
8
&&
qheads
>=
8
||
batchsize
>=
64
)
||
else
if
(
batchsize
<
128
){
reusekv
=
2
;
num_thread
=
64
;}
max_seq_len
<
1900
&&
batchsize
>=
8
&&
qheads
==
28
else
{
reusekv
=
4
;
num_thread
=
64
;}
)
}
max_num_partitions
=
1
;
else
if
(
seq
<=
400
){
int
blocks
=
max_num_partitions
*
batchsize
*
qheads
;
if
(
batchsize
<
16
){
reusekv
=
1
;
num_thread
=
256
;}
if
(
device_name
==
"gfx928"
){
else
if
(
batchsize
<
64
){
reusekv
=
2
;
num_thread
=
256
;}
if
(
batchsize
*
qheads
>
1024
&&
max_seq_len
>=
2000
){
else
if
(
batchsize
<=
128
){
max_num_partitions
=
1
;
reusekv
=
4
;
if
(
max_seq_len
<
3900
)
reusekv
=
8
;
if
(
qheads
%
7
==
0
)
num_thread
=
64
;
//qwen7b
else
if
(
max_seq_len
<
7800
)
reusekv
=
4
;
else
num_thread
=
256
;
//llama70b
else
{
PARTITION_SIZE
=
2048
;
reusekv
=
8
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
}
}
else
{
reusekv
=
8
;
num_thread
=
64
;}
return
;
}
else
if
(
seq
<=
1000
){
if
(
batchsize
<
16
){
reusekv
=
1
;
num_thread
=
256
;}
else
if
(
qheads
%
7
==
0
&&
batchsize
<=
128
){
//qwen7b
if
(
batchsize
<
64
){
reusekv
=
4
;
num_thread
=
256
;}
else
{
reusekv
=
4
;
num_thread
=
64
;}
}
}
else
if
(
batchsize
<=
64
){
reusekv
=
4
;
num_thread
=
256
;}
else
{
reusekv
=
8
;
num_thread
=
128
;}
}
}
else
if
(
seq
<
3900
)
{
reusekv
=
8
;
num_thread
=
256
;}
if
(
max_num_partitions
==
1
){
else
if
(
seq
<
7800
)
{
reusekv
=
4
;
num_thread
=
256
;}
if
(
max_seq_len
<
512
){
else
{
reusekv
=
2
;
num_thread
=
256
;}
int
bytes
=
max_seq_len
*
qheads
*
batchsize
;
return
;
if
(
bytes
<
51200
)
reusekv
=
1
;
}
else
if
(
bytes
<
256000
)
reusekv
=
4
;
else
reusekv
=
8
;
if
(
qheads
/
kvheads
>
4
&&
seq
<
3900
)
reusekv
=
8
;
return
;
else
if
(
qheads
/
kvheads
>
2
&&
seq
<
7800
)
reusekv
=
4
;
}
else
if
(
qheads
/
kvheads
>=
2
&&
seq
<
15600
)
reusekv
=
2
;
if
(
batchsize
<
4
||
batchsize
==
4
&&
qheads
==
8
)
reusekv
=
1
;
else
if
(
batchsize
<
32
||
batchsize
<=
64
&&
qheads
==
8
)
reusekv
=
4
;
if
(
seq
<=
64
){
else
reusekv
=
8
;
num_thread
=
64
;
return
;
if
(
batchsize
<=
64
)
reusekv
=
1
;
}
else
num_thread
=
256
;
}
// TODO(woosuk): Tune NUM_THREADS.
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
void
paged_attention_v1_launcher_opt_tc
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
int
max_num_blocks_per_seq
=
block_tables
.
size
(
1
);
int
q_stride
=
query
.
stride
(
0
);
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
int
num_threads
=
128
;
// printf("paged_attention_v1\n");
if
(
num_heads
!=
num_kv_heads
)
{
num_threads
=
256
;
}
[[
maybe_unused
]]
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
assert
(
head_size
%
thread_group_size
==
0
);
// NOTE: alibi_slopes is optional.
const
float
*
alibi_slopes_ptr
=
alibi_slopes
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
:
nullptr
;
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
int
padded_max_seq_len
=
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
&&
KV_DTYPE
==
vllm
::
Fp8KVCacheDataType
::
kAuto
){
constexpr
int
HEAD_SIZE
=
128
;
constexpr
static
int
use_vmac
=
false
;
int
reusekv
,
num_thread
;
get_numberthread_and_reuse_kv_v1
(
num_thread
,
reusekv
,
num_seqs
,
padded_max_seq_len
,
num_heads
,
num_kv_heads
);
if
(
PA_REUSE_KV_TIMES
!=
0
&&
num_heads
>
num_kv_heads
)
reusekv
=
PA_REUSE_KV_TIMES
;
if
(
PA_BLOCK_SIZE
!=
0
)
num_thread
=
PA_BLOCK_SIZE
;
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
padded_max_seq_len
*
2
;
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
if
(
NUM_WARPS
==
64
)
outputs_size
=
0
;
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
dim3
grid
(
1
,(
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
,
num_seqs
);
dim3
block
(
NUM_THREADS
);
if
(
PA_PRINT_PARAM
)
printf
(
"reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d
\n
"
,
reusekv
,
num_thread
,
grid
.
x
,
grid
.
y
,
grid
.
z
,
num_heads
,
num_kv_heads
,
max_seq_len
,
num_seqs
);
LAUNCH_PAGED_ATTENTION_V1_TC
(
HEAD_SIZE
);
});
});
}
}
}
if
(
blocks
<
150
)
return
;
if
(
blocks
<
600
||
qheads
<=
kvheads
*
4
){
reusekv
=
4
;
return
;}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
reusekv
=
8
;
return
;
paged_attention_v1_launcher_opt_tc<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
if (is_block_sparse) { \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
} else { \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
}
if
(
device_name
==
"gfx928"
){
void
paged_attention_v1
(
if
(
batchsize
*
qheads
>
1024
&&
max_seq_len
>=
2000
){
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
max_num_partitions
=
1
;
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
if
(
max_seq_len
<
7800
)
reusekv
=
4
;
torch
::
Tensor
&
else
{
key_cache
,
// [num_blocks, num_heads, head_size/x, block_size, x]
PARTITION_SIZE
=
2048
;
torch
::
Tensor
&
reusekv
=
4
;
value_cache
,
// [num_blocks, num_heads, head_size, block_size]
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
int64_t
num_kv_heads
,
// [num_heads]
double
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v1_opt_tc
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
key_cache
,
// [num_blocks, num_heads, head_size/x, block_size, x]
torch
::
Tensor
&
value_cache
,
// [num_blocks, num_heads, head_size, block_size]
int64_t
num_kv_heads
,
// [num_heads]
double
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
device_name
!=
"gfx928"
&&
device_name
!=
"gfx936"
)){
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
else
{
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
CALL_V1_LAUNCHER_BLOCK_SIZE
)
}
}
#define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE) \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_kernel_TC< \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
IS_BLOCK_SPARSE, REUSE_KV_TIMES,use_vmac, PARTITION_SIZE>), \
dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \
max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \
dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
void
get_numberthread_and_reuse_kv_v2
(
int
&
num_thread
,
int
&
reusekv
,
int
batchsize
,
int
max_num_partitions
,
int
qheads
,
int
kvheads
,
int
num_blocks
){
reusekv
=
1
;
num_thread
=
256
;
if
(
device_name
==
"gfx936"
){
//bw
if
(
max_num_partitions
==
16
&&
num_blocks
==
1024
){
//ali test
if
(
batchsize
==
1
&&
qheads
==
16
&&
kvheads
==
16
){
num_thread
=
128
;
return
;}
if
(
batchsize
==
1
&&
qheads
==
32
&&
kvheads
==
32
){
num_thread
=
64
;
return
;}
if
(
batchsize
==
1
){
if
(
qheads
==
52
){
reusekv
=
8
;
return
;}
if
(
qheads
==
13
){
reusekv
=
2
;
return
;}
reusekv
=
4
;
return
;
}
if
(
batchsize
==
64
){
if
(
qheads
==
13
||
qheads
==
32
){
num_thread
=
128
;
reusekv
=
8
;
return
;}
if
(
qheads
==
52
){
reusekv
=
16
;
return
;}
reusekv
=
8
;
return
;
}
}
return
;
}
}
if
(
qheads
==
kvheads
)
return
;
int
bp
=
max_num_partitions
*
batchsize
;
if
(
qheads
/
kvheads
>
4
){
if
(
qheads
==
16
&&
bp
>
96
||
qheads
<
16
&&
bp
>=
192
||
qheads
>
16
&&
bp
>
24
){
reusekv
=
8
;
return
;}
}
if
(
qheads
/
4
*
bp
>=
32
)
reusekv
=
4
;
return
;
}
int
blocks
=
batchsize
*
qheads
*
max_num_partitions
;
if
(
qheads
==
kvheads
){
if
(
blocks
<=
80
||
blocks
>
8000
){
num_thread
=
256
;}
else
if
(
blocks
<=
160
){
num_thread
=
128
;}
else
num_thread
=
64
;
return
;
}
if
(
qheads
/
kvheads
>
8
&&
blocks
>
4000
){
reusekv
=
16
;
if
(
blocks
>
40000
)
num_thread
=
64
;
else
num_thread
=
128
;
}
else
if
(
qheads
/
kvheads
==
5
||
qheads
/
kvheads
==
7
){
if
(
blocks
<=
160
){
reusekv
=
1
;
num_thread
=
256
;}
else
if
(
blocks
<
640
/
5
*
qheads
/
kvheads
){
reusekv
=
4
;
num_thread
=
256
;}
else
if
(
blocks
<
1920
){
reusekv
=
8
;
num_thread
=
128
;}
else
{
reusekv
=
8
;
num_thread
=
64
;}
}
else
if
(
qheads
>
kvheads
*
4
){
if
(
blocks
<=
128
){
reusekv
=
1
;
num_thread
=
256
;}
else
if
(
blocks
<
1536
){
reusekv
=
4
;
num_thread
=
256
;}
else
if
(
blocks
<
6144
){
reusekv
=
8
;
num_thread
=
128
;}
else
{
reusekv
=
8
;
num_thread
=
64
;}
}
else
{
if
(
blocks
<=
128
){
reusekv
=
1
;
num_thread
=
256
;}
else
if
(
blocks
<
3000
){
reusekv
=
4
;
num_thread
=
256
;}
else
{
reusekv
=
4
;
num_thread
=
64
;}
}
}
}
if
(
max_seq_len
<=
1000
||
max_seq_len
<=
1500
&&
(
qheads
>
4
&&
batchsize
>=
16
||
batchsize
>=
64
))
max_num_partitions
=
1
;
int
blocks
=
max_num_partitions
*
batchsize
*
qheads
;
if
(
blocks
>=
150
||
batchsize
>=
16
||
qheads
>=
8
&&
(
batchsize
>=
4
||
max_seq_len
>=
2000
))
reusekv
=
4
;
}
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
=
512
>
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
void
paged_attention_v2_launcher_opt_tc
(
void
paged_attention_v2_launcher_opt_tc
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
...
@@ -1146,45 +898,54 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -1146,45 +898,54 @@ void paged_attention_v2_launcher_opt_tc(
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
int
num_blocks
=
key_cache
.
size
(
0
);
int
num_blocks
=
key_cache
.
size
(
0
);
// printf("paged_attention_v2\n");
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
assert
(
head_size
%
thread_group_size
==
0
);
// NOTE: alibi_slopes is optional.
// NOTE: alibi_slopes is optional.
const
float
*
alibi_slopes_ptr
=
const
float
*
alibi_slopes_ptr
=
alibi_slopes
alibi_slopes
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
:
nullptr
;
:
nullptr
;
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
float
*
exp_sums_ptr
=
reinterpret_cast
<
float
*>
(
exp_sums
.
data_ptr
());
//
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float
*
max_logits_ptr
=
reinterpret_cast
<
float
*>
(
max_logits
.
data_ptr
());
//
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T
*
tmp_out_ptr
=
reinterpret_cast
<
T
*>
(
tmp_out
.
data_ptr
());
//
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
static
float
*
exp_sums_ptr
=
nullptr
;
static
float
*
max_logits_ptr
=
nullptr
;
static
T
*
tmp_out_ptr
=
nullptr
;
if
(
exp_sums_ptr
==
nullptr
){
hipMalloc
(
&
exp_sums_ptr
,
1000000
);
// 1m
hipMalloc
(
&
max_logits_ptr
,
1000000
);
// 1m
hipMalloc
(
&
tmp_out_ptr
,
100000000
);
// 100m
}
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
dim3
reduce_grid
(
num_heads
,
num_seqs
);
dim3
reduce_grid
(
num_heads
,
num_seqs
);
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
int
reduce_shared_mem_size
=
2
*
max_num_partitions
*
sizeof
(
float
);
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
&&
KV_DTYPE
==
vllm
::
Fp8KVCacheDataType
::
kAuto
){
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
&&
KV_DTYPE
==
vllm
::
Fp8KVCacheDataType
::
kAuto
){
constexpr
int
HEAD_SIZE
=
128
;
constexpr
int
HEAD_SIZE
=
128
;
constexpr
static
int
use_vmac
=
false
;
int
reusekv
,
num_thread
,
max_num_partitions
,
PARTITION_SIZE
;
int
reusekv
,
num_thread
;
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
PARTITION_SIZE
,
max_num_partitions
,
num_seqs
,
max_seq_len
,
num_heads
,
num_kv_heads
,
num_blocks
);
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
num_seqs
,
max_num_partitions
,
num_heads
,
num_kv_heads
,
num_blocks
);
if
(
PA_PARTITION_SIZE
!=
0
){
PARTITION_SIZE
=
PA_PARTITION_SIZE
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
}
if
(
PA_REUSE_KV_TIMES
!=
0
&&
num_heads
>
num_kv_heads
)
reusekv
=
PA_REUSE_KV_TIMES
;
if
(
PA_REUSE_KV_TIMES
!=
0
&&
num_heads
>
num_kv_heads
)
reusekv
=
PA_REUSE_KV_TIMES
;
if
(
PA_BLOCK_SIZE
!=
0
)
num_thread
=
PA_BLOCK_SIZE
;
if
(
PA_BLOCK_SIZE
!=
0
)
num_thread
=
PA_BLOCK_SIZE
;
if
(
PA_USE_V1
!=
0
)
max_num_partitions
=
1
;
if
(
max_num_partitions
==
1
)
PARTITION_SIZE
=
max_seq_len
;
assert
(
num_seqs
*
num_heads
*
max_num_partitions
*
head_size
<=
100000000
);
int
reduce_shared_mem_size
=
2
*
max_num_partitions
*
sizeof
(
float
);
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
2
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
2
;
if
(
max_num_partitions
==
1
)
PARTITION_SIZE
=
0
;
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
dim3
grid
;
dim3
grid
;
grid
.
y
=
(
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
;
grid
.
y
=
(
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
;
...
@@ -1291,6 +1052,57 @@ void paged_attention_v2_opt_tc(
...
@@ -1291,6 +1052,57 @@ void paged_attention_v2_opt_tc(
}
}
}
}
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
key_cache
,
// [num_blocks, num_heads, head_size/x, block_size, x]
torch
::
Tensor
&
value_cache
,
// [num_blocks, num_heads, head_size, block_size]
int64_t
num_kv_heads
,
// [num_heads]
double
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v1_opt_tc
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
key_cache
,
// [num_blocks, num_heads, head_size/x, block_size, x]
torch
::
Tensor
&
value_cache
,
// [num_blocks, num_heads, head_size, block_size]
int64_t
num_kv_heads
,
// [num_heads]
double
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
device_name
!=
"gfx928"
&&
device_name
!=
"gfx936"
)){
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
else
{
paged_attention_v2_opt_tc
(
out
,
out
,
out
,
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
}
#undef WARP_SIZE
#undef WARP_SIZE
#undef MAX
#undef MAX
#undef MIN
#undef MIN
...
...
csrc/attention/attention_with_mask_kernels_opt_tc.cu
View file @
fbde1e5a
...
@@ -316,13 +316,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -316,13 +316,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
}
__syncthreads
();
__syncthreads
();
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
WARP_SIZE
);
//2
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
WARP_SIZE
);
//2
if
(
q_boundary
<=
2
){
if
constexpr
(
REUSE_KV_TIMES
<=
2
){
constexpr
int
acc_size
=
REUSE_KV_TIMES
==
1
?
1
:
2
;
float
accs
[
REUSE_KV_TIMES
][
NUM_ROWS_PER_THREAD
];
float
accs
[
acc_size
][
NUM_ROWS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
acc_size
;
k
++
)
for
(
int
k
=
0
;
k
<
REUSE_KV_TIMES
;
k
++
)
{
{
accs
[
k
][
i
]
=
0.
f
;
accs
[
k
][
i
]
=
0.
f
;
}
}
...
@@ -356,7 +355,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -356,7 +355,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
float4_t
out_vec
=
{
0
,
0
,
0
,
0
};
float4_t
out_vec
=
{
0
,
0
,
0
,
0
};
builtin_amdgcn_mmac
<
is_half
>
(
v_vec
,
logits_vec
,
out_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
v_vec
,
logits_vec
,
out_vec
);
if
(
rows
==
k
){
if
(
rows
==
k
){
for
(
int
resuseid
=
0
;
resuseid
<
acc_size
;
resuseid
++
){
for
(
int
resuseid
=
0
;
resuseid
<
REUSE_KV_TIMES
;
resuseid
++
){
accs
[
resuseid
][
i
]
+=
out_vec
[
resuseid
];
accs
[
resuseid
][
i
]
+=
out_vec
[
resuseid
];
}
}
}
}
...
@@ -366,8 +365,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -366,8 +365,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
__syncthreads
();
__syncthreads
();
using
floatV_t
=
__attribute__
(
(
__vector_size__
(
NUM_ROWS_PER_THREAD
*
sizeof
(
float
))
))
float
;
using
floatV_t
=
__attribute__
(
(
__vector_size__
(
NUM_ROWS_PER_THREAD
*
sizeof
(
float
))
))
float
;
// Perform reduction across warps.
// Perform reduction across warps.
#pragma unroll
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
q_boundary
;
reuse_kv_idx
++
)
{
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
acc_size
;
reuse_kv_idx
++
)
{
if
constexpr
(
NUM_THREADS
>
64
){
if
constexpr
(
NUM_THREADS
>
64
){
floatV_t
*
out_smem
=
reinterpret_cast
<
floatV_t
*>
(
shared_mem
);
floatV_t
*
out_smem
=
reinterpret_cast
<
floatV_t
*>
(
shared_mem
);
#pragma unroll
#pragma unroll
...
@@ -780,97 +778,9 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern
...
@@ -780,97 +778,9 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern
max_num_partitions,PARTITION_SIZE);}
max_num_partitions,PARTITION_SIZE);}
static
void
get_numberthread_and_reuse_kv_v2
(
int
&
num_thread
,
int
&
reusekv
,
int
&
PARTITION_SIZE
,
int
&
max_num_partitions
,
void
get_numberthread_and_reuse_kv_v2
(
int
&
num_thread
,
int
&
reusekv
,
int
&
PARTITION_SIZE
,
int
&
max_num_partitions
,
int
batchsize
,
int
max_seq_len
,
int
qheads
,
int
kvheads
,
int
num_blocks
)
int
batchsize
,
int
max_seq_len
,
int
qheads
,
int
kvheads
,
int
num_blocks
);
{
reusekv
=
1
;
num_thread
=
256
;
PARTITION_SIZE
=
512
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
if
(
max_seq_len
==
8192
&&
num_blocks
==
1024
){
//ali test
if
(
batchsize
==
1
&&
qheads
==
16
&&
kvheads
==
16
){
num_thread
=
128
;
return
;}
if
(
batchsize
==
1
&&
qheads
==
32
&&
kvheads
==
32
){
num_thread
=
64
;
return
;}
if
(
batchsize
==
1
){
if
(
qheads
==
52
){
reusekv
=
8
;
return
;}
if
(
qheads
==
13
){
reusekv
=
2
;
return
;}
reusekv
=
4
;
return
;
}
if
(
batchsize
==
64
){
if
(
qheads
==
13
){
PARTITION_SIZE
=
256
;
num_thread
=
128
;
reusekv
=
8
;}
else
if
(
qheads
==
32
){
PARTITION_SIZE
=
1024
;
reusekv
=
8
;}
else
if
(
qheads
==
52
||
qheads
==
26
){
reusekv
=
16
;}
else
reusekv
=
8
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
return
;
}
}
if
(
qheads
==
kvheads
){
if
(
max_seq_len
<=
8192
){
if
(
batchsize
*
qheads
>=
512
){
max_num_partitions
=
1
;
num_thread
=
64
;
}
if
(
qheads
==
32
&&
max_seq_len
<=
1024
)
max_num_partitions
=
1
;
}
return
;
}
if
(
max_seq_len
<
800
)
max_num_partitions
=
1
;
if
(
qheads
>
kvheads
*
4
){
if
(
max_seq_len
<=
1000
||
max_seq_len
<
1500
&&
(
batchsize
>=
8
&&
qheads
>=
8
||
batchsize
>=
64
)
||
max_seq_len
<
1900
&&
batchsize
>=
8
&&
qheads
==
28
)
max_num_partitions
=
1
;
int
blocks
=
max_num_partitions
*
batchsize
*
qheads
;
if
(
device_name
==
"gfx928"
){
if
(
batchsize
*
qheads
>
1024
&&
max_seq_len
>=
2000
){
max_num_partitions
=
1
;
if
(
max_seq_len
<
3900
)
reusekv
=
8
;
else
if
(
max_seq_len
<
7800
)
reusekv
=
4
;
else
{
PARTITION_SIZE
=
2048
;
reusekv
=
8
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
}
return
;
}
}
if
(
max_num_partitions
==
1
){
if
(
max_seq_len
<
512
){
int
bytes
=
max_seq_len
*
qheads
*
batchsize
;
if
(
bytes
<
51200
)
reusekv
=
1
;
else
if
(
bytes
<
256000
)
reusekv
=
4
;
else
reusekv
=
8
;
return
;
}
if
(
batchsize
<
4
||
batchsize
==
4
&&
qheads
==
8
)
reusekv
=
1
;
else
if
(
batchsize
<
32
||
batchsize
<=
64
&&
qheads
==
8
)
reusekv
=
4
;
else
reusekv
=
8
;
return
;
}
if
(
blocks
<
150
)
return
;
if
(
blocks
<
600
||
qheads
<=
kvheads
*
4
){
reusekv
=
4
;
return
;}
reusekv
=
8
;
return
;
}
if
(
device_name
==
"gfx928"
){
if
(
batchsize
*
qheads
>
1024
&&
max_seq_len
>=
2000
){
max_num_partitions
=
1
;
if
(
max_seq_len
<
7800
)
reusekv
=
4
;
else
{
PARTITION_SIZE
=
2048
;
reusekv
=
4
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
}
return
;
}
}
if
(
max_seq_len
<=
1000
||
max_seq_len
<=
1500
&&
(
qheads
>
4
&&
batchsize
>=
16
||
batchsize
>=
64
))
max_num_partitions
=
1
;
int
blocks
=
max_num_partitions
*
batchsize
*
qheads
;
if
(
blocks
>=
150
||
batchsize
>=
16
||
qheads
>=
8
&&
(
batchsize
>=
4
||
max_seq_len
>=
2000
))
reusekv
=
4
;
}
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
void
paged_attention_v2_launcher_opt_tc_with_mask
(
void
paged_attention_v2_launcher_opt_tc_with_mask
(
...
...
vllm/attention/ops/paged_attn.py
View file @
fbde1e5a
...
@@ -14,7 +14,8 @@ if HAS_TRITON:
...
@@ -14,7 +14,8 @@ if HAS_TRITON:
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
_PARTITION_SIZE
=
512
gpuname
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
name
support_tc
=
gpuname
.
startswith
(
'K100_AI'
)
or
gpuname
.
startswith
(
'BW'
)
@
dataclass
@
dataclass
class
PagedAttentionMetadata
:
class
PagedAttentionMetadata
:
...
@@ -128,12 +129,62 @@ class PagedAttention:
...
@@ -128,12 +129,62 @@ class PagedAttention:
# to parallelize.
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
use_v1
=
(
max_seq_len
<
8192
and
(
max_seq_len
<
(
1024
if
num_kv_heads
==
num_heads
else
600
)
or
num_seqs
*
num_heads
>
(
1024
if
num_kv_heads
<
num_heads
else
512
)))
if
envs
.
VLLM_USE_TC_PAGED_ATTN
and
support_tc
:
else
:
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
use_v1
=
(
max_seq_len
<=
8192
print
(
"PA V1 SIZE:"
)
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
attn_masks
is
None
:
ops
.
paged_attention_v1_opt_tc
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
else
:
ops
.
paged_attention_v1_opt_tc_with_mask
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
return
output
use_v1
=
(
max_seq_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
if
use_v1
:
if
use_v1
:
# Run PagedAttention V1.
# Run PagedAttention V1.
...
@@ -143,100 +194,52 @@ class PagedAttention:
...
@@ -143,100 +194,52 @@ class PagedAttention:
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
if
attn_masks
is
None
:
if
attn_masks
is
None
:
ops
.
paged_attention_v1_opt
(
ops
.
paged_attention_v1_opt_tc
(
output
,
output
,
query
,
query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
num_kv_heads
,
num_kv_heads
,
scale
,
scale
,
block_tables
,
block_tables
,
seq_lens
,
seq_lens
,
block_size
,
block_size
,
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
k_scale
,
k_scale
,
v_scale
,
v_scale
,
tp_rank
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
blocksparse_head_sliding_step
)
)
else
:
ops
.
paged_attention_v1_opt_tc_with_mask
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
else
:
if
attn_masks
is
None
:
ops
.
paged_attention_v1_opt_with_mask
(
ops
.
paged_attention_v1_opt
(
output
,
output
,
query
,
query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
num_kv_heads
,
num_kv_heads
,
scale
,
scale
,
block_tables
,
block_tables
,
seq_lens
,
seq_lens
,
block_size
,
block_size
,
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
k_scale
,
k_scale
,
v_scale
,
v_scale
,
tp_rank
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
blocksparse_head_sliding_step
attn_masks
,
)
attn_masks_stride
else
:
)
ops
.
paged_attention_v1_opt_with_mask
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
else
:
if
attn_masks
is
None
:
if
attn_masks
is
None
:
ops
.
paged_attention_v1
(
ops
.
paged_attention_v1
(
...
@@ -306,112 +309,58 @@ class PagedAttention:
...
@@ -306,112 +309,58 @@ class PagedAttention:
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
if
attn_masks
is
None
:
if
attn_masks
is
None
:
ops
.
paged_attention_v2_opt
(
ops
.
paged_attention_v2_opt_tc
(
output
,
output
,
exp_sums
,
exp_sums
,
max_logits
,
max_logits
,
tmp_output
,
tmp_output
,
query
,
query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
num_kv_heads
,
num_kv_heads
,
scale
,
scale
,
block_tables
,
block_tables
,
seq_lens
,
seq_lens
,
block_size
,
block_size
,
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
k_scale
,
k_scale
,
v_scale
,
v_scale
,
tp_rank
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
blocksparse_head_sliding_step
)
)
else
:
ops
.
paged_attention_v2_opt_tc_with_mask
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
else
:
if
attn_masks
is
None
:
ops
.
paged_attention_v2_opt_with_mask
(
ops
.
paged_attention_v2_opt
(
output
,
output
,
exp_sums
,
exp_sums
,
max_logits
,
max_logits
,
tmp_output
,
tmp_output
,
query
,
query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
num_kv_heads
,
num_kv_heads
,
scale
,
scale
,
block_tables
,
block_tables
,
seq_lens
,
seq_lens
,
block_size
,
block_size
,
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
k_scale
,
k_scale
,
v_scale
,
v_scale
,
tp_rank
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
blocksparse_head_sliding_step
attn_masks
,
)
attn_masks_stride
else
:
)
ops
.
paged_attention_v2_opt_with_mask
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
else
:
if
attn_masks
is
None
:
if
attn_masks
is
None
:
ops
.
paged_attention_v2
(
ops
.
paged_attention_v2
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment