Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
67b4221a
Unverified
Commit
67b4221a
authored
Apr 11, 2024
by
SangBin Cho
Committed by
GitHub
Apr 10, 2024
Browse files
[Core][5/N] Fully working chunked prefill e2e (#3884)
parent
63e7176f
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
261 additions
and
90 deletions
+261
-90
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+9
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+2
-3
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+4
-1
vllm/lora/layers.py
vllm/lora/layers.py
+3
-2
vllm/sequence.py
vllm/sequence.py
+2
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+241
-82
No files found.
vllm/distributed/communication_op.py
View file @
67b4221a
...
@@ -173,10 +173,18 @@ def broadcast_tensor_dict(
...
@@ -173,10 +173,18 @@ def broadcast_tensor_dict(
torch
.
distributed
.
broadcast_object_list
([
metadata_list
],
torch
.
distributed
.
broadcast_object_list
([
metadata_list
],
src
=
src
,
src
=
src
,
group
=
group
)
group
=
group
)
async_handles
=
[]
for
key
,
value
in
metadata_list
:
for
key
,
value
in
metadata_list
:
if
isinstance
(
value
,
TensorMetadata
):
if
isinstance
(
value
,
TensorMetadata
):
tensor
=
tensor_dict
[
key
]
tensor
=
tensor_dict
[
key
]
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
group
=
group
)
async_handles
.
append
(
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
group
=
group
,
async_op
=
True
))
for
async_handle
in
async_handles
:
async_handle
.
wait
()
else
:
else
:
recv_metadata_list
=
[
None
]
recv_metadata_list
=
[
None
]
torch
.
distributed
.
broadcast_object_list
(
recv_metadata_list
,
torch
.
distributed
.
broadcast_object_list
(
recv_metadata_list
,
...
...
vllm/engine/arg_utils.py
View file @
67b4221a
...
@@ -386,9 +386,8 @@ class EngineArgs:
...
@@ -386,9 +386,8 @@ class EngineArgs:
'prompt latency) before scheduling next prompt.'
)
'prompt latency) before scheduling next prompt.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--enable-chunked-prefill'
,
'--enable-chunked-prefill'
,
type
=
bool
,
action
=
'store_true'
,
default
=
False
,
help
=
'If set, the prefill requests can be chunked based on the '
help
=
'If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens'
)
'max_num_batched_tokens'
)
parser
.
add_argument
(
parser
.
add_argument
(
...
...
vllm/engine/llm_engine.py
View file @
67b4221a
...
@@ -633,7 +633,10 @@ class LLMEngine:
...
@@ -633,7 +633,10 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
update_num_computed_tokens
(
seq_group
.
update_num_computed_tokens
(
scheduled_seq_group
.
token_chunk_size
)
scheduled_seq_group
.
token_chunk_size
)
self
.
_process_sequence_group_outputs
(
seq_group
,
outputs
)
# If uncomputed tokens > 0, it means prefill is chunked.
# We don't need to process outputs in that case.
if
seq_group
.
get_num_uncomputed_tokens
()
==
0
:
self
.
_process_sequence_group_outputs
(
seq_group
,
outputs
)
# Free the finished sequence groups.
# Free the finished sequence groups.
self
.
scheduler
.
free_finished_seq_groups
()
self
.
scheduler
.
free_finished_seq_groups
()
...
...
vllm/lora/layers.py
View file @
67b4221a
...
@@ -267,12 +267,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -267,12 +267,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
added_tokens_mask
=
x
>
self
.
base_layer
.
org_vocab_size
-
1
added_tokens_mask
=
x
>
self
.
base_layer
.
org_vocab_size
-
1
indices
=
self
.
embeddings_indices
[
1
][:
self
.
indices_len
[
3
]].
view_as
(
x
)
embedding_len
=
self
.
indices_len
[
3
]
indices
=
self
.
embeddings_indices
[
1
][:
embedding_len
].
view_as
(
x
)
full_lora_a_embeddings
=
F
.
embedding
(
full_lora_a_embeddings
=
F
.
embedding
(
x
+
indices
,
x
+
indices
,
self
.
lora_a_stacked_2d
,
self
.
lora_a_stacked_2d
,
)
)
indices
=
self
.
embeddings_indices
[
0
][:
self
.
indices
_len
[
3
]
].
view_as
(
x
)
indices
=
self
.
embeddings_indices
[
0
][:
embedding
_len
].
view_as
(
x
)
full_output
=
self
.
base_layer
.
forward
(
full_output
=
self
.
base_layer
.
forward
(
x
.
add_
(
indices
*
added_tokens_mask
))
x
.
add_
(
indices
*
added_tokens_mask
))
...
...
vllm/sequence.py
View file @
67b4221a
...
@@ -500,7 +500,8 @@ class SequenceGroup:
...
@@ -500,7 +500,8 @@ class SequenceGroup:
def
get_num_uncomputed_tokens
(
self
)
->
int
:
def
get_num_uncomputed_tokens
(
self
)
->
int
:
num_uncomputed_tokens
=
0
num_uncomputed_tokens
=
0
for
seq
in
self
.
get_seqs
():
for
seq
in
self
.
get_seqs
():
num_uncomputed_tokens
+=
seq
.
data
.
get_num_uncomputed_tokens
()
if
not
seq
.
is_finished
():
num_uncomputed_tokens
+=
seq
.
data
.
get_num_uncomputed_tokens
()
return
num_uncomputed_tokens
return
num_uncomputed_tokens
def
num_seqs
(
self
,
status
:
Optional
[
SequenceStatus
]
=
None
)
->
int
:
def
num_seqs
(
self
,
status
:
Optional
[
SequenceStatus
]
=
None
)
->
int
:
...
...
vllm/worker/model_runner.py
View file @
67b4221a
import
contextlib
import
contextlib
import
time
import
time
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
enum
import
IntEnum
from
typing
import
Dict
,
List
,
NamedTuple
,
Optional
,
Set
,
Tuple
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataPerStage
,
get_attn_backend
)
from
vllm.config
import
(
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
from
vllm.config
import
(
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
,
with_pynccl_for_all_reduce
from
vllm.distributed
import
broadcast_tensor_dict
,
with_pynccl_for_all_reduce
...
@@ -37,6 +39,66 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
...
@@ -37,6 +39,66 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
]
]
class
PreparePromptMetadata
(
NamedTuple
):
input_tokens
:
List
[
int
]
input_positions
:
List
[
int
]
attn_metadata
:
Optional
[
AttentionMetadataPerStage
]
prompt_lens
:
List
[
int
]
subquery_lens
:
List
[
int
]
lora_index_mapping
:
List
[
int
]
lora_prompt_mapping
:
List
[
int
]
lora_requests
:
Set
[
LoRARequest
]
multi_modal_input
:
Optional
[
torch
.
Tensor
]
slot_mapping
:
List
[
int
]
@
classmethod
def
empty
(
cls
):
return
PreparePromptMetadata
(
input_tokens
=
[],
input_positions
=
[],
attn_metadata
=
None
,
prompt_lens
=
[],
subquery_lens
=
[],
lora_index_mapping
=
[],
lora_prompt_mapping
=
[],
lora_requests
=
set
(),
multi_modal_input
=
None
,
slot_mapping
=
[],
)
class
PrepareDecodeMetadata
(
NamedTuple
):
input_tokens
:
List
[
int
]
input_positions
:
List
[
int
]
attn_metadata
:
Optional
[
AttentionMetadata
]
lora_index_mapping
:
List
[
int
]
lora_prompt_mapping
:
List
[
int
]
lora_requests
:
Set
[
LoRARequest
]
slot_mapping
:
List
[
int
]
@
classmethod
def
empty
(
cls
):
return
PrepareDecodeMetadata
(
input_tokens
=
[],
input_positions
=
[],
attn_metadata
=
None
,
lora_index_mapping
=
[],
lora_prompt_mapping
=
[],
lora_requests
=
set
(),
slot_mapping
=
[],
)
# How batches are constructed.
class
BatchType
(
IntEnum
):
# Every batch is prefill.
PREFILL
=
0
# Every batch is decode.
DECODE
=
1
# Batch is a mixture of prefill and decode.
MIXED
=
2
class
ModelRunner
:
class
ModelRunner
:
def
__init__
(
def
__init__
(
...
@@ -152,10 +214,7 @@ class ModelRunner:
...
@@ -152,10 +214,7 @@ class ModelRunner:
def
_prepare_prompt
(
def
_prepare_prompt
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
List
[
int
],
)
->
PreparePromptMetadata
:
List
[
int
],
List
[
int
],
List
[
int
],
Set
[
LoRARequest
],
torch
.
Tensor
]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
...
@@ -169,6 +228,9 @@ class ModelRunner:
...
@@ -169,6 +228,9 @@ class ModelRunner:
prefix_block_tables
:
List
[
List
[
int
]]
=
[]
prefix_block_tables
:
List
[
List
[
int
]]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
if
len
(
seq_group_metadata_list
)
==
0
:
return
PreparePromptMetadata
.
empty
()
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
...
@@ -178,7 +240,8 @@ class ModelRunner:
...
@@ -178,7 +240,8 @@ class ModelRunner:
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
if
(
self
.
scheduler_config
is
not
None
if
(
self
.
scheduler_config
is
not
None
and
self
.
scheduler_config
.
chunked_prefill_enabled
and
self
.
scheduler_config
.
chunked_prefill_enabled
and
computed_block_nums
is
not
None
):
and
not
(
computed_block_nums
is
None
or
computed_block_nums
==
[])):
raise
RuntimeError
(
raise
RuntimeError
(
"chunked prefill cannot be used with prefix caching "
"chunked prefill cannot be used with prefix caching "
"now."
)
"now."
)
...
@@ -190,13 +253,8 @@ class ModelRunner:
...
@@ -190,13 +253,8 @@ class ModelRunner:
# it contains output tokens.
# it contains output tokens.
prefill_end
=
min
(
seq_data
.
get_len
(),
prefill_end
=
min
(
seq_data
.
get_len
(),
computed_len
+
token_chunk_size
)
computed_len
+
token_chunk_size
)
# TODO(sang): Rename it after chunked prefill is introduced.
prompt_tokens
=
seq_data
.
get_token_ids
()[
computed_len
:
prefill_end
]
prompt_tokens
=
seq_data
.
get_token_ids
()[
computed_len
:
prefill_end
]
prompt_len
=
len
(
prompt_tokens
)
prompt_len
=
prefill_end
# Right now, the prefill_end is always same as the length of
# sequence. However, once chunked prefill is introduced, this
# assumption can be changed.
assert
prefill_end
==
seq_data
.
get_len
()
prompt_lens
.
append
(
prompt_len
)
prompt_lens
.
append
(
prompt_len
)
# NOTE: This only works for oooooooxxx style attention.
# NOTE: This only works for oooooooxxx style attention.
...
@@ -206,6 +264,14 @@ class ModelRunner:
...
@@ -206,6 +264,14 @@ class ModelRunner:
computed_len
=
len
(
computed_block_nums
)
*
self
.
block_size
computed_len
=
len
(
computed_block_nums
)
*
self
.
block_size
prompt_tokens
=
prompt_tokens
[
computed_len
:]
prompt_tokens
=
prompt_tokens
[
computed_len
:]
prefix_block_tables
.
append
(
computed_block_nums
)
prefix_block_tables
.
append
(
computed_block_nums
)
elif
self
.
scheduler_config
.
chunked_prefill_enabled
:
if
seq_group_metadata
.
block_tables
is
not
None
:
# Prefill has chunked before.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
prefix_block_tables
.
append
(
block_table
)
else
:
# The first prefill.
prefix_block_tables
.
append
([])
else
:
else
:
prefix_block_tables
.
append
([])
prefix_block_tables
.
append
([])
# Right now, prefill start is always 0. However, this
# Right now, prefill start is always 0. However, this
...
@@ -267,20 +333,8 @@ class ModelRunner:
...
@@ -267,20 +333,8 @@ class ModelRunner:
max_subquery_len
=
max
(
subquery_lens
)
max_subquery_len
=
max
(
subquery_lens
)
max_prompt_len
=
max
(
prompt_lens
)
max_prompt_len
=
max
(
prompt_lens
)
num_prompt_tokens
=
len
(
input_tokens
)
assert
max_subquery_len
>
0
assert
max_subquery_len
>
0
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
lora_index_mapping
=
lora_index_mapping
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
device
=
self
.
device
)
...
@@ -332,11 +386,8 @@ class ModelRunner:
...
@@ -332,11 +386,8 @@ class ModelRunner:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
is_prompt
=
True
,
slot_mapping
=
slot_mapping
,
prompt_lens
=
prompt_lens
,
prompt_lens
=
prompt_lens
,
prompt_lens_tensor
=
prompt_lens_tensor
,
prompt_lens_tensor
=
prompt_lens_tensor
,
num_prompt_tokens
=
num_prompt_tokens
,
num_generation_tokens
=
0
,
max_subquery_len
=
max_subquery_len
,
max_subquery_len
=
max_subquery_len
,
max_context_len
=
None
,
max_context_len
=
None
,
max_prompt_len
=
max_prompt_len
,
max_prompt_len
=
max_prompt_len
,
...
@@ -345,18 +396,25 @@ class ModelRunner:
...
@@ -345,18 +396,25 @@ class ModelRunner:
context_lens
=
context_lens_tensor
,
context_lens
=
context_lens_tensor
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
use_cuda_graph
=
False
,
use_cuda_graph
=
False
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
,
subquery_lens
,
lora_index_mapping
,
lora_prompt_mapping
,
return
PreparePromptMetadata
(
lora_requests
,
multi_modal_input
)
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
attn_metadata
=
attn_metadata
,
prompt_lens
=
prompt_lens
,
subquery_lens
=
subquery_lens
,
lora_index_mapping
=
lora_index_mapping
,
lora_prompt_mapping
=
lora_prompt_mapping
,
lora_requests
=
lora_requests
,
multi_modal_input
=
multi_modal_input
,
slot_mapping
=
slot_mapping
,
)
def
_prepare_decode
(
def
_prepare_decode
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
List
[
int
],
)
->
PrepareDecodeMetadata
:
List
[
int
],
Set
[
LoRARequest
]]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
...
@@ -366,6 +424,9 @@ class ModelRunner:
...
@@ -366,6 +424,9 @@ class ModelRunner:
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_requests
:
Set
[
LoRARequest
]
=
set
()
lora_requests
:
Set
[
LoRARequest
]
=
set
()
if
len
(
seq_group_metadata_list
)
==
0
:
return
PrepareDecodeMetadata
.
empty
()
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
not
seq_group_metadata
.
is_prompt
assert
not
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
token_chunk_size
==
1
assert
seq_group_metadata
.
token_chunk_size
==
1
...
@@ -424,15 +485,6 @@ class ModelRunner:
...
@@ -424,15 +485,6 @@ class ModelRunner:
lora_index_mapping
.
append
(
0
)
lora_index_mapping
.
append
(
0
)
batch_size
=
graph_batch_size
batch_size
=
graph_batch_size
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
context_lens
=
torch
.
tensor
(
context_lens
,
context_lens
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
device
=
self
.
device
)
...
@@ -440,9 +492,9 @@ class ModelRunner:
...
@@ -440,9 +492,9 @@ class ModelRunner:
if
use_captured_graph
:
if
use_captured_graph
:
# When using cuda-graph all these tensors should be
# When using cuda-graph all these tensors should be
# padded.
# padded.
assert
context_lens
.
shape
[
0
]
==
input_tokens
.
shape
[
0
]
assert
context_lens
.
shape
[
0
]
==
len
(
input_tokens
)
assert
context_lens
.
shape
[
0
]
==
input_positions
.
shape
[
0
]
assert
context_lens
.
shape
[
0
]
==
len
(
input_positions
)
assert
context_lens
.
shape
[
0
]
==
slot_mapping
.
shape
[
0
]
assert
context_lens
.
shape
[
0
]
==
len
(
slot_mapping
)
# The shape of graph_block_tables is
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
# [max batch size, max context len // block size].
...
@@ -464,11 +516,8 @@ class ModelRunner:
...
@@ -464,11 +516,8 @@ class ModelRunner:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
is_prompt
=
False
,
slot_mapping
=
slot_mapping
,
prompt_lens
=
None
,
prompt_lens
=
None
,
prompt_lens_tensor
=
None
,
prompt_lens_tensor
=
None
,
num_prompt_tokens
=
0
,
num_generation_tokens
=
len
(
input_tokens
),
max_subquery_len
=
None
,
max_subquery_len
=
None
,
max_context_len
=
max_context_len
,
max_context_len
=
max_context_len
,
max_prompt_len
=
None
,
max_prompt_len
=
None
,
...
@@ -477,10 +526,16 @@ class ModelRunner:
...
@@ -477,10 +526,16 @@ class ModelRunner:
context_lens
=
context_lens
,
context_lens
=
context_lens
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
use_cuda_graph
=
use_captured_graph
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
return
PrepareDecodeMetadata
(
lora_index_mapping
,
lora_prompt_mapping
,
lora_requests
)
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
attn_metadata
=
attn_metadata
,
lora_index_mapping
=
lora_index_mapping
,
lora_prompt_mapping
=
lora_prompt_mapping
,
lora_requests
=
lora_requests
,
slot_mapping
=
slot_mapping
,
)
def
_prepare_sample
(
def
_prepare_sample
(
self
,
self
,
...
@@ -586,26 +641,66 @@ class ModelRunner:
...
@@ -586,26 +641,66 @@ class ModelRunner:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
,
Set
[
int
],
LoRAMapping
,
torch
.
Tensor
]:
Set
[
int
],
LoRAMapping
,
torch
.
Tensor
]:
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
# NOTE: We assume that all sequences in the group are all prompts or
prefill_reqs
=
[]
# all decodes.
decode_reqs
=
[]
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
for
seq_group_meta
in
seq_group_metadata_list
:
if
seq_group_meta
.
is_prompt
:
prefill_reqs
.
append
(
seq_group_meta
)
else
:
decode_reqs
.
append
(
seq_group_meta
)
# Prepare input tensors.
# Prepare input tensors.
if
is_prompt
:
(
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
,
input_tokens
,
subquery_lens
,
lora_index_mapping
,
lora_prompt_mapping
,
input_positions
,
lora_requests
,
multi_modal_input
prefill_attn_metadata
,
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
prompt_lens
,
else
:
subquery_lens
,
(
input_tokens
,
input_positions
,
attn_metadata
,
lora_index_mapping
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_prompt_mapping
,
lora_requests
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
lora_requests
,
prompt_lens
=
[]
multi_modal_input
,
subquery_lens
=
None
slot_mapping
,
multi_modal_input
=
None
)
=
self
.
_prepare_prompt
(
prefill_reqs
)
(
decode_input_tokens
,
decode_input_positions
,
decode_attn_metadata
,
decode_lora_index_mapping
,
decode_lora_prompt_mapping
,
decode_lora_requests
,
decode_slot_mapping
,
)
=
self
.
_prepare_decode
(
decode_reqs
)
sampling_metadata
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
,
prompt_lens
,
subquery_lens
)
subquery_lens
)
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
:
assert
(
len
(
prefill_reqs
)
and
len
(
decode_reqs
))
==
0
num_prefills
=
len
(
prompt_lens
)
num_prefill_tokens
=
len
(
input_tokens
)
num_decode_tokens
=
len
(
decode_input_tokens
)
# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens
.
extend
(
decode_input_tokens
)
input_positions
.
extend
(
decode_input_positions
)
slot_mapping
.
extend
(
decode_slot_mapping
)
lora_index_mapping
.
extend
(
decode_lora_index_mapping
)
lora_prompt_mapping
.
extend
(
decode_lora_prompt_mapping
)
lora_requests
.
update
(
decode_lora_requests
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
if
self
.
lora_config
:
if
self
.
lora_config
:
lora_mapping
=
LoRAMapping
(
lora_mapping
=
LoRAMapping
(
lora_index_mapping
,
lora_index_mapping
,
...
@@ -615,6 +710,16 @@ class ModelRunner:
...
@@ -615,6 +710,16 @@ class ModelRunner:
lora_mapping
=
None
lora_mapping
=
None
# Broadcast the metadata.
# Broadcast the metadata.
# If batch contains both prefill and decode, it sends 2 broadcasts.
# If it only contains 1 type, it triggers a single broadcast.
if
(
prefill_attn_metadata
is
not
None
and
decode_attn_metadata
is
not
None
):
batch_type
=
BatchType
.
MIXED
elif
prefill_attn_metadata
is
not
None
:
batch_type
=
BatchType
.
PREFILL
else
:
batch_type
=
BatchType
.
DECODE
metadata_dict
=
{
metadata_dict
=
{
"input_tokens"
:
input_tokens
,
"input_tokens"
:
input_tokens
,
"input_positions"
:
input_positions
,
"input_positions"
:
input_positions
,
...
@@ -623,19 +728,49 @@ class ModelRunner:
...
@@ -623,19 +728,49 @@ class ModelRunner:
"lora_requests"
:
lora_requests
,
"lora_requests"
:
lora_requests
,
"lora_mapping"
:
lora_mapping
,
"lora_mapping"
:
lora_mapping
,
"multi_modal_input"
:
multi_modal_input
,
"multi_modal_input"
:
multi_modal_input
,
"num_prefill_tokens"
:
num_prefill_tokens
,
"num_decode_tokens"
:
num_decode_tokens
,
"slot_mapping"
:
slot_mapping
,
"num_prefills"
:
num_prefills
,
"batch_type"
:
batch_type
,
}
}
metadata_dict
.
update
(
attn_metadata
.
asdict_zerocopy
())
if
prefill_attn_metadata
is
not
None
:
metadata_dict
.
update
(
prefill_attn_metadata
.
asdict_zerocopy
())
else
:
metadata_dict
.
update
(
decode_attn_metadata
.
asdict_zerocopy
())
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
# We can potentially reduce the overhead by coelescing tensors.
if
batch_type
==
BatchType
.
MIXED
:
assert
decode_attn_metadata
is
not
None
metadata_dict
=
decode_attn_metadata
.
asdict_zerocopy
()
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
else
:
else
:
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
input_tokens
=
metadata_dict
.
pop
(
"input_tokens"
)
input_tokens
=
metadata_dict
.
pop
(
"input_tokens"
)
input_positions
=
metadata_dict
.
pop
(
"input_positions"
)
input_positions
=
metadata_dict
.
pop
(
"input_positions"
)
slot_mapping
=
metadata_dict
.
pop
(
"slot_mapping"
)
num_prefills
=
metadata_dict
.
pop
(
"num_prefills"
)
selected_token_indices
=
metadata_dict
.
pop
(
selected_token_indices
=
metadata_dict
.
pop
(
"selected_token_indices"
)
"selected_token_indices"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
multi_modal_input
=
metadata_dict
.
pop
(
"multi_modal_input"
)
multi_modal_input
=
metadata_dict
.
pop
(
"multi_modal_input"
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
num_prefill_tokens
=
metadata_dict
.
pop
(
"num_prefill_tokens"
)
num_decode_tokens
=
metadata_dict
.
pop
(
"num_decode_tokens"
)
batch_type
=
metadata_dict
.
pop
(
"batch_type"
)
# Create an attention metadata.
prefill_attn_metadata
=
None
decode_attn_metadata
=
None
if
batch_type
==
BatchType
.
PREFILL
or
batch_type
==
BatchType
.
MIXED
:
prefill_attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
else
:
decode_attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
sampling_metadata
=
SamplingMetadata
(
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
None
,
seq_groups
=
None
,
seq_data
=
None
,
seq_data
=
None
,
...
@@ -646,6 +781,23 @@ class ModelRunner:
...
@@ -646,6 +781,23 @@ class ModelRunner:
perform_sampling
=
False
,
perform_sampling
=
False
,
)
)
# if it is a mixed batch, decode attn_metadata is broadcasted
# separately.
if
batch_type
==
BatchType
.
MIXED
:
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
decode_attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
attn_metadata
=
AttentionMetadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
slot_mapping
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
prefill_metadata
=
prefill_attn_metadata
,
decode_metadata
=
decode_attn_metadata
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
return
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
lora_requests
,
lora_mapping
,
sampling_metadata
,
lora_requests
,
lora_mapping
,
multi_modal_input
)
multi_modal_input
)
...
@@ -663,8 +815,10 @@ class ModelRunner:
...
@@ -663,8 +815,10 @@ class ModelRunner:
if
self
.
lora_config
:
if
self
.
lora_config
:
self
.
set_active_loras
(
lora_requests
,
lora_mapping
)
self
.
set_active_loras
(
lora_requests
,
lora_mapping
)
# Execute the model.
# Currently cuda graph is only supported by the decode phase.
if
attn_metadata
.
use_cuda_graph
:
prefill_meta
=
attn_metadata
.
prefill_metadata
decode_meta
=
attn_metadata
.
decode_metadata
if
prefill_meta
is
None
and
decode_meta
.
use_cuda_graph
:
graph_batch_size
=
input_tokens
.
shape
[
0
]
graph_batch_size
=
input_tokens
.
shape
[
0
]
model_executable
=
self
.
graph_runners
[
graph_batch_size
]
model_executable
=
self
.
graph_runners
[
graph_batch_size
]
else
:
else
:
...
@@ -842,13 +996,10 @@ class ModelRunner:
...
@@ -842,13 +996,10 @@ class ModelRunner:
# memory usage of CUDA graph.
# memory usage of CUDA graph.
for
batch_size
in
reversed
(
batch_size_capture_list
):
for
batch_size
in
reversed
(
batch_size_capture_list
):
# Create dummy attn_metadata.
# Create dummy attn_metadata.
attn
_metadata
=
self
.
attn_backend
.
make_metadata
(
decode
_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
is_prompt
=
False
,
slot_mapping
=
slot_mapping
[:
batch_size
],
prompt_lens
=
None
,
prompt_lens
=
None
,
prompt_lens_tensor
=
None
,
prompt_lens_tensor
=
None
,
num_prompt_tokens
=
0
,
num_generation_tokens
=
batch_size
,
max_subquery_len
=
None
,
max_subquery_len
=
None
,
max_context_len
=
self
.
max_context_len_to_capture
,
max_context_len
=
self
.
max_context_len_to_capture
,
max_prompt_len
=
None
,
max_prompt_len
=
None
,
...
@@ -857,6 +1008,14 @@ class ModelRunner:
...
@@ -857,6 +1008,14 @@ class ModelRunner:
context_lens
=
context_lens
[:
batch_size
],
context_lens
=
context_lens
[:
batch_size
],
block_tables
=
block_tables
[:
batch_size
],
block_tables
=
block_tables
[:
batch_size
],
use_cuda_graph
=
True
,
use_cuda_graph
=
True
,
)
attn_metadata
=
AttentionMetadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
,
slot_mapping
=
slot_mapping
[:
batch_size
],
prefill_metadata
=
None
,
decode_metadata
=
decode_metadata
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
)
...
@@ -950,8 +1109,8 @@ class CUDAGraphRunner:
...
@@ -950,8 +1109,8 @@ class CUDAGraphRunner:
"positions"
:
positions
,
"positions"
:
positions
,
"kv_caches"
:
kv_caches
,
"kv_caches"
:
kv_caches
,
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
"context_lens"
:
attn_metadata
.
context_lens
,
"context_lens"
:
attn_metadata
.
decode_metadata
.
context_lens
,
"block_tables"
:
attn_metadata
.
block_tables
,
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
}
}
self
.
output_buffers
=
{
"hidden_states"
:
hidden_states
}
self
.
output_buffers
=
{
"hidden_states"
:
hidden_states
}
return
return
...
@@ -972,10 +1131,10 @@ class CUDAGraphRunner:
...
@@ -972,10 +1131,10 @@ class CUDAGraphRunner:
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
attn_metadata
.
slot_mapping
,
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
attn_metadata
.
slot_mapping
,
non_blocking
=
True
)
non_blocking
=
True
)
self
.
input_buffers
[
"context_lens"
].
copy_
(
attn_metadata
.
context_lens
,
self
.
input_buffers
[
"context_lens"
].
copy_
(
non_blocking
=
True
)
attn_metadata
.
decode_metadata
.
context_lens
,
non_blocking
=
True
)
self
.
input_buffers
[
"block_tables"
].
copy_
(
attn_metadata
.
block_tables
,
self
.
input_buffers
[
"block_tables"
].
copy_
(
non_blocking
=
True
)
attn_metadata
.
decode_metadata
.
block_tables
,
non_blocking
=
True
)
# Run the graph.
# Run the graph.
self
.
graph
.
replay
()
self
.
graph
.
replay
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment