Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8af5263f
"tests/vscode:/vscode.git/clone" did not exist on "6f1229f91ddbc8c12b77929b4eae4006d3a46ca5"
Commit
8af5263f
authored
Dec 18, 2024
by
zhangshao
Browse files
增加bw pa tc优化
parent
c56b26cd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
188 additions
and
51 deletions
+188
-51
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+188
-51
No files found.
csrc/attention/attention_kernels_opt_tc.cu
View file @
8af5263f
...
...
@@ -6,26 +6,19 @@
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef
__hip_bfloat16
__nv_bfloat16
;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define WARP_SIZE 64
#include "static_switch_tc.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
inline
std
::
string
get_device_name
()
std
::
string
get_device_name
()
{
hipDeviceProp_t
props
{};
int
device
;
...
...
@@ -43,6 +36,9 @@ inline std::string get_device_name()
const
std
::
string
raw_name
(
props
.
gcnArchName
);
return
raw_name
.
substr
(
0
,
raw_name
.
find
(
':'
));
// str.substr(0, npos) returns str.
}
static
const
std
::
string
device_name
=
get_device_name
();
static
inline
int
get_env_
(
const
char
*
env_var
)
{
if
(
char
*
value
=
std
::
getenv
(
env_var
))
{
return
atoi
(
value
);
...
...
@@ -170,8 +166,8 @@ __device__ void paged_attention_kernel_TC(
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
seq_idx
=
blockIdx
.
z
;
const
int
partition_idx
=
blockIdx
.
y
;
const
int
max_num_partitions
=
gridDim
.
y
;
const
int
partition_idx
=
blockIdx
.
x
;
const
int
max_num_partitions
=
gridDim
.
x
;
constexpr
bool
USE_PARTITIONING
=
PARTITION_SIZE
>
0
;
const
int
seq_len
=
__builtin_amdgcn_readfirstlane
(
seq_lens
[
seq_idx
]);
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
{
...
...
@@ -203,14 +199,10 @@ __device__ void paged_attention_kernel_TC(
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
num_blocks_per_kv
=
((
num_queries_per_kv
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
);
const
int
odd_tg_round
=
(((
blockIdx
.
z
*
gridDim
.
y
*
gridDim
.
x
)
+
blockIdx
.
y
*
gridDim
.
x
)
/
128
)
%
2
;
const
int
mid_x
=
gridDim
.
x
/
2
;
const
int
blockIdx_shift
=
(
odd_tg_round
|
(
gridDim
.
x
&
1
))
?
blockIdx
.
x
:
(
blockIdx
.
x
<
mid_x
?
(
blockIdx
.
x
+
mid_x
)
:
(
blockIdx
.
x
-
mid_x
));
const
int
head_idx
=
(
blockIdx_shift
/
num_blocks_per_kv
)
*
num_queries_per_kv
+
(
blockIdx_shift
%
num_blocks_per_kv
)
*
REUSE_KV_TIMES
;
//const int head_idx=(blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
const
int
head_idx
=
(
blockIdx
.
y
/
num_blocks_per_kv
)
*
num_queries_per_kv
+
(
blockIdx
.
y
%
num_blocks_per_kv
)
*
REUSE_KV_TIMES
;
int
q_boundary
=
REUSE_KV_TIMES
;
if
(
num_heads
<
REUSE_KV_TIMES
*
gridDim
.
x
&&
(
num_blocks_per_kv
-
1
)
*
REUSE_KV_TIMES
==
head_idx
%
num_queries_per_kv
)
if
(
num_heads
<
REUSE_KV_TIMES
*
gridDim
.
y
&&
(
num_blocks_per_kv
-
1
)
*
REUSE_KV_TIMES
==
head_idx
%
num_queries_per_kv
)
q_boundary
=
num_queries_per_kv
-
(
num_blocks_per_kv
-
1
)
*
REUSE_KV_TIMES
;
const
int
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
constexpr
int
reuse_group
=
(
REUSE_KV_TIMES
-
1
)
/
4
+
1
;
...
...
@@ -233,7 +225,7 @@ __device__ void paged_attention_kernel_TC(
q_vec
.
data
[
1
]
=
{
0
,
0
,
0
,
0
};
__shared__
half4x2
q_vecs
[
REUSE_KV_TIMES
][
16
];
//if(thread_idx==0)printf("blockIdx.
x
==%d,q_boundary=%d,head_idx=%d,kv_head_idx=%d\n",blockIdx.
x
,q_boundary,head_idx,kv_head_idx);
//if(thread_idx==0)printf("blockIdx.
y
==%d,q_boundary=%d,head_idx=%d,kv_head_idx=%d\n",blockIdx.
y
,q_boundary,head_idx,kv_head_idx);
for
(
int
i
=
0
;
i
<
q_boundary
;
i
++
){
if
(
thread_idx
<
16
){
q_vecs
[
i
][
thread_idx
]
=*
reinterpret_cast
<
const
half4x2
*>
(
q_ptr
+
i
*
HEAD_SIZE
+
thread_idx
*
8
);
...
...
@@ -303,7 +295,7 @@ __device__ void paged_attention_kernel_TC(
}
}
}
// if(blockIdx.
x
==0)printf("%d,qkmax=%f\n",threadIdx.x,qk_max[0]);
// if(blockIdx.
y
==0)printf("%d,qkmax=%f\n",threadIdx.x,qk_max[0]);
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
...
...
@@ -353,7 +345,6 @@ __device__ void paged_attention_kernel_TC(
*
exp_sums_ptr
=
exp_sum
;
}
}
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
WARP_SIZE
);
//2
if
constexpr
(
REUSE_KV_TIMES
<=
2
){
float
accs
[
REUSE_KV_TIMES
][
NUM_ROWS_PER_THREAD
];
...
...
@@ -441,6 +432,7 @@ __device__ void paged_attention_kernel_TC(
}
}
}
#if defined __gfx928__
else
{
constexpr
int
GROUPS
=
reuse_group
*
4
;
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
...
...
@@ -533,6 +525,99 @@ __device__ void paged_attention_kernel_TC(
}
}
}
#else
else
{
constexpr
int
GROUPS
=
reuse_group
*
4
;
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float4_t
accs
[
4
][
NUM_ROWS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
accs
[
k
][
i
]
=
{
0.
f
,
0.
f
,
0.
f
,
0.
f
};
}
}
scalar_t
zero_value
;
zero
(
zero_value
);
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
q_boundary
){
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
rows
*
4
+
rowid
*
16
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
){
int
offset
=
i
*
1024
+
k
*
256
;
half4_t
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
if
(
block_idx
==
num_seq_blocks
-
1
)
{
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
}
builtin_amdgcn_mmac
<
is_half
,
use_vmac
>
(
v_vec
,
logits_vec
,
accs
[
k
][
i
]);
}
}
}
if
constexpr
(
NUM_THREADS
>
64
){
__syncthreads
();
using
floatV_t
=
__attribute__
(
(
__vector_size__
(
reuse_group
*
sizeof
(
float
))
))
float
;
// Perform reduction across warps.
for
(
int
m
=
0
;
m
<
4
;
m
++
)
{
floatV_t
*
out_smem
=
reinterpret_cast
<
floatV_t
*>
(
shared_mem
);
#pragma unroll
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
int
mid
=
i
/
2
;
// Upper warps write to shared memory.
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
for
(
int
k
=
0
;
k
<
NUM_ROWS_PER_THREAD
;
k
++
){
out_smem
[((
warp_idx
-
mid
)
*
64
+
lane
)
*
NUM_ROWS_PER_THREAD
+
k
]
=*
(
floatV_t
*
)(
&
(
accs
[
m
][
k
]));
}
}
__syncthreads
();
// Lower warps update the output.
if
(
warp_idx
<
mid
)
{
for
(
int
k
=
0
;
k
<
NUM_ROWS_PER_THREAD
;
k
++
){
floatV_t
tmp
=
out_smem
[
thread_idx
*
NUM_ROWS_PER_THREAD
+
k
];
#pragma unroll
for
(
int
i
=
0
;
i
<
reuse_group
;
i
++
)
{
accs
[
m
][
k
][
i
]
+=
tmp
[
i
];
}
}
}
__syncthreads
();
}
}
}
if
(
warp_idx
==
0
)
{
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
int
reusekvid
=
g
*
4
+
rows
;
if
(
reusekvid
<
q_boundary
){
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
(
head_idx
+
reusekvid
)
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
k
=
0
;
k
<
4
;
k
++
){
const
int
row_idx
=
rowid
+
16
*
k
+
i
*
WARP_SIZE
;
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
k
][
i
][
g
]);
}
}
}
}
}
}
#endif
}
...
...
@@ -736,6 +821,35 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
void
get_numberthread_and_reuse_kv_v1
(
int
&
num_thread
,
int
&
reusekv
,
int
batchsize
,
int
seq
,
int
qheads
,
int
kvheads
){
//mha
reusekv
=
1
;
num_thread
=
256
;
if
(
device_name
==
"gfx936"
){
//bw
if
(
qheads
==
kvheads
){
if
(
seq
<
16
){
num_thread
=
64
;
return
;}
if
(
batchsize
>=
32
&&
seq
>=
1000
)
return
;
if
(
batchsize
*
qheads
>=
512
)
num_thread
=
64
;
return
;
}
if
(
seq
<=
16
){
num_thread
=
64
;
if
(
qheads
*
batchsize
>
1000
)
reusekv
=
4
;
return
;
}
if
(
seq
<=
64
){
if
(
qheads
*
batchsize
>
1000
)
reusekv
=
4
;
return
;
}
if
(
seq
<=
200
){
if
(
qheads
*
batchsize
>
400
)
reusekv
=
4
;
return
;
}
if
(
seq
<=
500
){
if
(
qheads
*
batchsize
>
200
)
reusekv
=
4
;
return
;
}
if
(((
qheads
-
1
)
/
16
+
1
)
*
batchsize
>=
64
&&
qheads
/
kvheads
>
4
&&
seq
<
7800
)
reusekv
=
8
;
else
if
(
qheads
*
batchsize
>
100
)
reusekv
=
4
;
return
;
}
if
(
qheads
==
kvheads
){
//llama 7B ,其他模型未可知
if
(
seq
<=
16
||
batchsize
>=
32
)
num_thread
=
64
;
...
...
@@ -844,7 +958,7 @@ void paged_attention_v1_launcher_opt_tc(
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
if
(
NUM_WARPS
==
64
)
outputs_size
=
0
;
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
dim3
grid
((
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
,
1
,
num_seqs
);
dim3
grid
(
1
,
(
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
,
num_seqs
);
dim3
block
(
NUM_THREADS
);
if
(
PA_PRINT_PARAM
)
printf
(
"reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d
\n
"
,
reusekv
,
num_thread
,
grid
.
x
,
grid
.
y
,
grid
.
z
,
num_heads
,
num_kv_heads
,
max_seq_len
,
num_seqs
);
...
...
@@ -927,7 +1041,7 @@ void paged_attention_v1_opt_tc(
const
int64_t
blocksparse_head_sliding_step
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
get_
device_name
()
!=
"gfx928"
&&
get_
device_name
()
!=
"gfx936"
)){
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
device_name
!=
"gfx928"
&&
device_name
!=
"gfx936"
)){
paged_attention_v1_opt
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
...
...
@@ -958,8 +1072,32 @@ void paged_attention_v1_opt_tc(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
void
get_numberthread_and_reuse_kv_v2
(
int
&
num_thread
,
int
&
reusekv
,
int
batchsize
,
int
max_num_partitions
,
int
qheads
,
int
kvheads
){
void
get_numberthread_and_reuse_kv_v2
(
int
&
num_thread
,
int
&
reusekv
,
int
batchsize
,
int
max_num_partitions
,
int
qheads
,
int
kvheads
,
int
num_blocks
){
reusekv
=
1
;
num_thread
=
256
;
if
(
device_name
==
"gfx936"
){
//bw
if
(
max_num_partitions
==
16
&&
num_blocks
==
1024
){
//ali test
if
(
batchsize
==
1
&&
qheads
==
16
&&
kvheads
==
16
){
num_thread
=
128
;
return
;}
if
(
batchsize
==
1
&&
qheads
==
32
&&
kvheads
==
32
){
num_thread
=
64
;
return
;}
if
(
batchsize
==
1
){
if
(
qheads
==
52
){
reusekv
=
8
;
return
;}
if
(
qheads
==
13
){
reusekv
=
2
;
return
;}
reusekv
=
4
;
return
;
}
if
(
batchsize
==
64
){
if
(
qheads
==
13
||
qheads
==
32
){
num_thread
=
128
;
reusekv
=
8
;
return
;}
if
(
qheads
==
52
){
reusekv
=
16
;
return
;}
reusekv
=
8
;
return
;
}
}
if
(
qheads
==
kvheads
)
return
;
int
bp
=
max_num_partitions
*
batchsize
;
if
(
qheads
/
kvheads
>
4
){
if
(
qheads
==
16
&&
bp
>
96
||
qheads
<
16
&&
bp
>=
192
||
qheads
>
16
&&
bp
>
24
){
reusekv
=
8
;
return
;}
}
if
(
qheads
/
4
*
bp
>=
32
)
reusekv
=
4
;
return
;
}
int
blocks
=
batchsize
*
qheads
*
max_num_partitions
;
if
(
qheads
==
kvheads
){
if
(
blocks
<=
80
||
blocks
>
8000
){
num_thread
=
256
;}
...
...
@@ -1009,6 +1147,7 @@ void paged_attention_v2_launcher_opt_tc(
int
q_stride
=
query
.
stride
(
0
);
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
int
num_blocks
=
key_cache
.
size
(
0
);
// printf("paged_attention_v2\n");
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
assert
(
head_size
%
thread_group_size
==
0
);
...
...
@@ -1036,31 +1175,29 @@ void paged_attention_v2_launcher_opt_tc(
int
reduce_shared_mem_size
=
2
*
max_num_partitions
*
sizeof
(
float
);
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
&&
KV_DTYPE
==
vllm
::
Fp8KVCacheDataType
::
kAuto
){
//if(head_size==128&&get_device_name()=="gfx928"){
constexpr
int
HEAD_SIZE
=
128
;
constexpr
static
int
use_vmac
=
false
;
int
reusekv
,
num_thread
;
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
num_seqs
,
max_num_partitions
,
num_heads
,
num_kv_heads
);
if
(
PA_REUSE_KV_TIMES
!=
0
&&
num_heads
>
num_kv_heads
)
reusekv
=
PA_REUSE_KV_TIMES
;
if
(
PA_BLOCK_SIZE
!=
0
)
num_thread
=
PA_BLOCK_SIZE
;
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
2
;
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
dim3
grid
;
grid
.
x
=
(
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
;
grid
.
y
=
max_num_partitions
;
grid
.
z
=
num_seqs
;
dim3
block
(
NUM_THREADS
);
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
if
(
PA_PRINT_PARAM
)
printf
(
"reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d
\n
"
,
reusekv
,
num_thread
,
grid
.
x
,
grid
.
y
,
grid
.
z
,
num_heads
,
num_kv_heads
,
max_seq_len
,
num_seqs
);
LAUNCH_PAGED_ATTENTION_V2_TC
(
HEAD_SIZE
);
});
constexpr
int
HEAD_SIZE
=
128
;
constexpr
static
int
use_vmac
=
false
;
int
reusekv
,
num_thread
;
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
num_seqs
,
max_num_partitions
,
num_heads
,
num_kv_heads
,
num_blocks
);
if
(
PA_REUSE_KV_TIMES
!=
0
&&
num_heads
>
num_kv_heads
)
reusekv
=
PA_REUSE_KV_TIMES
;
if
(
PA_BLOCK_SIZE
!=
0
)
num_thread
=
PA_BLOCK_SIZE
;
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
2
;
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
dim3
grid
;
grid
.
y
=
(
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
;
grid
.
x
=
max_num_partitions
;
grid
.
z
=
num_seqs
;
dim3
block
(
NUM_THREADS
);
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
if
(
PA_PRINT_PARAM
)
printf
(
"reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d
\n
"
,
reusekv
,
num_thread
,
grid
.
x
,
grid
.
y
,
grid
.
z
,
num_heads
,
num_kv_heads
,
max_seq_len
,
num_seqs
);
LAUNCH_PAGED_ATTENTION_V2_TC
(
HEAD_SIZE
);
});
}
//
}
}
);
}
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
...
...
@@ -1145,7 +1282,7 @@ void paged_attention_v2_opt_tc(
const
int64_t
blocksparse_head_sliding_step
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
get_
device_name
()
!=
"gfx928"
&&
get_
device_name
()
!=
"gfx936"
)){
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
device_name
!=
"gfx928"
&&
device_name
!=
"gfx936"
)){
paged_attention_v2_opt
(
out
,
exp_sums
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment