Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d365012a
Commit
d365012a
authored
Feb 26, 2025
by
zhangshao
Browse files
解决cudagraph计算错误的问题
parent
4d4c6fe3
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
665 additions
and
750 deletions
+665
-750
csrc/attention/attention_kernels_opt.cu
csrc/attention/attention_kernels_opt.cu
+19
-27
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+21
-18
csrc/attention/attention_with_mask_kernels.cu
csrc/attention/attention_with_mask_kernels.cu
+18
-16
csrc/attention/attention_with_mask_kernels_opt.cu
csrc/attention/attention_with_mask_kernels_opt.cu
+18
-16
csrc/attention/attention_with_mask_kernels_opt_tc.cu
csrc/attention/attention_with_mask_kernels_opt_tc.cu
+549
-633
csrc/ops.h
csrc/ops.h
+10
-10
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+10
-10
vllm/_custom_ops.py
vllm/_custom_ops.py
+20
-20
No files found.
csrc/attention/attention_kernels_opt.cu
View file @
d365012a
...
@@ -92,7 +92,7 @@ __device__ void paged_attention_kernel_opt(
...
@@ -92,7 +92,7 @@ __device__ void paged_attention_kernel_opt(
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
_ptr
,
const
float
*
v_scale
_ptr
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
seq_idx
=
blockIdx
.
z
;
const
int
seq_idx
=
blockIdx
.
z
;
...
@@ -316,7 +316,7 @@ __device__ void paged_attention_kernel_opt(
...
@@ -316,7 +316,7 @@ __device__ void paged_attention_kernel_opt(
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vec_quant
,
k_scale
);
k_vec_quant
,
*
k_scale
_ptr
);
}
}
}
}
}
}
...
@@ -483,7 +483,7 @@ __device__ void paged_attention_kernel_opt(
...
@@ -483,7 +483,7 @@ __device__ void paged_attention_kernel_opt(
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
// Vector conversion from V_quant_vec to V_vec.
// Vector conversion from V_quant_vec to V_vec.
v_vec
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
v_vec
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
v_scale
);
*
v_scale
_ptr
);
}
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
if
(
block_idx
==
num_seq_blocks
-
1
)
{
// NOTE(woosuk): When v_vec contains the tokens that are out of the
// NOTE(woosuk): When v_vec contains the tokens that are out of the
...
@@ -610,7 +610,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
...
@@ -610,7 +610,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
...
@@ -650,7 +650,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt(
...
@@ -650,7 +650,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt(
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
...
@@ -770,7 +770,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
...
@@ -770,7 +770,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
}
}
}
// namespace vllm
}
// namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel_opt<T, CACHE_T, HEAD_SIZE, \
((void*)vllm::paged_attention_v1_kernel_opt<T, CACHE_T, HEAD_SIZE, \
...
@@ -783,20 +783,10 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
...
@@ -783,20 +783,10 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale
_ptr
, v_scale
_ptr
, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
// NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads> \
// <<<dim3(grid), dim3(block)>>>( \
// out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
// scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
// alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
// kv_scale, tp_rank, blocksparse_local_blocks, \
// blocksparse_vert_stride, blocksparse_block_size, \
// blocksparse_head_sliding_step);
// TODO(woosuk): Tune NUM_THREADS.
// TODO(woosuk): Tune NUM_THREADS.
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
...
@@ -805,8 +795,8 @@ void paged_attention_v1_launcher(
...
@@ -805,8 +795,8 @@ void paged_attention_v1_launcher(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
torch
::
Tensor
&
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
...
@@ -829,7 +819,8 @@ void paged_attention_v1_launcher(
...
@@ -829,7 +819,8 @@ void paged_attention_v1_launcher(
alibi_slopes
alibi_slopes
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
:
nullptr
;
:
nullptr
;
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
...
@@ -910,7 +901,7 @@ void paged_attention_v1_opt(
...
@@ -910,7 +901,7 @@ void paged_attention_v1_opt(
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
int64_t
blocksparse_head_sliding_step
)
{
...
@@ -928,7 +919,7 @@ void paged_attention_v1_opt(
...
@@ -928,7 +919,7 @@ void paged_attention_v1_opt(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
kv_block_stride, kv_head_stride, k_scale
_ptr
, v_scale
_ptr
, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
blocksparse_block_size, blocksparse_head_sliding_step); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
...
@@ -945,8 +936,8 @@ void paged_attention_v2_launcher(
...
@@ -945,8 +936,8 @@ void paged_attention_v2_launcher(
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
torch
::
Tensor
&
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
...
@@ -965,7 +956,8 @@ void paged_attention_v2_launcher(
...
@@ -965,7 +956,8 @@ void paged_attention_v2_launcher(
alibi_slopes
alibi_slopes
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
:
nullptr
;
:
nullptr
;
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
float
*
exp_sums_ptr
=
reinterpret_cast
<
float
*>
(
exp_sums
.
data_ptr
());
float
*
exp_sums_ptr
=
reinterpret_cast
<
float
*>
(
exp_sums
.
data_ptr
());
float
*
max_logits_ptr
=
reinterpret_cast
<
float
*>
(
max_logits
.
data_ptr
());
float
*
max_logits_ptr
=
reinterpret_cast
<
float
*>
(
max_logits
.
data_ptr
());
...
@@ -1058,7 +1050,7 @@ void paged_attention_v2_opt(
...
@@ -1058,7 +1050,7 @@ void paged_attention_v2_opt(
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
int64_t
blocksparse_head_sliding_step
)
{
...
@@ -1070,4 +1062,4 @@ void paged_attention_v2_opt(
...
@@ -1070,4 +1062,4 @@ void paged_attention_v2_opt(
#undef WARP_SIZE
#undef WARP_SIZE
#undef MAX
#undef MAX
#undef MIN
#undef MIN
#undef DIVIDE_ROUND_UP
#undef DIVIDE_ROUND_UP
\ No newline at end of file
csrc/attention/attention_kernels_opt_tc.cu
View file @
d365012a
...
@@ -162,7 +162,7 @@ __device__ void paged_attention_kernel_TC(
...
@@ -162,7 +162,7 @@ __device__ void paged_attention_kernel_TC(
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
seq_idx
=
blockIdx
.
z
;
const
int
seq_idx
=
blockIdx
.
z
;
...
@@ -639,7 +639,7 @@ __global__ void paged_attention_v1_kernel_TC(
...
@@ -639,7 +639,7 @@ __global__ void paged_attention_v1_kernel_TC(
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
#if defined(__gfx936__) || defined(__gfx928__)
#if defined(__gfx936__) || defined(__gfx928__)
...
@@ -678,7 +678,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
...
@@ -678,7 +678,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
#if defined(__gfx936__) || defined(__gfx928__)
#if defined(__gfx936__) || defined(__gfx928__)
...
@@ -814,7 +814,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
...
@@ -814,7 +814,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale
_ptr
, v_scale
_ptr
, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step);
...
@@ -908,8 +908,8 @@ void paged_attention_v1_launcher_opt_tc(
...
@@ -908,8 +908,8 @@ void paged_attention_v1_launcher_opt_tc(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
torch
::
Tensor
&
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
...
@@ -932,7 +932,8 @@ void paged_attention_v1_launcher_opt_tc(
...
@@ -932,7 +932,8 @@ void paged_attention_v1_launcher_opt_tc(
alibi_slopes
alibi_slopes
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
:
nullptr
;
:
nullptr
;
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
...
@@ -1001,7 +1002,7 @@ void paged_attention_v1_launcher_opt_tc(
...
@@ -1001,7 +1002,7 @@ void paged_attention_v1_launcher_opt_tc(
break; \
break; \
}
}
void
paged_attention_v1
_opt
(
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
torch
::
Tensor
&
...
@@ -1014,7 +1015,7 @@ void paged_attention_v1_opt(
...
@@ -1014,7 +1015,7 @@ void paged_attention_v1_opt(
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
);
...
@@ -1032,14 +1033,14 @@ void paged_attention_v1_opt_tc(
...
@@ -1032,14 +1033,14 @@ void paged_attention_v1_opt_tc(
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
int64_t
blocksparse_head_sliding_step
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
device_name
!=
"gfx928"
&&
device_name
!=
"gfx936"
)){
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
device_name
!=
"gfx928"
&&
device_name
!=
"gfx936"
)){
paged_attention_v1
_opt
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_block_size
,
blocksparse_head_sliding_step
);
...
@@ -1059,7 +1060,7 @@ void paged_attention_v1_opt_tc(
...
@@ -1059,7 +1060,7 @@ void paged_attention_v1_opt_tc(
max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \
max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
kv_head_stride, k_scale
_ptr
, v_scale
_ptr
, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); \
blocksparse_head_sliding_step); \
hipLaunchKernelGGL( \
hipLaunchKernelGGL( \
...
@@ -1133,8 +1134,8 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -1133,8 +1134,8 @@ void paged_attention_v2_launcher_opt_tc(
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
torch
::
Tensor
&
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
...
@@ -1156,6 +1157,8 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -1156,6 +1157,8 @@ void paged_attention_v2_launcher_opt_tc(
:
nullptr
;
:
nullptr
;
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
float
*
exp_sums_ptr
=
reinterpret_cast
<
float
*>
(
exp_sums
.
data_ptr
());
float
*
exp_sums_ptr
=
reinterpret_cast
<
float
*>
(
exp_sums
.
data_ptr
());
float
*
max_logits_ptr
=
reinterpret_cast
<
float
*>
(
max_logits
.
data_ptr
());
float
*
max_logits_ptr
=
reinterpret_cast
<
float
*>
(
max_logits
.
data_ptr
());
T
*
tmp_out_ptr
=
reinterpret_cast
<
T
*>
(
tmp_out
.
data_ptr
());
T
*
tmp_out_ptr
=
reinterpret_cast
<
T
*>
(
tmp_out
.
data_ptr
());
...
@@ -1231,7 +1234,7 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -1231,7 +1234,7 @@ void paged_attention_v2_launcher_opt_tc(
break; \
break; \
}
}
void
paged_attention_v2
_opt
(
void
paged_attention_v2
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
...
@@ -1248,7 +1251,7 @@ void paged_attention_v2_opt(
...
@@ -1248,7 +1251,7 @@ void paged_attention_v2_opt(
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
);
...
@@ -1270,14 +1273,14 @@ void paged_attention_v2_opt_tc(
...
@@ -1270,14 +1273,14 @@ void paged_attention_v2_opt_tc(
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
int64_t
blocksparse_head_sliding_step
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
device_name
!=
"gfx928"
&&
device_name
!=
"gfx936"
)){
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
device_name
!=
"gfx928"
&&
device_name
!=
"gfx936"
)){
paged_attention_v2
_opt
(
out
,
exp_sums
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
paged_attention_v2
(
out
,
exp_sums
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_block_size
,
blocksparse_head_sliding_step
);
...
...
csrc/attention/attention_with_mask_kernels.cu
View file @
d365012a
...
@@ -105,7 +105,7 @@ __device__ void paged_attention_with_mask_kernel(
...
@@ -105,7 +105,7 @@ __device__ void paged_attention_with_mask_kernel(
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
_ptr
,
const
float
*
v_scale
_ptr
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
...
@@ -286,7 +286,7 @@ __device__ void paged_attention_with_mask_kernel(
...
@@ -286,7 +286,7 @@ __device__ void paged_attention_with_mask_kernel(
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vec_quant
,
k_scale
);
k_vec_quant
,
*
k_scale
_ptr
);
}
}
}
}
...
@@ -424,7 +424,7 @@ __device__ void paged_attention_with_mask_kernel(
...
@@ -424,7 +424,7 @@ __device__ void paged_attention_with_mask_kernel(
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
// Vector conversion from V_quant_vec to V_vec.
// Vector conversion from V_quant_vec to V_vec.
v_vec
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
v_vec
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
v_scale
);
*
v_scale
_ptr
);
}
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
if
(
block_idx
==
num_seq_blocks
-
1
)
{
// NOTE(woosuk): When v_vec contains the tokens that are out of the
// NOTE(woosuk): When v_vec contains the tokens that are out of the
...
@@ -522,7 +522,7 @@ __global__ void paged_attention_v1_with_mask_kernel(
...
@@ -522,7 +522,7 @@ __global__ void paged_attention_v1_with_mask_kernel(
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
...
@@ -559,7 +559,7 @@ __global__ void paged_attention_v2_with_mask_kernel(
...
@@ -559,7 +559,7 @@ __global__ void paged_attention_v2_with_mask_kernel(
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
...
@@ -693,7 +693,7 @@ __global__ void paged_attention_v2_reduce_kernel(
...
@@ -693,7 +693,7 @@ __global__ void paged_attention_v2_reduce_kernel(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale
_ptr
, v_scale
_ptr
, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \
blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride);
attn_masks_stride);
...
@@ -706,8 +706,8 @@ void paged_attention_v1_launcher(
...
@@ -706,8 +706,8 @@ void paged_attention_v1_launcher(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
torch
::
Tensor
&
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
...
@@ -728,7 +728,8 @@ void paged_attention_v1_launcher(
...
@@ -728,7 +728,8 @@ void paged_attention_v1_launcher(
alibi_slopes
alibi_slopes
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
:
nullptr
;
:
nullptr
;
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
...
@@ -842,7 +843,7 @@ void paged_attention_v1_with_mask(
...
@@ -842,7 +843,7 @@ void paged_attention_v1_with_mask(
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
,
...
@@ -862,7 +863,7 @@ void paged_attention_v1_with_mask(
...
@@ -862,7 +863,7 @@ void paged_attention_v1_with_mask(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
kv_block_stride, kv_head_stride, k_scale
_ptr
, v_scale
_ptr
, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \
attn_masks_ptr, attn_masks_stride); \
...
@@ -880,8 +881,8 @@ void paged_attention_v2_launcher(
...
@@ -880,8 +881,8 @@ void paged_attention_v2_launcher(
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
torch
::
Tensor
&
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
...
@@ -902,7 +903,8 @@ void paged_attention_v2_launcher(
...
@@ -902,7 +903,8 @@ void paged_attention_v2_launcher(
alibi_slopes
alibi_slopes
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
:
nullptr
;
:
nullptr
;
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
float
*
exp_sums_ptr
=
reinterpret_cast
<
float
*>
(
exp_sums
.
data_ptr
());
float
*
exp_sums_ptr
=
reinterpret_cast
<
float
*>
(
exp_sums
.
data_ptr
());
float
*
max_logits_ptr
=
reinterpret_cast
<
float
*>
(
max_logits
.
data_ptr
());
float
*
max_logits_ptr
=
reinterpret_cast
<
float
*>
(
max_logits
.
data_ptr
());
...
@@ -1023,7 +1025,7 @@ void paged_attention_v2_with_mask(
...
@@ -1023,7 +1025,7 @@ void paged_attention_v2_with_mask(
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
,
...
@@ -1037,4 +1039,4 @@ void paged_attention_v2_with_mask(
...
@@ -1037,4 +1039,4 @@ void paged_attention_v2_with_mask(
#undef WARP_SIZE
#undef WARP_SIZE
#undef MAX
#undef MAX
#undef MIN
#undef MIN
#undef DIVIDE_ROUND_UP
#undef DIVIDE_ROUND_UP
\ No newline at end of file
csrc/attention/attention_with_mask_kernels_opt.cu
View file @
d365012a
...
@@ -92,7 +92,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
...
@@ -92,7 +92,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
_ptr
,
const
float
*
v_scale
_ptr
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
...
@@ -317,7 +317,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
...
@@ -317,7 +317,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vec_quant
,
k_scale
);
k_vec_quant
,
*
k_scale
_ptr
);
}
}
}
}
}
}
...
@@ -498,7 +498,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
...
@@ -498,7 +498,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
// Vector conversion from V_quant_vec to V_vec.
// Vector conversion from V_quant_vec to V_vec.
v_vec
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
v_vec
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
v_scale
);
*
v_scale
_ptr
);
}
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
if
(
block_idx
==
num_seq_blocks
-
1
)
{
// NOTE(woosuk): When v_vec contains the tokens that are out of the
// NOTE(woosuk): When v_vec contains the tokens that are out of the
...
@@ -625,7 +625,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_with_mask_kernel_opt
...
@@ -625,7 +625,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_with_mask_kernel_opt
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
...
@@ -666,7 +666,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_with_mask_kernel_opt
...
@@ -666,7 +666,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_with_mask_kernel_opt
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
...
@@ -800,7 +800,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
...
@@ -800,7 +800,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale
_ptr
, v_scale
_ptr
, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \
blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride);
attn_masks_stride);
...
@@ -823,8 +823,8 @@ void paged_attention_v1_launcher(
...
@@ -823,8 +823,8 @@ void paged_attention_v1_launcher(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
torch
::
Tensor
&
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
...
@@ -849,7 +849,8 @@ void paged_attention_v1_launcher(
...
@@ -849,7 +849,8 @@ void paged_attention_v1_launcher(
alibi_slopes
alibi_slopes
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
:
nullptr
;
:
nullptr
;
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
...
@@ -940,7 +941,7 @@ void paged_attention_v1_opt_with_mask(
...
@@ -940,7 +941,7 @@ void paged_attention_v1_opt_with_mask(
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
,
...
@@ -960,7 +961,7 @@ void paged_attention_v1_opt_with_mask(
...
@@ -960,7 +961,7 @@ void paged_attention_v1_opt_with_mask(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
kv_block_stride, kv_head_stride, k_scale
_ptr
, v_scale
_ptr
, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \
attn_masks_ptr, attn_masks_stride); \
...
@@ -978,8 +979,8 @@ void paged_attention_v2_launcher(
...
@@ -978,8 +979,8 @@ void paged_attention_v2_launcher(
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
torch
::
Tensor
&
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
...
@@ -1000,7 +1001,8 @@ void paged_attention_v2_launcher(
...
@@ -1000,7 +1001,8 @@ void paged_attention_v2_launcher(
alibi_slopes
alibi_slopes
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
:
nullptr
;
:
nullptr
;
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
float
*
exp_sums_ptr
=
reinterpret_cast
<
float
*>
(
exp_sums
.
data_ptr
());
float
*
exp_sums_ptr
=
reinterpret_cast
<
float
*>
(
exp_sums
.
data_ptr
());
float
*
max_logits_ptr
=
reinterpret_cast
<
float
*>
(
max_logits
.
data_ptr
());
float
*
max_logits_ptr
=
reinterpret_cast
<
float
*>
(
max_logits
.
data_ptr
());
...
@@ -1100,7 +1102,7 @@ void paged_attention_v2_opt_with_mask(
...
@@ -1100,7 +1102,7 @@ void paged_attention_v2_opt_with_mask(
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
,
...
@@ -1114,4 +1116,4 @@ void paged_attention_v2_opt_with_mask(
...
@@ -1114,4 +1116,4 @@ void paged_attention_v2_opt_with_mask(
#undef WARP_SIZE
#undef WARP_SIZE
#undef MAX
#undef MAX
#undef MIN
#undef MIN
#undef DIVIDE_ROUND_UP
#undef DIVIDE_ROUND_UP
\ No newline at end of file
csrc/attention/attention_with_mask_kernels_opt_tc.cu
View file @
d365012a
This diff is collapsed.
Click to expand it.
csrc/ops.h
View file @
d365012a
...
@@ -57,7 +57,7 @@ void paged_attention_v1_opt(
...
@@ -57,7 +57,7 @@ void paged_attention_v1_opt(
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
);
...
@@ -68,7 +68,7 @@ void paged_attention_v2_opt(
...
@@ -68,7 +68,7 @@ void paged_attention_v2_opt(
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
);
...
@@ -78,7 +78,7 @@ void paged_attention_v1_opt_tc(
...
@@ -78,7 +78,7 @@ void paged_attention_v1_opt_tc(
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
);
...
@@ -89,7 +89,7 @@ void paged_attention_v2_opt_tc(
...
@@ -89,7 +89,7 @@ void paged_attention_v2_opt_tc(
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
);
...
@@ -101,7 +101,7 @@ void paged_attention_v1_with_mask(
...
@@ -101,7 +101,7 @@ void paged_attention_v1_with_mask(
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
,
...
@@ -114,7 +114,7 @@ void paged_attention_v2_with_mask(
...
@@ -114,7 +114,7 @@ void paged_attention_v2_with_mask(
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
,
...
@@ -126,7 +126,7 @@ void paged_attention_v1_opt_with_mask(
...
@@ -126,7 +126,7 @@ void paged_attention_v1_opt_with_mask(
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
,
...
@@ -139,7 +139,7 @@ void paged_attention_v2_opt_with_mask(
...
@@ -139,7 +139,7 @@ void paged_attention_v2_opt_with_mask(
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
,
...
@@ -151,7 +151,7 @@ void paged_attention_v1_opt_tc_with_mask(
...
@@ -151,7 +151,7 @@ void paged_attention_v1_opt_tc_with_mask(
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
,
...
@@ -164,7 +164,7 @@ void paged_attention_v2_opt_tc_with_mask(
...
@@ -164,7 +164,7 @@ void paged_attention_v2_opt_tc_with_mask(
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
,
...
...
csrc/torch_bindings.cpp
View file @
d365012a
...
@@ -58,7 +58,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -58,7 +58,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype,
float
k_scale,
float
v_scale,"
" str kv_cache_dtype,
Tensor
k_scale,
Tensor
v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step) -> ()"
);
...
@@ -72,7 +72,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -72,7 +72,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype,
float
k_scale,
float
v_scale,"
" str kv_cache_dtype,
Tensor
k_scale,
Tensor
v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step) -> ()"
);
...
@@ -86,7 +86,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -86,7 +86,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype,
float
k_scale,
float
v_scale,"
" str kv_cache_dtype,
Tensor
k_scale,
Tensor
v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step) -> ()"
);
...
@@ -100,7 +100,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -100,7 +100,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype,
float
k_scale,
float
v_scale,"
" str kv_cache_dtype,
Tensor
k_scale,
Tensor
v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step) -> ()"
);
...
@@ -114,7 +114,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -114,7 +114,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype,
float
k_scale,
float
v_scale,"
" str kv_cache_dtype,
Tensor
k_scale,
Tensor
v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
" int blocksparse_head_sliding_step,"
...
@@ -130,7 +130,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -130,7 +130,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype,
float
k_scale,
float
v_scale,"
" str kv_cache_dtype,
Tensor
k_scale,
Tensor
v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
" int blocksparse_head_sliding_step,"
...
@@ -146,7 +146,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -146,7 +146,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype,
float
k_scale,
float
v_scale,"
" str kv_cache_dtype,
Tensor
k_scale,
Tensor
v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
" int blocksparse_head_sliding_step,"
...
@@ -162,7 +162,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -162,7 +162,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype,
float
k_scale,
float
v_scale,"
" str kv_cache_dtype,
Tensor
k_scale,
Tensor
v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
" int blocksparse_head_sliding_step,"
...
@@ -178,7 +178,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -178,7 +178,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype,
float
k_scale,
float
v_scale,"
" str kv_cache_dtype,
Tensor
k_scale,
Tensor
v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
" int blocksparse_head_sliding_step,"
...
@@ -194,7 +194,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -194,7 +194,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype,
float
k_scale,
float
v_scale,"
" str kv_cache_dtype,
Tensor
k_scale,
Tensor
v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
" int blocksparse_head_sliding_step,"
...
...
vllm/_custom_ops.py
View file @
d365012a
...
@@ -117,8 +117,8 @@ def paged_attention_v1_with_mask(
...
@@ -117,8 +117,8 @@ def paged_attention_v1_with_mask(
max_seq_len
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
k_scale
:
torch
.
Tensor
,
v_scale
:
float
,
v_scale
:
torch
.
Tensor
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
@@ -152,8 +152,8 @@ def paged_attention_v2_with_mask(
...
@@ -152,8 +152,8 @@ def paged_attention_v2_with_mask(
max_seq_len
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
k_scale
:
torch
.
Tensor
,
v_scale
:
float
,
v_scale
:
torch
.
Tensor
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
@@ -185,8 +185,8 @@ def paged_attention_v1_opt(
...
@@ -185,8 +185,8 @@ def paged_attention_v1_opt(
max_seq_len
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
k_scale
:
torch
.
Tensor
,
v_scale
:
float
,
v_scale
:
torch
.
Tensor
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
@@ -217,8 +217,8 @@ def paged_attention_v2_opt(
...
@@ -217,8 +217,8 @@ def paged_attention_v2_opt(
max_seq_len
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
k_scale
:
torch
.
Tensor
,
v_scale
:
float
,
v_scale
:
torch
.
Tensor
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
@@ -246,8 +246,8 @@ def paged_attention_v1_opt_with_mask(
...
@@ -246,8 +246,8 @@ def paged_attention_v1_opt_with_mask(
max_seq_len
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
k_scale
:
torch
.
Tensor
,
v_scale
:
float
,
v_scale
:
torch
.
Tensor
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
@@ -281,8 +281,8 @@ def paged_attention_v2_opt_with_mask(
...
@@ -281,8 +281,8 @@ def paged_attention_v2_opt_with_mask(
max_seq_len
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
k_scale
:
torch
.
Tensor
,
v_scale
:
float
,
v_scale
:
torch
.
Tensor
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
@@ -314,8 +314,8 @@ def paged_attention_v1_opt_tc(
...
@@ -314,8 +314,8 @@ def paged_attention_v1_opt_tc(
max_seq_len
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
k_scale
:
torch
.
Tensor
,
v_scale
:
float
,
v_scale
:
torch
.
Tensor
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
@@ -345,8 +345,8 @@ def paged_attention_v2_opt_tc(
...
@@ -345,8 +345,8 @@ def paged_attention_v2_opt_tc(
max_seq_len
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
k_scale
:
torch
.
Tensor
,
v_scale
:
float
,
v_scale
:
torch
.
Tensor
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
@@ -375,8 +375,8 @@ def paged_attention_v1_opt_tc_with_mask(
...
@@ -375,8 +375,8 @@ def paged_attention_v1_opt_tc_with_mask(
max_seq_len
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
k_scale
:
torch
.
Tensor
,
v_scale
:
float
,
v_scale
:
torch
.
Tensor
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
@@ -409,8 +409,8 @@ def paged_attention_v2_opt_tc_with_mask(
...
@@ -409,8 +409,8 @@ def paged_attention_v2_opt_tc_with_mask(
max_seq_len
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
k_scale
:
torch
.
Tensor
,
v_scale
:
float
,
v_scale
:
torch
.
Tensor
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment