Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3521ba4f
Unverified
Commit
3521ba4f
authored
May 04, 2024
by
SangBin Cho
Committed by
GitHub
May 03, 2024
Browse files
[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)
parent
2d7bce9c
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
373 additions
and
361 deletions
+373
-361
benchmarks/kernels/benchmark_paged_attention.py
benchmarks/kernels/benchmark_paged_attention.py
+12
-13
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+38
-38
csrc/cpu/attention.cpp
csrc/cpu/attention.cpp
+46
-46
csrc/ops.h
csrc/ops.h
+4
-4
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+17
-18
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+8
-8
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+17
-17
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+2
-2
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+14
-10
tests/spec_decode/test_ngram_worker.py
tests/spec_decode/test_ngram_worker.py
+15
-9
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+4
-4
tests/test_logits_processor.py
tests/test_logits_processor.py
+4
-4
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+48
-51
vllm/_custom_ops.py
vllm/_custom_ops.py
+9
-9
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+22
-22
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+30
-30
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+18
-18
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+32
-33
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+17
-18
vllm/config.py
vllm/config.py
+16
-7
No files found.
benchmarks/kernels/benchmark_paged_attention.py
View file @
3521ba4f
...
...
@@ -16,7 +16,7 @@ PARTITION_SIZE = 512
def
main
(
version
:
str
,
num_seqs
:
int
,
context
_len
:
int
,
seq
_len
:
int
,
num_query_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
...
...
@@ -48,12 +48,12 @@ def main(
dtype
=
torch
.
float
,
device
=
device
)
context
_lens
=
[
context
_len
for
_
in
range
(
num_seqs
)]
max_
context
_len
=
max
(
context
_lens
)
context
_lens
=
torch
.
tensor
(
context
_lens
,
dtype
=
torch
.
int
,
device
=
device
)
seq
_lens
=
[
seq
_len
for
_
in
range
(
num_seqs
)]
max_
seq
_len
=
max
(
seq
_lens
)
seq
_lens
=
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
int
,
device
=
device
)
# Create the block tables.
max_num_blocks_per_seq
=
(
max_
context
_len
+
block_size
-
1
)
//
block_size
max_num_blocks_per_seq
=
(
max_
seq
_len
+
block_size
-
1
)
//
block_size
block_tables
=
[]
for
_
in
range
(
num_seqs
):
block_table
=
[
...
...
@@ -77,8 +77,7 @@ def main(
# Prepare for the paged attention kernel.
output
=
torch
.
empty_like
(
query
)
if
version
==
"v2"
:
num_partitions
=
((
max_context_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_query_heads
,
num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
...
...
@@ -110,9 +109,9 @@ def main(
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
seq
_lens
,
block_size
,
max_
context
_len
,
max_
seq
_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
...
...
@@ -129,9 +128,9 @@ def main(
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
seq
_lens
,
block_size
,
max_
context
_len
,
max_
seq
_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
...
...
@@ -166,7 +165,7 @@ if __name__ == '__main__':
choices
=
[
"v1"
,
"v2"
],
default
=
"v2"
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--
context-
len"
,
type
=
int
,
default
=
4096
)
parser
.
add_argument
(
"--
seq_
len"
,
type
=
int
,
default
=
4096
)
parser
.
add_argument
(
"--num-query-heads"
,
type
=
int
,
default
=
64
)
parser
.
add_argument
(
"--num-kv-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--head-size"
,
...
...
@@ -199,7 +198,7 @@ if __name__ == '__main__':
main
(
version
=
args
.
version
,
num_seqs
=
args
.
batch_size
,
context
_len
=
args
.
context
_len
,
seq
_len
=
args
.
seq
_len
,
num_query_heads
=
args
.
num_query_heads
,
num_kv_heads
=
args
.
num_kv_heads
,
head_size
=
args
.
head_size
,
...
...
csrc/attention/attention_kernels.cu
View file @
3521ba4f
...
...
@@ -104,7 +104,7 @@ __device__ void paged_attention_kernel(
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
...
...
@@ -115,23 +115,23 @@ __device__ void paged_attention_kernel(
const
int
partition_idx
=
blockIdx
.
z
;
const
int
max_num_partitions
=
gridDim
.
z
;
constexpr
bool
USE_PARTITIONING
=
PARTITION_SIZE
>
0
;
const
int
context_len
=
context
_lens
[
seq_idx
];
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
context
_len
)
{
const
int
seq_len
=
seq
_lens
[
seq_idx
];
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq
_len
)
{
// No work to do. Terminate the thread block.
return
;
}
const
int
num_
context
_blocks
=
DIVIDE_ROUND_UP
(
context
_len
,
BLOCK_SIZE
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_
context
_blocks
;
const
int
num_
seq
_blocks
=
DIVIDE_ROUND_UP
(
seq
_len
,
BLOCK_SIZE
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_
seq
_blocks
;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const
int
start_block_idx
=
USE_PARTITIONING
?
partition_idx
*
num_blocks_per_partition
:
0
;
const
int
end_block_idx
=
MIN
(
start_block_idx
+
num_blocks_per_partition
,
num_
context
_blocks
);
const
int
end_block_idx
=
MIN
(
start_block_idx
+
num_blocks_per_partition
,
num_
seq
_blocks
);
const
int
num_blocks
=
end_block_idx
-
start_block_idx
;
// [start_token_idx, end_token_idx) is the range of tokens to process.
const
int
start_token_idx
=
start_block_idx
*
BLOCK_SIZE
;
const
int
end_token_idx
=
MIN
(
start_token_idx
+
num_blocks
*
BLOCK_SIZE
,
context
_len
);
const
int
end_token_idx
=
MIN
(
start_token_idx
+
num_blocks
*
BLOCK_SIZE
,
seq
_len
);
const
int
num_tokens
=
end_token_idx
-
start_token_idx
;
constexpr
int
THREAD_GROUP_SIZE
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
...
...
@@ -245,12 +245,12 @@ __device__ void paged_attention_kernel(
// This includes a reduction across the threads in the same thread group.
float
qk
=
scale
*
Qk_dot
<
scalar_t
,
THREAD_GROUP_SIZE
>::
dot
(
q_vecs
[
thread_group_offset
],
k_vecs
);
// Add the ALiBi bias if slopes are given.
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
context
_len
+
1
)
:
0
;
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq
_len
+
1
)
:
0
;
if
(
thread_group_offset
==
0
)
{
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const
bool
mask
=
token_idx
>=
context
_len
;
const
bool
mask
=
token_idx
>=
seq
_len
;
logits
[
token_idx
-
start_token_idx
]
=
mask
?
0.
f
:
qk
;
// Update the max value.
qk_max
=
mask
?
qk_max
:
fmaxf
(
qk_max
,
qk
);
...
...
@@ -364,14 +364,14 @@ __device__ void paged_attention_kernel(
}
else
{
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
}
if
(
block_idx
==
num_
context
_blocks
-
1
)
{
if
(
block_idx
==
num_
seq
_blocks
-
1
)
{
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// we should explicitly zero out the values since they may contain NaNs.
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
context
_len
?
v_vec_ptr
[
j
]
:
zero_value
;
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq
_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
}
accs
[
i
]
+=
dot
(
logits_vec
,
v_vec
);
...
...
@@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel(
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
...
...
@@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel(
const
float
kv_scale
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
IS_FP8_KV_CACHE
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq
_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
kv_scale
);
}
...
...
@@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel(
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
...
...
@@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel(
const
float
kv_scale
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
IS_FP8_KV_CACHE
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
block_tables
,
seq
_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
kv_scale
);
}
...
...
@@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel(
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
const
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_partitions
)
{
const
int
num_heads
=
gridDim
.
x
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
seq_idx
=
blockIdx
.
y
;
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
context
_len
,
PARTITION_SIZE
);
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
seq
_len
,
PARTITION_SIZE
);
if
(
num_partitions
==
1
)
{
// No need to reduce. Only copy tmp_out to out.
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
...
...
@@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel(
num_kv_heads, \
scale, \
block_tables_ptr, \
context
_lens_ptr, \
seq
_lens_ptr,
\
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
...
...
@@ -639,8 +639,8 @@ void paged_attention_v1_launcher(
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
max_
context
_len
,
torch
::
Tensor
&
seq
_lens
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
kv_scale
)
{
int
num_seqs
=
query
.
size
(
0
);
...
...
@@ -664,11 +664,11 @@ void paged_attention_v1_launcher(
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context
_lens_ptr
=
context
_lens
.
data_ptr
<
int
>
();
int
*
seq
_lens_ptr
=
seq
_lens
.
data_ptr
<
int
>
();
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
padded_max_
context
_len
=
DIVIDE_ROUND_UP
(
max_
context
_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
int
logits_size
=
padded_max_
context
_len
*
sizeof
(
float
);
int
padded_max_
seq
_len
=
DIVIDE_ROUND_UP
(
max_
seq
_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
int
logits_size
=
padded_max_
seq
_len
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
...
...
@@ -715,8 +715,8 @@ void paged_attention_v1_launcher(
num_kv_heads, \
scale, \
block_tables, \
context
_lens, \
max_
context
_len, \
seq
_lens, \
max_
seq
_len, \
alibi_slopes, \
kv_scale);
...
...
@@ -746,9 +746,9 @@ void paged_attention_v1(
int
num_kv_heads
,
// [num_heads]
float
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
context
_lens
,
// [num_seqs]
torch
::
Tensor
&
seq
_lens
,
// [num_seqs]
int
block_size
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
...
...
@@ -790,7 +790,7 @@ void paged_attention_v1(
num_kv_heads, \
scale, \
block_tables_ptr, \
context
_lens_ptr, \
seq
_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
...
...
@@ -803,7 +803,7 @@ void paged_attention_v1(
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
context
_lens_ptr, \
seq
_lens_ptr, \
max_num_partitions);
template
<
...
...
@@ -824,8 +824,8 @@ void paged_attention_v2_launcher(
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
max_
context
_len
,
torch
::
Tensor
&
seq
_lens
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
kv_scale
)
{
int
num_seqs
=
query
.
size
(
0
);
...
...
@@ -852,10 +852,10 @@ void paged_attention_v2_launcher(
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context
_lens_ptr
=
context
_lens
.
data_ptr
<
int
>
();
int
*
seq
_lens_ptr
=
seq
_lens
.
data_ptr
<
int
>
();
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_
context
_len
,
PARTITION_SIZE
);
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_
seq
_len
,
PARTITION_SIZE
);
int
logits_size
=
PARTITION_SIZE
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
...
...
@@ -909,8 +909,8 @@ void paged_attention_v2_launcher(
num_kv_heads, \
scale, \
block_tables, \
context
_lens, \
max_
context
_len, \
seq
_lens, \
max_
seq
_len, \
alibi_slopes, \
kv_scale);
...
...
@@ -943,9 +943,9 @@ void paged_attention_v2(
int
num_kv_heads
,
// [num_heads]
float
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
context
_lens
,
// [num_seqs]
torch
::
Tensor
&
seq
_lens
,
// [num_seqs]
int
block_size
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
...
...
csrc/cpu/attention.cpp
View file @
3521ba4f
...
...
@@ -70,11 +70,11 @@ template <typename T>
FORCE_INLINE
std
::
pair
<
T
,
T
>
reduceSoftmaxAlibi
(
T
*
data
,
const
int
size
,
const
int
capacity
,
const
float
alibi_slope
,
const
int
start_index
,
const
int
context
_len
)
{
data
[
0
]
+=
alibi_slope
*
(
start_index
-
context
_len
+
1
);
const
int
seq
_len
)
{
data
[
0
]
+=
alibi_slope
*
(
start_index
-
seq
_len
+
1
);
T
max
=
data
[
0
];
for
(
int
i
=
1
;
i
<
size
;
++
i
)
{
T
qk
=
data
[
i
]
+
alibi_slope
*
(
start_index
+
i
-
context
_len
+
1
);
T
qk
=
data
[
i
]
+
alibi_slope
*
(
start_index
+
i
-
seq
_len
+
1
);
data
[
i
]
=
qk
;
max
=
max
>=
qk
?
max
:
qk
;
}
...
...
@@ -225,7 +225,7 @@ struct paged_attention_v1_impl {
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
...
...
@@ -235,32 +235,32 @@ struct paged_attention_v1_impl {
static_assert
(
BLOCK_SIZE
==
16
);
int
max_
context
_len
=
max_num_blocks_per_seq
*
BLOCK_SIZE
;
int
max_
context
_len_padded
=
(
max_
context
_len
+
15
)
&
0xFFFFFFF0
;
TORCH_CHECK
((
max_
context
_len_padded
*
sizeof
(
float
))
%
64
==
0
);
int
max_
seq
_len
=
max_num_blocks_per_seq
*
BLOCK_SIZE
;
int
max_
seq
_len_padded
=
(
max_
seq
_len
+
15
)
&
0xFFFFFFF0
;
TORCH_CHECK
((
max_
seq
_len_padded
*
sizeof
(
float
))
%
64
==
0
);
const
int
parallel_work_item_num
=
omp_get_max_threads
();
size_t
logits_bytes
=
parallel_work_item_num
*
max_
context
_len_padded
*
sizeof
(
float
);
parallel_work_item_num
*
max_
seq
_len_padded
*
sizeof
(
float
);
float
*
logits
=
(
float
*
)
std
::
aligned_alloc
(
64
,
logits_bytes
);
// Cacheline alignment for each context token.
// [parallel_work_item_num, max_
context
_len_padded]
// [parallel_work_item_num, max_
seq
_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
for
(
int
seq_idx
=
0
;
seq_idx
<
num_seqs
;
++
seq_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
int
context_len
=
context
_lens
[
seq_idx
];
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
*
seq_block_table
=
block_tables
+
max_num_blocks_per_seq
*
seq_idx
;
const
int
block_num
=
(
context
_len
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
const
int
block_num
=
(
seq
_len
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
const
int64_t
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
const
scalar_t
*
__restrict__
q_vec_ptr
=
q
+
seq_idx
*
q_stride
+
head_idx
*
HEAD_SIZE
;
const
int
last_block_token_num
=
context
_len
-
(
block_num
-
1
)
*
BLOCK_SIZE
;
seq
_len
-
(
block_num
-
1
)
*
BLOCK_SIZE
;
float
*
__restrict__
thread_block_logits
=
logits
+
omp_get_thread_num
()
*
max_
context
_len_padded
;
logits
+
omp_get_thread_num
()
*
max_
seq
_len_padded
;
// Compute logits
for
(
int
block_idx
=
0
;
block_idx
<
block_num
;
++
block_idx
)
{
...
...
@@ -278,11 +278,11 @@ struct paged_attention_v1_impl {
// Compute softmax
if
(
alibi_slopes
)
{
reduceSoftmaxAlibi
(
thread_block_logits
,
context
_len
,
reduceSoftmaxAlibi
(
thread_block_logits
,
seq
_len
,
block_num
*
BLOCK_SIZE
,
alibi_slopes
[
head_idx
],
0
,
context
_len
);
seq
_len
);
}
else
{
reduceSoftmax
(
thread_block_logits
,
context
_len
,
reduceSoftmax
(
thread_block_logits
,
seq
_len
,
block_num
*
BLOCK_SIZE
);
}
...
...
@@ -340,7 +340,7 @@ struct paged_attention_v1_impl {
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr,
context
_lens_ptr, max_num_blocks_per_seq, \
block_tables_ptr,
seq
_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads);
...
...
@@ -348,8 +348,8 @@ template <typename T, int BLOCK_SIZE>
void
paged_attention_v1_impl_launcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
max_
context
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
)
{
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq
_lens
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
...
...
@@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher(
T
*
key_cache_ptr
=
reinterpret_cast
<
T
*>
(
key_cache
.
data_ptr
());
T
*
value_cache_ptr
=
reinterpret_cast
<
T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context
_lens_ptr
=
context
_lens
.
data_ptr
<
int
>
();
int
*
seq
_lens_ptr
=
seq
_lens
.
data_ptr
<
int
>
();
switch
(
head_size
)
{
case
64
:
...
...
@@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher(
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
context
_lens, max_
context
_len, alibi_slopes);
seq
_lens, max_
seq
_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
...
...
@@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
block_size
,
int
max_
context
_len
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
TORCH_CHECK
(
kv_scale
==
1.0
f
);
...
...
@@ -448,7 +448,7 @@ struct paged_attention_v2_impl {
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
...
...
@@ -465,22 +465,22 @@ struct paged_attention_v2_impl {
for
(
int
partition_idx
=
0
;
partition_idx
<
max_num_partitions
;
++
partition_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
start_token_idx
=
partition_idx
*
PARTITION_SIZE
;
if
(
start_token_idx
>=
context
_len
)
if
(
start_token_idx
>=
seq
_len
)
continue
;
const
int
partition_num
=
(
context
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
(
seq
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
const
bool
no_reduce
=
(
partition_num
==
1
);
const
int
context_
token_num
=
(
std
::
min
(
context
_len
,
start_token_idx
+
PARTITION_SIZE
)
-
const
int
token_num
=
(
std
::
min
(
seq
_len
,
start_token_idx
+
PARTITION_SIZE
)
-
start_token_idx
);
const
int
block_num
=
(
context_
token_num
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
(
token_num
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
const
int
last_block_token_num
=
context_
token_num
-
(
block_num
-
1
)
*
BLOCK_SIZE
;
token_num
-
(
block_num
-
1
)
*
BLOCK_SIZE
;
const
int
*
seq_block_table
=
block_tables
+
max_num_blocks_per_seq
*
seq_idx
+
start_token_idx
/
BLOCK_SIZE
;
...
...
@@ -507,10 +507,10 @@ struct paged_attention_v2_impl {
std
::
pair
<
float
,
float
>
max_and_sum
;
if
(
alibi_slopes
)
{
max_and_sum
=
reduceSoftmaxAlibi
(
logits
,
context_
token_num
,
block_num
*
BLOCK_SIZE
,
alibi_slopes
[
head_idx
],
start_token_idx
,
context
_len
);
logits
,
token_num
,
block_num
*
BLOCK_SIZE
,
alibi_slopes
[
head_idx
],
start_token_idx
,
seq
_len
);
}
else
{
max_and_sum
=
reduceSoftmax
(
logits
,
context_
token_num
,
max_and_sum
=
reduceSoftmax
(
logits
,
token_num
,
block_num
*
BLOCK_SIZE
);
}
...
...
@@ -583,9 +583,9 @@ struct paged_attention_v2_impl {
#pragma omp parallel for collapse(2) schedule(static, 1)
for
(
int
seq_idx
=
0
;
seq_idx
<
num_seqs
;
++
seq_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
partition_num
=
(
context
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
(
seq
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
if
(
partition_num
==
1
)
continue
;
...
...
@@ -612,9 +612,9 @@ struct paged_attention_v2_impl {
for
(
int
seq_idx
=
0
;
seq_idx
<
num_seqs
;
++
seq_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
for
(
int
group_idx
=
0
;
group_idx
<
head_group_num
;
++
group_idx
)
{
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
partition_num
=
(
context
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
(
seq
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
if
(
partition_num
==
1
)
continue
;
...
...
@@ -649,7 +649,7 @@ struct paged_attention_v2_impl {
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
context
_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
seq
_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions);
...
...
@@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
block_size
,
int
max_
context
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
)
{
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
...
...
@@ -683,7 +683,7 @@ void paged_attention_v2_impl_launcher(
T
*
key_cache_ptr
=
reinterpret_cast
<
T
*>
(
key_cache
.
data_ptr
());
T
*
value_cache_ptr
=
reinterpret_cast
<
T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context
_lens_ptr
=
context
_lens
.
data_ptr
<
int
>
();
int
*
seq
_lens_ptr
=
seq
_lens
.
data_ptr
<
int
>
();
switch
(
head_size
)
{
case
64
:
...
...
@@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher(
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables,
context
_lens, block_size, \
max_
context
_len, alibi_slopes);
num_kv_heads, scale, block_tables,
seq
_lens, block_size, \
max_
seq
_len, alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
...
...
@@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
block_size
,
int
max_
context
_len
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
TORCH_CHECK
(
kv_scale
==
1.0
f
);
...
...
csrc/ops.h
View file @
3521ba4f
...
...
@@ -10,9 +10,9 @@ void paged_attention_v1(
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
...
...
@@ -28,9 +28,9 @@ void paged_attention_v2(
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
...
...
tests/kernels/test_attention.py
View file @
3521ba4f
...
...
@@ -61,7 +61,7 @@ def ref_single_query_cached_kv_attention(
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context
_lens
:
torch
.
Tensor
,
seq
_lens
:
torch
.
Tensor
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
)
->
None
:
...
...
@@ -72,15 +72,15 @@ def ref_single_query_cached_kv_attention(
num_seqs
=
query
.
shape
[
0
]
block_tables
=
block_tables
.
cpu
().
tolist
()
context
_lens
=
context
_lens
.
cpu
().
tolist
()
seq
_lens
=
seq
_lens
.
cpu
().
tolist
()
for
i
in
range
(
num_seqs
):
q
=
query
[
i
].
unsqueeze
(
0
)
block_table
=
block_tables
[
i
]
context
_len
=
int
(
context
_lens
[
i
])
seq
_len
=
int
(
seq
_lens
[
i
])
keys
=
[]
values
=
[]
for
j
in
range
(
context
_len
):
for
j
in
range
(
seq
_len
):
block_number
=
int
(
block_table
[
j
//
block_size
])
block_offset
=
j
%
block_size
...
...
@@ -100,8 +100,8 @@ def ref_single_query_cached_kv_attention(
alibi_bias
=
None
if
alibi_slopes
is
not
None
:
# Create the ALiBi bias used in the paged attention kernel.
position_ids
=
torch
.
arange
(
context
_len
).
int
()
alibi_bias
=
(
position_ids
-
context
_len
+
1
).
float
()
position_ids
=
torch
.
arange
(
seq
_len
).
int
()
alibi_bias
=
(
position_ids
-
seq
_len
+
1
).
float
()
alibi_bias
=
alibi_slopes
.
view
(
-
1
,
1
,
1
)
*
alibi_bias
.
view
(
1
,
1
,
-
1
)
...
...
@@ -149,13 +149,13 @@ def test_paged_attention(
if
use_alibi
:
alibi_slopes
=
torch
.
randn
(
num_query_heads
,
dtype
=
torch
.
float
)
context
_lens
=
[
random
.
randint
(
1
,
MAX_SEQ_LEN
)
for
_
in
range
(
num_seqs
)]
context
_lens
[
-
1
]
=
MAX_SEQ_LEN
max_
context
_len
=
max
(
context
_lens
)
context
_lens
=
torch
.
tensor
(
context
_lens
,
dtype
=
torch
.
int
)
seq
_lens
=
[
random
.
randint
(
1
,
MAX_SEQ_LEN
)
for
_
in
range
(
num_seqs
)]
seq
_lens
[
-
1
]
=
MAX_SEQ_LEN
max_
seq
_len
=
max
(
seq
_lens
)
seq
_lens
=
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
int
)
# Create the block tables.
max_num_blocks_per_seq
=
(
max_
context
_len
+
block_size
-
1
)
//
block_size
max_num_blocks_per_seq
=
(
max_
seq
_len
+
block_size
-
1
)
//
block_size
block_tables
=
[]
for
_
in
range
(
num_seqs
):
block_table
=
[
...
...
@@ -186,16 +186,15 @@ def test_paged_attention(
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
seq
_lens
,
block_size
,
max_
context
_len
,
max_
seq
_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
)
elif
version
==
"v2"
:
num_partitions
=
((
max_context_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
assert
PARTITION_SIZE
%
block_size
==
0
num_seqs
,
num_heads
,
head_size
=
output
.
shape
tmp_output
=
torch
.
empty
(
...
...
@@ -218,9 +217,9 @@ def test_paged_attention(
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
seq
_lens
,
block_size
,
max_
context
_len
,
max_
seq
_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
...
...
@@ -255,7 +254,7 @@ def test_paged_attention(
key_cache
,
value_cache
,
block_tables
,
context
_lens
,
seq
_lens
,
scale
,
alibi_slopes
,
)
...
...
tests/kernels/test_prefix_prefill.py
View file @
3521ba4f
...
...
@@ -51,12 +51,12 @@ def test_contexted_kv_attention(
cache_size
=
640
block_size
=
32
max_block_per_request
=
64
sub
query_lens
=
[
random
.
randint
(
16
,
MAX_SEQ_LEN
)
for
_
in
range
(
BS
)]
query_lens
=
[
random
.
randint
(
16
,
MAX_SEQ_LEN
)
for
_
in
range
(
BS
)]
ctx_lens
=
[
random
.
randint
(
16
,
MAX_CTX_LEN
)
for
_
in
range
(
BS
)]
seq_lens
=
[
a
+
b
for
a
,
b
in
zip
(
sub
query_lens
,
ctx_lens
)]
seq_lens
=
[
a
+
b
for
a
,
b
in
zip
(
query_lens
,
ctx_lens
)]
num_kv_heads
=
num_heads
//
num_queries_per_kv
num_tokens
=
sum
(
sub
query_lens
)
num_tokens
=
sum
(
query_lens
)
query
=
torch
.
empty
(
num_tokens
,
num_heads
,
head_size
,
dtype
=
dtype
)
query
.
uniform_
(
-
1e-3
,
1e-3
)
output
=
torch
.
empty
(
num_tokens
,
num_heads
,
head_size
,
dtype
=
dtype
)
...
...
@@ -75,15 +75,15 @@ def test_contexted_kv_attention(
num_kv_heads
,
head_size
,
dtype
=
dtype
)
k
=
torch
.
zeros
(
sum
(
sub
query_lens
),
num_kv_heads
,
head_size
,
dtype
=
dtype
)
v
=
torch
.
zeros
(
sum
(
sub
query_lens
),
num_kv_heads
,
head_size
,
dtype
=
dtype
)
k
=
torch
.
zeros
(
sum
(
query_lens
),
num_kv_heads
,
head_size
,
dtype
=
dtype
)
v
=
torch
.
zeros
(
sum
(
query_lens
),
num_kv_heads
,
head_size
,
dtype
=
dtype
)
values
=
torch
.
arange
(
0
,
cache_size
,
dtype
=
torch
.
long
)
values
=
values
[
torch
.
randperm
(
cache_size
)]
block_table
=
values
[:
BS
*
max_block_per_request
].
view
(
BS
,
max_block_per_request
)
b_seq_len
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
long
)
b_ctx_len
=
torch
.
tensor
(
ctx_lens
,
dtype
=
torch
.
long
)
b_start_loc
=
torch
.
cumsum
(
torch
.
tensor
([
0
]
+
sub
query_lens
[:
-
1
],
b_start_loc
=
torch
.
cumsum
(
torch
.
tensor
([
0
]
+
query_lens
[:
-
1
],
dtype
=
torch
.
long
),
dim
=
0
)
max_input_len
=
MAX_SEQ_LEN
...
...
@@ -92,7 +92,7 @@ def test_contexted_kv_attention(
dtype
=
torch
.
long
),
dim
=
0
)
for
i
in
range
(
BS
):
for
j
in
range
(
sub
query_lens
[
i
]):
for
j
in
range
(
query_lens
[
i
]):
k
[
b_start_loc
[
i
]
+
j
].
copy_
(
key
[
b_seq_start_loc
[
i
]
+
b_ctx_len
[
i
]
+
j
])
v
[
b_start_loc
[
i
]
+
j
].
copy_
(
value
[
b_seq_start_loc
[
i
]
+
...
...
@@ -178,7 +178,7 @@ def test_contexted_kv_attention(
value
=
value
.
unsqueeze
(
0
)
attn_bias
=
BlockDiagonalCausalFromBottomRightMask
.
from_seqlens
(
sub
query_lens
,
seq_lens
)
query_lens
,
seq_lens
)
if
sliding_window
>
0
:
attn_bias
=
attn_bias
.
make_local_attention_from_bottomright
(
sliding_window
)
...
...
tests/samplers/test_sampler.py
View file @
3521ba4f
...
...
@@ -58,7 +58,7 @@ def _do_sample(
device
:
str
,
):
seq_group_metadata_list
=
[]
prompt
_lens
=
[]
seq
_lens
=
[]
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
...
...
@@ -68,12 +68,12 @@ def _do_sample(
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
))
prompt
_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
seq
_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
prompt
_lens
,
sub
query_lens
=
prompt
_lens
,
seq
_lens
,
query_lens
=
seq
_lens
,
device
=
device
,
pin_memory
=
model_runner
.
pin_memory
)
return
sampler
(
logits
=
input_tensor
,
sampling_metadata
=
sampling_metadata
)
...
...
@@ -421,7 +421,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
"Invalid test case, need seq_group_metadata_list"
batch_size
=
0
prompt
_lens
=
[]
seq
_lens
=
[]
sampling_params_per_row
=
[]
for
sgm
in
seq_group_metadata_list
:
sampling_params
=
sgm
.
sampling_params
...
...
@@ -431,7 +431,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
# a prompt seq_group has only one sequence
seq_data
=
next
(
iter
(
sgm
.
seq_data
.
values
()))
prompt_len
=
seq_data
.
get_prompt_len
()
prompt
_lens
.
append
(
prompt_len
)
seq
_lens
.
append
(
prompt_len
)
if
sgm
.
sampling_params
.
prompt_logprobs
:
# with prompt_logprobs each token in the prompt has a row in
...
...
@@ -451,8 +451,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
_
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
prompt_lens
=
prompt
_lens
if
prompt
_lens
else
None
,
sub
query_lens
=
prompt
_lens
if
prompt
_lens
else
None
,
seq_lens
=
seq
_lens
if
seq
_lens
else
None
,
query_lens
=
seq
_lens
if
seq
_lens
else
None
,
device
=
device
,
pin_memory
=
model_runner
.
pin_memory
)
# the logits tensor is modified in-place by the sampler
...
...
@@ -497,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str):
seq_group_metadata_list
=
[]
expected_tokens
:
List
[
Optional
[
List
[
int
]]]
=
[]
prompt
_lens
=
[]
seq
_lens
=
[]
for
i
in
range
(
batch_size
):
expected
:
Optional
[
List
[
int
]]
=
None
sampling_type
=
random
.
randint
(
0
,
3
)
...
...
@@ -532,13 +532,13 @@ def test_sampler_mixed(seed: int, device: str):
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
))
prompt
_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
seq
_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
def
test_sampling
(
model_runner
:
ModelRunner
):
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
prompt
_lens
,
sub
query_lens
=
prompt
_lens
,
seq
_lens
,
query_lens
=
seq
_lens
,
device
=
device
,
pin_memory
=
model_runner
.
pin_memory
)
sampler_output
=
sampler
(
logits
=
fake_logits
,
...
...
@@ -575,7 +575,7 @@ def test_sampler_mixed(seed: int, device: str):
# Shuffle the batch and resample
target_index
=
list
(
range
(
batch_size
))
for
list_to_shuffle
in
(
target_index
,
seq_group_metadata_list
,
expected_tokens
,
prompt
_lens
):
expected_tokens
,
seq
_lens
):
random
.
Random
(
seed
).
shuffle
(
list_to_shuffle
)
target_index
=
torch
.
tensor
(
target_index
)
input_tensor
.
data
=
input_tensor
.
index_select
(
0
,
target_index
)
...
...
@@ -620,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
assert
len
(
warpers
)
==
2
# top_p and top_k
seq_group_metadata_list
=
[]
prompt
_lens
=
[]
seq
_lens
=
[]
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
...
...
@@ -634,12 +634,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
),
block_tables
=
{
0
:
[
1
]},
))
prompt
_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
seq
_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
prompt
_lens
,
sub
query_lens
=
prompt
_lens
,
seq
_lens
,
query_lens
=
seq
_lens
,
device
=
device
,
pin_memory
=
model_runner
.
pin_memory
)
...
...
tests/spec_decode/e2e/conftest.py
View file @
3521ba4f
...
...
@@ -45,7 +45,7 @@ class AsyncLLM:
gpu_memory_utilization
:
float
=
0.9
,
swap_space
:
int
=
4
,
enforce_eager
:
bool
=
False
,
max_
context
_len_to_capture
:
int
=
8192
,
max_
seq
_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
**
kwargs
,
)
->
None
:
...
...
@@ -66,7 +66,7 @@ class AsyncLLM:
gpu_memory_utilization
=
gpu_memory_utilization
,
swap_space
=
swap_space
,
enforce_eager
=
enforce_eager
,
max_
context
_len_to_capture
=
max_
context
_len_to_capture
,
max_
seq
_len_to_capture
=
max_
seq
_len_to_capture
,
engine_use_ray
=
True
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
**
kwargs
,
...
...
tests/spec_decode/test_multi_step_worker.py
View file @
3521ba4f
...
...
@@ -34,7 +34,7 @@ def test_assert_enough_kv_space(num_steps: int):
list
(
range
(
block_size
*
2
)),
]
final_
seq
_lens
=
[
final_
prompt
_lens
=
[
len
(
prompt
+
output
)
+
num_steps
for
prompt
,
output
in
zip
(
prompts
,
prev_output_tokens
)
]
...
...
@@ -43,7 +43,7 @@ def test_assert_enough_kv_space(num_steps: int):
prompts
,
num_gpu_blocks
,
block_size
,
final_
seq
_lens
,
final_
prompt
_lens
,
continuations
=
prev_output_tokens
)
assert_enough_kv_space
=
MultiStepWorker
.
_assert_enough_kv_space
# pylint: disable=protected-access
...
...
@@ -103,17 +103,21 @@ def test_same_output_for_single_step():
[
6
,
7
,
8
,
9
,
10
],
]
final_
seq
_lens
=
[
len
(
prompt
)
+
num_steps
for
prompt
in
prompts
]
final_
prompt
_lens
=
[
len
(
prompt
)
+
num_steps
for
prompt
in
prompts
]
multi_step_execute_model_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_seq_lens
=
final_seq_lens
))
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
))
single_step_execute_model_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_seq_lens
=
final_seq_lens
))
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
))
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
set_random_seed
(
seed
)
...
...
@@ -181,7 +185,7 @@ def test_same_output_for_multi_step():
random
.
randint
(
0
,
1000
)
for
_
in
range
(
random
.
randint
(
10
,
20
))
]
for
_
in
range
(
10
)]
final_
seq
_lens
=
[
len
(
prompt
)
+
num_steps
for
prompt
in
prompts
]
final_
prompt
_lens
=
[
len
(
prompt
)
+
num_steps
for
prompt
in
prompts
]
rand_seeds
=
list
(
random
.
randint
(
0
,
100
)
for
_
in
range
(
num_steps
))
multi_step_worker
.
execute_model
=
patch_execute_model_with_seeds
(
...
...
@@ -195,7 +199,7 @@ def test_same_output_for_multi_step():
num_gpu_blocks
,
block_size
,
continuations
=
continuations
,
final_
seq
_lens
=
final_
seq
_lens
),
)
final_
prompt
_lens
=
final_
prompt
_lens
),
)
# Run multi-step.
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
...
...
@@ -217,7 +221,7 @@ def test_same_output_for_multi_step():
num_gpu_blocks
,
block_size
,
continuations
=
continuations
,
final_
seq
_lens
=
final_
seq
_lens
))
final_
prompt
_lens
=
final_
prompt
_lens
))
single_step_output
.
extend
(
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
))
...
...
tests/spec_decode/test_ngram_worker.py
View file @
3521ba4f
...
...
@@ -43,11 +43,13 @@ def test_ngram_algo_correctness_for_single_no_match():
]
proposal_len
=
5
final_
seq
_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
final_
prompt
_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
ngram_sampler_output_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_seq_lens
=
final_seq_lens
))
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
))
proposals
=
proposer
.
get_proposals
(
**
ngram_sampler_output_data
.
to_dict
(),
...
...
@@ -110,11 +112,13 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
]
proposal_len
=
5
final_
seq
_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
final_
prompt
_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
ngram_sampler_output_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_seq_lens
=
final_seq_lens
))
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
))
proposals
=
proposer
.
get_proposals
(
**
ngram_sampler_output_data
.
to_dict
(),
...
...
@@ -180,11 +184,13 @@ def test_ngram_algo_correctness_for_batches_match_all():
]
proposal_len
=
5
final_
seq
_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
final_
prompt
_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
ngram_sampler_output_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_seq_lens
=
final_seq_lens
))
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
))
proposals
=
proposer
.
get_proposals
(
**
ngram_sampler_output_data
.
to_dict
(),
...
...
tests/spec_decode/utils.py
View file @
3521ba4f
...
...
@@ -144,7 +144,7 @@ def create_seq_group_metadata_from_prompts(
prompts
:
List
[
List
[
int
]],
num_gpu_blocks
:
int
,
block_size
:
int
,
final_
seq
_lens
:
List
[
int
],
final_
prompt
_lens
:
List
[
int
],
continuations
:
Optional
[
List
[
List
[
int
]]]
=
None
,
seq_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
List
[
SequenceGroupMetadata
]:
...
...
@@ -162,7 +162,7 @@ def create_seq_group_metadata_from_prompts(
free_gpu_blocks
.
pop
()
for
_
in
range
(
round_up_to_next_block
(
final_len
,
block_size
))
]
for
i
,
final_len
in
enumerate
(
final_
seq
_lens
)
for
i
,
final_len
in
enumerate
(
final_
prompt
_lens
)
}
return
[
...
...
@@ -251,13 +251,13 @@ def create_batch(batch_size,
prev_output_tokens
=
[[
next
(
iterator
)
for
_
in
range
(
prev_output_token_len
)
]
for
_
in
range
(
batch_size
)]
final_
seq
_lens
=
[
final_
prompt
_lens
=
[
len
(
prompt
)
+
len
(
prev_output_token
)
+
k
+
1
for
prompt
,
prev_output_token
in
zip
(
prompts
,
prev_output_tokens
)
]
execute_model_data
=
create_execute_model_data
(
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_
seq
_lens
,
block_size
,
final_
prompt
_lens
,
prev_output_tokens
,
seq_ids
),
)
return
execute_model_data
,
prompts
,
prev_output_tokens
tests/test_logits_processor.py
View file @
3521ba4f
...
...
@@ -70,7 +70,7 @@ def test_logits_processors(seed: int, device: str):
return
logits
seq_group_metadata_list
=
[]
prompt
_lens
=
[]
seq
_lens
=
[]
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
...
...
@@ -81,12 +81,12 @@ def test_logits_processors(seed: int, device: str):
logits_processors
=
[
pick_ith
]),
block_tables
=
{
0
:
[
1
]},
))
prompt
_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
seq
_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
prompt
_lens
,
sub
query_lens
=
prompt
_lens
,
seq
_lens
,
query_lens
=
seq
_lens
,
device
=
model_runner
.
device
,
pin_memory
=
model_runner
.
pin_memory
)
logits_processor_output
=
logits_processor
(
...
...
tests/worker/test_model_runner.py
View file @
3521ba4f
...
...
@@ -23,14 +23,14 @@ def test_prepare_prompt(batch_size):
lora_config
=
None
)
model_runner
.
set_block_size
(
16
)
prompt
_lens
=
[]
seq
_lens
=
[]
seq_group_metadata_list
=
[]
block_tables
=
{
0
:
[
1
]}
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
prompt
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt
_lens
.
append
(
prompt
_len
)
seq_data
=
SequenceData
(
list
(
range
(
prompt
_len
)))
seq
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq
_lens
.
append
(
seq
_len
)
seq_data
=
SequenceData
(
list
(
range
(
seq
_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
...
...
@@ -43,29 +43,29 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices
=
[]
selected_token_start_idx
=
0
for
prompt
_len
in
prompt
_lens
:
for
seq
_len
in
seq
_lens
:
expected_selected_token_indices
.
append
(
selected_token_start_idx
+
prompt_len
-
1
)
selected_token_start_idx
+=
prompt_len
(
input_tokens
,
input_positions
,
attn_metadata
,
return_prompt_lens
,
_
,
_
,
_
,
_
,
_
,
slot_mapping
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
assert
return_prompt_lens
==
prompt_lens
seq_len
-
1
)
selected_token_start_idx
+=
seq_len
(
input_tokens
,
input_positions
,
attn_metadata
,
return_seq_lens
,
_
,
_
,
_
,
_
,
_
,
slot_mapping
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
assert
return_seq_lens
==
seq_lens
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
# Verify input metadata is correct for prompts.
device
=
model_runner
.
device
assert
attn_metadata
.
is_prompt
is
True
assert
torch
.
allclose
(
attn_metadata
.
prompt_lens_tensor
,
torch
.
tensor
(
prompt_lens
,
device
=
device
))
assert
attn_metadata
.
prompt_lens
==
prompt_lens
assert
attn_metadata
.
max_prompt_len
==
max
(
prompt_lens
)
assert
torch
.
allclose
(
attn_metadata
.
seq_lens_tensor
,
torch
.
tensor
(
seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
seq_lens
==
seq_lens
assert
attn_metadata
.
max_seq_len
==
max
(
seq_lens
)
# Test subquery start locs.
start_idx
=
0
start_loc
=
[
start_idx
]
for
prompt
_len
in
prompt
_lens
:
start_idx
+=
prompt
_len
for
seq
_len
in
seq
_lens
:
start_idx
+=
seq
_len
start_loc
.
append
(
start_idx
)
assert
torch
.
allclose
(
attn_metadata
.
subquery_start_loc
,
...
...
@@ -75,17 +75,16 @@ def test_prepare_prompt(batch_size):
# equivalent to subquery_start_loc.
start_idx
=
0
seq_start_loc
=
[
start_idx
]
for
prompt
_len
in
prompt
_lens
:
start_idx
+=
prompt
_len
for
seq
_len
in
seq
_lens
:
start_idx
+=
seq
_len
seq_start_loc
.
append
(
start_idx
)
assert
torch
.
allclose
(
attn_metadata
.
seq_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
assert
attn_metadata
.
max_context_len
is
None
assert
torch
.
allclose
(
attn_metadata
.
context_lens
,
torch
.
zeros
(
attn_metadata
.
context_lens
.
shape
[
0
],
attn_metadata
.
context_lens
_tensor
,
torch
.
zeros
(
attn_metadata
.
context_lens
_tensor
.
shape
[
0
],
dtype
=
torch
.
int
,
device
=
device
))
...
...
@@ -96,18 +95,18 @@ def test_prepare_prompt(batch_size):
# Cuda graph should not be used for prerill.
assert
attn_metadata
.
use_cuda_graph
is
False
assert
len
(
input_tokens
)
==
sum
(
prompt
_lens
)
assert
len
(
input_positions
)
==
sum
(
prompt
_lens
)
assert
len
(
input_tokens
)
==
sum
(
seq
_lens
)
assert
len
(
input_positions
)
==
sum
(
seq
_lens
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
prompt
_lens
,
sub
query_lens
=
prompt
_lens
,
seq
_lens
,
query_lens
=
seq
_lens
,
device
=
model_runner
.
device
,
pin_memory
=
model_runner
.
pin_memory
)
assert
len
(
input_tokens
)
==
sum
(
prompt
_lens
)
assert
len
(
input_positions
)
==
sum
(
prompt
_lens
)
assert
len
(
input_tokens
)
==
sum
(
seq
_lens
)
assert
len
(
input_positions
)
==
sum
(
seq
_lens
)
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
...
...
@@ -146,13 +145,13 @@ def test_prepare_decode_cuda_graph(batch_size):
lora_config
=
None
)
model_runner
.
set_block_size
(
16
)
prompt
_lens
=
[]
seq
_lens
=
[]
seq_group_metadata_list
=
[]
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
prompt
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt
_lens
.
append
(
prompt
_len
)
seq_data
=
list
(
range
(
prompt
_len
))
seq
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq
_lens
.
append
(
seq
_len
)
seq_data
=
list
(
range
(
seq
_len
))
seq_data
=
SequenceData
(
seq_data
)
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
...
...
@@ -172,14 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify input metadata is correct for prompts.
device
=
model_runner
.
device
assert
attn_metadata
.
is_prompt
is
False
assert
attn_metadata
.
prompt_lens
is
None
assert
attn_metadata
.
max_prompt_len
is
None
assert
attn_metadata
.
seq_lens
is
None
assert
attn_metadata
.
subquery_start_loc
is
None
assert
attn_metadata
.
seq_start_loc
is
None
assert
attn_metadata
.
max_
context
_len
==
max
(
prompt
_lens
)
assert
attn_metadata
.
max_
seq
_len
==
max
(
seq
_lens
)
assert
torch
.
allclose
(
attn_metadata
.
context_lens
[:
len
(
prompt
_lens
)],
torch
.
tensor
(
prompt
_lens
,
dtype
=
torch
.
int
,
device
=
device
))
attn_metadata
.
seq_lens_tensor
[:
len
(
seq
_lens
)],
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
int
,
device
=
device
))
# block table's first index corresponds to each batch, meaning in
# decoding it is each token.
...
...
@@ -198,13 +196,13 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify Sampling
expected_selected_token_indices
=
[]
selected_token_start_idx
=
0
for
prompt
_len
in
prompt
_lens
:
for
seq
_len
in
seq
_lens
:
expected_selected_token_indices
.
append
(
selected_token_start_idx
)
selected_token_start_idx
+=
1
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
prompt
_lens
,
sub
query_lens
=
prompt
_lens
,
seq
_lens
,
query_lens
=
seq
_lens
,
device
=
model_runner
.
device
,
pin_memory
=
model_runner
.
pin_memory
)
actual
=
sampling_metadata
.
selected_token_indices
...
...
@@ -241,14 +239,13 @@ def test_empty_seq_group():
assert
attn_metadata
is
None
assert
len
(
slot_mapping
)
==
0
(
input_tokens
,
input_positions
,
attn_metadata
,
return_prompt_lens
,
_
,
_
,
_
,
_
,
_
,
slot_mapping
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
(
input_tokens
,
input_positions
,
attn_metadata
,
return_seq_lens
,
_
,
_
,
_
,
_
,
_
,
slot_mapping
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
assert
len
(
input_tokens
)
==
0
assert
len
(
input_positions
)
==
0
assert
attn_metadata
is
None
assert
len
(
slot_mapping
)
==
0
assert
len
(
return_
prompt
_lens
)
==
0
assert
len
(
return_
seq
_lens
)
==
0
@
pytest
.
fixture
...
...
@@ -288,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
model_runner
.
set_block_size
(
16
)
# Add prefill requests.
prompt
_lens
=
[]
seq
_lens
=
[]
seq_group_metadata_list
=
[]
prefill_metadata_list
=
[]
decode_metadata_list
=
[]
...
...
@@ -297,9 +294,9 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
decode_batch_size
=
batch_size
-
prefill_batch_size
for
i
in
range
(
prefill_batch_size
):
# make sure all tokens fit into one block
prompt
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt
_lens
.
append
(
prompt
_len
)
seq_data
=
SequenceData
(
list
(
range
(
prompt
_len
)))
seq
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq
_lens
.
append
(
seq
_len
)
seq_data
=
SequenceData
(
list
(
range
(
seq
_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
...
...
@@ -314,8 +311,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# Add decode requests
for
i
in
range
(
prefill_batch_size
,
batch_size
):
# make sure all tokens fit into one block
prompt
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_toks
=
list
(
range
(
prompt
_len
))
seq
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_toks
=
list
(
range
(
seq
_len
))
seq_data
=
SequenceData
(
prompt_toks
)
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
...
...
@@ -343,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
else
:
assert
attn_metadata
.
num_decode_tokens
==
_get_graph_batch_size
(
decode_batch_size
)
assert
attn_metadata
.
num_prefill_tokens
==
sum
(
prompt
_lens
)
assert
attn_metadata
.
num_prefill_tokens
==
sum
(
seq
_lens
)
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
...
...
vllm/_custom_ops.py
View file @
3521ba4f
...
...
@@ -39,17 +39,17 @@ def paged_attention_v1(
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
context
_lens
:
torch
.
Tensor
,
seq
_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_
context
_len
:
int
,
max_
seq
_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_scale
:
float
,
)
->
None
:
vllm_ops
.
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
context_lens
,
block_size
,
max_
context_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
)
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_
seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
)
def
paged_attention_v2
(
...
...
@@ -63,17 +63,17 @@ def paged_attention_v2(
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
context
_lens
:
torch
.
Tensor
,
seq
_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_
context
_len
:
int
,
max_
seq
_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_scale
:
float
,
)
->
None
:
vllm_ops
.
paged_attention_v2
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
block_size
,
max_
context
_len
,
alibi_slopes
,
kv_cache_dtype
,
block_tables
,
seq
_lens
,
block_size
,
max_
seq
_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
)
...
...
vllm/attention/backends/flash_attn.py
View file @
3521ba4f
...
...
@@ -66,27 +66,24 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens
:
Optional
[
List
[
int
]]
# prompt_lens stored as a tensor.
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# NOTE(sang): Definition of context_len,
sub
query_len, and seqlen.
# NOTE(sang): Definition of context_len, query_len, and seq
_
len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |-
sub
query_len -|
# |-------------------- seq
_
len ----------------------|
# |-
-
query_len
--
-|
# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len
:
Optional
[
int
]
# Maximum prompt length in the batch.
max_prompt_len
:
Optional
[
int
]
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
# Maximum sequence length in the batch.
max_seq_len
:
Optional
[
int
]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
...
...
@@ -95,6 +92,9 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
...
...
@@ -223,8 +223,8 @@ class FlashAttentionImpl(AttentionImpl):
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_
prompt
_len
,
max_seqlen_k
=
prefill_meta
.
max_
prompt
_len
,
max_seqlen_q
=
prefill_meta
.
max_
seq
_len
,
max_seqlen_k
=
prefill_meta
.
max_
seq
_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
...
...
@@ -245,9 +245,9 @@ class FlashAttentionImpl(AttentionImpl):
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
subquery_start_loc
,
prefill_meta
.
prompt
_lens_tensor
,
prefill_meta
.
context_lens
,
prefill_meta
.
max_
sub
query_len
,
prefill_meta
.
seq
_lens_tensor
,
prefill_meta
.
context_lens
_tensor
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
)
...
...
@@ -258,8 +258,8 @@ class FlashAttentionImpl(AttentionImpl):
key_cache
,
value_cache
,
decode_meta
.
block_tables
,
decode_meta
.
context_lens
,
decode_meta
.
max_
context
_len
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_
seq
_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
3521ba4f
...
...
@@ -64,27 +64,24 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens
:
Optional
[
List
[
int
]]
# prompt_lens stored as a tensor.
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# NOTE(sang): Definition of context_len,
sub
query_len, and seqlen.
# NOTE(sang): Definition of context_len, query_len, and seq
_
len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |-
sub
query_len -|
# |-------------------- seq
_
len ----------------------|
# |-
-
query_len
--
-|
# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len
:
Optional
[
int
]
# Maximum prompt length in the batch.
max_prompt_len
:
Optional
[
int
]
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
# Maximum sequence length in the batch.
max_seq_len
:
Optional
[
int
]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
...
...
@@ -98,6 +95,9 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
class
ROCmFlashAttentionImpl
(
AttentionImpl
):
...
...
@@ -247,7 +247,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
assert
prefill_meta
.
prompt
_lens
is
not
None
assert
prefill_meta
.
seq
_lens
is
not
None
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
# triton attention
# When block_tables are not filled, it means q and k are the
...
...
@@ -260,8 +260,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
None
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
max_
prompt
_len
,
prefill_meta
.
max_
prompt
_len
,
prefill_meta
.
max_
seq
_len
,
prefill_meta
.
max_
seq
_len
,
True
,
self
.
scale
,
)
...
...
@@ -274,7 +274,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query
,
key
,
value
,
prefill_meta
.
prompt
_lens
,
prefill_meta
.
seq
_lens
,
self
.
scale
,
)
else
:
...
...
@@ -284,8 +284,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_
prompt
_len
,
max_seqlen_k
=
prefill_meta
.
max_
prompt
_len
,
max_seqlen_q
=
prefill_meta
.
max_
seq
_len
,
max_seqlen_k
=
prefill_meta
.
max_
seq
_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
...
...
@@ -303,9 +303,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
subquery_start_loc
,
prefill_meta
.
prompt
_lens_tensor
,
prefill_meta
.
context_lens
,
prefill_meta
.
max_
sub
query_len
,
prefill_meta
.
seq
_lens_tensor
,
prefill_meta
.
context_lens
_tensor
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
)
...
...
@@ -317,8 +317,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key_cache
,
value_cache
,
decode_meta
.
block_tables
,
decode_meta
.
context_lens
,
decode_meta
.
max_
context
_len
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_
seq
_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
...
...
@@ -334,13 +334,13 @@ def _naive_attention(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
prompt
_lens
:
List
[
int
],
seq
_lens
:
List
[
int
],
scale
:
float
,
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
start
=
0
for
_
,
prompt
_len
in
enumerate
(
prompt
_lens
):
end
=
start
+
prompt
_len
for
_
,
seq
_len
in
enumerate
(
seq
_lens
):
end
=
start
+
seq
_len
out
=
_naive_masked_attention
(
query
[
start
:
end
],
key
[
start
:
end
],
...
...
@@ -349,7 +349,7 @@ def _naive_attention(
)
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
)
start
+=
prompt
_len
start
+=
seq
_len
return
output
...
...
vllm/attention/backends/torch_sdpa.py
View file @
3521ba4f
...
...
@@ -58,7 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
slot_mapping
:
torch
.
Tensor
prompt
_lens
:
Optional
[
List
[
int
]]
seq
_lens
:
Optional
[
List
[
int
]]
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
...
...
@@ -136,7 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
kv_scale
)
if
attn_metadata
.
is_prompt
:
assert
attn_metadata
.
prompt
_lens
is
not
None
assert
attn_metadata
.
seq
_lens
is
not
None
if
(
kv_cache
is
None
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
...
...
@@ -147,13 +147,13 @@ class TorchSDPABackendImpl(AttentionImpl):
if
self
.
alibi_slopes
is
not
None
:
att_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
attn_metadata
.
prompt
_lens
)
# type: ignore
attn_metadata
.
seq
_lens
)
# type: ignore
elif
self
.
sliding_window
is
not
None
:
att_masks
=
_make_sliding_window_bias
(
attn_metadata
.
prompt
_lens
,
self
.
sliding_window
,
attn_metadata
.
seq
_lens
,
self
.
sliding_window
,
query
.
dtype
)
# type: ignore
else
:
att_masks
=
[
None
]
*
len
(
attn_metadata
.
prompt
_lens
)
att_masks
=
[
None
]
*
len
(
attn_metadata
.
seq
_lens
)
attn_metadata
.
attn_bias
=
att_masks
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
...
...
@@ -164,9 +164,9 @@ class TorchSDPABackendImpl(AttentionImpl):
output
=
torch
.
empty
(
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
),
dtype
=
query
.
dtype
)
for
prompt
_len
,
mask
in
zip
(
attn_metadata
.
prompt
_lens
,
attn_metadata
.
attn_bias
):
end
=
start
+
prompt
_len
for
seq
_len
,
mask
in
zip
(
attn_metadata
.
seq
_lens
,
attn_metadata
.
attn_bias
):
end
=
start
+
seq
_len
sub_out
=
scaled_dot_product_attention
(
query
[:,
start
:
end
,
:],
key
[:,
start
:
end
,
:],
...
...
@@ -189,8 +189,8 @@ class TorchSDPABackendImpl(AttentionImpl):
key_cache
,
value_cache
,
attn_metadata
.
block_tables
,
attn_metadata
.
context_lens
,
attn_metadata
.
max_
context
_len
,
attn_metadata
.
seq_lens_tensor
,
attn_metadata
.
max_
seq
_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
...
...
@@ -205,13 +205,13 @@ class TorchSDPABackendImpl(AttentionImpl):
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
prompt
_lens
:
List
[
int
],
seq
_lens
:
List
[
int
],
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
for
prompt
_len
in
prompt
_lens
:
bias
=
torch
.
arange
(
prompt
_len
,
dtype
=
dtype
)
for
seq
_len
in
seq
_lens
:
bias
=
torch
.
arange
(
seq
_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(
prompt
_len, 1)`
# `bias = bias[None, :].repeat(
seq
_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
...
...
@@ -221,7 +221,7 @@ def _make_alibi_bias(
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
inf_mask
=
torch
.
empty
(
(
1
,
prompt_len
,
prompt
_len
),
(
1
,
seq_len
,
seq
_len
),
dtype
=
bias
.
dtype
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
)
attn_biases
.
append
((
bias
+
inf_mask
).
to
(
dtype
))
...
...
@@ -229,14 +229,14 @@ def _make_alibi_bias(
def
_make_sliding_window_bias
(
prompt
_lens
:
List
[
int
],
seq
_lens
:
List
[
int
],
window_size
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
for
prompt
_len
in
prompt
_lens
:
for
seq
_len
in
seq
_lens
:
tensor
=
torch
.
full
(
(
1
,
prompt_len
,
prompt
_len
),
(
1
,
seq_len
,
seq
_len
),
dtype
=
dtype
,
fill_value
=
1
,
)
...
...
vllm/attention/backends/xformers.py
View file @
3521ba4f
...
...
@@ -66,28 +66,24 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens
:
Optional
[
List
[
int
]]
# prompt_lens stored as a tensor.
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |-
sub
query_len -|
# |-------------------- seq
_
len ----------------------|
# |-
-
query_len
--
-|
# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len
:
Optional
[
int
]
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
# FIXME: It is for flash attn.
# Maximum
prompt
length in the batch.
max_
prompt
_len
:
Optional
[
int
]
# Maximum
sequence
length in the batch.
max_
seq
_len
:
Optional
[
int
]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
...
...
@@ -97,6 +93,9 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
...
...
@@ -242,9 +241,9 @@ class XFormersImpl(AttentionImpl):
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
subquery_start_loc
,
prefill_meta
.
prompt
_lens_tensor
,
prefill_meta
.
context_lens
,
prefill_meta
.
max_
sub
query_len
,
prefill_meta
.
seq
_lens_tensor
,
prefill_meta
.
context_lens
_tensor
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
sliding_window
,
)
...
...
@@ -257,8 +256,8 @@ class XFormersImpl(AttentionImpl):
key_cache
,
value_cache
,
decode_meta
.
block_tables
,
decode_meta
.
context_lens
,
decode_meta
.
max_
context
_len
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_
seq
_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
...
...
@@ -289,7 +288,7 @@ class XFormersImpl(AttentionImpl):
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
assert
attn_metadata
.
prompt
_lens
is
not
None
assert
attn_metadata
.
seq
_lens
is
not
None
original_query
=
query
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# GQA/MQA requires the shape [B, M, G, H, K].
...
...
@@ -310,7 +309,7 @@ class XFormersImpl(AttentionImpl):
if
attn_metadata
.
attn_bias
is
None
:
if
self
.
alibi_slopes
is
None
:
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
attn_metadata
.
prompt
_lens
)
attn_metadata
.
seq
_lens
)
if
self
.
sliding_window
is
not
None
:
attn_bias
=
attn_bias
.
make_local_attention
(
self
.
sliding_window
)
...
...
@@ -318,7 +317,7 @@ class XFormersImpl(AttentionImpl):
else
:
attn_metadata
.
attn_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
query
.
dtype
,
attn_metadata
.
prompt
_lens
)
attn_metadata
.
seq
_lens
)
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
...
...
@@ -343,8 +342,8 @@ class XFormersImpl(AttentionImpl):
# one. This is inefficient, especially when we have many short prompts.
output
=
torch
.
empty_like
(
original_query
)
start
=
0
for
i
,
prompt
_len
in
enumerate
(
attn_metadata
.
prompt
_lens
):
end
=
start
+
prompt
_len
for
i
,
seq
_len
in
enumerate
(
attn_metadata
.
seq
_lens
):
end
=
start
+
seq
_len
out
=
xops
.
memory_efficient_attention_forward
(
query
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
...
...
@@ -354,7 +353,7 @@ class XFormersImpl(AttentionImpl):
scale
=
self
.
scale
)
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
.
view_as
(
original_query
[
start
:
end
]))
start
+=
prompt
_len
start
+=
seq
_len
return
output
...
...
@@ -362,13 +361,13 @@ def _make_alibi_bias(
alibi_slopes
:
torch
.
Tensor
,
num_kv_heads
:
int
,
dtype
:
torch
.
dtype
,
prompt
_lens
:
List
[
int
],
seq
_lens
:
List
[
int
],
)
->
LowerTriangularMaskWithTensorBias
:
attn_biases
=
[]
for
prompt
_len
in
prompt
_lens
:
bias
=
torch
.
arange
(
prompt
_len
,
dtype
=
dtype
)
for
seq
_len
in
seq
_lens
:
bias
=
torch
.
arange
(
seq
_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(
prompt
_len, 1)`
# `bias = bias[None, :].repeat(
seq
_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
...
...
@@ -376,16 +375,16 @@ def _make_alibi_bias(
# element.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
padded_len
=
(
prompt
_len
+
7
)
//
8
*
8
padded_len
=
(
seq
_len
+
7
)
//
8
*
8
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
torch
.
empty
(
1
,
# batch size
num_heads
,
prompt
_len
,
seq
_len
,
padded_len
,
device
=
alibi_slopes
.
device
,
dtype
=
dtype
,
)[:,
:,
:,
:
prompt
_len
].
copy_
(
bias
)
)[:,
:,
:,
:
seq
_len
].
copy_
(
bias
)
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
if
num_heads
!=
num_kv_heads
:
bias
=
bias
.
unflatten
(
1
,
(
num_kv_heads
,
num_heads
//
num_kv_heads
))
...
...
vllm/attention/ops/paged_attn.py
View file @
3521ba4f
...
...
@@ -13,12 +13,11 @@ _PARTITION_SIZE = 512
@
dataclass
class
PagedAttentionMetadata
:
"""Metadata for PagedAttention."""
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
context_lens
:
Optional
[
torch
.
Tensor
]
# Maximum context length in the batch.
max_context_len
:
Optional
[
int
]
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Maximum sequence length in the batch.
max_seq_len
:
Optional
[
int
]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
...
...
@@ -85,8 +84,8 @@ class PagedAttention:
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context
_lens
:
torch
.
Tensor
,
max_
context
_len
:
int
,
seq
_lens
:
torch
.
Tensor
,
max_
seq
_len
:
int
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
scale
:
float
,
...
...
@@ -97,7 +96,7 @@ class PagedAttention:
block_size
=
value_cache
.
shape
[
3
]
num_seqs
,
num_heads
,
head_size
=
query
.
shape
max_num_partitions
=
((
max_
context
_len
+
_PARTITION_SIZE
-
1
)
//
max_num_partitions
=
((
max_
seq
_len
+
_PARTITION_SIZE
-
1
)
//
_PARTITION_SIZE
)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
...
...
@@ -106,7 +105,7 @@ class PagedAttention:
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1
=
(
max_
context
_len
<=
8192
use_v1
=
(
max_
seq
_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
if
use_v1
:
# Run PagedAttention V1.
...
...
@@ -118,9 +117,9 @@ class PagedAttention:
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
seq
_lens
,
block_size
,
max_
context
_len
,
max_
seq
_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
...
...
@@ -150,9 +149,9 @@ class PagedAttention:
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
seq
_lens
,
block_size
,
max_
context
_len
,
max_
seq
_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
...
...
@@ -168,9 +167,9 @@ class PagedAttention:
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
subquery_start_loc
:
torch
.
Tensor
,
prompt
_lens_tensor
:
torch
.
Tensor
,
seq
_lens_tensor
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
max_
sub
query_len
:
int
,
max_query_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
sliding_window
:
Optional
[
int
],
)
->
torch
.
Tensor
:
...
...
@@ -185,9 +184,9 @@ class PagedAttention:
block_tables
,
# subquery_start_loc is (batch_size + 1,)
subquery_start_loc
[:
-
1
],
prompt
_lens_tensor
,
seq
_lens_tensor
,
context_lens
,
max_
sub
query_len
,
max_query_len
,
alibi_slopes
,
sliding_window
,
)
...
...
vllm/config.py
View file @
3521ba4f
...
...
@@ -63,7 +63,10 @@ class ModelConfig:
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
"""
...
...
@@ -84,6 +87,7 @@ class ModelConfig:
quantization_param_path
:
Optional
[
str
]
=
None
,
enforce_eager
:
bool
=
False
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
Optional
[
int
]
=
None
,
max_logprobs
:
int
=
5
,
skip_tokenizer_init
:
bool
=
False
,
)
->
None
:
...
...
@@ -99,6 +103,11 @@ class ModelConfig:
self
.
quantization_param_path
=
quantization_param_path
self
.
enforce_eager
=
enforce_eager
self
.
max_context_len_to_capture
=
max_context_len_to_capture
if
self
.
max_context_len_to_capture
is
not
None
:
raise
ValueError
(
"`max_context_len_to_capture` is deprecated. "
"Use `max_seq_len_to_capture` instead."
)
self
.
max_seq_len_to_capture
=
(
max_seq_len_to_capture
or
max_context_len_to_capture
)
self
.
max_logprobs
=
max_logprobs
self
.
skip_tokenizer_init
=
skip_tokenizer_init
...
...
@@ -190,10 +199,10 @@ class ModelConfig:
"non-quantized models."
,
self
.
quantization
)
def
_verify_cuda_graph
(
self
)
->
None
:
if
self
.
max_
context
_len_to_capture
is
None
:
self
.
max_
context
_len_to_capture
=
self
.
max_model_len
self
.
max_
context
_len_to_capture
=
min
(
self
.
max_
context
_len_to_capture
,
self
.
max_model_len
)
if
self
.
max_
seq
_len_to_capture
is
None
:
self
.
max_
seq
_len_to_capture
=
self
.
max_model_len
self
.
max_
seq
_len_to_capture
=
min
(
self
.
max_
seq
_len_to_capture
,
self
.
max_model_len
)
def
verify_with_parallel_config
(
self
,
...
...
@@ -772,8 +781,8 @@ class SpeculativeConfig:
max_model_len
=
None
,
quantization
=
draft_quantization
,
enforce_eager
=
target_model_config
.
enforce_eager
,
max_
context
_len_to_capture
=
target_model_config
.
max_
context
_len_to_capture
,
max_
seq
_len_to_capture
=
target_model_config
.
max_
seq
_len_to_capture
,
max_logprobs
=
target_model_config
.
max_logprobs
,
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment