Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
44e3ca68
Commit
44e3ca68
authored
Dec 11, 2024
by
王敏
Browse files
[feat]优化medusa代码,通过VLLM_TREE_DECODING环境变量控制是否采用tree-style解码,计算逻辑主干隔离
parent
54b92ba4
Changes
38
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3801 additions
and
312 deletions
+3801
-312
CMakeLists.txt
CMakeLists.txt
+4
-1
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+13
-45
csrc/attention/attention_kernels_opt.cu
csrc/attention/attention_kernels_opt.cu
+13
-51
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+17
-55
csrc/attention/attention_with_mask_kernels.cu
csrc/attention/attention_with_mask_kernels.cu
+1040
-0
csrc/attention/attention_with_mask_kernels_opt.cu
csrc/attention/attention_with_mask_kernels_opt.cu
+1117
-0
csrc/attention/attention_with_mask_kernels_opt_tc.cu
csrc/attention/attention_with_mask_kernels_opt_tc.cu
+1201
-0
csrc/ops.h
csrc/ops.h
+71
-5
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+96
-11
examples/medusa/README.md
examples/medusa/README.md
+33
-11
examples/medusa/medusa_benchmark_throughput.py
examples/medusa/medusa_benchmark_throughput.py
+2
-11
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+2
-2
tests/worker/test_model_input.py
tests/worker/test_model_input.py
+0
-10
vllm/_custom_ops.py
vllm/_custom_ops.py
+191
-9
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+1
-13
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+0
-48
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+0
-10
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+0
-10
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+0
-10
vllm/attention/backends/openvino.py
vllm/attention/backends/openvino.py
+0
-10
No files found.
CMakeLists.txt
View file @
44e3ca68
...
@@ -198,7 +198,10 @@ set(VLLM_EXT_SRC
...
@@ -198,7 +198,10 @@ set(VLLM_EXT_SRC
"csrc/cuda_utils_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp"
)
"csrc/torch_bindings.cpp"
"csrc/attention/attention_with_mask_kernels.cu"
"csrc/attention/attention_with_mask_kernels_opt.cu"
"csrc/attention/attention_with_mask_kernels_opt_tc.cu"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
SET
(
CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL
"Enable only the header library"
)
SET
(
CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL
"Enable only the header library"
)
...
...
csrc/attention/attention_kernels.cu
View file @
44e3ca68
...
@@ -107,8 +107,7 @@ __device__ void paged_attention_kernel(
...
@@ -107,8 +107,7 @@ __device__ void paged_attention_kernel(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
const
int
seq_idx
=
blockIdx
.
y
;
const
int
seq_idx
=
blockIdx
.
y
;
const
int
partition_idx
=
blockIdx
.
z
;
const
int
partition_idx
=
blockIdx
.
z
;
const
int
max_num_partitions
=
gridDim
.
z
;
const
int
max_num_partitions
=
gridDim
.
z
;
...
@@ -297,14 +296,6 @@ __device__ void paged_attention_kernel(
...
@@ -297,14 +296,6 @@ __device__ void paged_attention_kernel(
// Add the ALiBi bias if slopes are given.
// Add the ALiBi bias if slopes are given.
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq_len
+
1
)
:
0
;
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq_len
+
1
)
:
0
;
// used for tree-style attention
if
(
attn_masks
!=
nullptr
)
{
const
int
*
attn_masks_ptr
=
attn_masks
+
seq_idx
*
attn_masks_stride
;
if
(
attn_masks_ptr
[
token_idx
]
==
0
)
{
qk
=
-
FLT_MAX
;
}
}
if
(
thread_group_offset
==
0
)
{
if
(
thread_group_offset
==
0
)
{
// Store the partial reductions to shared memory.
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
// NOTE(woosuk): It is required to zero out the masked logits.
...
@@ -524,8 +515,7 @@ __global__ void paged_attention_v1_kernel(
...
@@ -524,8 +515,7 @@ __global__ void paged_attention_v1_kernel(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
>
(
KV_DTYPE
,
IS_BLOCK_SPARSE
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
...
@@ -533,7 +523,7 @@ __global__ void paged_attention_v1_kernel(
...
@@ -533,7 +523,7 @@ __global__ void paged_attention_v1_kernel(
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
blocksparse_head_sliding_step
);
}
}
// Grid: (num_heads, num_seqs, max_num_partitions).
// Grid: (num_heads, num_seqs, max_num_partitions).
...
@@ -561,15 +551,14 @@ __global__ void paged_attention_v2_kernel(
...
@@ -561,15 +551,14 @@ __global__ void paged_attention_v2_kernel(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
PARTITION_SIZE
>
(
KV_DTYPE
,
IS_BLOCK_SPARSE
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
blocksparse_head_sliding_step
);
}
}
// Grid: (num_heads, num_seqs).
// Grid: (num_heads, num_seqs).
...
@@ -695,8 +684,7 @@ __global__ void paged_attention_v2_reduce_kernel(
...
@@ -695,8 +684,7 @@ __global__ void paged_attention_v2_reduce_kernel(
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \
blocksparse_head_sliding_step);
attn_masks_stride);
// TODO(woosuk): Tune NUM_THREADS.
// TODO(woosuk): Tune NUM_THREADS.
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
...
@@ -709,9 +697,7 @@ void paged_attention_v1_launcher(
...
@@ -709,9 +697,7 @@ void paged_attention_v1_launcher(
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_head_sliding_step
)
{
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
int
head_size
=
query
.
size
(
2
);
...
@@ -736,12 +722,6 @@ void paged_attention_v1_launcher(
...
@@ -736,12 +722,6 @@ void paged_attention_v1_launcher(
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
// NOTE: attn_masks is optional.
const
int
*
attn_masks_ptr
=
attn_masks
?
attn_masks
.
value
().
data_ptr
<
int
>
()
:
nullptr
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
padded_max_seq_len
=
int
padded_max_seq_len
=
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
...
@@ -798,8 +778,7 @@ void paged_attention_v1_launcher(
...
@@ -798,8 +778,7 @@ void paged_attention_v1_launcher(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \
blocksparse_block_size, blocksparse_head_sliding_step);
attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
switch (is_block_sparse) { \
...
@@ -845,9 +824,7 @@ void paged_attention_v1(
...
@@ -845,9 +824,7 @@ void paged_attention_v1(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
...
@@ -864,8 +841,7 @@ void paged_attention_v1(
...
@@ -864,8 +841,7 @@ void paged_attention_v1(
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \
blocksparse_block_size, blocksparse_head_sliding_step); \
attn_masks_ptr, attn_masks_stride); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE> \
PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
...
@@ -883,9 +859,7 @@ void paged_attention_v2_launcher(
...
@@ -883,9 +859,7 @@ void paged_attention_v2_launcher(
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_head_sliding_step
)
{
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int
attn_masks_stride
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
int
head_size
=
query
.
size
(
2
);
...
@@ -913,10 +887,6 @@ void paged_attention_v2_launcher(
...
@@ -913,10 +887,6 @@ void paged_attention_v2_launcher(
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
// NOTE: attn_masks is optional.
const
int
*
attn_masks_ptr
=
attn_masks
?
attn_masks
.
value
().
data_ptr
<
int
>
()
:
nullptr
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
int
logits_size
=
PARTITION_SIZE
*
sizeof
(
float
);
int
logits_size
=
PARTITION_SIZE
*
sizeof
(
float
);
...
@@ -976,7 +946,7 @@ void paged_attention_v2_launcher(
...
@@ -976,7 +946,7 @@ void paged_attention_v2_launcher(
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step
, attn_masks, attn_masks_stride
);
blocksparse_head_sliding_step);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
switch (is_block_sparse) { \
...
@@ -1026,9 +996,7 @@ void paged_attention_v2(
...
@@ -1026,9 +996,7 @@ void paged_attention_v2(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
CALL_V2_LAUNCHER_BLOCK_SIZE
)
CALL_V2_LAUNCHER_BLOCK_SIZE
)
...
...
csrc/attention/attention_kernels_opt.cu
View file @
44e3ca68
...
@@ -94,8 +94,7 @@ __device__ void paged_attention_kernel_opt(
...
@@ -94,8 +94,7 @@ __device__ void paged_attention_kernel_opt(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
const
int
seq_idx
=
blockIdx
.
z
;
const
int
seq_idx
=
blockIdx
.
z
;
const
int
partition_idx
=
blockIdx
.
y
;
const
int
partition_idx
=
blockIdx
.
y
;
const
int
max_num_partitions
=
gridDim
.
y
;
const
int
max_num_partitions
=
gridDim
.
y
;
...
@@ -328,25 +327,11 @@ __device__ void paged_attention_kernel_opt(
...
@@ -328,25 +327,11 @@ __device__ void paged_attention_kernel_opt(
// Add the ALiBi bias if slopes are given.
// Add the ALiBi bias if slopes are given.
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq_len
+
1
)
:
0
;
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq_len
+
1
)
:
0
;
// used for tree-style attention
if
(
attn_masks
!=
nullptr
)
{
const
int
*
attn_masks_ptr
=
attn_masks
+
seq_idx
*
attn_masks_stride
;
if
(
attn_masks_ptr
[
token_idx
]
==
0
)
{
qk
=
-
FLT_MAX
;
}
}
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
if
(
thread_group_offset
==
0
)
{
if
(
thread_group_offset
==
0
)
{
// Store the partial reductions to shared memory.
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
// NOTE(woosuk): It is required to zero out the masked logits.
const
bool
mask
=
token_idx
>=
seq_len
;
const
bool
mask
=
token_idx
>=
seq_len
;
// used for tree-style attention
/*if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
mask |= attn_masks_ptr[token_idx] == 0;
}*/
logits
[(
reuse_kv_idx
*
partition_size
)
+
(
token_idx
-
start_token_idx
)]
=
mask
?
0.
f
:
qk
;
logits
[(
reuse_kv_idx
*
partition_size
)
+
(
token_idx
-
start_token_idx
)]
=
mask
?
0.
f
:
qk
;
// Update the max value.
// Update the max value.
qk_max
[
reuse_kv_idx
]
=
mask
?
qk_max
[
reuse_kv_idx
]
:
fmaxf
(
qk_max
[
reuse_kv_idx
],
qk
);
qk_max
[
reuse_kv_idx
]
=
mask
?
qk_max
[
reuse_kv_idx
]
:
fmaxf
(
qk_max
[
reuse_kv_idx
],
qk
);
...
@@ -627,8 +612,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
...
@@ -627,8 +612,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
paged_attention_kernel_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
>
(
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
...
@@ -636,7 +620,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
...
@@ -636,7 +620,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
blocksparse_head_sliding_step
);
}
}
// Grid: (num_heads, num_seqs, max_num_partitions).
// Grid: (num_heads, num_seqs, max_num_partitions).
...
@@ -668,15 +652,14 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt(
...
@@ -668,15 +652,14 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt(
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
paged_attention_kernel_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
,
PARTITION_SIZE
>
(
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
blocksparse_head_sliding_step
);
}
}
// Grid: (num_heads, num_seqs).
// Grid: (num_heads, num_seqs).
...
@@ -802,8 +785,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
...
@@ -802,8 +785,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \
blocksparse_head_sliding_step);
attn_masks_stride);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
...
@@ -826,9 +808,7 @@ void paged_attention_v1_launcher(
...
@@ -826,9 +808,7 @@ void paged_attention_v1_launcher(
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_head_sliding_step
)
{
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
int
head_size
=
query
.
size
(
2
);
...
@@ -857,12 +837,6 @@ void paged_attention_v1_launcher(
...
@@ -857,12 +837,6 @@ void paged_attention_v1_launcher(
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
// NOTE: attn_masks is optional.
const
int
*
attn_masks_ptr
=
attn_masks
?
attn_masks
.
value
().
data_ptr
<
int
>
()
:
nullptr
;
int
padded_max_seq_len
=
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
int
padded_max_seq_len
=
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
REUSEKV_SWITCH_V1
(
num_heads
*
num_seqs
,
[
&
]
{
REUSEKV_SWITCH_V1
(
num_heads
*
num_seqs
,
[
&
]
{
BOOL_SWITCH
((
num_heads
/
num_kv_heads
%
REUSE_KV_TIMES
!=
0
),
odd_nheads
,
[
&
]
{
BOOL_SWITCH
((
num_heads
/
num_kv_heads
%
REUSE_KV_TIMES
!=
0
),
odd_nheads
,
[
&
]
{
...
@@ -896,8 +870,7 @@ void paged_attention_v1_launcher(
...
@@ -896,8 +870,7 @@ void paged_attention_v1_launcher(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \
blocksparse_block_size, blocksparse_head_sliding_step);
attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
switch (is_block_sparse) { \
...
@@ -943,9 +916,7 @@ void paged_attention_v1_opt(
...
@@ -943,9 +916,7 @@ void paged_attention_v1_opt(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
...
@@ -962,8 +933,7 @@ void paged_attention_v1_opt(
...
@@ -962,8 +933,7 @@ void paged_attention_v1_opt(
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \
blocksparse_block_size, blocksparse_head_sliding_step); \
attn_masks_ptr, attn_masks_stride); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>) \
PARTITION_SIZE>) \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
...
@@ -981,9 +951,7 @@ void paged_attention_v2_launcher(
...
@@ -981,9 +951,7 @@ void paged_attention_v2_launcher(
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_head_sliding_step
)
{
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int
attn_masks_stride
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
int
head_size
=
query
.
size
(
2
);
...
@@ -1011,10 +979,6 @@ void paged_attention_v2_launcher(
...
@@ -1011,10 +979,6 @@ void paged_attention_v2_launcher(
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
// NOTE: attn_masks is optional.
const
int
*
attn_masks_ptr
=
attn_masks
?
attn_masks
.
value
().
data_ptr
<
int
>
()
:
nullptr
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
REUSEKV_SWITCH
(
num_heads
*
max_num_partitions
*
num_seqs
,
[
&
]
{
REUSEKV_SWITCH
(
num_heads
*
max_num_partitions
*
num_seqs
,
[
&
]
{
...
@@ -1053,7 +1017,7 @@ void paged_attention_v2_launcher(
...
@@ -1053,7 +1017,7 @@ void paged_attention_v2_launcher(
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step
, attn_masks, attn_masks_stride
);
blocksparse_head_sliding_step);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
switch (is_block_sparse) { \
...
@@ -1103,9 +1067,7 @@ void paged_attention_v2_opt(
...
@@ -1103,9 +1067,7 @@ void paged_attention_v2_opt(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
CALL_V2_LAUNCHER_BLOCK_SIZE
)
CALL_V2_LAUNCHER_BLOCK_SIZE
)
...
...
csrc/attention/attention_kernels_opt_tc.cu
View file @
44e3ca68
...
@@ -168,8 +168,7 @@ __device__ void paged_attention_kernel_TC(
...
@@ -168,8 +168,7 @@ __device__ void paged_attention_kernel_TC(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
const
int
seq_idx
=
blockIdx
.
z
;
const
int
seq_idx
=
blockIdx
.
z
;
const
int
partition_idx
=
blockIdx
.
y
;
const
int
partition_idx
=
blockIdx
.
y
;
const
int
max_num_partitions
=
gridDim
.
y
;
const
int
max_num_partitions
=
gridDim
.
y
;
...
@@ -293,14 +292,6 @@ __device__ void paged_attention_kernel_TC(
...
@@ -293,14 +292,6 @@ __device__ void paged_attention_kernel_TC(
qk_vec
[
i
]
=
alibi
;
qk_vec
[
i
]
=
alibi
;
}
}
// used for tree-style attention
if
(
attn_masks
!=
nullptr
)
{
const
int
*
attn_masks_ptr
=
attn_masks
+
seq_idx
*
attn_masks_stride
;
if
(
attn_masks_ptr
[
token_idx
]
==
0
)
{
qk_vec
[
i
]
=
-
FLT_MAX
;
}
}
const
bool
mask
=
(
token_idx
>=
seq_len
);
const
bool
mask
=
(
token_idx
>=
seq_len
);
if
(
mask
){
if
(
mask
){
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
0.
f
);
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
0.
f
);
...
@@ -565,8 +556,7 @@ __global__ void paged_attention_v1_kernel_TC(
...
@@ -565,8 +556,7 @@ __global__ void paged_attention_v1_kernel_TC(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
#if defined(__gfx936__) || defined(__gfx928__)
#if defined(__gfx936__) || defined(__gfx928__)
paged_attention_kernel_TC
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel_TC
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
use_vmac
>
(
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
use_vmac
>
(
...
@@ -575,7 +565,7 @@ __global__ void paged_attention_v1_kernel_TC(
...
@@ -575,7 +565,7 @@ __global__ void paged_attention_v1_kernel_TC(
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
blocksparse_head_sliding_step
);
#endif
#endif
}
}
...
@@ -605,8 +595,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
...
@@ -605,8 +595,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
#if defined(__gfx936__) || defined(__gfx928__)
#if defined(__gfx936__) || defined(__gfx928__)
paged_attention_kernel_TC
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel_TC
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
use_vmac
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
use_vmac
,
...
@@ -615,7 +604,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
...
@@ -615,7 +604,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
blocksparse_head_sliding_step
);
#endif
#endif
}
}
...
@@ -742,8 +731,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
...
@@ -742,8 +731,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \
blocksparse_head_sliding_step);
attn_masks_stride);
void
get_numberthread_and_reuse_kv_v1
(
int
&
num_thread
,
int
&
reusekv
,
int
batchsize
,
int
seq
,
int
qheads
,
int
kvheads
){
void
get_numberthread_and_reuse_kv_v1
(
int
&
num_thread
,
int
&
reusekv
,
int
batchsize
,
int
seq
,
int
qheads
,
int
kvheads
){
//mha
//mha
...
@@ -809,9 +797,7 @@ void paged_attention_v1_launcher_opt_tc(
...
@@ -809,9 +797,7 @@ void paged_attention_v1_launcher_opt_tc(
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_head_sliding_step
)
{
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
int
head_size
=
query
.
size
(
2
);
...
@@ -840,12 +826,6 @@ void paged_attention_v1_launcher_opt_tc(
...
@@ -840,12 +826,6 @@ void paged_attention_v1_launcher_opt_tc(
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
// NOTE: attn_masks is optional.
const
int
*
attn_masks_ptr
=
attn_masks
?
attn_masks
.
value
().
data_ptr
<
int
>
()
:
nullptr
;
int
padded_max_seq_len
=
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
int
padded_max_seq_len
=
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
...
@@ -880,8 +860,7 @@ void paged_attention_v1_launcher_opt_tc(
...
@@ -880,8 +860,7 @@ void paged_attention_v1_launcher_opt_tc(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \
blocksparse_block_size, blocksparse_head_sliding_step);
attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
switch (is_block_sparse) { \
...
@@ -927,9 +906,7 @@ void paged_attention_v1_opt(
...
@@ -927,9 +906,7 @@ void paged_attention_v1_opt(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
);
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
);
void
paged_attention_v1_opt_tc
(
void
paged_attention_v1_opt_tc
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
...
@@ -947,17 +924,14 @@ void paged_attention_v1_opt_tc(
...
@@ -947,17 +924,14 @@ void paged_attention_v1_opt_tc(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
get_device_name
()
!=
"gfx928"
&&
get_device_name
()
!=
"gfx936"
)){
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
get_device_name
()
!=
"gfx928"
&&
get_device_name
()
!=
"gfx936"
)){
paged_attention_v1_opt
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
paged_attention_v1_opt
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
attn_masks
,
attn_masks_stride
);
}
}
else
{
else
{
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
...
@@ -976,8 +950,7 @@ void paged_attention_v1_opt_tc(
...
@@ -976,8 +950,7 @@ void paged_attention_v1_opt_tc(
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, \
blocksparse_head_sliding_step); \
attn_masks_ptr, attn_masks_stride); \
hipLaunchKernelGGL( \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \
PARTITION_SIZE>), \
...
@@ -1028,9 +1001,7 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -1028,9 +1001,7 @@ void paged_attention_v2_launcher_opt_tc(
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_head_sliding_step
)
{
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int
attn_masks_stride
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
int
head_size
=
query
.
size
(
2
);
...
@@ -1058,10 +1029,6 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -1058,10 +1029,6 @@ void paged_attention_v2_launcher_opt_tc(
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
// NOTE: attn_masks is optional.
const
int
*
attn_masks_ptr
=
attn_masks
?
attn_masks
.
value
().
data_ptr
<
int
>
()
:
nullptr
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
dim3
reduce_grid
(
num_heads
,
num_seqs
);
dim3
reduce_grid
(
num_heads
,
num_seqs
);
...
@@ -1103,7 +1070,7 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -1103,7 +1070,7 @@ void paged_attention_v2_launcher_opt_tc(
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step
, attn_masks, attn_masks_stride
);
blocksparse_head_sliding_step);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
switch (is_block_sparse) { \
...
@@ -1153,9 +1120,7 @@ void paged_attention_v2_opt(
...
@@ -1153,9 +1120,7 @@ void paged_attention_v2_opt(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
);
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
);
void
paged_attention_v2_opt_tc
(
void
paged_attention_v2_opt_tc
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
...
@@ -1177,17 +1142,14 @@ void paged_attention_v2_opt_tc(
...
@@ -1177,17 +1142,14 @@ void paged_attention_v2_opt_tc(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
get_device_name
()
!=
"gfx928"
&&
get_device_name
()
!=
"gfx936"
)){
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
get_device_name
()
!=
"gfx928"
&&
get_device_name
()
!=
"gfx936"
)){
paged_attention_v2_opt
(
out
,
exp_sums
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
paged_attention_v2_opt
(
out
,
exp_sums
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
attn_masks_stride
);
}
}
else
{
else
{
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
...
...
csrc/attention/attention_with_mask_kernels.cu
0 → 100644
View file @
44e3ca68
This diff is collapsed.
Click to expand it.
csrc/attention/attention_with_mask_kernels_opt.cu
0 → 100644
View file @
44e3ca68
This diff is collapsed.
Click to expand it.
csrc/attention/attention_with_mask_kernels_opt_tc.cu
0 → 100644
View file @
44e3ca68
This diff is collapsed.
Click to expand it.
csrc/ops.h
View file @
44e3ca68
...
@@ -6,6 +6,71 @@
...
@@ -6,6 +6,71 @@
#include "core/scalar_type.hpp"
#include "core/scalar_type.hpp"
void
paged_attention_v1
(
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v2
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v1_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v2_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v1_opt_tc
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v2_opt_tc
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
// paged_attention with attn_masks
void
paged_attention_v1_with_mask
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
...
@@ -17,7 +82,7 @@ void paged_attention_v1(
...
@@ -17,7 +82,7 @@ void paged_attention_v1(
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
);
const
int64_t
attn_masks_stride
=
0
);
void
paged_attention_v2
(
void
paged_attention_v2
_with_mask
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
...
@@ -30,7 +95,7 @@ void paged_attention_v2(
...
@@ -30,7 +95,7 @@ void paged_attention_v2(
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
);
const
int64_t
attn_masks_stride
=
0
);
void
paged_attention_v1_opt
(
void
paged_attention_v1_opt
_with_mask
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
...
@@ -42,7 +107,7 @@ void paged_attention_v1_opt(
...
@@ -42,7 +107,7 @@ void paged_attention_v1_opt(
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
);
const
int64_t
attn_masks_stride
=
0
);
void
paged_attention_v2_opt
(
void
paged_attention_v2_opt
_with_mask
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
...
@@ -55,7 +120,7 @@ void paged_attention_v2_opt(
...
@@ -55,7 +120,7 @@ void paged_attention_v2_opt(
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
);
const
int64_t
attn_masks_stride
=
0
);
void
paged_attention_v1_opt_tc
(
void
paged_attention_v1_opt_tc
_with_mask
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
...
@@ -67,7 +132,7 @@ void paged_attention_v1_opt_tc(
...
@@ -67,7 +132,7 @@ void paged_attention_v1_opt_tc(
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
);
const
int64_t
attn_masks_stride
=
0
);
void
paged_attention_v2_opt_tc
(
void
paged_attention_v2_opt_tc
_with_mask
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
...
@@ -80,6 +145,7 @@ void paged_attention_v2_opt_tc(
...
@@ -80,6 +145,7 @@ void paged_attention_v2_opt_tc(
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
);
const
int64_t
attn_masks_stride
=
0
);
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
double
epsilon
);
double
epsilon
);
...
...
csrc/torch_bindings.cpp
View file @
44e3ca68
...
@@ -30,14 +30,98 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -30,14 +30,98 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v1"
,
torch
::
kCUDA
,
&
paged_attention_v1
);
// PagedAttention V2.
ops
.
def
(
"paged_attention_v2("
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops
.
def
(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v1_opt"
,
torch
::
kCUDA
,
&
paged_attention_v1_opt
);
// PagedAttention V2 (opt).
ops
.
def
(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2_opt"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt
);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops
.
def
(
"paged_attention_v1_opt_tc("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v1_opt_tc"
,
torch
::
kCUDA
,
&
paged_attention_v1_opt_tc
);
// PagedAttention V2 (opt).
ops
.
def
(
"paged_attention_v2_opt_tc("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2_opt_tc"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt_tc
);
// paged_attention with atth_masks
ops
.
def
(
"paged_attention_v1_with_mask("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()"
);
" int attn_masks_stride) -> ()"
);
ops
.
impl
(
"paged_attention_v1"
,
torch
::
kCUDA
,
&
paged_attention_v1
);
ops
.
impl
(
"paged_attention_v1
_with_mask
"
,
torch
::
kCUDA
,
&
paged_attention_v1
_with_mask
);
// PagedAttention V2.
// PagedAttention V2.
ops
.
def
(
ops
.
def
(
"paged_attention_v2("
"paged_attention_v2
_with_mask
("
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
...
@@ -49,12 +133,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -49,12 +133,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step,"
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()"
);
" int attn_masks_stride) -> ()"
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
ops
.
impl
(
"paged_attention_v2
_with_mask
"
,
torch
::
kCUDA
,
&
paged_attention_v2
_with_mask
);
// Compute the attention between an input query and the cached
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
// keys/values using PagedAttention. (opt)
ops
.
def
(
ops
.
def
(
"paged_attention_v1_opt("
"paged_attention_v1_opt
_with_mask
("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
...
@@ -65,11 +149,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -65,11 +149,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step,"
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()"
);
" int attn_masks_stride) -> ()"
);
ops
.
impl
(
"paged_attention_v1_opt"
,
torch
::
kCUDA
,
&
paged_attention_v1_opt
);
ops
.
impl
(
"paged_attention_v1_opt
_with_mask
"
,
torch
::
kCUDA
,
&
paged_attention_v1_opt
_with_mask
);
// PagedAttention V2 (opt).
// PagedAttention V2 (opt).
ops
.
def
(
ops
.
def
(
"paged_attention_v2_opt("
"paged_attention_v2_opt
_with_mask
("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
...
@@ -81,12 +165,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -81,12 +165,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step,"
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()"
);
" int attn_masks_stride) -> ()"
);
ops
.
impl
(
"paged_attention_v2_opt"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt
);
ops
.
impl
(
"paged_attention_v2_opt
_with_mask
"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt
_with_mask
);
// Compute the attention between an input query and the cached
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
// keys/values using PagedAttention. (opt)
ops
.
def
(
ops
.
def
(
"paged_attention_v1_opt_tc("
"paged_attention_v1_opt_tc
_with_mask
("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
...
@@ -97,11 +181,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -97,11 +181,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step,"
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()"
);
" int attn_masks_stride) -> ()"
);
ops
.
impl
(
"paged_attention_v1_opt_tc"
,
torch
::
kCUDA
,
&
paged_attention_v1_opt_tc
);
ops
.
impl
(
"paged_attention_v1_opt_tc
_with_mask
"
,
torch
::
kCUDA
,
&
paged_attention_v1_opt_tc
_with_mask
);
// PagedAttention V2 (opt).
// PagedAttention V2 (opt).
ops
.
def
(
ops
.
def
(
"paged_attention_v2_opt_tc("
"paged_attention_v2_opt_tc
_with_mask
("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
...
@@ -113,7 +197,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -113,7 +197,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step,"
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()"
);
" int attn_masks_stride) -> ()"
);
ops
.
impl
(
"paged_attention_v2_opt_tc"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt_tc
);
ops
.
impl
(
"paged_attention_v2_opt_tc_with_mask"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt_tc_with_mask
);
// Activation ops
// Activation ops
// Activation function used in SwiGLU.
// Activation function used in SwiGLU.
...
...
examples/medusa/README.md
View file @
44e3ca68
# Medusa Decoding
# Medusa Decoding
本文说明如何使用vllm构建和运行medusa模型
本文说明如何使用vllm构建和运行medusa模型,目前medusa支持tree-style generation,target model和draft model均可多卡推理
## Overview
## Overview
Medusa是一种大模型并行解码算法,除了支持官方提供的Top1-proposer,我们还支持tree-style并行解码,target model和draft model均可多卡推理
与其他模型不同,medusa解码需要一个base model和若干Medusa heads.
与其他模型不同,medusa解码需要一个base model和若干Medusa heads.
Vllm medusa model的实现在[vllm/model_executor/models/medusa.py]
Vllm medusa model的实现在[vllm/model_executor/models/medusa.py]
...
@@ -19,28 +20,43 @@ Vllm medusa model的实现在[vllm/model_executor/models/medusa.py]
...
@@ -19,28 +20,43 @@ Vllm medusa model的实现在[vllm/model_executor/models/medusa.py]
```
bash
```
bash
python medusa_weight_converter.py
--medusa_num_heads
4
--medusa_num_layers
1
--medusa_model_path
/work/model.bin
--vocab_size
152064
--hidden_size
8192
--output_dir
/work/medusa/vllm-medusa-qwen2-72b-head-4
--medusa_choices
=
"[(0), (0, 0), (0, 0, 0), (0, 1), (1), (1, 0), (0, 0, 0, 0), (0, 0, 1), (0, 2), (0, 1, 0), (2), (0, 0, 2), (0, 3), (1, 0, 0), (2, 0), (0, 2, 0), (0, 4), (0, 0, 3), (3), (0, 0, 0, 1), (0, 5), (0, 0, 1, 0), (0, 0, 4)]"
python medusa_weight_converter.py
--medusa_num_heads
4
--medusa_num_layers
1
--medusa_model_path
/work/model.bin
--vocab_size
152064
--hidden_size
8192
--output_dir
/work/medusa/vllm-medusa-qwen2-72b-head-4
--medusa_choices
=
"[(0), (0, 0), (0, 0, 0), (0, 1), (1), (1, 0), (0, 0, 0, 0), (0, 0, 1), (0, 2), (0, 1, 0), (2), (0, 0, 2), (0, 3), (1, 0, 0), (2, 0), (0, 2, 0), (0, 4), (0, 0, 3), (3), (0, 0, 0, 1), (0, 5), (0, 0, 1, 0), (0, 0, 4)]"
```
```
此处model.bin是训练后保存的medusa head权重
此处model.bin是训练后保存的medusa head权重
,如果希望采用Top1-proposer,medusa_choices可以不设置
### Run
### Run
tree-style generation server
```
bash
```
bash
python3
-m
vllm.entrypoints.openai.api_server
\
VLLM_TREE_DECODING
=
1
python3
-m
vllm.entrypoints.openai.api_server
\
--served-model-name
qwen_medusa
\
--served-model-name
qwen_medusa
\
--model
/models/Qwen2-72B-Instruct/
-tp
4
\
--model
/models/Qwen2-72B-Instruct/
-tp
4
\
--max-model-len
1024
--max-num-seqs
8
--gpu-memory-utilization
0.8
\
--max-model-len
1024
--max-num-seqs
8
--gpu-memory-utilization
0.8
\
--speculative-model
/work/medusa/vllm-medusa-qwen2-72b-head-4
\
--speculative-model
/work/medusa/vllm-medusa-qwen2-72b-head-4
\
--speculative-draft-tensor-parallel-size
4
\
--speculative-draft-tensor-parallel-size
4
\
--speculative-disable-by-batch-size
4
\
--speculative-disable-by-batch-size
9
\
--use-v2-block-manager
\
--use-v2-block-manager
\
--spec-decoding-acceptance-method
typical_acceptance_sampler
\
--spec-decoding-acceptance-method
typical_acceptance_sampler
\
--dtype
float16
--trust-remote-code
--port
8086
\
--dtype
float16
--trust-remote-code
--port
8086
\
--tree-style-spec-decoding
True
\
--num-speculative-heads
4
--num-speculative-tokens
24
--num-speculative-heads
4
--num-speculative-tokens
24
```
```
注意:
num_speculative_tokens = len(medusa_choices) + 1
medusa_choices个数不能太多,否则多batch下会降低推理速度
speculative-disable-by-batch-size要大于max-num-seqs,否则当batch等于max-num-seqs时,不会走并行解码
merge-lora可以将lora权重和base model权重融合,提升整体推理速度,若对精度有严格要求,可不设置此参数
### Run Top1-proposer server
num-speculative-tokens和medusa choices的个数相关,num_speculative_tokens = len(medusa_choices) + 1
python3 -m vllm.entrypoints.openai.api_server
\
--served-model-name qwen_medusa
\
--model /models/Qwen2-72B-Instruct/ -tp 4
\
--max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.8
\
--speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4
\
--speculative-draft-tensor-parallel-size 4
\
--speculative-disable-by-batch-size 9
\
--use-v2-block-manager
\
--spec-decoding-acceptance-method typical_acceptance_sampler
\
--dtype float16 --trust-remote-code --port 8086
\
--num-speculative-tokens 4
注意:
使用Top1-proposer时,num-speculative-tokens就是medusa head的个数
# do request
# do request
```
bash
```
bash
...
@@ -54,8 +70,14 @@ curl http://localhost:8086/v1/completions \
...
@@ -54,8 +70,14 @@ curl http://localhost:8086/v1/completions \
}'
}'
```
```
### benchmark
### Run tree-style benchmark
python medusa_benchmark_throughput.py --model /data/llm-models/qwen2/Qwen2-72B-Instruct/ -tp 4 --dtype float16 --trust-remote-code --max-num-seqs 1 --dataset /work/test/medusa_benchmark_data.json --max-model-len 4096 --gpu-memory-utilization 0.9
```
bash
VLLM_TREE_DECODING
=
1 python /work/test/medusa_benchmark_throughput.py
--model
/models/Qwen2-72B-Instruct/
-tp
4
--dtype
float16
--trust-remote-code
--max-num-seqs
4
--speculative-model
/work/medusa/vllm-medusa1-qwen2-72b-head-4
--speculative-draft-tensor-parallel-size
4
--speculative-disable-by-batch-size
9
--use-v2-block-manager
--spec-decoding-acceptance-method
typical_acceptance_sampler
--max-model-len
1024
--dataset
/work/medusa_benchmark_data.json
--num-speculative-heads
4
--num-speculative-tokens
24
--gpu-memory-utilization
0.95
```
### Run Top1-proposer benchmark
```
bash
python /work/test/medusa_benchmark_throughput.py
--model
/models/Qwen2-72B-Instruct/
-tp
4
--dtype
float16
--trust-remote-code
--max-num-seqs
4
--speculative-model
/work/medusa/vllm-medusa1-qwen2-72b-head-4
--speculative-draft-tensor-parallel-size
4
--speculative-disable-by-batch-size
9
--use-v2-block-manager
--spec-decoding-acceptance-method
typical_acceptance_sampler
--max-model-len
1024
--dataset
/work/medusa_benchmark_data.json
--num-speculative-tokens
4
--gpu-memory-utilization
0.95
```
可设置max-num-seqs对不同的batch进行性能测试
可设置max-num-seqs对不同的batch进行性能测试
examples/medusa/medusa_benchmark_throughput.py
View file @
44e3ca68
...
@@ -98,7 +98,6 @@ def run_vllm(
...
@@ -98,7 +98,6 @@ def run_vllm(
merge_lora
:
bool
=
False
,
merge_lora
:
bool
=
False
,
lora_extra_vocab_size
:
int
=
0
,
lora_extra_vocab_size
:
int
=
0
,
lora_target_modules
:
List
[
str
]
=
None
,
lora_target_modules
:
List
[
str
]
=
None
,
tree_style_spec_decoding
:
bool
=
False
,
num_speculative_heads
:
int
=
5
,
num_speculative_heads
:
int
=
5
,
num_speculative_tokens
:
int
=
64
,
num_speculative_tokens
:
int
=
64
,
use_new_beam_search_impl
:
bool
=
False
,
use_new_beam_search_impl
:
bool
=
False
,
...
@@ -138,7 +137,6 @@ def run_vllm(
...
@@ -138,7 +137,6 @@ def run_vllm(
merge_lora
=
merge_lora
,
merge_lora
=
merge_lora
,
lora_extra_vocab_size
=
lora_extra_vocab_size
,
lora_extra_vocab_size
=
lora_extra_vocab_size
,
lora_target_modules
=
lora_target_modules
,
lora_target_modules
=
lora_target_modules
,
tree_style_spec_decoding
=
tree_style_spec_decoding
,
num_speculative_heads
=
num_speculative_heads
,
num_speculative_heads
=
num_speculative_heads
,
num_speculative_tokens
=
num_speculative_tokens
num_speculative_tokens
=
num_speculative_tokens
)
)
...
@@ -234,7 +232,6 @@ async def run_vllm_async(
...
@@ -234,7 +232,6 @@ async def run_vllm_async(
merge_lora
:
bool
=
False
,
merge_lora
:
bool
=
False
,
lora_extra_vocab_size
:
int
=
0
,
lora_extra_vocab_size
:
int
=
0
,
lora_target_modules
:
List
[
str
]
=
None
,
lora_target_modules
:
List
[
str
]
=
None
,
tree_style_spec_decoding
:
bool
=
False
,
num_speculative_heads
:
int
=
5
,
num_speculative_heads
:
int
=
5
,
num_speculative_tokens
:
int
=
64
,
num_speculative_tokens
:
int
=
64
,
use_new_beam_search_impl
:
bool
=
False
,
use_new_beam_search_impl
:
bool
=
False
,
...
@@ -276,7 +273,6 @@ async def run_vllm_async(
...
@@ -276,7 +273,6 @@ async def run_vllm_async(
merge_lora
=
merge_lora
,
merge_lora
=
merge_lora
,
lora_extra_vocab_size
=
lora_extra_vocab_size
,
lora_extra_vocab_size
=
lora_extra_vocab_size
,
lora_target_modules
=
lora_target_modules
,
lora_target_modules
=
lora_target_modules
,
tree_style_spec_decoding
=
tree_style_spec_decoding
,
num_speculative_heads
=
num_speculative_heads
,
num_speculative_heads
=
num_speculative_heads
,
num_speculative_tokens
=
num_speculative_tokens
num_speculative_tokens
=
num_speculative_tokens
)
)
...
@@ -350,7 +346,7 @@ def main(args: argparse.Namespace):
...
@@ -350,7 +346,7 @@ def main(args: argparse.Namespace):
args
.
speculative_model
,
args
.
speculative_draft_tensor_parallel_size
,
args
.
speculative_model
,
args
.
speculative_draft_tensor_parallel_size
,
args
.
speculative_disable_by_batch_size
,
args
.
spec_decoding_acceptance_method
,
args
.
speculative_disable_by_batch_size
,
args
.
spec_decoding_acceptance_method
,
args
.
enable_lora
,
args
.
max_lora_rank
,
args
.
merge_lora
,
args
.
lora_extra_vocab_size
,
args
.
enable_lora
,
args
.
max_lora_rank
,
args
.
merge_lora
,
args
.
lora_extra_vocab_size
,
args
.
lora_target_modules
,
args
.
tree_style_spec_decoding
,
args
.
num_speculative_heads
,
args
.
lora_target_modules
,
args
.
num_speculative_heads
,
args
.
num_speculative_tokens
args
.
num_speculative_tokens
]
]
else
:
else
:
...
@@ -368,7 +364,7 @@ def main(args: argparse.Namespace):
...
@@ -368,7 +364,7 @@ def main(args: argparse.Namespace):
args
.
speculative_model
,
args
.
speculative_draft_tensor_parallel_size
,
args
.
speculative_model
,
args
.
speculative_draft_tensor_parallel_size
,
args
.
speculative_disable_by_batch_size
,
args
.
spec_decoding_acceptance_method
,
args
.
speculative_disable_by_batch_size
,
args
.
spec_decoding_acceptance_method
,
args
.
enable_lora
,
args
.
max_lora_rank
,
args
.
merge_lora
,
args
.
lora_extra_vocab_size
,
args
.
enable_lora
,
args
.
max_lora_rank
,
args
.
merge_lora
,
args
.
lora_extra_vocab_size
,
args
.
lora_target_modules
,
args
.
tree_style_spec_decoding
,
args
.
num_speculative_heads
,
args
.
lora_target_modules
,
args
.
num_speculative_heads
,
args
.
num_speculative_tokens
args
.
num_speculative_tokens
]
]
...
@@ -625,11 +621,6 @@ if __name__ == "__main__":
...
@@ -625,11 +621,6 @@ if __name__ == "__main__":
default
=
None
,
default
=
None
,
help
=
'List of lora module name, If not specified, modules will be chosen according to the model architecture.'
)
help
=
'List of lora module name, If not specified, modules will be chosen according to the model architecture.'
)
parser
.
add_argument
(
'--tree-style-spec-decoding'
,
type
=
bool
,
default
=
False
,
help
=
'If set to True, tree-style generation will be activated.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--num-speculative-heads'
,
'--num-speculative-heads'
,
type
=
int
,
type
=
int
,
...
...
tests/kernels/test_attention.py
View file @
44e3ca68
...
@@ -225,7 +225,7 @@ def test_paged_attention(
...
@@ -225,7 +225,7 @@ def test_paged_attention(
opcheck
(
torch
.
ops
.
_C
.
paged_attention_v1
,
opcheck
(
torch
.
ops
.
_C
.
paged_attention_v1
,
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
,
None
,
0
),
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
and
block_size
==
BLOCK_SIZES
[
0
]))
and
block_size
==
BLOCK_SIZES
[
0
]))
else
:
else
:
...
@@ -291,7 +291,7 @@ def test_paged_attention(
...
@@ -291,7 +291,7 @@ def test_paged_attention(
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
,
None
,
0
),
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
and
block_size
==
BLOCK_SIZES
[
0
]))
and
block_size
==
BLOCK_SIZES
[
0
]))
else
:
else
:
...
...
tests/worker/test_model_input.py
View file @
44e3ca68
...
@@ -60,16 +60,6 @@ class MockAttentionBackend(AttentionBackend):
...
@@ -60,16 +60,6 @@ class MockAttentionBackend(AttentionBackend):
)
->
None
:
)
->
None
:
pass
pass
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
NotImplementedError
def
test_model_runner_input
():
def
test_model_runner_input
():
sampling_metadata
=
SamplingMetadata
(
sampling_metadata
=
SamplingMetadata
(
...
...
vllm/_custom_ops.py
View file @
44e3ca68
...
@@ -89,6 +89,66 @@ def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
...
@@ -89,6 +89,66 @@ def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
# page attention ops
# page attention ops
def
paged_attention_v1
(
def
paged_attention_v1
(
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_v2
(
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
tmp_out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v2
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_v1_with_mask
(
out
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
...
@@ -111,7 +171,7 @@ def paged_attention_v1(
...
@@ -111,7 +171,7 @@ def paged_attention_v1(
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks_stride
:
int
=
0
,
attn_masks_stride
:
int
=
0
,
)
->
None
:
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v1
(
torch
.
ops
.
_C
.
paged_attention_v1
_with_mask
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
...
@@ -120,7 +180,7 @@ def paged_attention_v1(
...
@@ -120,7 +180,7 @@ def paged_attention_v1(
attn_masks_stride
)
attn_masks_stride
)
def
paged_attention_v2
(
def
paged_attention_v2
_with_mask
(
out
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
...
@@ -146,7 +206,7 @@ def paged_attention_v2(
...
@@ -146,7 +206,7 @@ def paged_attention_v2(
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks_stride
:
int
=
0
,
attn_masks_stride
:
int
=
0
,
)
->
None
:
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v2
(
torch
.
ops
.
_C
.
paged_attention_v2
_with_mask
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
...
@@ -157,6 +217,67 @@ def paged_attention_v2(
...
@@ -157,6 +217,67 @@ def paged_attention_v2(
# page attention ops (opt)
# page attention ops (opt)
def
paged_attention_v1_opt
(
def
paged_attention_v1_opt
(
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v1_opt
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_v2_opt
(
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
tmp_out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v2_opt
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_v1_opt_with_mask
(
out
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
...
@@ -179,7 +300,7 @@ def paged_attention_v1_opt(
...
@@ -179,7 +300,7 @@ def paged_attention_v1_opt(
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks_stride
:
int
=
0
,
attn_masks_stride
:
int
=
0
,
)
->
None
:
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v1_opt
(
torch
.
ops
.
_C
.
paged_attention_v1_opt
_with_mask
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
...
@@ -188,7 +309,7 @@ def paged_attention_v1_opt(
...
@@ -188,7 +309,7 @@ def paged_attention_v1_opt(
attn_masks_stride
)
attn_masks_stride
)
def
paged_attention_v2_opt
(
def
paged_attention_v2_opt
_with_mask
(
out
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
...
@@ -214,7 +335,7 @@ def paged_attention_v2_opt(
...
@@ -214,7 +335,7 @@ def paged_attention_v2_opt(
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks_stride
:
int
=
0
,
attn_masks_stride
:
int
=
0
,
)
->
None
:
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v2_opt
(
torch
.
ops
.
_C
.
paged_attention_v2_opt
_with_mask
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
...
@@ -225,6 +346,67 @@ def paged_attention_v2_opt(
...
@@ -225,6 +346,67 @@ def paged_attention_v2_opt(
# page attention ops (opt)
# page attention ops (opt)
def
paged_attention_v1_opt_tc
(
def
paged_attention_v1_opt_tc
(
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v1_opt_tc
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_v2_opt_tc
(
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
tmp_out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v2_opt_tc
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
# page attention ops (opt)
def
paged_attention_v1_opt_tc_with_mask
(
out
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
...
@@ -247,7 +429,7 @@ def paged_attention_v1_opt_tc(
...
@@ -247,7 +429,7 @@ def paged_attention_v1_opt_tc(
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks_stride
:
int
=
0
,
attn_masks_stride
:
int
=
0
,
)
->
None
:
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v1_opt_tc
(
torch
.
ops
.
_C
.
paged_attention_v1_opt_tc
_with_mask
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
...
@@ -255,7 +437,7 @@ def paged_attention_v1_opt_tc(
...
@@ -255,7 +437,7 @@ def paged_attention_v1_opt_tc(
attn_masks
,
attn_masks_stride
)
attn_masks
,
attn_masks_stride
)
def
paged_attention_v2_opt_tc
(
def
paged_attention_v2_opt_tc
_with_mask
(
out
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
...
@@ -281,7 +463,7 @@ def paged_attention_v2_opt_tc(
...
@@ -281,7 +463,7 @@ def paged_attention_v2_opt_tc(
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks_stride
:
int
=
0
,
attn_masks_stride
:
int
=
0
,
)
->
None
:
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v2_opt_tc
(
torch
.
ops
.
_C
.
paged_attention_v2_opt_tc
_with_mask
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
...
...
vllm/attention/backends/abstract.py
View file @
44e3ca68
...
@@ -82,17 +82,6 @@ class AttentionBackend(ABC):
...
@@ -82,17 +82,6 @@ class AttentionBackend(ABC):
src_to_dists
:
torch
.
Tensor
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@
staticmethod
@
abstractmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
raise
NotImplementedError
def
advance_step
(
self
,
model_input
:
"ModelRunnerInputBase"
,
def
advance_step
(
self
,
model_input
:
"ModelRunnerInputBase"
,
sampled_token_ids
:
Optional
[
torch
.
Tensor
],
sampled_token_ids
:
Optional
[
torch
.
Tensor
],
...
@@ -206,8 +195,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
...
@@ -206,8 +195,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
@
abstractmethod
@
abstractmethod
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
,
cuda_graph_pad_size
:
int
,
batch_size
:
int
)
->
T
:
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
)
->
T
:
"""Build attention metadata with on-device tensors."""
"""Build attention metadata with on-device tensors."""
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
44e3ca68
...
@@ -129,50 +129,6 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
...
@@ -129,50 +129,6 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
)
->
None
:
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
key_caches
=
[]
value_caches
=
[]
num_layers
=
len
(
kv_caches
)
token_num
=
src_to_dists
.
shape
[
0
]
tmp_store_kv
=
torch
.
empty
(
(
2
,
num_layers
,
token_num
,
num_kv_heads
,
head_size
),
dtype
=
kv_caches
[
0
].
dtype
,
device
=
kv_caches
[
0
].
device
)
keys
=
tmp_store_kv
[
0
].
contiguous
()
values
=
tmp_store_kv
[
1
].
contiguous
()
for
kv_cache
in
kv_caches
:
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
num_kv_heads
,
head_size
)
key_caches
.
append
(
key_cache
)
value_caches
.
append
(
value_cache
)
ops
.
read_cache
(
keys
,
values
,
key_caches
,
value_caches
,
src_to_dists
[:,
0
].
contiguous
(),
kv_cache_dtype
)
ops
.
write_cache_multi_layers
(
keys
,
values
,
key_caches
,
value_caches
,
src_to_dists
[:,
1
].
contiguous
(),
kv_cache_dtype
)
@
dataclass
@
dataclass
class
BlocksparseFlashAttentionMetadata
(
AttentionMetadata
):
class
BlocksparseFlashAttentionMetadata
(
AttentionMetadata
):
...
@@ -235,8 +191,6 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
...
@@ -235,8 +191,6 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
"BlocksparseFlashAttentionMetadata"
]
=
None
"BlocksparseFlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
_cached_decode_metadata
:
Optional
[
"BlocksparseFlashAttentionMetadata"
]
=
None
"BlocksparseFlashAttentionMetadata"
]
=
None
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
block_tables_list
:
Optional
[
List
[
int
]]
=
None
block_tables_list
:
Optional
[
List
[
int
]]
=
None
...
@@ -271,7 +225,6 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
...
@@ -271,7 +225,6 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
use_cuda_graph
=
False
,
use_cuda_graph
=
False
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables_list
block_tables_list
=
self
.
block_tables_list
)
)
return
self
.
_cached_prefill_metadata
return
self
.
_cached_prefill_metadata
...
@@ -301,7 +254,6 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
...
@@ -301,7 +254,6 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor
=
None
,
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
use_cuda_graph
=
self
.
use_cuda_graph
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables_list
block_tables_list
=
self
.
block_tables_list
)
)
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
...
...
vllm/attention/backends/flash_attn.py
View file @
44e3ca68
...
@@ -221,16 +221,6 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -221,16 +221,6 @@ class FlashAttentionBackend(AttentionBackend):
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
NotImplementedError
@
dataclass
@
dataclass
class
FlashAttentionMetadata
(
AttentionMetadata
):
class
FlashAttentionMetadata
(
AttentionMetadata
):
...
...
vllm/attention/backends/flashinfer.py
View file @
44e3ca68
...
@@ -92,16 +92,6 @@ class FlashInferBackend(AttentionBackend):
...
@@ -92,16 +92,6 @@ class FlashInferBackend(AttentionBackend):
return
torch
.
float8_e5m2
return
torch
.
float8_e5m2
else
:
else
:
raise
ValueError
(
f
"Unrecognized FP8 dtype:
{
kv_cache_dtype
}
"
)
raise
ValueError
(
f
"Unrecognized FP8 dtype:
{
kv_cache_dtype
}
"
)
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
NotImplementedError
class
FlashInferState
(
AttentionState
):
class
FlashInferState
(
AttentionState
):
...
...
vllm/attention/backends/ipex_attn.py
View file @
44e3ca68
...
@@ -62,16 +62,6 @@ class IpexAttnBackend(AttentionBackend):
...
@@ -62,16 +62,6 @@ class IpexAttnBackend(AttentionBackend):
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
NotImplementedError
@
dataclass
@
dataclass
class
IpexAttnMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
class
IpexAttnMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
...
...
vllm/attention/backends/openvino.py
View file @
44e3ca68
...
@@ -62,16 +62,6 @@ class OpenVINOAttentionBackend(AttentionBackend):
...
@@ -62,16 +62,6 @@ class OpenVINOAttentionBackend(AttentionBackend):
key_cache
.
data
[
dst
,
:]
=
key_cache
.
data
[
src
,
:]
key_cache
.
data
[
dst
,
:]
=
key_cache
.
data
[
src
,
:]
value_cache
.
data
[
dst
,
:]
=
value_cache
.
data
[
src
,
:]
value_cache
.
data
[
dst
,
:]
=
value_cache
.
data
[
src
,
:]
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
NotImplementedError
@
dataclass
@
dataclass
class
OpenVINOAttentionMetadata
:
class
OpenVINOAttentionMetadata
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment