Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ac811e51
Commit
ac811e51
authored
Mar 18, 2025
by
zhuwenwen
Browse files
解决PA部分size计算错误的问题
优化bf16精度 解决bf16精度问题,解决cudagraph精度问题
parent
a5b976df
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
395 additions
and
785 deletions
+395
-785
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+374
-617
csrc/attention/attention_with_mask_kernels_opt_tc.cu
csrc/attention/attention_with_mask_kernels_opt_tc.cu
+21
-168
No files found.
csrc/attention/attention_kernels_opt_tc.cu
View file @
ac811e51
This diff is collapsed.
Click to expand it.
csrc/attention/attention_with_mask_kernels_opt_tc.cu
View file @
ac811e51
...
@@ -107,11 +107,11 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4
...
@@ -107,11 +107,11 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4
{
{
if
constexpr
(
is_half
){
if
constexpr
(
is_half
){
asm
volatile
(
"v_mmac_f32_16x16x16_f16 %0, %1, %2, %0"
:
asm
volatile
(
"
\n
s_nop 1
\n
v_mmac_f32_16x16x16_f16 %0, %1, %2, %0"
:
"=v"
(
reg_c
)
:
"v"
(
reg_a
),
"v"
(
reg_b
),
"0"
(
reg_c
));
"=v"
(
reg_c
)
:
"v"
(
reg_a
),
"v"
(
reg_b
),
"0"
(
reg_c
));
}
}
else
{
else
{
asm
volatile
(
"v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0"
:
asm
volatile
(
"
\n
s_nop 1
\n
v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0"
:
"=v"
(
reg_c
)
:
"v"
(
reg_a
),
"v"
(
reg_b
),
"0"
(
reg_c
));
"=v"
(
reg_c
)
:
"v"
(
reg_a
),
"v"
(
reg_b
),
"0"
(
reg_c
));
}
}
}
}
...
@@ -159,7 +159,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -159,7 +159,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
const
int
seq_len
=
__builtin_amdgcn_readfirstlane
(
seq_lens
[
seq_idx
]);
const
int
seq_len
=
__builtin_amdgcn_readfirstlane
(
seq_lens
[
seq_idx
]);
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
bool
USE_PARTITIONING
=
PARTITION_SIZE
<
num_seq_blocks
*
BLOCK_SIZE
&&
PARTITION_SIZE
>
0
;
const
bool
USE_PARTITIONING
=
PARTITION_SIZE
<
num_seq_blocks
*
BLOCK_SIZE
&&
PARTITION_SIZE
>
0
;
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
if
(
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
...
@@ -209,10 +209,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -209,10 +209,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
for
(
int
i
=
0
;
i
<
q_boundary
;
i
++
){
for
(
int
i
=
0
;
i
<
q_boundary
;
i
++
){
if
(
thread_idx
<
16
){
if
(
thread_idx
<
16
){
half4x2
temp
=
*
reinterpret_cast
<
const
half4x2
*>
(
q_ptr
+
i
*
HEAD_SIZE
+
thread_idx
*
8
);
half4x2
temp
=
*
reinterpret_cast
<
const
half4x2
*>
(
q_ptr
+
i
*
HEAD_SIZE
+
thread_idx
*
8
);
if
constexpr
(
is_half
){
scalar_t
*
t
=
reinterpret_cast
<
scalar_t
*>
(
&
temp
);
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
){
for
(
int
k
=
0
;
k
<
8
;
k
++
){
temp
.
data
[
0
][
k
]
=
((
float
)
temp
.
data
[
0
]
[
k
])
*
scale
;
from_float
(
t
[
k
],
to_float
(
t
[
k
])
*
scale
)
;
temp
.
data
[
1
][
k
]
=
((
float
)
temp
.
data
[
1
][
k
])
*
scale
;
}
}
}
q_vecs
[
i
][
thread_idx
]
=
temp
;
q_vecs
[
i
][
thread_idx
]
=
temp
;
}
}
...
@@ -249,6 +251,9 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -249,6 +251,9 @@ __global__ void paged_attention_kernel_TC_with_mask(
int
reuse_kv_idx
=
rows
+
i
*
4
;
int
reuse_kv_idx
=
rows
+
i
*
4
;
if
(
reuse_kv_idx
<
REUSE_KV_TIMES
){
if
(
reuse_kv_idx
<
REUSE_KV_TIMES
){
if
(
reuse_kv_idx
>=
q_boundary
)
qk_vec
[
i
]
=
0
;
if
(
reuse_kv_idx
>=
q_boundary
)
qk_vec
[
i
]
=
0
;
else
{
if
constexpr
(
!
is_half
)
qk_vec
[
i
]
*=
scale
;
}
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rowid
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rowid
;
if
(
alibi_slope
[
i
]
!=
0
){
if
(
alibi_slope
[
i
]
!=
0
){
float
alibi
=
alibi_slope
[
i
]
*
(
token_idx
-
seq_len
+
1
);
float
alibi
=
alibi_slope
[
i
]
*
(
token_idx
-
seq_len
+
1
);
...
@@ -316,13 +321,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -316,13 +321,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
}
__syncthreads
();
__syncthreads
();
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
WARP_SIZE
);
//2
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
WARP_SIZE
);
//2
if
(
q_boundary
<=
2
){
if
constexpr
(
REUSE_KV_TIMES
<=
2
){
constexpr
int
acc_size
=
REUSE_KV_TIMES
==
1
?
1
:
2
;
float
accs
[
REUSE_KV_TIMES
][
NUM_ROWS_PER_THREAD
];
float
accs
[
acc_size
][
NUM_ROWS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
acc_size
;
k
++
)
for
(
int
k
=
0
;
k
<
REUSE_KV_TIMES
;
k
++
)
{
{
accs
[
k
][
i
]
=
0.
f
;
accs
[
k
][
i
]
=
0.
f
;
}
}
...
@@ -356,7 +360,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -356,7 +360,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
float4_t
out_vec
=
{
0
,
0
,
0
,
0
};
float4_t
out_vec
=
{
0
,
0
,
0
,
0
};
builtin_amdgcn_mmac
<
is_half
>
(
v_vec
,
logits_vec
,
out_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
v_vec
,
logits_vec
,
out_vec
);
if
(
rows
==
k
){
if
(
rows
==
k
){
for
(
int
resuseid
=
0
;
resuseid
<
acc_size
;
resuseid
++
){
for
(
int
resuseid
=
0
;
resuseid
<
REUSE_KV_TIMES
;
resuseid
++
){
accs
[
resuseid
][
i
]
+=
out_vec
[
resuseid
];
accs
[
resuseid
][
i
]
+=
out_vec
[
resuseid
];
}
}
}
}
...
@@ -366,8 +370,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -366,8 +370,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
__syncthreads
();
__syncthreads
();
using
floatV_t
=
__attribute__
(
(
__vector_size__
(
NUM_ROWS_PER_THREAD
*
sizeof
(
float
))
))
float
;
using
floatV_t
=
__attribute__
(
(
__vector_size__
(
NUM_ROWS_PER_THREAD
*
sizeof
(
float
))
))
float
;
// Perform reduction across warps.
// Perform reduction across warps.
#pragma unroll
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
q_boundary
;
reuse_kv_idx
++
)
{
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
acc_size
;
reuse_kv_idx
++
)
{
if
constexpr
(
NUM_THREADS
>
64
){
if
constexpr
(
NUM_THREADS
>
64
){
floatV_t
*
out_smem
=
reinterpret_cast
<
floatV_t
*>
(
shared_mem
);
floatV_t
*
out_smem
=
reinterpret_cast
<
floatV_t
*>
(
shared_mem
);
#pragma unroll
#pragma unroll
...
@@ -780,97 +783,9 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern
...
@@ -780,97 +783,9 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern
max_num_partitions,PARTITION_SIZE);}
max_num_partitions,PARTITION_SIZE);}
static
void
get_numberthread_and_reuse_kv_v2
(
int
&
num_thread
,
int
&
reusekv
,
int
&
PARTITION_SIZE
,
int
&
max_num_partitions
,
void
get_numberthread_and_reuse_kv_v2
(
int
&
num_thread
,
int
&
reusekv
,
int
&
PARTITION_SIZE
,
int
&
max_num_partitions
,
int
batchsize
,
int
max_seq_len
,
int
qheads
,
int
kvheads
,
int
num_blocks
)
int
batchsize
,
int
max_seq_len
,
int
qheads
,
int
kvheads
,
int
num_blocks
);
{
reusekv
=
1
;
num_thread
=
256
;
PARTITION_SIZE
=
512
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
if
(
max_seq_len
==
8192
&&
num_blocks
==
1024
){
//ali test
if
(
batchsize
==
1
&&
qheads
==
16
&&
kvheads
==
16
){
num_thread
=
128
;
return
;}
if
(
batchsize
==
1
&&
qheads
==
32
&&
kvheads
==
32
){
num_thread
=
64
;
return
;}
if
(
batchsize
==
1
){
if
(
qheads
==
52
){
reusekv
=
8
;
return
;}
if
(
qheads
==
13
){
reusekv
=
2
;
return
;}
reusekv
=
4
;
return
;
}
if
(
batchsize
==
64
){
if
(
qheads
==
13
){
PARTITION_SIZE
=
256
;
num_thread
=
128
;
reusekv
=
8
;}
else
if
(
qheads
==
32
){
PARTITION_SIZE
=
1024
;
reusekv
=
8
;}
else
if
(
qheads
==
52
||
qheads
==
26
){
reusekv
=
16
;}
else
reusekv
=
8
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
return
;
}
}
if
(
qheads
==
kvheads
){
if
(
max_seq_len
<=
8192
){
if
(
batchsize
*
qheads
>=
512
){
max_num_partitions
=
1
;
num_thread
=
64
;
}
if
(
qheads
==
32
&&
max_seq_len
<=
1024
)
max_num_partitions
=
1
;
}
return
;
}
if
(
max_seq_len
<
800
)
max_num_partitions
=
1
;
if
(
qheads
>
kvheads
*
4
){
if
(
max_seq_len
<=
1000
||
max_seq_len
<
1500
&&
(
batchsize
>=
8
&&
qheads
>=
8
||
batchsize
>=
64
)
||
max_seq_len
<
1900
&&
batchsize
>=
8
&&
qheads
==
28
)
max_num_partitions
=
1
;
int
blocks
=
max_num_partitions
*
batchsize
*
qheads
;
if
(
device_name
==
"gfx928"
){
if
(
batchsize
*
qheads
>
1024
&&
max_seq_len
>=
2000
){
max_num_partitions
=
1
;
if
(
max_seq_len
<
3900
)
reusekv
=
8
;
else
if
(
max_seq_len
<
7800
)
reusekv
=
4
;
else
{
PARTITION_SIZE
=
2048
;
reusekv
=
8
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
}
return
;
}
}
if
(
max_num_partitions
==
1
){
if
(
max_seq_len
<
512
){
int
bytes
=
max_seq_len
*
qheads
*
batchsize
;
if
(
bytes
<
51200
)
reusekv
=
1
;
else
if
(
bytes
<
256000
)
reusekv
=
4
;
else
reusekv
=
8
;
return
;
}
if
(
batchsize
<
4
||
batchsize
==
4
&&
qheads
==
8
)
reusekv
=
1
;
else
if
(
batchsize
<
32
||
batchsize
<=
64
&&
qheads
==
8
)
reusekv
=
4
;
else
reusekv
=
8
;
return
;
}
if
(
blocks
<
150
)
return
;
if
(
blocks
<
600
||
qheads
<=
kvheads
*
4
){
reusekv
=
4
;
return
;}
reusekv
=
8
;
return
;
}
if
(
device_name
==
"gfx928"
){
if
(
batchsize
*
qheads
>
1024
&&
max_seq_len
>=
2000
){
max_num_partitions
=
1
;
if
(
max_seq_len
<
7800
)
reusekv
=
4
;
else
{
PARTITION_SIZE
=
2048
;
reusekv
=
4
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
}
return
;
}
}
if
(
max_seq_len
<=
1000
||
max_seq_len
<=
1500
&&
(
qheads
>
4
&&
batchsize
>=
16
||
batchsize
>=
64
))
max_num_partitions
=
1
;
int
blocks
=
max_num_partitions
*
batchsize
*
qheads
;
if
(
blocks
>=
150
||
batchsize
>=
16
||
qheads
>=
8
&&
(
batchsize
>=
4
||
max_seq_len
>=
2000
))
reusekv
=
4
;
}
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
void
paged_attention_v2_launcher_opt_tc_with_mask
(
void
paged_attention_v2_launcher_opt_tc_with_mask
(
...
@@ -995,30 +910,6 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
...
@@ -995,30 +910,6 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
break; \
break; \
}
}
void
paged_attention_v2_with_mask
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
tmp_out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
key_cache
,
// [num_blocks, num_heads, head_size/x, block_size, x]
torch
::
Tensor
&
value_cache
,
// [num_blocks, num_heads, head_size, block_size]
int64_t
num_kv_heads
,
// [num_heads]
double
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
// [num_seqs, max_seq_len]
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
);
void
paged_attention_v2_opt_tc_with_mask
(
void
paged_attention_v2_opt_tc_with_mask
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
...
@@ -1043,38 +934,10 @@ void paged_attention_v2_opt_tc_with_mask(
...
@@ -1043,38 +934,10 @@ void paged_attention_v2_opt_tc_with_mask(
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
device_name
!=
"gfx928"
&&
device_name
!=
"gfx936"
)){
paged_attention_v2_with_mask
(
out
,
exp_sums
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
}
else
{
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
CALL_V2_LAUNCHER_BLOCK_SIZE
)
CALL_V2_LAUNCHER_BLOCK_SIZE
)
}
}
}
void
paged_attention_v1_with_mask
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
key_cache
,
// [num_blocks, num_heads, head_size/x, block_size, x]
torch
::
Tensor
&
value_cache
,
// [num_blocks, num_heads, head_size, block_size]
int64_t
num_kv_heads
,
// [num_heads]
double
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
);
void
paged_attention_v1_opt_tc_with_mask
(
void
paged_attention_v1_opt_tc_with_mask
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
...
@@ -1095,20 +958,10 @@ void paged_attention_v1_opt_tc_with_mask(
...
@@ -1095,20 +958,10 @@ void paged_attention_v1_opt_tc_with_mask(
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
device_name
!=
"gfx928"
&&
device_name
!=
"gfx936"
)){
paged_attention_v1_with_mask
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
}
else
{
paged_attention_v2_opt_tc_with_mask
(
out
,
out
,
out
,
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
paged_attention_v2_opt_tc_with_mask
(
out
,
out
,
out
,
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
}
}
}
#undef WARP_SIZE
#undef WARP_SIZE
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment