Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9ff4511e
Unverified
Commit
9ff4511e
authored
Oct 30, 2024
by
Elfie Guo
Committed by
GitHub
Oct 30, 2024
Browse files
[Misc] Add chunked-prefill support on FlashInfer. (#9781)
parent
81f09cfd
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
28 deletions
+72
-28
tests/basic_correctness/test_chunked_prefill.py
tests/basic_correctness/test_chunked_prefill.py
+12
-0
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+60
-28
No files found.
tests/basic_correctness/test_chunked_prefill.py
View file @
9ff4511e
...
...
@@ -11,6 +11,8 @@ from contextlib import nullcontext
import
pytest
from
tests.kernels.utils
import
override_backend_env_variable
from
..models.utils
import
check_logprobs_close
,
check_outputs_equal
from
..utils
import
multi_gpu_test
...
...
@@ -28,6 +30,7 @@ MODELS = [
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset the distributed env properly. Use a value > 1 only when you test.
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASHINFER"
,
"FLASH_ATTN"
])
def
test_models
(
hf_runner
,
vllm_runner
,
...
...
@@ -38,11 +41,15 @@ def test_models(
chunked_prefill_token_size
:
int
,
enforce_eager
:
bool
,
tensor_parallel_size
:
int
,
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
"""
Checks exact match decode between huggingface model and vllm runner with
chunked prefill.
"""
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
max_num_seqs
=
chunked_prefill_token_size
max_num_batched_tokens
=
chunked_prefill_token_size
...
...
@@ -71,13 +78,18 @@ def test_models(
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"distributed_executor_backend"
,
[
"ray"
,
"mp"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASHINFER"
,
"FLASH_ATTN"
])
def
test_models_distributed
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
distributed_executor_backend
:
str
,
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
if
(
model
==
"meta-llama/Llama-2-7b-hf"
and
distributed_executor_backend
==
"ray"
):
# test ray adag
...
...
vllm/attention/backends/flashinfer.py
View file @
9ff4511e
...
...
@@ -268,6 +268,11 @@ class FlashInferMetadata(AttentionMetadata):
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len
:
int
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len
:
Optional
[
int
]
=
1
use_cuda_graph
:
bool
=
True
...
...
@@ -335,6 +340,7 @@ class FlashInferMetadata(AttentionMetadata):
assert
self
.
paged_kv_last_page_len
is
not
None
assert
self
.
block_table_bound
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
self
.
query_start_loc
=
self
.
query_start_loc
[:
self
.
num_prefills
+
1
]
batch_size
=
self
.
query_start_loc
.
shape
[
0
]
-
1
assert
batch_size
>=
0
# We will use flash attention for profiling to
...
...
@@ -349,11 +355,13 @@ class FlashInferMetadata(AttentionMetadata):
self
.
paged_kv_indices
=
self
.
paged_kv_indices
.
to
(
self
.
device
)
self
.
prefill_wrapper
.
end_forward
()
self
.
prefill_wrapper
.
begin_forward
(
self
.
query_start_loc
,
self
.
paged_kv_indptr
,
self
.
paged_kv_indices
,
self
.
paged_kv_last_page_len
,
self
.
query_start_loc
,
self
.
paged_kv_indptr
[:
self
.
num_prefills
+
1
],
self
.
paged_kv_indices
,
self
.
paged_kv_last_page_len
[:
self
.
num_prefills
],
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
)
else
:
if
self
.
num_decode_tokens
>
0
:
assert
self
.
paged_kv_indices
is
not
None
assert
self
.
paged_kv_indptr
is
not
None
assert
self
.
paged_kv_last_page_len
is
not
None
...
...
@@ -370,9 +378,9 @@ class FlashInferMetadata(AttentionMetadata):
assert
self
.
decode_wrapper
is
not
None
self
.
decode_wrapper
.
end_forward
()
self
.
decode_wrapper
.
begin_forward
(
self
.
paged_kv_indptr
,
self
.
paged_kv_indptr
[
self
.
num_prefills
:]
,
self
.
paged_kv_indices
,
self
.
paged_kv_last_page_len
,
self
.
paged_kv_last_page_len
[
self
.
num_prefills
:]
,
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
...
...
@@ -397,21 +405,14 @@ class FlashInferMetadata(AttentionMetadata):
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"FlashInferMetadata"
]:
# Currently chunked prefill is not supported
if
self
.
num_decode_tokens
==
0
:
assert
self
.
num_prefills
>
0
return
self
if
self
.
num_prefills
==
0
:
return
None
return
self
@
property
def
decode_metadata
(
self
)
->
Optional
[
"FlashInferMetadata"
]:
# Currently chunked prefill is not supported
if
self
.
num_prefills
>
0
:
assert
self
.
num_decode_tokens
==
0
,
(
"Chunked prefill is not supported with flashinfer yet."
)
if
self
.
num_decode_tokens
==
0
:
return
None
return
self
def
advance_step
(
self
,
...
...
@@ -599,11 +600,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
decode_query_len
=
max
(
query_lens
[
self
.
num_prefills
:],
default
=
1
)
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
num_decode_tokens
=
batch_size
-
self
.
num_prefill_tokens
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
...
...
@@ -689,6 +691,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
runner
.
kv_cache_dtype
,
self
.
runner
.
model_config
.
dtype
)
return
FlashInferMetadata
(
decode_query_len
=
decode_query_len
,
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
...
...
@@ -811,12 +814,6 @@ def unified_flash_infer(
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
if
attn_metadata
.
num_prefill_tokens
>
0
:
assert
attn_metadata
.
num_decode_tokens
==
0
,
(
"Chunked prefill is not supported with flashinfer yet."
)
if
attn_metadata
.
num_decode_tokens
>
0
:
assert
attn_metadata
.
num_prefill_tokens
==
0
,
(
"Chunked prefill is not supported with flashinfer yet."
)
if
kv_cache
.
numel
()
>
0
:
# Use the same reshape and cache kernel as flash attention.
ops
.
reshape_and_cache_flash
(
...
...
@@ -836,14 +833,33 @@ def unified_flash_infer(
kv_cache_dtype
)
kv_cache
=
kv_cache
.
view
(
torch_dtype
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
,
\
f
"key :
{
key
.
shape
}
: #prefill tokens
{
num_prefill_tokens
}
: #decode tokens
{
num_decode_tokens
}
"
# noqa
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
,
\
f
"value :
{
value
.
shape
}
: #prefill toks
{
num_prefill_tokens
}
: #decode toks
{
num_decode_tokens
}
"
# noqa
query
=
query
.
contiguous
()
# Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached.
# QKV for prefill.
decode_query
=
query
[
num_prefill_tokens
:]
query
=
query
[:
num_prefill_tokens
]
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if
kv_cache
.
numel
()
==
0
:
output
=
flash_attn_varlen_func
(
prefill_
output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
...
...
@@ -859,18 +875,34 @@ def unified_flash_infer(
else
:
assert
prefill_meta
is
not
None
assert
prefill_meta
.
prefill_wrapper
is
not
None
output
=
prefill_meta
.
prefill_wrapper
.
forward
(
prefill_
output
=
prefill_meta
.
prefill_wrapper
.
forward
(
query
,
kv_cache
,
logits_soft_cap
=
logits_soft_cap
,
causal
=
True
)
else
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
output
=
attn_metadata
.
decode_metadata
.
decode_wrapper
.
forward
(
query
,
decode_
output
=
attn_metadata
.
decode_metadata
.
decode_wrapper
.
forward
(
decode_
query
,
kv_cache
,
sm_scale
=
softmax_scale
,
logits_soft_cap
=
logits_soft_cap
,
k_scale
=
k_scale
,
v_scale
=
v_scale
)
if
prefill_output
is
None
and
decode_output
is
not
None
:
# Decode only batch.
output
,
num_tokens
=
decode_output
,
num_decode_tokens
elif
decode_output
is
None
and
prefill_output
is
not
None
:
# Prefill only batch.
output
,
num_tokens
=
prefill_output
,
num_prefill_tokens
else
:
# Chunked prefill batch does not work with speculative decoding in
# FlashInfer backend, so the query length for decode should be 1.
assert
prefill_output
is
not
None
assert
decode_output
is
not
None
assert
decode_meta
is
not
None
assert
decode_meta
.
decode_query_len
==
1
decode_output
=
decode_output
.
squeeze
(
1
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
return
output
.
view
(
num_tokens
,
hidden_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment