Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0794e744
Unverified
Commit
0794e744
authored
Jan 14, 2025
by
Elfie Guo
Committed by
GitHub
Jan 15, 2025
Browse files
[Misc] Add multipstep chunked-prefill support for FlashInfer (#10467)
parent
b7ee940a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
169 additions
and
109 deletions
+169
-109
csrc/prepare_inputs/advance_step.cu
csrc/prepare_inputs/advance_step.cu
+10
-0
tests/multi_step/test_correctness_llm.py
tests/multi_step/test_correctness_llm.py
+16
-1
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+24
-5
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+118
-102
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+1
-1
No files found.
csrc/prepare_inputs/advance_step.cu
View file @
0794e744
...
...
@@ -95,6 +95,16 @@ __global__ void advance_step_flashinfer_kernel(
long
*
input_positions_ptr
,
int
*
seq_lens_ptr
,
long
*
slot_mapping_ptr
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
,
int
*
paged_kv_last_page_len_ptr
,
int
*
block_table_bound_ptr
)
{
int
const
n_pad
=
num_seqs
-
num_queries
;
if
(
n_pad
&&
blockIdx
.
x
==
0
)
{
// Handle cuda graph padding
int
const
offset
=
num_queries
;
for
(
int
i
=
threadIdx
.
x
;
i
<
n_pad
;
i
+=
blockDim
.
x
)
{
input_tokens_ptr
[
offset
+
i
]
=
0
;
input_positions_ptr
[
offset
+
i
]
=
0
;
slot_mapping_ptr
[
offset
+
i
]
=
-
1
;
}
}
int
num_query_blocks
=
div_ceil
(
num_queries
,
num_threads
);
if
(
blockIdx
.
x
<
num_query_blocks
)
{
...
...
tests/multi_step/test_correctness_llm.py
View file @
0794e744
...
...
@@ -5,6 +5,8 @@ from typing import Optional
import
pytest
from
tests.kernels.utils
import
override_backend_env_variable
from
..models.utils
import
check_logprobs_close
,
check_outputs_equal
MODELS
=
[
...
...
@@ -19,10 +21,11 @@ NUM_PROMPTS = [10]
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"enable_chunked_prefill"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
def
test_multi_step_llm
(
hf_runner
,
vllm_runner
,
...
...
@@ -36,6 +39,8 @@ def test_multi_step_llm(
num_scheduler_steps
:
int
,
num_prompts
:
int
,
num_logprobs
:
Optional
[
int
],
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.
...
...
@@ -63,6 +68,7 @@ def test_multi_step_llm(
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> 1 logprob returned.
"""
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
prompts
=
example_prompts
if
len
(
prompts
)
<
num_prompts
:
...
...
@@ -114,6 +120,7 @@ def test_multi_step_llm(
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_logprobs,num_prompt_logprobs"
,
[(
5
,
5
)])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASH_ATTN"
])
def
test_multi_step_llm_w_prompt_logprobs
(
vllm_runner
,
example_prompts
,
...
...
@@ -126,6 +133,8 @@ def test_multi_step_llm_w_prompt_logprobs(
num_prompts
:
int
,
num_logprobs
:
Optional
[
int
],
num_prompt_logprobs
:
Optional
[
int
],
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
"""Test prompt logprobs with multi-step scheduling via sync LLM Engine.
...
...
@@ -155,6 +164,7 @@ def test_multi_step_llm_w_prompt_logprobs(
note that this argument is not supported by the
OpenAI completions endpoint.
"""
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
prompts
=
example_prompts
if
len
(
prompts
)
<
num_prompts
:
...
...
@@ -205,6 +215,7 @@ def test_multi_step_llm_w_prompt_logprobs(
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASH_ATTN"
])
def
test_multi_step_llm_chunked_prefill_prefix_cache
(
vllm_runner
,
example_prompts
,
...
...
@@ -216,6 +227,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
num_scheduler_steps
:
int
,
num_prompts
:
int
,
num_logprobs
:
Optional
[
int
],
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
"""Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
...
...
@@ -278,6 +291,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
#
# The Incorrect scheduling behavior - if it occurs - will cause an exception
# in the model runner resulting from `do_sample=False`.
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
assert
len
(
example_prompts
)
>=
2
challenge_prompts
=
copy
.
deepcopy
(
example_prompts
)
challenge_prompts
[
0
]
=
(
'vLLM is a high-throughput and memory-efficient '
...
...
vllm/attention/backends/flashinfer.py
View file @
0794e744
...
...
@@ -256,7 +256,12 @@ class FlashInferState(AttentionState):
def
begin_forward
(
self
,
model_input
):
assert
not
self
.
_is_graph_capturing
state
=
self
if
model_input
.
attn_metadata
.
use_cuda_graph
:
use_cuda_graph
=
model_input
.
attn_metadata
.
use_cuda_graph
is_decode
=
model_input
.
attn_metadata
.
num_prefills
==
0
# In case of multistep chunked-prefill, there might be prefill requests
# scheduled while CUDA graph mode is enabled. We don't run graph in that
# case.
if
use_cuda_graph
and
is_decode
:
batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
state
=
(
self
.
runner
.
graph_runners
[
model_input
.
virtual_engine
]
[
batch_size
].
attn_state
)
...
...
@@ -429,10 +434,24 @@ class FlashInferMetadata(AttentionMetadata):
Update metadata in-place to advance one decode step.
"""
assert
not
turn_prefills_into_decodes
,
\
(
"Chunked prefill is not supported with flashinfer yet."
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
"specific parameter."
)
if
turn_prefills_into_decodes
:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert
self
.
num_decode_tokens
+
self
.
num_prefills
==
num_seqs
# Flashinfer doesn't support speculative decoding + chunked-prefill
# + multi-step scheduling yet.
assert
self
.
decode_query_len
==
1
self
.
num_decode_tokens
+=
self
.
num_prefills
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
max_prefill_seq_len
=
0
self
.
max_query_len
=
1
self
.
slot_mapping
=
self
.
slot_mapping
[:
num_seqs
]
else
:
assert
self
.
seq_lens_tensor
is
not
None
assert
num_seqs
>
0
assert
num_queries
>
0
...
...
vllm/worker/model_runner.py
View file @
0794e744
...
...
@@ -5,6 +5,7 @@ import itertools
import
time
import
warnings
import
weakref
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
,
Union
)
...
...
@@ -1028,6 +1029,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
has_inner_state
=
model_config
.
has_inner_state
self
.
in_profile_run
=
False
# When using CUDA graph, the input block tables must be padded to
# max_seq_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
...
...
@@ -1228,110 +1231,123 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
return
builder
.
build
()
# type: ignore
@
contextmanager
def
set_in_profile_run
(
self
):
self
.
in_profile_run
=
True
try
:
yield
finally
:
self
.
in_profile_run
=
False
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params
=
SamplingParams
(
top_p
=
0.99
,
top_k
=
self
.
vocab_size
-
1
)
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
# passed in, which contains a lora from the lora warmup path.
dummy_lora_requests
:
List
[
LoRARequest
]
=
[]
dummy_lora_requests_per_seq
:
List
[
LoRARequest
]
=
[]
if
self
.
lora_config
:
assert
self
.
lora_manager
is
not
None
with
self
.
lora_manager
.
dummy_lora_cache
():
for
idx
in
range
(
self
.
lora_config
.
max_loras
):
lora_id
=
idx
+
1
dummy_lora_request
=
LoRARequest
(
lora_name
=
f
"warmup_
{
lora_id
}
"
,
lora_int_id
=
lora_id
,
lora_path
=
"/not/a/real/path"
,
)
self
.
lora_manager
.
add_dummy_lora
(
dummy_lora_request
,
rank
=
LORA_WARMUP_RANK
)
dummy_lora_requests
.
append
(
dummy_lora_request
)
dummy_lora_requests_per_seq
=
[
dummy_lora_requests
[
idx
%
len
(
dummy_lora_requests
)]
for
idx
in
range
(
max_num_seqs
)
]
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
# Additional GPU memory may be needed for multi-modal encoding, which
# needs to be accounted for when calculating the GPU blocks for
# vLLM blocker manager.
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
max_mm_tokens
=
self
.
mm_registry
.
get_max_multimodal_tokens
(
self
.
model_config
)
if
max_mm_tokens
>
0
:
max_num_seqs_orig
=
max_num_seqs
max_num_seqs
=
min
(
max_num_seqs
,
max_num_batched_tokens
//
max_mm_tokens
)
if
max_num_seqs
<
1
:
expr
=
(
f
"min(
{
max_num_seqs_orig
}
, "
f
"
{
max_num_batched_tokens
}
//
{
max_mm_tokens
}
)"
)
logger
.
warning
(
"Computed max_num_seqs (%s) to be less than 1. "
"Setting it to the minimum value of 1."
,
expr
)
max_num_seqs
=
1
batch_size
=
0
for
group_id
in
range
(
max_num_seqs
):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
batch_size
+=
seq_len
dummy_data
=
self
.
input_registry
\
.
dummy_data_for_profiling
(
self
.
model_config
,
seq_len
,
self
.
mm_registry
)
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
is_prompt
=
True
,
seq_data
=
{
group_id
:
dummy_data
.
seq_data
},
sampling_params
=
sampling_params
,
block_tables
=
None
,
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
dummy_data
.
multi_modal_data
,
multi_modal_placeholders
=
dummy_data
.
multi_modal_placeholders
,
)
seqs
.
append
(
seq
)
# Run the model with the dummy inputs.
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
# it is important to create tensors inside the loop, rather than
# multiplying the list, to avoid Dynamo from treating them as
# tensor aliasing.
kv_caches
=
[
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
for
_
in
range
(
num_layers
)
]
finished_requests_ids
=
[
seq
.
request_id
for
seq
in
seqs
]
model_input
=
self
.
prepare_model_input
(
seqs
,
finished_requests_ids
=
finished_requests_ids
)
intermediate_tensors
=
None
if
not
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
self
.
model
.
make_empty_intermediate_tensors
(
batch_size
=
batch_size
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
self
.
execute_model
(
model_input
,
kv_caches
,
intermediate_tensors
)
torch
.
cuda
.
synchronize
()
return
with
self
.
set_in_profile_run
():
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params
=
\
SamplingParams
(
top_p
=
0.99
,
top_k
=
self
.
vocab_size
-
1
)
max_num_batched_tokens
=
\
self
.
scheduler_config
.
max_num_batched_tokens
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
# passed in, which contains a lora from the lora warmup path.
dummy_lora_requests
:
List
[
LoRARequest
]
=
[]
dummy_lora_requests_per_seq
:
List
[
LoRARequest
]
=
[]
if
self
.
lora_config
:
assert
self
.
lora_manager
is
not
None
with
self
.
lora_manager
.
dummy_lora_cache
():
for
idx
in
range
(
self
.
lora_config
.
max_loras
):
lora_id
=
idx
+
1
dummy_lora_request
=
LoRARequest
(
lora_name
=
f
"warmup_
{
lora_id
}
"
,
lora_int_id
=
lora_id
,
lora_path
=
"/not/a/real/path"
,
)
self
.
lora_manager
.
add_dummy_lora
(
dummy_lora_request
,
rank
=
LORA_WARMUP_RANK
)
dummy_lora_requests
.
append
(
dummy_lora_request
)
dummy_lora_requests_per_seq
=
[
dummy_lora_requests
[
idx
%
len
(
dummy_lora_requests
)]
for
idx
in
range
(
max_num_seqs
)
]
# Profile memory usage with max_num_sequences sequences and the
# total number of tokens equal to max_num_batched_tokens.
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
# Additional GPU memory may be needed for multi-modal encoding,
# which needs to be accounted for when calculating the GPU blocks
# for vLLM blocker manager.
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
max_mm_tokens
=
self
.
mm_registry
.
get_max_multimodal_tokens
(
self
.
model_config
)
if
max_mm_tokens
>
0
:
max_num_seqs_orig
=
max_num_seqs
max_num_seqs
=
min
(
max_num_seqs
,
max_num_batched_tokens
//
max_mm_tokens
)
if
max_num_seqs
<
1
:
expr
=
(
f
"min(
{
max_num_seqs_orig
}
, "
f
"
{
max_num_batched_tokens
}
//
{
max_mm_tokens
}
)"
)
logger
.
warning
(
"Computed max_num_seqs (%s) to be less than 1. "
"Setting it to the minimum value of 1."
,
expr
)
max_num_seqs
=
1
batch_size
=
0
for
group_id
in
range
(
max_num_seqs
):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
batch_size
+=
seq_len
dummy_data
=
self
.
input_registry
\
.
dummy_data_for_profiling
(
self
.
model_config
,
seq_len
,
self
.
mm_registry
)
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
is_prompt
=
True
,
seq_data
=
{
group_id
:
dummy_data
.
seq_data
},
sampling_params
=
sampling_params
,
block_tables
=
None
,
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
dummy_data
.
multi_modal_data
,
multi_modal_placeholders
=
dummy_data
.
multi_modal_placeholders
,
)
seqs
.
append
(
seq
)
# Run the model with the dummy inputs.
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
# it is important to create tensors inside the loop, rather than
# multiplying the list, to avoid Dynamo from treating them as
# tensor aliasing.
kv_caches
=
[
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
for
_
in
range
(
num_layers
)
]
finished_requests_ids
=
[
seq
.
request_id
for
seq
in
seqs
]
model_input
=
self
.
prepare_model_input
(
seqs
,
finished_requests_ids
=
finished_requests_ids
)
intermediate_tensors
=
None
if
not
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
\
self
.
model
.
make_empty_intermediate_tensors
(
batch_size
=
batch_size
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
self
.
execute_model
(
model_input
,
kv_caches
,
intermediate_tensors
)
torch
.
cuda
.
synchronize
()
return
def
remove_all_loras
(
self
):
if
not
self
.
lora_manager
:
...
...
vllm/worker/multi_step_model_runner.py
View file @
0794e744
...
...
@@ -32,7 +32,7 @@ logger = init_logger(__name__)
MULTI_STEP_ATTENTION_BACKENDS
=
[
"FLASH_ATTN"
,
"ROCM_FLASH"
,
"FLASHINFER"
,
"NO_ATTENTION"
]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
=
[
"FLASH_ATTN"
]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
=
[
"FLASH_ATTN"
,
"FLASHINFER"
]
def
_get_supported_attention_backends
(
chunked_prefill_enabled
:
bool
)
\
->
List
[
str
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment