Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0794e744
Unverified
Commit
0794e744
authored
Jan 14, 2025
by
Elfie Guo
Committed by
GitHub
Jan 15, 2025
Browse files
[Misc] Add multipstep chunked-prefill support for FlashInfer (#10467)
parent
b7ee940a
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
169 additions
and
109 deletions
+169
-109
csrc/prepare_inputs/advance_step.cu
csrc/prepare_inputs/advance_step.cu
+10
-0
tests/multi_step/test_correctness_llm.py
tests/multi_step/test_correctness_llm.py
+16
-1
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+24
-5
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+118
-102
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+1
-1
No files found.
csrc/prepare_inputs/advance_step.cu
View file @
0794e744
...
@@ -95,6 +95,16 @@ __global__ void advance_step_flashinfer_kernel(
...
@@ -95,6 +95,16 @@ __global__ void advance_step_flashinfer_kernel(
long
*
input_positions_ptr
,
int
*
seq_lens_ptr
,
long
*
slot_mapping_ptr
,
long
*
input_positions_ptr
,
int
*
seq_lens_ptr
,
long
*
slot_mapping_ptr
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
,
int
*
paged_kv_last_page_len_ptr
,
int
*
block_table_bound_ptr
)
{
int
*
paged_kv_last_page_len_ptr
,
int
*
block_table_bound_ptr
)
{
int
const
n_pad
=
num_seqs
-
num_queries
;
if
(
n_pad
&&
blockIdx
.
x
==
0
)
{
// Handle cuda graph padding
int
const
offset
=
num_queries
;
for
(
int
i
=
threadIdx
.
x
;
i
<
n_pad
;
i
+=
blockDim
.
x
)
{
input_tokens_ptr
[
offset
+
i
]
=
0
;
input_positions_ptr
[
offset
+
i
]
=
0
;
slot_mapping_ptr
[
offset
+
i
]
=
-
1
;
}
}
int
num_query_blocks
=
div_ceil
(
num_queries
,
num_threads
);
int
num_query_blocks
=
div_ceil
(
num_queries
,
num_threads
);
if
(
blockIdx
.
x
<
num_query_blocks
)
{
if
(
blockIdx
.
x
<
num_query_blocks
)
{
...
...
tests/multi_step/test_correctness_llm.py
View file @
0794e744
...
@@ -5,6 +5,8 @@ from typing import Optional
...
@@ -5,6 +5,8 @@ from typing import Optional
import
pytest
import
pytest
from
tests.kernels.utils
import
override_backend_env_variable
from
..models.utils
import
check_logprobs_close
,
check_outputs_equal
from
..models.utils
import
check_logprobs_close
,
check_outputs_equal
MODELS
=
[
MODELS
=
[
...
@@ -19,10 +21,11 @@ NUM_PROMPTS = [10]
...
@@ -19,10 +21,11 @@ NUM_PROMPTS = [10]
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"enable_chunked_prefill"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"enable_chunked_prefill"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
def
test_multi_step_llm
(
def
test_multi_step_llm
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
...
@@ -36,6 +39,8 @@ def test_multi_step_llm(
...
@@ -36,6 +39,8 @@ def test_multi_step_llm(
num_scheduler_steps
:
int
,
num_scheduler_steps
:
int
,
num_prompts
:
int
,
num_prompts
:
int
,
num_logprobs
:
Optional
[
int
],
num_logprobs
:
Optional
[
int
],
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
)
->
None
:
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.
...
@@ -63,6 +68,7 @@ def test_multi_step_llm(
...
@@ -63,6 +68,7 @@ def test_multi_step_llm(
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> 1 logprob returned.
completions endpoint; `None` -> 1 logprob returned.
"""
"""
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
prompts
=
example_prompts
prompts
=
example_prompts
if
len
(
prompts
)
<
num_prompts
:
if
len
(
prompts
)
<
num_prompts
:
...
@@ -114,6 +120,7 @@ def test_multi_step_llm(
...
@@ -114,6 +120,7 @@ def test_multi_step_llm(
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_logprobs,num_prompt_logprobs"
,
[(
5
,
5
)])
@
pytest
.
mark
.
parametrize
(
"num_logprobs,num_prompt_logprobs"
,
[(
5
,
5
)])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASH_ATTN"
])
def
test_multi_step_llm_w_prompt_logprobs
(
def
test_multi_step_llm_w_prompt_logprobs
(
vllm_runner
,
vllm_runner
,
example_prompts
,
example_prompts
,
...
@@ -126,6 +133,8 @@ def test_multi_step_llm_w_prompt_logprobs(
...
@@ -126,6 +133,8 @@ def test_multi_step_llm_w_prompt_logprobs(
num_prompts
:
int
,
num_prompts
:
int
,
num_logprobs
:
Optional
[
int
],
num_logprobs
:
Optional
[
int
],
num_prompt_logprobs
:
Optional
[
int
],
num_prompt_logprobs
:
Optional
[
int
],
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
)
->
None
:
"""Test prompt logprobs with multi-step scheduling via sync LLM Engine.
"""Test prompt logprobs with multi-step scheduling via sync LLM Engine.
...
@@ -155,6 +164,7 @@ def test_multi_step_llm_w_prompt_logprobs(
...
@@ -155,6 +164,7 @@ def test_multi_step_llm_w_prompt_logprobs(
note that this argument is not supported by the
note that this argument is not supported by the
OpenAI completions endpoint.
OpenAI completions endpoint.
"""
"""
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
prompts
=
example_prompts
prompts
=
example_prompts
if
len
(
prompts
)
<
num_prompts
:
if
len
(
prompts
)
<
num_prompts
:
...
@@ -205,6 +215,7 @@ def test_multi_step_llm_w_prompt_logprobs(
...
@@ -205,6 +215,7 @@ def test_multi_step_llm_w_prompt_logprobs(
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASH_ATTN"
])
def
test_multi_step_llm_chunked_prefill_prefix_cache
(
def
test_multi_step_llm_chunked_prefill_prefix_cache
(
vllm_runner
,
vllm_runner
,
example_prompts
,
example_prompts
,
...
@@ -216,6 +227,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
...
@@ -216,6 +227,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
num_scheduler_steps
:
int
,
num_scheduler_steps
:
int
,
num_prompts
:
int
,
num_prompts
:
int
,
num_logprobs
:
Optional
[
int
],
num_logprobs
:
Optional
[
int
],
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
)
->
None
:
"""Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
"""Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
...
@@ -278,6 +291,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
...
@@ -278,6 +291,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
#
#
# The Incorrect scheduling behavior - if it occurs - will cause an exception
# The Incorrect scheduling behavior - if it occurs - will cause an exception
# in the model runner resulting from `do_sample=False`.
# in the model runner resulting from `do_sample=False`.
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
assert
len
(
example_prompts
)
>=
2
assert
len
(
example_prompts
)
>=
2
challenge_prompts
=
copy
.
deepcopy
(
example_prompts
)
challenge_prompts
=
copy
.
deepcopy
(
example_prompts
)
challenge_prompts
[
0
]
=
(
'vLLM is a high-throughput and memory-efficient '
challenge_prompts
[
0
]
=
(
'vLLM is a high-throughput and memory-efficient '
...
...
vllm/attention/backends/flashinfer.py
View file @
0794e744
...
@@ -256,7 +256,12 @@ class FlashInferState(AttentionState):
...
@@ -256,7 +256,12 @@ class FlashInferState(AttentionState):
def
begin_forward
(
self
,
model_input
):
def
begin_forward
(
self
,
model_input
):
assert
not
self
.
_is_graph_capturing
assert
not
self
.
_is_graph_capturing
state
=
self
state
=
self
if
model_input
.
attn_metadata
.
use_cuda_graph
:
use_cuda_graph
=
model_input
.
attn_metadata
.
use_cuda_graph
is_decode
=
model_input
.
attn_metadata
.
num_prefills
==
0
# In case of multistep chunked-prefill, there might be prefill requests
# scheduled while CUDA graph mode is enabled. We don't run graph in that
# case.
if
use_cuda_graph
and
is_decode
:
batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
state
=
(
self
.
runner
.
graph_runners
[
model_input
.
virtual_engine
]
state
=
(
self
.
runner
.
graph_runners
[
model_input
.
virtual_engine
]
[
batch_size
].
attn_state
)
[
batch_size
].
attn_state
)
...
@@ -429,10 +434,24 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -429,10 +434,24 @@ class FlashInferMetadata(AttentionMetadata):
Update metadata in-place to advance one decode step.
Update metadata in-place to advance one decode step.
"""
"""
assert
not
turn_prefills_into_decodes
,
\
if
turn_prefills_into_decodes
:
(
"Chunked prefill is not supported with flashinfer yet."
# When Multi-Step is enabled with Chunked-Prefill, prefills and
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
# decodes are scheduled together. In the first step, all the
"specific parameter."
)
# prefills turn into decodes. This update reflects that
# conversion.
assert
self
.
num_decode_tokens
+
self
.
num_prefills
==
num_seqs
# Flashinfer doesn't support speculative decoding + chunked-prefill
# + multi-step scheduling yet.
assert
self
.
decode_query_len
==
1
self
.
num_decode_tokens
+=
self
.
num_prefills
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
max_prefill_seq_len
=
0
self
.
max_query_len
=
1
self
.
slot_mapping
=
self
.
slot_mapping
[:
num_seqs
]
else
:
assert
self
.
seq_lens_tensor
is
not
None
assert
num_seqs
>
0
assert
num_seqs
>
0
assert
num_queries
>
0
assert
num_queries
>
0
...
...
vllm/worker/model_runner.py
View file @
0794e744
...
@@ -5,6 +5,7 @@ import itertools
...
@@ -5,6 +5,7 @@ import itertools
import
time
import
time
import
warnings
import
warnings
import
weakref
import
weakref
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
,
Union
)
Tuple
,
Type
,
TypeVar
,
Union
)
...
@@ -1028,6 +1029,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1028,6 +1029,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
has_inner_state
=
model_config
.
has_inner_state
self
.
has_inner_state
=
model_config
.
has_inner_state
self
.
in_profile_run
=
False
# When using CUDA graph, the input block tables must be padded to
# When using CUDA graph, the input block tables must be padded to
# max_seq_len_to_capture. However, creating the block table in
# max_seq_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# Python can be expensive. To optimize this, we cache the block table
...
@@ -1228,11 +1231,22 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1228,11 +1231,22 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
return
builder
.
build
()
# type: ignore
return
builder
.
build
()
# type: ignore
@
contextmanager
def
set_in_profile_run
(
self
):
self
.
in_profile_run
=
True
try
:
yield
finally
:
self
.
in_profile_run
=
False
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
def
profile_run
(
self
)
->
None
:
with
self
.
set_in_profile_run
():
# Enable top-k sampling to reflect the accurate memory usage.
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params
=
SamplingParams
(
top_p
=
0.99
,
top_k
=
self
.
vocab_size
-
1
)
sampling_params
=
\
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
SamplingParams
(
top_p
=
0.99
,
top_k
=
self
.
vocab_size
-
1
)
max_num_batched_tokens
=
\
self
.
scheduler_config
.
max_num_batched_tokens
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
# This represents the maximum number of different requests
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# that will have unique loras, an therefore the max amount of memory
...
@@ -1258,12 +1272,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1258,12 +1272,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
for
idx
in
range
(
max_num_seqs
)
for
idx
in
range
(
max_num_seqs
)
]
]
# Profile memory usage with max_num_sequences sequences and the
total
# Profile memory usage with max_num_sequences sequences and the
#
number of tokens equal to max_num_batched_tokens.
# total
number of tokens equal to max_num_batched_tokens.
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
# Additional GPU memory may be needed for multi-modal encoding,
which
# Additional GPU memory may be needed for multi-modal encoding,
#
needs to be accounted for when calculating the GPU blocks
for
# which
needs to be accounted for when calculating the GPU blocks
#
vLLM blocker manager.
# for
vLLM blocker manager.
# To exercise the worst scenario for GPU memory consumption,
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
# of images processed.
...
@@ -1302,7 +1316,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1302,7 +1316,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
dummy_data
.
multi_modal_data
,
multi_modal_data
=
dummy_data
.
multi_modal_data
,
multi_modal_placeholders
=
dummy_data
.
multi_modal_placeholders
,
multi_modal_placeholders
=
dummy_data
.
multi_modal_placeholders
,
)
)
seqs
.
append
(
seq
)
seqs
.
append
(
seq
)
...
@@ -1324,7 +1339,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1324,7 +1339,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
seqs
,
finished_requests_ids
=
finished_requests_ids
)
seqs
,
finished_requests_ids
=
finished_requests_ids
)
intermediate_tensors
=
None
intermediate_tensors
=
None
if
not
get_pp_group
().
is_first_rank
:
if
not
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
self
.
model
.
make_empty_intermediate_tensors
(
intermediate_tensors
=
\
self
.
model
.
make_empty_intermediate_tensors
(
batch_size
=
batch_size
,
batch_size
=
batch_size
,
dtype
=
self
.
model_config
.
dtype
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
device
=
self
.
device
)
...
...
vllm/worker/multi_step_model_runner.py
View file @
0794e744
...
@@ -32,7 +32,7 @@ logger = init_logger(__name__)
...
@@ -32,7 +32,7 @@ logger = init_logger(__name__)
MULTI_STEP_ATTENTION_BACKENDS
=
[
MULTI_STEP_ATTENTION_BACKENDS
=
[
"FLASH_ATTN"
,
"ROCM_FLASH"
,
"FLASHINFER"
,
"NO_ATTENTION"
"FLASH_ATTN"
,
"ROCM_FLASH"
,
"FLASHINFER"
,
"NO_ATTENTION"
]
]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
=
[
"FLASH_ATTN"
]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
=
[
"FLASH_ATTN"
,
"FLASHINFER"
]
def
_get_supported_attention_backends
(
chunked_prefill_enabled
:
bool
)
\
def
_get_supported_attention_backends
(
chunked_prefill_enabled
:
bool
)
\
->
List
[
str
]:
->
List
[
str
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment