Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e96edbbe
Commit
e96edbbe
authored
Oct 25, 2024
by
zhuwenwen
Browse files
Merge remote-tracking branch 'origin/v0.6.2-dev-medusa' into v0.6.2-dev
parents
3bda0405
19bc93d9
Changes
45
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1123 additions
and
69 deletions
+1123
-69
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+45
-14
csrc/attention/attention_kernels_opt.cu
csrc/attention/attention_kernels_opt.cu
+44
-14
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+56
-18
csrc/cache.h
csrc/cache.h
+16
-0
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+272
-0
csrc/cpu/cache.cpp
csrc/cpu/cache.cpp
+21
-0
csrc/ops.h
csrc/ops.h
+18
-6
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+34
-6
examples/medusa/medusa_weight_converter.py
examples/medusa/medusa_weight_converter.py
+393
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+49
-6
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+13
-1
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+49
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+3
-1
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+13
-1
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+10
-0
vllm/attention/backends/openvino.py
vllm/attention/backends/openvino.py
+10
-0
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+10
-0
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+52
-0
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+11
-0
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+4
-2
No files found.
csrc/attention/attention_kernels.cu
View file @
e96edbbe
...
...
@@ -107,7 +107,8 @@ __device__ void paged_attention_kernel(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
const
int
seq_idx
=
blockIdx
.
y
;
const
int
partition_idx
=
blockIdx
.
z
;
const
int
max_num_partitions
=
gridDim
.
z
;
...
...
@@ -299,7 +300,14 @@ __device__ void paged_attention_kernel(
if
(
thread_group_offset
==
0
)
{
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const
bool
mask
=
token_idx
>=
seq_len
;
bool
mask
=
token_idx
>=
seq_len
;
// used for tree-style attention
if
(
attn_masks
!=
nullptr
)
{
const
int
*
attn_masks_ptr
=
attn_masks
+
seq_idx
*
attn_masks_stride
;
mask
|=
attn_masks_ptr
[
token_idx
]
==
0
;
}
logits
[
token_idx
-
start_token_idx
]
=
mask
?
0.
f
:
qk
;
// Update the max value.
qk_max
=
mask
?
qk_max
:
fmaxf
(
qk_max
,
qk
);
...
...
@@ -515,7 +523,8 @@ __global__ void paged_attention_v1_kernel(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
...
...
@@ -523,7 +532,7 @@ __global__ void paged_attention_v1_kernel(
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
...
...
@@ -551,14 +560,15 @@ __global__ void paged_attention_v2_kernel(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
}
// Grid: (num_heads, num_seqs).
...
...
@@ -684,7 +694,8 @@ __global__ void paged_attention_v2_reduce_kernel(
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride);
// TODO(woosuk): Tune NUM_THREADS.
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
...
...
@@ -697,7 +708,9 @@ void paged_attention_v1_launcher(
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
...
...
@@ -722,6 +735,12 @@ void paged_attention_v1_launcher(
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
// NOTE: attn_masks is optional.
const
int
*
attn_masks_ptr
=
attn_masks
?
attn_masks
.
value
().
data_ptr
<
int
>
()
:
nullptr
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
padded_max_seq_len
=
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
...
...
@@ -778,7 +797,8 @@ void paged_attention_v1_launcher(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
...
...
@@ -824,7 +844,9 @@ void paged_attention_v1(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
...
...
@@ -841,7 +863,8 @@ void paged_attention_v1(
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
...
...
@@ -859,7 +882,9 @@ void paged_attention_v2_launcher(
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int
attn_masks_stride
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
...
...
@@ -887,6 +912,10 @@ void paged_attention_v2_launcher(
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
// NOTE: attn_masks is optional.
const
int
*
attn_masks_ptr
=
attn_masks
?
attn_masks
.
value
().
data_ptr
<
int
>
()
:
nullptr
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
int
logits_size
=
PARTITION_SIZE
*
sizeof
(
float
);
...
...
@@ -946,7 +975,7 @@ void paged_attention_v2_launcher(
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step
, attn_masks, attn_masks_stride
);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
...
...
@@ -996,7 +1025,9 @@ void paged_attention_v2(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
CALL_V2_LAUNCHER_BLOCK_SIZE
)
...
...
csrc/attention/attention_kernels_opt.cu
View file @
e96edbbe
...
...
@@ -94,7 +94,8 @@ __device__ void paged_attention_kernel_opt(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
const
int
seq_idx
=
blockIdx
.
z
;
const
int
partition_idx
=
blockIdx
.
y
;
const
int
max_num_partitions
=
gridDim
.
y
;
...
...
@@ -330,7 +331,13 @@ __device__ void paged_attention_kernel_opt(
if
(
thread_group_offset
==
0
)
{
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const
bool
mask
=
token_idx
>=
seq_len
;
bool
mask
=
token_idx
>=
seq_len
;
// used for tree-style attention
if
(
attn_masks
!=
nullptr
)
{
const
int
*
attn_masks_ptr
=
attn_masks
+
seq_idx
*
attn_masks_stride
;
mask
|=
attn_masks_ptr
[
token_idx
]
==
0
;
}
logits
[(
reuse_kv_idx
*
partition_size
)
+
(
token_idx
-
start_token_idx
)]
=
mask
?
0.
f
:
qk
;
// Update the max value.
qk_max
[
reuse_kv_idx
]
=
mask
?
qk_max
[
reuse_kv_idx
]
:
fmaxf
(
qk_max
[
reuse_kv_idx
],
qk
);
...
...
@@ -611,7 +618,8 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
paged_attention_kernel_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
...
...
@@ -619,7 +627,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
...
...
@@ -651,14 +659,15 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt(
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
paged_attention_kernel_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
}
// Grid: (num_heads, num_seqs).
...
...
@@ -784,7 +793,8 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
...
...
@@ -807,7 +817,9 @@ void paged_attention_v1_launcher(
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
...
...
@@ -836,6 +848,12 @@ void paged_attention_v1_launcher(
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
// NOTE: attn_masks is optional.
const
int
*
attn_masks_ptr
=
attn_masks
?
attn_masks
.
value
().
data_ptr
<
int
>
()
:
nullptr
;
int
padded_max_seq_len
=
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
REUSEKV_SWITCH_V1
(
num_heads
*
num_seqs
,
[
&
]
{
BOOL_SWITCH
((
num_heads
/
num_kv_heads
%
REUSE_KV_TIMES
!=
0
),
odd_nheads
,
[
&
]
{
...
...
@@ -869,7 +887,8 @@ void paged_attention_v1_launcher(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
...
...
@@ -915,7 +934,9 @@ void paged_attention_v1_opt(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
...
...
@@ -932,7 +953,8 @@ void paged_attention_v1_opt(
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>) \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
...
...
@@ -950,7 +972,9 @@ void paged_attention_v2_launcher(
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int
attn_masks_stride
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
...
...
@@ -978,6 +1002,10 @@ void paged_attention_v2_launcher(
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
// NOTE: attn_masks is optional.
const
int
*
attn_masks_ptr
=
attn_masks
?
attn_masks
.
value
().
data_ptr
<
int
>
()
:
nullptr
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
REUSEKV_SWITCH
(
num_heads
*
max_num_partitions
*
num_seqs
,
[
&
]
{
...
...
@@ -1016,7 +1044,7 @@ void paged_attention_v2_launcher(
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step
, attn_masks, attn_masks_stride
);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
...
...
@@ -1066,7 +1094,9 @@ void paged_attention_v2_opt(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
CALL_V2_LAUNCHER_BLOCK_SIZE
)
...
...
csrc/attention/attention_kernels_opt_tc.cu
View file @
e96edbbe
...
...
@@ -168,7 +168,8 @@ __device__ void paged_attention_kernel_TC(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
const
int
seq_idx
=
blockIdx
.
z
;
const
int
partition_idx
=
blockIdx
.
y
;
const
int
max_num_partitions
=
gridDim
.
y
;
...
...
@@ -291,7 +292,13 @@ __device__ void paged_attention_kernel_TC(
float
alibi
=
alibi_slope
[
i
]
*
(
token_idx
-
seq_len
+
1
);
qk_vec
[
i
]
+=
alibi
;
}
const
bool
mask
=
(
token_idx
>=
seq_len
);
bool
mask
=
(
token_idx
>=
seq_len
);
// used for tree-style attention
if
(
attn_masks
!=
nullptr
)
{
const
int
*
attn_masks_ptr
=
attn_masks
+
seq_idx
*
attn_masks_stride
;
mask
|=
attn_masks_ptr
[
token_idx
]
==
0
;
}
if
(
mask
){
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
0.
f
);
}
...
...
@@ -555,7 +562,8 @@ __global__ void paged_attention_v1_kernel_TC(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
#ifdef __gfx928__
paged_attention_kernel_TC
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
use_vmac
>
(
...
...
@@ -564,7 +572,7 @@ __global__ void paged_attention_v1_kernel_TC(
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
#endif
}
...
...
@@ -594,7 +602,8 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
#ifdef __gfx928__
paged_attention_kernel_TC
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
use_vmac
,
...
...
@@ -603,7 +612,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
#endif
}
...
...
@@ -730,7 +739,8 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride);
void
get_numberthread_and_reuse_kv_v1
(
int
&
num_thread
,
int
&
reusekv
,
int
batchsize
,
int
seq
,
int
qheads
,
int
kvheads
){
//mha
...
...
@@ -796,7 +806,9 @@ void paged_attention_v1_launcher_opt_tc(
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
...
...
@@ -824,6 +836,13 @@ void paged_attention_v1_launcher_opt_tc(
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
// NOTE: attn_masks is optional.
const
int
*
attn_masks_ptr
=
attn_masks
?
attn_masks
.
value
().
data_ptr
<
int
>
()
:
nullptr
;
int
padded_max_seq_len
=
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
...
...
@@ -858,7 +877,8 @@ void paged_attention_v1_launcher_opt_tc(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
...
...
@@ -904,7 +924,9 @@ void paged_attention_v1_opt(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
);
void
paged_attention_v1_opt_tc
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
...
...
@@ -922,14 +944,17 @@ void paged_attention_v1_opt_tc(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
get_device_name
()
!=
"gfx928"
){
paged_attention_v1_opt
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
}
else
{
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
...
...
@@ -948,7 +973,8 @@ void paged_attention_v1_opt_tc(
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); \
blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \
...
...
@@ -999,7 +1025,9 @@ void paged_attention_v2_launcher_opt_tc(
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int
attn_masks_stride
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
...
...
@@ -1026,6 +1054,11 @@ void paged_attention_v2_launcher_opt_tc(
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
// NOTE: attn_masks is optional.
const
int
*
attn_masks_ptr
=
attn_masks
?
attn_masks
.
value
().
data_ptr
<
int
>
()
:
nullptr
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
dim3
reduce_grid
(
num_heads
,
num_seqs
);
...
...
@@ -1067,7 +1100,7 @@ void paged_attention_v2_launcher_opt_tc(
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step
, attn_masks, attn_masks_stride
);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
...
...
@@ -1117,7 +1150,9 @@ void paged_attention_v2_opt(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
);
void
paged_attention_v2_opt_tc
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
...
...
@@ -1139,14 +1174,17 @@ void paged_attention_v2_opt_tc(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
get_device_name
()
!=
"gfx928"
){
paged_attention_v2_opt
(
out
,
exp_sums
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
}
else
{
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
...
...
csrc/cache.h
View file @
e96edbbe
...
...
@@ -31,3 +31,19 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
// Only used by unit tests.
// Declaration only; the definition is not part of this header.
// Takes a destination tensor, a source tensor, a scaling factor and the
// KV-cache dtype name, and returns nothing.
// NOTE(review): presumably converts `src_cache` to/from an fp8 representation
// and stores the result in `dst_cache` — confirm against the implementation.
void
convert_fp8
(
torch
::
Tensor
&
dst_cache
,
torch
::
Tensor
&
src_cache
,
const
double
scale
,
const
std
::
string
&
kv_cache_dtype
);
// Declaration only; the definition is not part of this header.
// Parameters: two mutable tensors (`keys`, `values`), a vector of tensors for
// the key caches and one for the value caches (both passed by const
// reference), a `slot_mapping` tensor and the KV-cache dtype name.
// Returns nothing, so any result must be delivered through the tensor
// arguments.
// NOTE(review): presumably gathers the cache entries addressed by
// `slot_mapping` from every cache in `key_caches`/`value_caches` into
// `keys`/`values` — confirm against the implementation.
void
read_cache
(
torch
::
Tensor
&
keys
,
torch
::
Tensor
&
values
,
std
::
vector
<
torch
::
Tensor
>
const
&
key_caches
,
std
::
vector
<
torch
::
Tensor
>
const
&
value_caches
,
torch
::
Tensor
&
slot_mapping
,
const
std
::
string
&
kv_cache_dtype
);
// Declaration only; the definition is not part of this header.
// Parameters: two mutable tensors (`keys`, `values`), a vector of tensors for
// the key caches and one for the value caches (both passed by const
// reference), a `slot_mapping` tensor and the KV-cache dtype name.
// Returns nothing.
// NOTE(review): presumably the host entry point that writes `keys`/`values`
// into the caches of several layers in one call, at the slots given by
// `slot_mapping` — confirm against the implementation in cache_kernels.cu.
void
write_cache_multi_layers
(
torch
::
Tensor
&
keys
,
torch
::
Tensor
&
values
,
std
::
vector
<
torch
::
Tensor
>
const
&
key_caches
,
std
::
vector
<
torch
::
Tensor
>
const
&
value_caches
,
torch
::
Tensor
&
slot_mapping
,
const
std
::
string
&
kv_cache_dtype
);
csrc/cache_kernels.cu
View file @
e96edbbe
...
...
@@ -245,6 +245,133 @@ __global__ void reshape_and_cache_flash_kernel(
}
}
}
// Writes the keys/values of every layer into that layer's paged KV cache.
//
// Work decomposition, as indexed below: blockIdx.x is the layer, blockIdx.y is
// the token, and the threads of a block step through the
// num_heads * head_size elements of that token in increments of blockDim.x.
//
// key_cache_ptrs / value_cache_ptrs hold one raw cache address per layer
// (stored as int64_t and reinterpreted as cache_t* here). Tokens whose slot is
// negative are padding and are skipped.
//
// NOTE: the source offset of a layer is layer_idx * num_tokens * stride, i.e.
// the layer dimension of keys/values is assumed to be densely packed.
// NOTE: the fp8 branch converts with a fixed scale of 1.0.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void write_cache_multi_layers_kernel(
    scalar_t* __restrict__ keys,    // [num_layers, num_tokens, num_heads,
                                    //  head_size]
    scalar_t* __restrict__ values,  // [num_layers, num_tokens, num_heads,
                                    //  head_size]
    int64_t* key_cache_ptrs,        // per layer: [num_blocks, num_heads,
                                    //  head_size/x, block_size, x]
    int64_t* value_cache_ptrs,      // per layer: [num_blocks, num_heads,
                                    //  head_size, block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x,
    const int num_tokens) {
  const int layer_idx = blockIdx.x;
  const int token_idx = blockIdx.y;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  cache_t* key_cache = reinterpret_cast<cache_t*>(key_cache_ptrs[layer_idx]);
  cache_t* value_cache =
      reinterpret_cast<cache_t*>(value_cache_ptrs[layer_idx]);

  // Widen to 64 bits *before* multiplying. layer_idx * num_tokens * stride and
  // token_idx * stride were previously evaluated in 32-bit int and overflow
  // once the product exceeds INT_MAX (many layers x long batches x wide heads).
  const int64_t layer_token_offset =
      static_cast<int64_t>(layer_idx) * num_tokens;
  scalar_t* key = keys + layer_token_offset * key_stride;
  scalar_t* value = values + layer_token_offset * value_stride;
  const int64_t src_key_base = static_cast<int64_t>(token_idx) * key_stride;
  const int64_t src_value_base =
      static_cast<int64_t>(token_idx) * value_stride;

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    // Destination offsets inside the paged caches.
    const int64_t tgt_key_idx =
        block_idx * num_heads * (head_size / x) * block_size * x +
        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
        block_offset * x + x_offset;
    const int64_t tgt_value_idx =
        block_idx * num_heads * head_size * block_size +
        head_idx * head_size * block_size + head_offset * block_size +
        block_offset;

    // Source element of this token inside the current layer's slice.
    scalar_t src_key = key[src_key_base + i];
    scalar_t src_value = value[src_value_base + i];

    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key_cache[tgt_key_idx] = src_key;
      value_cache[tgt_value_idx] = src_value;
    } else {
      key_cache[tgt_key_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src_key, 1.0);
      value_cache[tgt_value_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src_value, 1.0);
    }
  }
}
// Reads cached keys/values of every layer out of that layer's paged KV cache
// into the dense keys/values tensors.
//
// Work decomposition, as indexed below: blockIdx.x is the layer, blockIdx.y is
// the token, and the threads of a block step through the
// num_heads * head_size elements of that token in increments of blockDim.x.
//
// key_cache_ptrs / value_cache_ptrs hold one raw cache address per layer
// (stored as int64_t and reinterpreted as cache_t* here). Tokens whose slot is
// negative are padding and are skipped, leaving their output untouched.
//
// NOTE: the destination offset of a layer is layer_idx * num_tokens * stride,
// i.e. the layer dimension of keys/values is assumed to be densely packed.
// NOTE: the fp8 branch converts with a fixed scale of 1.0.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void read_cache_kernel(
    scalar_t* __restrict__ keys,    // [num_layers, num_tokens, num_heads,
                                    //  head_size]
    scalar_t* __restrict__ values,  // [num_layers, num_tokens, num_heads,
                                    //  head_size]
    int64_t* key_cache_ptrs,        // per layer: [num_blocks, num_heads,
                                    //  head_size/x, block_size, x]
    int64_t* value_cache_ptrs,      // per layer: [num_blocks, num_heads,
                                    //  head_size, block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x,
    const int num_tokens) {
  const int layer_idx = blockIdx.x;
  const int token_idx = blockIdx.y;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  cache_t* key_cache = reinterpret_cast<cache_t*>(key_cache_ptrs[layer_idx]);
  cache_t* value_cache =
      reinterpret_cast<cache_t*>(value_cache_ptrs[layer_idx]);

  // Widen to 64 bits *before* multiplying. layer_idx * num_tokens * stride and
  // token_idx * stride were previously evaluated in 32-bit int and overflow
  // once the product exceeds INT_MAX (many layers x long batches x wide heads).
  const int64_t layer_token_offset =
      static_cast<int64_t>(layer_idx) * num_tokens;
  scalar_t* key = keys + layer_token_offset * key_stride;
  scalar_t* value = values + layer_token_offset * value_stride;
  const int64_t tgt_key_base = static_cast<int64_t>(token_idx) * key_stride;
  const int64_t tgt_value_base =
      static_cast<int64_t>(token_idx) * value_stride;

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    // Source offsets inside the paged caches.
    const int64_t src_key_idx =
        block_idx * num_heads * (head_size / x) * block_size * x +
        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
        block_offset * x + x_offset;
    const int64_t src_value_idx =
        block_idx * num_heads * head_size * block_size +
        head_idx * head_size * block_size + head_offset * block_size +
        block_offset;

    cache_t cached_key = key_cache[src_key_idx];
    cache_t cached_value = value_cache[src_value_idx];

    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key[tgt_key_base + i] = cached_key;
      value[tgt_value_base + i] = cached_value;
    } else {
      key[tgt_key_base + i] =
          fp8::scaled_convert<scalar_t, cache_t, kv_dt>(cached_key, 1.0);
      value[tgt_value_base + i] =
          fp8::scaled_convert<scalar_t, cache_t, kv_dt>(cached_value, 1.0);
    }
  }
}
}
// namespace vllm
// KV_T is the stored data type of kv-cache.
...
...
@@ -329,6 +456,151 @@ void reshape_and_cache_flash(
CALL_RESHAPE_AND_CACHE_FLASH
);
}
// KV_T is the data type of the key and value tensors.
// CACHE_T is the stored data type of the kv-cache.
// KV_DTYPE is the real data type of the kv-cache.
#define CALL_READ_CACHE(KV_T, CACHE_T, KV_DTYPE)           \
  vllm::read_cache_kernel<KV_T, CACHE_T, KV_DTYPE>         \
      <<<grid, block, 0, stream>>>(                        \
          reinterpret_cast<KV_T*>(keys.data_ptr()),        \
          reinterpret_cast<KV_T*>(values.data_ptr()),      \
          key_cache_ptrs_tensor.data_ptr<int64_t>(),       \
          value_cache_ptrs_tensor.data_ptr<int64_t>(),     \
          slot_mapping.data_ptr<int64_t>(),                \
          key_stride, value_stride,                        \
          num_heads, head_size, block_size, x, num_tokens);

// Gathers the cached keys/values of all layers for the tokens addressed by
// slot_mapping and stores them into `keys` / `values`.
//
// The per-layer cache addresses are packed into two int64 tensors on the cache
// device and handed to read_cache_kernel, which is launched on a
// (num_layers, num_tokens) grid. Cache geometry (num_heads, head_size,
// block_size, x) is taken from the first layer's caches.
//
// Returns without launching when there are no layers or no tokens.
// Raises (TORCH_CHECK) when the cache lists differ in length, a cache is not on
// the first cache's CUDA device, or keys/values/slot_mapping do not cover the
// extents the kernel indexes.
void read_cache(
    torch::Tensor& keys,    // [num_layers, seq_len, num_heads, head_size]
    torch::Tensor& values,  // [num_layers, seq_len, num_heads, head_size]
    std::vector<torch::Tensor> const&
        key_caches,  // [num_blocks, num_heads, head_size/x, block_size, x]
    std::vector<torch::Tensor> const&
        value_caches,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype) {
  TORCH_CHECK(key_caches.size() == value_caches.size(),
              "read_cache: key_caches and value_caches must have the same "
              "number of layers.");
  const int num_layers = static_cast<int>(key_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  const int num_tokens = keys.size(1);
  if (num_tokens == 0) {
    // Nothing to gather, and a grid with a zero-sized dimension is not a valid
    // launch configuration.
    return;
  }
  const int key_stride = keys.stride(1);
  const int value_stride = values.stride(1);

  // The kernel indexes keys/values as [layer][token] with a layer pitch of
  // num_tokens * stride(1) and reads slot_mapping[token]; reject inputs for
  // which that would touch memory outside the tensors.
  TORCH_CHECK(keys.size(0) >= num_layers && values.size(0) >= num_layers,
              "read_cache: keys/values must provide one slice per cache "
              "layer.");
  TORCH_CHECK(values.size(1) >= num_tokens,
              "read_cache: values has fewer tokens than keys.");
  TORCH_CHECK(slot_mapping.numel() >= num_tokens,
              "read_cache: slot_mapping has fewer entries than tokens.");
  TORCH_CHECK(
      num_layers == 1 ||
          (keys.stride(0) == static_cast<int64_t>(num_tokens) * key_stride &&
           values.stride(0) ==
               static_cast<int64_t>(num_tokens) * value_stride),
      "read_cache: keys/values must be densely packed along the layer "
      "dimension.");

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  // (std::vector instead of a variable-length array, which is not standard
  // C++.)
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    TORCH_CHECK(key_caches[layer_idx].device() == cache_device &&
                    value_caches[layer_idx].device() == cache_device,
                "read_cache: all caches must live on the same device.");
    key_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  auto kv_dtype = keys.dtype();
  torch::Tensor key_cache = key_caches[0];
  torch::Tensor value_cache = value_caches[0];

  int num_heads = value_cache.size(1);
  int head_size = value_cache.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor =
      torch::from_blob(key_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);
  torch::Tensor value_cache_ptrs_tensor =
      torch::from_blob(value_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);

  dim3 grid(num_layers, num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(slot_mapping));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(kv_dtype, kv_cache_dtype, CALL_READ_CACHE);
}
// KV_T is the data type of the key and value tensors.
// CACHE_T is the stored data type of the kv-cache.
// KV_DTYPE is the real data type of the kv-cache.
#define CALL_WRITE_CACHE_MULTI_LAYERS(KV_T, CACHE_T, KV_DTYPE)      \
  vllm::write_cache_multi_layers_kernel<KV_T, CACHE_T, KV_DTYPE>    \
      <<<grid, block, 0, stream>>>(                                 \
          reinterpret_cast<KV_T*>(keys.data_ptr()),                 \
          reinterpret_cast<KV_T*>(values.data_ptr()),               \
          key_cache_ptrs_tensor.data_ptr<int64_t>(),                \
          value_cache_ptrs_tensor.data_ptr<int64_t>(),              \
          slot_mapping.data_ptr<int64_t>(),                         \
          key_stride, value_stride,                                 \
          num_heads, head_size, block_size, x, num_tokens);

// Scatters the keys/values of all layers into the per-layer paged KV caches at
// the slots given by slot_mapping.
//
// The per-layer cache addresses are packed into two int64 tensors on the cache
// device and handed to write_cache_multi_layers_kernel, which is launched on a
// (num_layers, num_tokens) grid. Cache geometry (num_heads, head_size,
// block_size, x) is taken from the first layer's caches.
//
// Returns without launching when there are no layers or no tokens.
// Raises (TORCH_CHECK) when the cache lists differ in length, a cache is not on
// the first cache's CUDA device, or keys/values/slot_mapping do not cover the
// extents the kernel indexes.
void write_cache_multi_layers(
    torch::Tensor& keys,    // [num_layers, seq_len, num_heads, head_size]
    torch::Tensor& values,  // [num_layers, seq_len, num_heads, head_size]
    std::vector<torch::Tensor> const&
        key_caches,  // [num_blocks, num_heads, head_size/x, block_size, x]
    std::vector<torch::Tensor> const&
        value_caches,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype) {
  TORCH_CHECK(key_caches.size() == value_caches.size(),
              "write_cache_multi_layers: key_caches and value_caches must "
              "have the same number of layers.");
  const int num_layers = static_cast<int>(key_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  const int num_tokens = keys.size(1);
  if (num_tokens == 0) {
    // Nothing to write, and a grid with a zero-sized dimension is not a valid
    // launch configuration.
    return;
  }
  const int key_stride = keys.stride(1);
  const int value_stride = values.stride(1);

  // The kernel indexes keys/values as [layer][token] with a layer pitch of
  // num_tokens * stride(1) and reads slot_mapping[token]; reject inputs for
  // which that would touch memory outside the tensors.
  TORCH_CHECK(keys.size(0) >= num_layers && values.size(0) >= num_layers,
              "write_cache_multi_layers: keys/values must provide one slice "
              "per cache layer.");
  TORCH_CHECK(values.size(1) >= num_tokens,
              "write_cache_multi_layers: values has fewer tokens than keys.");
  TORCH_CHECK(slot_mapping.numel() >= num_tokens,
              "write_cache_multi_layers: slot_mapping has fewer entries than "
              "tokens.");
  TORCH_CHECK(
      num_layers == 1 ||
          (keys.stride(0) == static_cast<int64_t>(num_tokens) * key_stride &&
           values.stride(0) ==
               static_cast<int64_t>(num_tokens) * value_stride),
      "write_cache_multi_layers: keys/values must be densely packed along "
      "the layer dimension.");

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  // (std::vector instead of a variable-length array, which is not standard
  // C++.)
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    TORCH_CHECK(key_caches[layer_idx].device() == cache_device &&
                    value_caches[layer_idx].device() == cache_device,
                "write_cache_multi_layers: all caches must live on the same "
                "device.");
    key_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  auto kv_dtype = keys.dtype();
  torch::Tensor key_cache = key_caches[0];
  torch::Tensor value_cache = value_caches[0];

  int num_heads = value_cache.size(1);
  int head_size = value_cache.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor =
      torch::from_blob(key_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);
  torch::Tensor value_cache_ptrs_tensor =
      torch::from_blob(value_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);

  dim3 grid(num_layers, num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(slot_mapping));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(kv_dtype, kv_cache_dtype,
                             CALL_WRITE_CACHE_MULTI_LAYERS);
}
namespace
vllm
{
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
...
...
csrc/cpu/cache.cpp
View file @
e96edbbe
...
...
@@ -136,3 +136,24 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const
torch
::
Tensor
&
block_mapping
)
{
TORCH_CHECK
(
false
,
"swap_blocks is unsupported on CPU."
)
}
void
read_cache
(
std
::
vector
<
torch
::
Tensor
>
const
&
keys
,
std
::
vector
<
torch
::
Tensor
>
const
&
values
,
std
::
vector
<
torch
::
Tensor
>
const
&
key_caches
,
std
::
vector
<
torch
::
Tensor
>
const
&
value_caches
,
torch
::
Tensor
&
slot_mapping
,
const
std
::
string
&
kv_cache_dtype
)
{
TORCH_CHECK
(
false
,
"read_cache is unsupported on CPU."
)
}
void
write_cache_multi_layers
(
std
::
vector
<
torch
::
Tensor
>
const
&
keys
,
std
::
vector
<
torch
::
Tensor
>
const
&
values
,
std
::
vector
<
torch
::
Tensor
>
const
&
key_caches
,
std
::
vector
<
torch
::
Tensor
>
const
&
value_caches
,
torch
::
Tensor
&
slot_mapping
,
const
std
::
string
&
kv_cache_dtype
)
{
TORCH_CHECK
(
false
,
"write_cache_multi_layers is unsupported on CPU."
)
}
csrc/ops.h
View file @
e96edbbe
...
...
@@ -13,7 +13,9 @@ void paged_attention_v1(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
);
void
paged_attention_v2
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
...
...
@@ -24,7 +26,9 @@ void paged_attention_v2(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
);
void
paged_attention_v1_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
...
...
@@ -34,7 +38,9 @@ void paged_attention_v1_opt(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
);
void
paged_attention_v2_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
...
...
@@ -45,7 +51,9 @@ void paged_attention_v2_opt(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
);
void
paged_attention_v1_opt_tc
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
...
...
@@ -55,7 +63,9 @@ void paged_attention_v1_opt_tc(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
);
void
paged_attention_v2_opt_tc
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
...
...
@@ -66,7 +76,9 @@ void paged_attention_v2_opt_tc(
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
);
// Declaration only: the definition is not visible in this header.
// NOTE(review): by its name and parameters this looks like an RMS
// normalization of `input`, scaled by `weight`, written to `out`, with
// `epsilon` as the stabilizer -- confirm against the kernel source.
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
double
epsilon
);
...
...
csrc/torch_bindings.cpp
View file @
e96edbbe
...
...
@@ -30,7 +30,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()"
);
ops
.
impl
(
"paged_attention_v1"
,
torch
::
kCUDA
,
&
paged_attention_v1
);
// PagedAttention V2.
...
...
@@ -44,7 +46,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()"
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
// Compute the attention between an input query and the cached
...
...
@@ -58,7 +62,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()"
);
ops
.
impl
(
"paged_attention_v1_opt"
,
torch
::
kCUDA
,
&
paged_attention_v1_opt
);
// PagedAttention V2 (opt).
...
...
@@ -72,7 +78,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()"
);
ops
.
impl
(
"paged_attention_v2_opt"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt
);
// Compute the attention between an input query and the cached
...
...
@@ -86,7 +94,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()"
);
ops
.
impl
(
"paged_attention_v1_opt_tc"
,
torch
::
kCUDA
,
&
paged_attention_v1_opt_tc
);
// PagedAttention V2 (opt).
...
...
@@ -100,7 +110,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()"
);
ops
.
impl
(
"paged_attention_v2_opt_tc"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt_tc
);
// Activation ops
...
...
@@ -479,6 +491,22 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops
.
impl
(
"reshape_and_cache_flash"
,
torch
::
kCUDA
,
&
reshape_and_cache_flash
);
// Read keys and values from the kv cache.
cache_ops
.
def
(
"read_cache(Tensor keys, Tensor values,"
" Tensor[]! key_caches, Tensor[]! value_caches,"
" Tensor slot_mapping,"
" str kv_cache_dtype) -> ()"
);
cache_ops
.
impl
(
"read_cache"
,
torch
::
kCUDA
,
&
read_cache
);
// write multi-layers key and value to kv cache
cache_ops
.
def
(
"write_cache_multi_layers(Tensor keys, Tensor values,"
" Tensor[]! key_caches, Tensor[]! value_caches,"
" Tensor slot_mapping,"
" str kv_cache_dtype) -> ()"
);
cache_ops
.
impl
(
"write_cache_multi_layers"
,
torch
::
kCUDA
,
&
write_cache_multi_layers
);
// Convert the key and value cache to fp8 data type.
cache_ops
.
def
(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
...
...
examples/medusa/medusa_weight_converter.py
0 → 100644
View file @
e96edbbe
import
os
from
pathlib
import
Path
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
addict
import
Dict
import
yaml
import
argparse
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
Parameter
from
transformers
import
PretrainedConfig
from
safetensors.torch
import
save_model
,
safe_open
from
vllm.model_executor.layers.linear
import
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
# Vocabulary sizes are rounded up to a multiple of this value by
# pad_vocab_size() unless the caller passes a different `pad_to`.
DEFAULT_VOCAB_PADDING_SIZE = 64

# Parameter names inside the trained Medusa checkpoint that main() reads.
# The block template is formatted with (head index, layer index).
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE = 'base_model.model.medusa_head.{}.{}.linear.weight'
# Formatted with the head index only.
# NOTE(review): "NEMA" looks like a typo for "NAME"; the identifier is kept
# because main() references it under this spelling.
# NOTE(review): the hard-coded ".1." presumably addresses the output layer that
# follows a single residual layer -- confirm for checkpoints trained with more
# than one layer per head.
TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE = 'base_model.model.medusa_head.{}.1.weight'

# Parameter names of the Medusa module defined in this file, i.e. the names
# main() looks up in named_parameters(). Formatted like their TRAINED_*
# counterparts above.
VLLM_BLOCK_WEIGHT_NAME_TEMPLATE = 'blocks.{}.layers.{}.weight'
VLLM_MEDUSA_HEADS_WEIGHT_NAME_TEMPLATE = 'lm_heads.{}.weight'

# Written verbatim into the saved config as `medusa_choices` by main().
# NOTE(review): presumably each tuple is one candidate path through the Medusa
# tree (one top-k index per successive head) -- confirm against the consumer.
MEDUSA_CHOICES = [(0,), (0, 0), (0, 0, 0), (1,), (0, 1), (1, 0), (0, 2), (2,),
                  (0, 0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 3), (2, 0),
                  (1, 0, 0), (3,), (0, 0, 2), (0, 4), (0, 2, 0), (0, 5),
                  (4,), (1, 1), (0, 0, 3), (3, 0), (0, 6), (0, 0, 0, 1),
                  (0, 3, 0), (0, 0, 4), (0, 0, 1, 0), (2, 0, 0), (5,),
                  (0, 1, 0, 0), (0, 7)]
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Copy ``loaded_weight`` into ``param`` in place.

    Both tensors must already have the same size; a mismatch trips the
    assertion instead of broadcasting or truncating.
    """
    expected_size, actual_size = param.size(), loaded_weight.size()
    assert expected_size == actual_size
    param.data.copy_(loaded_weight)
def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Round ``vocab_size`` up to the nearest multiple of ``pad_to``."""
    bumped = vocab_size + pad_to - 1
    # Dropping the remainder is the same as floor-dividing and multiplying
    # back: n - n % p == (n // p) * p.
    return bumped - bumped % pad_to
class MedusaConfig(PretrainedConfig):
    """Configuration used to build the ``Medusa`` module in this script.

    Stores the sizes the module needs (hidden size, vocabulary size, number
    of heads and of layers per head) plus ``max_paths``, ``topk`` and a fixed
    ``max_seq_len`` of 2**20. ``truncated_vocab_size`` falls back to
    ``vocab_size`` when it is not given, and ``architectures`` defaults to
    ``["MedusaModel"]`` unless passed through ``kwargs``.
    """

    model_type = "medusa"

    def __init__(self,
                 hidden_size: int = 4096,
                 vocab_size: int = 32001,
                 num_heads: int = 5,
                 num_hidden_layers: int = 1,
                 max_paths: int = 64,
                 topk: int = 10,
                 truncated_vocab_size: Optional[int] = None,
                 **kwargs):
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.num_hidden_layers = num_hidden_layers
        self.max_paths = max_paths
        self.topk = topk
        self.max_seq_len = int(2**20)

        # No truncation requested: the heads cover the full vocabulary.
        if truncated_vocab_size is None:
            truncated_vocab_size = vocab_size
        self.truncated_vocab_size = truncated_vocab_size

        kwargs.setdefault("architectures", ["MedusaModel"])
        super().__init__(**kwargs)

    @property
    def num_attention_heads(self):
        """Always 0."""
        return 0

    @property
    def num_lookahead_tokens(self):
        """Alias of ``num_heads``."""
        return self.num_heads

    @num_lookahead_tokens.setter
    def num_lookahead_tokens(self, num_lookahead_tokens: int):
        self.num_heads = num_lookahead_tokens
class VocabParallelEmbedding(torch.nn.Module):
    """Embedding table whose row count is padded to a multiple of a block size.

    This is a single-device variant: no sharding is performed here. The
    original vocabulary (``org_num_embeddings``, or ``num_embeddings`` when
    that is not given) is padded to ``padding_size``; any additional
    embeddings are appended and the total is padded again. The weight itself
    is created by the layer's quantization method, which defaults to
    ``UnquantizedLinearMethod``.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        params_dtype: type of the parameters; defaults to torch's default
            dtype.
        org_num_embeddings: original vocabulary size (without added
            embeddings).
        padding_size: padding size for the vocabulary.
        quant_config: quant config for the layer.
        prefix: full name of the layer in the state dict.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        self.num_embeddings = num_embeddings
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
        extra_embeddings = num_embeddings - self.org_vocab_size
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                    self.padding_size)
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + extra_embeddings,
            self.padding_size)
        assert self.org_vocab_size_padded <= self.num_embeddings_padded
        self.embedding_dim = embedding_dim

        # Ask the quantization config first; fall back to the unquantized
        # method when there is no config or it declines this layer.
        quant_method = (quant_config.get_quant_method(self, prefix=prefix)
                        if quant_config is not None else None)
        if quant_method is None:
            quant_method = UnquantizedLinearMethod()
        self.linear_method: QuantizeMethodBase = quant_method

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.linear_method.create_weights(self,
                                          self.embedding_dim,
                                          [self.num_embeddings_padded],
                                          self.embedding_dim,
                                          self.num_embeddings_padded,
                                          params_dtype=params_dtype,
                                          weight_loader=self.weight_loader)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param``; shapes must match exactly."""
        assert param.data.shape == loaded_weight.shape
        param.data.copy_(loaded_weight)

    def forward(self, input_):
        """Look up the embeddings of ``input_`` (cast to int64 indices)."""
        return F.embedding(input_.long(), self.weight)
class ParallelLMHead(VocabParallelEmbedding):
    """LM head whose weight is created by ``VocabParallelEmbedding``.

    Only holds the output-projection weight (and an optional bias); calling
    ``forward`` raises because the weights are meant to be consumed by the
    sampler.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without added
            embeddings).
        padding_size: padding size for the vocabulary.
        quant_config: quant config for the layer.
        prefix: full name of the layer in the state dict.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(num_embeddings, embedding_dim, params_dtype,
                         org_num_embeddings, padding_size, quant_config,
                         prefix)
        if bias:
            # The parent class in this file does not shard, so it never sets
            # ``num_embeddings_per_partition``; the bias has one entry per
            # (padded) weight row instead.
            self.bias = Parameter(
                torch.empty(self.num_embeddings_padded, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def forward(self, input_):
        """Always raises: the head is not meant to be called directly."""
        del input_
        raise RuntimeError("LMHead's weights should be used in the sampler.")
class ResidualBlock(nn.Module):
    """Stack of bias-free square linear layers with SiLU residual updates.

    Each layer maps ``hidden_size`` to ``hidden_size``; the forward pass
    applies ``x = x + SiLU(layer(x))`` for every layer in order.
    """

    def __init__(self, hidden_size: int, num_layers: int) -> None:
        super().__init__()

        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size, bias=False)
            for _ in range(num_layers))
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after all residual updates."""
        hidden = x
        for linear in self.layers:
            hidden = hidden + self.act(linear(hidden))
        return hidden
class Medusa(nn.Module):
    """Container for the Medusa heads: residual blocks plus LM heads.

    Builds ``config.num_heads`` residual blocks (each with
    ``config.num_hidden_layers`` layers) and the same number of
    ``ParallelLMHead`` modules sized for ``config.truncated_vocab_size``.
    ``token_map`` stays ``None`` until ``load_weights`` receives one for a
    truncated vocabulary.
    """

    def __init__(self, config: MedusaConfig, **_) -> None:
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList([
            ResidualBlock(hidden_size=self.config.hidden_size,
                          num_layers=self.config.num_hidden_layers)
            for _ in range(self.config.num_heads)
        ])
        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

        self.lm_heads = nn.ModuleList([
            ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=self.truncated_vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            ) for _ in range(self.config.num_heads)
        ])

        self.token_map = None

    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
        """Run every residual block on ``hidden_states``; one output per head."""
        return [block(hidden_states) for block in self.blocks]

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs into this module's parameters.

        A leading ``"medusa_heads."`` in a name is stripped. An entry named
        ``"token_map"`` is kept only when the vocabulary is truncated; names
        that match no parameter are ignored. LM-head weights with more rows
        than the token map are reduced to the mapped rows before loading.

        Raises:
            AssertionError: if the vocabulary is truncated but no token map
                was supplied, or a tensor does not fit its parameter.
        """
        params_dict = dict(self.named_parameters())

        # First pass: pick up the token map and the weights we can place, so
        # the map is known before any LM-head weight is loaded.
        weights_map = {}
        for name, loaded_weight in weights:
            name = name.replace("medusa_heads.", "")

            if name == "token_map":
                if self.truncated_vocab_size < self.orig_vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name in params_dict:
                weights_map[name] = loaded_weight

        for name, loaded_weight in weights_map.items():
            if ("lm_head" in name and self.token_map is not None
                    and loaded_weight.shape[0] > self.token_map.shape[0]):
                loaded_weight = loaded_weight[self.token_map]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        if self.token_map is not None:
            # Tensor.to() returns a new tensor; assign it back, otherwise the
            # map silently stays on its original device.
            self.token_map.data = self.token_map.data.to(
                device=self.lm_heads[0].weight.device)

        assert (self.truncated_vocab_size
                == self.orig_vocab_size) or (self.token_map is not None)
class CustomMedusaConfig(PretrainedConfig):
    """Config object that main() saves next to the converted weights.

    Every argument is stored as an attribute of the same name
    (``name_or_path`` is stored as ``_name_or_path``); remaining keyword
    arguments are forwarded to ``PretrainedConfig``.

    Args:
        name_or_path: value stored as ``_name_or_path``.
        architectures: architecture names; ``None`` selects
            ``["MedusaModel"]``.
        hidden_size: size of hidden state.
        model_type: model type string.
        num_heads: number of Medusa heads.
        num_hidden_layers: number of layers per head.
        transformers_version: version string recorded in the config.
        truncated_vocab_size: truncated vocabulary size, if any.
        vocab_size: vocabulary size.
        medusa_choices: stored as given.
    """

    model_type = "medusa"

    def __init__(self,
                 name_or_path: str = "sugon/vllm-medusa-qwen1.5-7b-chat",
                 architectures: Optional[List[str]] = None,
                 hidden_size: int = 4096,
                 model_type: str = "medusa",
                 num_heads: int = 5,
                 num_hidden_layers: int = 1,
                 transformers_version: str = "4.41.2",
                 truncated_vocab_size: Optional[int] = None,
                 vocab_size: int = 151936,
                 medusa_choices: Optional[List[List[int]]] = None,
                 **kwargs):
        super().__init__(**kwargs)
        self._name_or_path = name_or_path
        # A fresh list per instance: a list literal as the default argument
        # would be shared (and mutable) across all instances.
        self.architectures = (["MedusaModel"]
                              if architectures is None else architectures)
        self.hidden_size = hidden_size
        self.model_type = model_type
        self.num_heads = num_heads
        self.num_hidden_layers = num_hidden_layers
        self.transformers_version = transformers_version
        self.truncated_vocab_size = truncated_vocab_size
        self.vocab_size = vocab_size
        self.medusa_choices = medusa_choices
def main(args):
    """Convert a trained Medusa checkpoint into the layout of ``Medusa``.

    Reads the number of heads and of layers per head from the YAML config,
    builds a ``Medusa`` module of that shape, copies the head and block
    weights from the trained checkpoint into it, then writes
    ``model.safetensors`` and a config into ``args.output_dir``.

    Args:
        args: parsed command-line arguments providing ``medusa_config_path``,
            ``medusa_model_path``, ``vocab_size``, ``hidden_size`` and
            ``output_dir``.

    Raises:
        KeyError: if an expected weight name is missing from the checkpoint
            or from the model.
    """
    # load the medusa config from the yaml file
    with open(args.medusa_config_path, encoding="utf-8") as file:
        medusa_cfg: Dict = Dict(yaml.safe_load(file))
    medusa_head_num = medusa_cfg.medusa_num_heads
    medusa_num_layers = medusa_cfg.medusa_num_layers

    # Build the model with the configured number of layers per head;
    # otherwise the block loop below looks up layers that do not exist as
    # soon as medusa_num_layers > 1.
    config = MedusaConfig(hidden_size=args.hidden_size,
                          vocab_size=args.vocab_size,
                          num_heads=medusa_head_num,
                          num_hidden_layers=medusa_num_layers)
    medusa_model = Medusa(config)
    params_dict = dict(medusa_model.named_parameters())

    # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
    # map_location keeps the conversion working on machines without the
    # device the checkpoint was saved from.
    trained_medusa_model = torch.load(args.medusa_model_path,
                                      map_location="cpu")

    def _copy_weight(vllm_name: str, trained_name: str) -> None:
        """Load one trained tensor into the matching model parameter."""
        param = params_dict[vllm_name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, trained_medusa_model[trained_name])

    for i in range(medusa_head_num):
        _copy_weight(VLLM_MEDUSA_HEADS_WEIGHT_NAME_TEMPLATE.format(i),
                     TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE.format(i))

    for i in range(medusa_head_num):
        for j in range(medusa_num_layers):
            _copy_weight(VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j),
                         TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j))

    os.makedirs(args.output_dir, exist_ok=True)

    save_model(medusa_model,
               os.path.join(args.output_dir, "model.safetensors"))
    to_save_config = CustomMedusaConfig(
        name_or_path=os.path.join(args.output_dir, "config.json"),
        hidden_size=args.hidden_size,
        num_heads=medusa_head_num,
        num_hidden_layers=medusa_num_layers,
        vocab_size=args.vocab_size,
        medusa_choices=MEDUSA_CHOICES)
    to_save_config.save_pretrained(args.output_dir)

    # Optional manual check of the written weights:
    # with safe_open("model.safetensors", framework="pt") as f:
    #     param = f.get_tensor(VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(0, 0))
    #     trained_param = trained_medusa_model[TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(0, 0)]
    #     mse_value = torch.nn.functional.mse_loss(param.cpu(), trained_param.cpu())
    #     print("weight mse:", mse_value)
if __name__ == "__main__":
    # Command-line entry point: every option is required and is passed to
    # main() unchanged. The description now names what the script does; it
    # previously read "Medusa Model Evaluator".
    parser = argparse.ArgumentParser(description="Medusa Weight Converter")
    parser.add_argument("--medusa_config_path",
                        type=str,
                        required=True,
                        help="Path to the medusa config file.")
    parser.add_argument("--medusa_model_path",
                        type=str,
                        required=True,
                        help="Path to the medusa model file.")
    parser.add_argument("--vocab_size",
                        type=int,
                        required=True,
                        help="Vocab size")
    parser.add_argument("--hidden_size",
                        type=int,
                        required=True,
                        help="Hidden size")
    parser.add_argument("--output_dir",
                        type=str,
                        required=True,
                        help="Output dir")
    args = parser.parse_args()
    main(args)
vllm/_custom_ops.py
View file @
e96edbbe
...
...
@@ -108,13 +108,16 @@ def paged_attention_v1(
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks_stride
:
int
=
0
,
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
def
paged_attention_v2
(
...
...
@@ -140,13 +143,16 @@ def paged_attention_v2(
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks_stride
:
int
=
0
,
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v2
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
# page attention ops (opt)
...
...
@@ -170,13 +176,16 @@ def paged_attention_v1_opt(
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks_stride
:
int
=
0
,
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v1_opt
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
def
paged_attention_v2_opt
(
...
...
@@ -202,13 +211,16 @@ def paged_attention_v2_opt(
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks_stride
:
int
=
0
,
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v2_opt
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
# page attention ops (opt)
...
...
@@ -232,12 +244,15 @@ def paged_attention_v1_opt_tc(
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks_stride
:
int
=
0
,
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v1_opt_tc
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
def
paged_attention_v2_opt_tc
(
...
...
@@ -263,13 +278,16 @@ def paged_attention_v2_opt_tc(
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks_stride
:
int
=
0
,
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v2_opt_tc
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
def
paged_attention_rocm
(
...
...
@@ -1142,6 +1160,31 @@ def register_graph_buffers(fa: int, handles: List[str],
offsets
:
List
[
List
[
int
]])
->
None
:
torch
.
ops
.
_C_custom_ar
.
register_graph_buffers
(
fa
,
handles
,
offsets
)
def read_cache(
    keys: torch.Tensor,
    values: torch.Tensor,
    key_caches: List[torch.Tensor],
    value_caches: List[torch.Tensor],
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
) -> None:
    """Invoke the compiled ``_C_cache_ops.read_cache`` operator.

    This is a thin Python wrapper: all arguments are forwarded positionally
    and unchanged, and nothing is returned.

    NOTE(review): judging by the name and argument order, the kernel
    presumably fills ``keys``/``values`` from the per-layer caches at the
    slots listed in ``slot_mapping`` -- confirm against the kernel source.
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.read_cache(keys, values, key_caches, value_caches,
                         slot_mapping, kv_cache_dtype)
def write_cache_multi_layers(
    keys: torch.Tensor,
    values: torch.Tensor,
    key_caches: List[torch.Tensor],
    value_caches: List[torch.Tensor],
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
) -> None:
    """Invoke the compiled ``_C_cache_ops.write_cache_multi_layers`` operator.

    This is a thin Python wrapper: all arguments are forwarded positionally
    and unchanged, and nothing is returned.

    NOTE(review): judging by the name and argument order, the kernel
    presumably stores ``keys``/``values`` into the per-layer caches at the
    slots listed in ``slot_mapping`` -- confirm against the kernel source.
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.write_cache_multi_layers(keys, values, key_caches,
                                       value_caches, slot_mapping,
                                       kv_cache_dtype)
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
...
...
vllm/attention/backends/abstract.py
View file @
e96edbbe
...
...
@@ -82,6 +82,17 @@ class AttentionBackend(ABC):
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
raise
NotImplementedError
@
staticmethod
@
abstractmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
raise
NotImplementedError
def
advance_step
(
self
,
model_input
:
"ModelRunnerInputBase"
,
sampled_token_ids
:
Optional
[
torch
.
Tensor
],
...
...
@@ -195,7 +206,8 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
@
abstractmethod
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
)
->
T
:
cuda_graph_pad_size
:
int
,
batch_size
:
int
,
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
)
->
T
:
"""Build attention metadata with on-device tensors."""
raise
NotImplementedError
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
e96edbbe
...
...
@@ -10,6 +10,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
from
vllm.attention.ops.blocksparse_attention.interface
import
(
LocalStridedBlockSparseAttn
,
get_head_sliding_step
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
...
...
@@ -128,6 +129,50 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Relay KV-cache entries through a temporary buffer for every layer.

    An uninitialised staging tensor of shape
    ``(2, num_layers, num_rows, num_kv_heads, head_size)`` is allocated with
    the dtype and device of the first layer's cache; index 0 is handed to
    the kernels as ``keys`` and index 1 as ``values``.  ``ops.read_cache``
    is called with column 0 of ``src_to_dists`` as its slot mapping, then
    ``ops.write_cache_multi_layers`` with column 1.

    NOTE(review): this presumably copies each row's source slot to its
    destination slot; the kernels themselves are not visible here.
    """
    layer_count = len(kv_caches)
    row_count = src_to_dists.shape[0]
    reference_cache = kv_caches[0]
    staging = torch.empty(
        (2, layer_count, row_count, num_kv_heads, head_size),
        dtype=reference_cache.dtype,
        device=reference_cache.device,
    )
    keys = staging[0].contiguous()
    values = staging[1].contiguous()

    # Split every layer's combined cache into its key and value halves.
    split_caches = [
        PagedAttention.split_kv_cache(layer_cache, num_kv_heads, head_size)
        for layer_cache in kv_caches
    ]
    key_caches = [key_half for key_half, _ in split_caches]
    value_caches = [value_half for _, value_half in split_caches]

    source_slots = src_to_dists[:, 0].contiguous()
    ops.read_cache(keys, values, key_caches, value_caches, source_slots,
                   kv_cache_dtype)

    target_slots = src_to_dists[:, 1].contiguous()
    ops.write_cache_multi_layers(keys, values, key_caches, value_caches,
                                 target_slots, kv_cache_dtype)
@
dataclass
class
BlocksparseFlashAttentionMetadata
(
AttentionMetadata
):
...
...
@@ -190,6 +235,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
"BlocksparseFlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"BlocksparseFlashAttentionMetadata"
]
=
None
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
@
property
def
prefill_metadata
(
...
...
@@ -222,6 +269,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
use_cuda_graph
=
False
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
)
return
self
.
_cached_prefill_metadata
...
...
@@ -250,6 +298,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
)
return
self
.
_cached_decode_metadata
...
...
vllm/attention/backends/flash_attn.py
View file @
e96edbbe
...
...
@@ -475,7 +475,8 @@ class FlashAttentionMetadataBuilder(
self
.
block_size
,
inter_data
.
block_tables
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
cuda_graph_pad_size
:
int
,
batch_size
:
int
,
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
):
"""Build attention metadata with on-device tensors.
Args:
...
...
@@ -484,6 +485,7 @@ class FlashAttentionMetadataBuilder(
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
tree_attention_masks_tensor: attention mask used in tree style attention.
"""
prefix_cache_hit
=
any
([
inter_data
.
prefix_cache_hit
...
...
vllm/attention/backends/flashinfer.py
View file @
e96edbbe
...
...
@@ -92,6 +92,16 @@ class FlashInferBackend(AttentionBackend):
return
torch
.
float8_e5m2
else
:
raise
ValueError
(
f
"Unrecognized FP8 dtype:
{
kv_cache_dtype
}
"
)
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
NotImplementedError
class
FlashInferState
(
AttentionState
):
...
...
@@ -574,7 +584,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
paged_kv_last_page_len
.
append
(
last_page_len
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
cuda_graph_pad_size
:
int
,
batch_size
:
int
,
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
):
"""Build attention metadata with on-device tensors.
Args:
...
...
@@ -583,6 +594,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
tree_attention_masks_tensor: attention mask used in tree style attention.
"""
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
...
...
vllm/attention/backends/ipex_attn.py
View file @
e96edbbe
...
...
@@ -62,6 +62,16 @@ class IpexAttnBackend(AttentionBackend):
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
NotImplementedError
@
dataclass
class
IpexAttnMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
...
...
vllm/attention/backends/openvino.py
View file @
e96edbbe
...
...
@@ -62,6 +62,16 @@ class OpenVINOAttentionBackend(AttentionBackend):
key_cache
.
data
[
dst
,
:]
=
key_cache
.
data
[
src
,
:]
value_cache
.
data
[
dst
,
:]
=
value_cache
.
data
[
src
,
:]
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
NotImplementedError
@
dataclass
class
OpenVINOAttentionMetadata
:
...
...
vllm/attention/backends/pallas.py
View file @
e96edbbe
...
...
@@ -53,6 +53,16 @@ class PallasAttentionBackend(AttentionBackend):
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
v_cache
,
True
)
v_cache
[:,
dst_indices
]
=
v_cache
[:,
src_indices
]
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
NotImplementedError
@
dataclass
class
PallasMetadata
(
AttentionMetadata
):
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
e96edbbe
...
...
@@ -13,6 +13,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
if
TYPE_CHECKING
:
...
...
@@ -71,6 +72,50 @@ class ROCmFlashAttentionBackend(AttentionBackend):
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Relay KV-cache entries through a temporary buffer for every layer.

    An uninitialised staging tensor of shape
    ``(2, num_layers, num_rows, num_kv_heads, head_size)`` is allocated with
    the dtype and device of the first layer's cache; index 0 is handed to
    the kernels as ``keys`` and index 1 as ``values``.  ``ops.read_cache``
    is called with column 0 of ``src_to_dists`` as its slot mapping, then
    ``ops.write_cache_multi_layers`` with column 1.

    NOTE(review): this presumably copies each row's source slot to its
    destination slot; the kernels themselves are not visible here.
    """
    layer_count = len(kv_caches)
    row_count = src_to_dists.shape[0]
    reference_cache = kv_caches[0]
    staging = torch.empty(
        (2, layer_count, row_count, num_kv_heads, head_size),
        dtype=reference_cache.dtype,
        device=reference_cache.device,
    )
    keys = staging[0].contiguous()
    values = staging[1].contiguous()

    # Split every layer's combined cache into its key and value halves.
    split_caches = [
        PagedAttention.split_kv_cache(layer_cache, num_kv_heads, head_size)
        for layer_cache in kv_caches
    ]
    key_caches = [key_half for key_half, _ in split_caches]
    value_caches = [value_half for _, value_half in split_caches]

    source_slots = src_to_dists[:, 0].contiguous()
    ops.read_cache(keys, values, key_caches, value_caches, source_slots,
                   kv_cache_dtype)

    target_slots = src_to_dists[:, 1].contiguous()
    ops.write_cache_multi_layers(keys, values, key_caches, value_caches,
                                 target_slots, kv_cache_dtype)
@
dataclass
class
ROCmFlashAttentionMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
...
...
@@ -122,6 +167,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
_cached_prefill_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"ROCmFlashAttentionMetadata"
]:
if
self
.
num_prefills
==
0
:
...
...
@@ -152,6 +199,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
use_cuda_graph
=
False
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
)
return
self
.
_cached_prefill_metadata
...
...
@@ -180,6 +228,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
)
return
self
.
_cached_decode_metadata
...
...
@@ -613,6 +662,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
v_scale
,
)
else
:
tree_attention_masks_tensor
=
decode_meta
.
tree_attention_masks_tensor
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
decode_query
,
key_cache
,
...
...
@@ -626,6 +676,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
alibi_slopes
,
k_scale
,
v_scale
,
attn_masks
=
tree_attention_masks_tensor
,
attn_masks_stride
=
tree_attention_masks_tensor
.
stride
(
0
)
if
tree_attention_masks_tensor
is
not
None
else
0
)
# Reshape the output tensor.
...
...
vllm/attention/backends/torch_sdpa.py
View file @
e96edbbe
...
...
@@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.paged_attn
import
PagedAttentionMetadata
from
vllm.utils
import
is_cpu
from
vllm
import
_custom_ops
as
ops
if
is_cpu
():
try
:
...
...
@@ -64,6 +65,16 @@ class TorchSDPABackend(AttentionBackend):
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
NotImplementedError
@
dataclass
class
TorchSDPAMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
...
...
vllm/attention/backends/utils.py
View file @
e96edbbe
"""Attention backend utils"""
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Type
,
TypeVar
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Type
,
TypeVar
,
Union
,
Optional
import
numpy
as
np
import
torch
...
...
@@ -188,7 +188,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self
.
block_size
,
inter_data
.
block_tables
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
cuda_graph_pad_size
:
int
,
batch_size
:
int
,
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
):
"""Build attention metadata with on-device tensors.
Args:
...
...
@@ -271,6 +272,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
tree_attention_masks_tensor
=
tree_attention_masks_tensor
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment