Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e02ac556
Unverified
Commit
e02ac556
authored
Aug 09, 2024
by
Alexander Matveev
Committed by
GitHub
Aug 08, 2024
Browse files
[Performance] Optimize e2e overheads: Reduce python allocations (#7162)
parent
73388c07
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
550 additions
and
125 deletions
+550
-125
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+5
-1
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+10
-2
vllm/block.py
vllm/block.py
+45
-3
vllm/core/block_manager_v1.py
vllm/core/block_manager_v1.py
+15
-12
vllm/core/scheduler.py
vllm/core/scheduler.py
+127
-44
vllm/model_executor/__init__.py
vllm/model_executor/__init__.py
+3
-1
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+71
-10
vllm/outputs.py
vllm/outputs.py
+1
-1
vllm/sequence.py
vllm/sequence.py
+24
-4
vllm/utils.py
vllm/utils.py
+38
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+211
-47
No files found.
vllm/attention/backends/flash_attn.py
View file @
e02ac556
...
...
@@ -259,7 +259,11 @@ class FlashAttentionMetadataBuilder(
block_table
=
block_tables
[
seq_id
]
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
if
curr_sliding_window_block
==
0
:
block_table
=
block_tables
[
seq_id
]
else
:
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
self
.
block_tables
.
append
(
block_table
)
# Compute slot mapping.
...
...
vllm/attention/backends/utils.py
View file @
e02ac556
...
...
@@ -68,13 +68,21 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
block_table
=
block_tables
[
seq_id
]
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
max
(
0
,
start_idx
-
context_len
))
for
i
in
range
(
max
(
start_idx
,
context_len
),
seq_len
):
def
add_slot
(
i
):
block_number
=
block_table
[
i
//
block_size
]
block_offset
=
i
%
block_size
slot
=
block_number
*
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
if
start_idx
==
0
and
(
seq_len
-
context_len
)
==
1
:
# Optimization for common-case of decoding next token
add_slot
(
seq_len
-
1
)
else
:
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
max
(
0
,
start_idx
-
context_len
))
for
i
in
range
(
max
(
start_idx
,
context_len
),
seq_len
):
add_slot
(
i
)
TAttentionMetadata
=
TypeVar
(
"TAttentionMetadata"
,
bound
=
'AttentionMetadata'
)
...
...
vllm/block.py
View file @
e02ac556
"""Token blocks."""
from
typing
import
List
from
typing
import
List
,
Optional
from
vllm.utils
import
Device
...
...
@@ -37,5 +37,47 @@ class PhysicalTokenBlock:
f
'computed=
{
self
.
computed
}
)'
)
# Mapping: logical block number -> physical block.
BlockTable
=
List
[
PhysicalTokenBlock
]
class BlockTable:
    """Holds a list of blocks with caching of their associated block_ids.

    Two parallel lists are maintained: ``_blocks`` (the block objects) and
    ``_block_ids`` (each block's ``block_number``). Every mutating method
    updates both so that ``ids()`` can be returned without rebuilding the
    id list on each call.
    """

    def __init__(self, blocks: Optional[List["PhysicalTokenBlock"]] = None):
        """Create a table, optionally pre-populated from ``blocks``.

        The given list is not stored; its elements are appended one by one,
        so the new table owns fresh internal lists.
        """
        self._blocks: List["PhysicalTokenBlock"] = []
        self._block_ids: List[int] = []

        if blocks is not None:
            for block in blocks:
                self.append(block)

    def append(self, block: "PhysicalTokenBlock"):
        """Append ``block`` and record its ``block_number``."""
        self._blocks.append(block)
        self._block_ids.append(block.block_number)

    def __len__(self) -> int:
        return len(self._blocks)

    def __getitem__(self, key):
        """Return the block at ``key`` (or a plain list for a slice)."""
        return self._blocks[key]

    def __setitem__(self, key, value):
        """Replace one block (int key) or a run of blocks (slice key).

        For a slice, ``value`` may be any iterable of blocks, including a
        one-shot iterator.
        """
        if isinstance(key, slice):
            # Materialize once: ``value`` is walked twice below (once for
            # the blocks, once for the ids). A generator would be exhausted
            # by the first pass and leave the id list out of sync.
            blocks = list(value)
            block_ids = [b.block_number for b in blocks]
            self._blocks[key] = blocks
            self._block_ids[key] = block_ids
        else:
            block = value
            # Read the id before mutating so a bad value cannot leave the
            # two lists disagreeing.
            block_id = block.block_number
            self._blocks[key] = block
            self._block_ids[key] = block_id

    def reset(self):
        """Drop all blocks by rebinding both internal lists to new ones."""
        self._blocks = []
        self._block_ids = []

    def copy(self) -> "BlockTable":
        """Return a new table with the same block objects (shallow copy)."""
        return BlockTable(self._blocks)

    # NOTE: this method name shadows the builtin ``list`` inside the class
    # body only; method bodies still resolve ``list`` to the builtin.
    def list(self) -> List["PhysicalTokenBlock"]:
        """Return the internal block list itself (not a copy)."""
        return self._blocks

    def ids(self) -> List[int]:
        """Return the internal cached id list itself (not a copy)."""
        return self._block_ids
vllm/core/block_manager_v1.py
View file @
e02ac556
...
...
@@ -170,7 +170,7 @@ class UncachedBlockAllocator(BlockAllocatorBase):
self
.
num_blocks
=
num_blocks
# Initialize the free blocks.
self
.
free_blocks
:
BlockTable
=
[]
self
.
free_blocks
:
List
[
PhysicalTokenBlock
]
=
[]
for
i
in
range
(
num_blocks
):
block
=
PhysicalTokenBlock
(
device
=
device
,
block_number
=
i
,
...
...
@@ -256,6 +256,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
Device
.
CPU
,
block_size
,
num_cpu_blocks
)
# Mapping: seq_id -> BlockTable.
self
.
block_tables
:
Dict
[
int
,
BlockTable
]
=
{}
# Mapping: req_id -> BlockTable
# Note that each SequenceGroup has a unique
# request ID
...
...
@@ -299,7 +300,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks
=
seq
.
n_blocks
block_table
:
BlockTable
=
[]
block_table
:
BlockTable
=
BlockTable
()
for
logical_idx
in
range
(
num_prompt_blocks
):
if
(
self
.
block_sliding_window
is
not
None
and
logical_idx
>=
self
.
block_sliding_window
):
...
...
@@ -326,13 +327,17 @@ class BlockSpaceManagerV1(BlockSpaceManager):
#
# NOTE: Here we assume that all sequences in the group have the same
# decoder prompt.
seq
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
)[
0
]
wait_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
)
seq
=
wait_seqs
[
0
]
block_table
:
BlockTable
=
\
self
.
_allocate_sequence
(
seq
,
seq_group
.
num_seqs
(),
is_encoder_decoder
)
# Assign the self-attention block tables for each sequence.
if
len
(
wait_seqs
)
==
1
:
self
.
block_tables
[
wait_seqs
[
0
].
seq_id
]
=
block_table
else
:
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
):
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
copy
()
...
...
@@ -476,6 +481,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
return
src_block_table
=
self
.
block_tables
[
parent_seq
.
seq_id
]
self
.
block_tables
[
child_seq
.
seq_id
]
=
src_block_table
.
copy
()
# When using a sliding window, blocks will be eventually reused.
# In this case the block tables will contain repeated blocks.
# When forking, we must make sure that each block's `ref_count`
...
...
@@ -527,7 +533,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
dest_allocator
:
BlockAllocatorBase
,
mapping
:
Dict
[
PhysicalTokenBlock
,
PhysicalTokenBlock
])
->
BlockTable
:
new_block_table
=
[]
new_block_table
:
BlockTable
=
BlockTable
()
for
from_block
in
block_table
:
if
from_block
in
mapping
:
...
...
@@ -553,8 +559,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
SWAPPED
):
self
.
block_tables
[
seq
.
seq_id
]
=
\
self
.
_swap_block_table
(
self
.
block_tables
[
seq
.
seq_id
],
self
.
cpu_allocator
,
self
.
gpu_allocator
,
self
.
cpu_allocator
,
self
.
gpu_allocator
,
mapping
)
if
seq_group
.
is_encoder_decoder
():
...
...
@@ -580,8 +585,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
self
.
block_tables
[
seq
.
seq_id
]
=
\
self
.
_swap_block_table
(
self
.
block_tables
[
seq
.
seq_id
],
self
.
gpu_allocator
,
self
.
cpu_allocator
,
self
.
gpu_allocator
,
self
.
cpu_allocator
,
mapping
)
if
seq_group
.
is_encoder_decoder
():
...
...
@@ -636,8 +640,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self
.
cross_block_tables
.
clear
()
def
get_block_table
(
self
,
seq
:
Sequence
)
->
List
[
int
]:
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
return
[
block
.
block_number
for
block
in
block_table
]
return
self
.
block_tables
[
seq
.
seq_id
].
ids
()
def
get_cross_block_table
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
int
]:
block_table
=
self
.
cross_block_tables
[
seq_group
.
request_id
]
...
...
vllm/core/scheduler.py
View file @
e02ac556
...
...
@@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
(
Sequence
,
SequenceData
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
from
vllm.utils
import
PyObjectCache
logger
=
init_logger
(
__name__
)
...
...
@@ -176,10 +177,10 @@ class SchedulerRunningOutputs:
enough memory, it can be preempted (for recompute) or swapped out.
"""
# Selected sequences that are running and in a decoding phase.
decode_seq_groups
:
List
[
SequenceGroup
]
decode_seq_groups
:
List
[
Scheduled
SequenceGroup
]
# Selected sequences that are running and in a prefill phase.
# I.e., it means the prefill has been chunked.
prefill_seq_groups
:
List
[
SequenceGroup
]
prefill_seq_groups
:
List
[
Scheduled
SequenceGroup
]
# The preempted sequences.
preempted
:
List
[
SequenceGroup
]
# Sequences that are swapped out.
...
...
@@ -191,6 +192,10 @@ class SchedulerRunningOutputs:
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
# Optimization for fast-access to seq_group lists
decode_seq_groups_list
:
List
[
SequenceGroup
]
prefill_seq_groups_list
:
List
[
SequenceGroup
]
@
classmethod
def
create_empty
(
cls
)
->
"SchedulerRunningOutputs"
:
return
SchedulerRunningOutputs
(
...
...
@@ -201,6 +206,8 @@ class SchedulerRunningOutputs:
blocks_to_swap_out
=
[],
blocks_to_copy
=
[],
num_lookahead_slots
=
0
,
decode_seq_groups_list
=
[],
prefill_seq_groups_list
=
[],
)
...
...
@@ -259,6 +266,30 @@ class SchedulerPrefillOutputs:
)
def seq_group_metadata_builder():
    """Build a placeholder SequenceGroupMetadata with empty/neutral fields.

    Takes no arguments so it can serve as an object factory; every call
    creates fresh ``seq_data`` and ``block_tables`` dicts.
    """
    placeholder_fields = dict(
        request_id="",
        is_prompt=False,
        seq_data={},
        sampling_params=None,
        block_tables={},
    )
    return SequenceGroupMetadata(**placeholder_fields)
def scheduler_running_outputs_builder():
    """Build an empty SchedulerRunningOutputs.

    Each list field gets its own new list object on every call, and the
    lookahead slot count starts at zero.
    """
    empty_outputs = SchedulerRunningOutputs(
        decode_seq_groups=[],
        prefill_seq_groups=[],
        preempted=[],
        swapped_out=[],
        blocks_to_swap_out=[],
        blocks_to_copy=[],
        num_lookahead_slots=0,
        prefill_seq_groups_list=[],
        decode_seq_groups_list=[],
    )
    return empty_outputs
def scheduled_seq_group_builder():
    """Build a placeholder ScheduledSequenceGroup.

    The result has no sequence group attached and a token chunk size of 0.
    """
    placeholder = ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
    return placeholder
class
Scheduler
:
def
__init__
(
...
...
@@ -331,6 +362,14 @@ class Scheduler:
else
0
)
self
.
num_cumulative_preemption
:
int
=
0
# Used to cache python objects
self
.
_seq_group_metadata_cache
:
PyObjectCache
=
PyObjectCache
(
seq_group_metadata_builder
)
self
.
_scheduler_running_outputs_cache
:
PyObjectCache
=
PyObjectCache
(
scheduler_running_outputs_builder
)
self
.
_scheduled_seq_group_cache
:
PyObjectCache
=
PyObjectCache
(
scheduled_seq_group_builder
)
@
property
def
lora_enabled
(
self
)
->
bool
:
return
bool
(
self
.
lora_config
)
...
...
@@ -441,14 +480,30 @@ class Scheduler:
Returns:
SchedulerRunningOutputs.
"""
ret
:
SchedulerRunningOutputs
=
\
self
.
_scheduler_running_outputs_cache
.
get_object
()
ret
.
blocks_to_swap_out
.
clear
()
ret
.
blocks_to_copy
.
clear
()
ret
.
decode_seq_groups
.
clear
()
ret
.
prefill_seq_groups
.
clear
()
ret
.
preempted
.
clear
()
ret
.
swapped_out
.
clear
()
ret
.
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
False
)
ret
.
decode_seq_groups_list
.
clear
()
ret
.
prefill_seq_groups_list
.
clear
()
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
ret
.
blocks_to_swap_out
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
ret
.
blocks_to_copy
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
preempted
:
List
[
SequenceGroup
]
=
[]
swapped_out
:
List
[
SequenceGroup
]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
ret
.
decode_seq_groups
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
ret
.
prefill_seq_groups
preempted
:
List
[
SequenceGroup
]
=
ret
.
preempted
swapped_out
:
List
[
SequenceGroup
]
=
ret
.
swapped_out
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
...
...
@@ -497,15 +552,19 @@ class Scheduler:
else
:
self
.
_append_slots
(
seq_group
,
blocks_to_copy
)
is_prefill
=
seq_group
.
is_prefill
()
scheduled_seq_group
:
ScheduledSequenceGroup
=
\
self
.
_scheduled_seq_group_cache
.
get_object
()
scheduled_seq_group
.
seq_group
=
seq_group
if
is_prefill
:
prefill_seq_groups
.
append
(
ScheduledSequenceGroup
(
seq_group
=
seq_group
,
token_chunk_size
=
num_running_tokens
))
scheduled_seq_group
.
token_chunk_size
=
num_running_tokens
prefill_seq_groups
.
append
(
scheduled_seq_group
)
ret
.
prefill_seq_groups_list
.
append
(
seq_group
)
else
:
decode_seq_groups
.
append
(
ScheduledSequenceGroup
(
seq_group
=
seq_group
,
token_chunk_size
=
1
))
scheduled_seq_group
.
token_chunk_size
=
1
decode_seq_groups
.
append
(
scheduled_seq_group
)
ret
.
decode_seq_groups_list
.
append
(
seq_group
)
budget
.
add_num_batched_tokens
(
seq_group
.
request_id
,
num_running_tokens
)
# OPTIMIZATION: Note that get_max_num_running_seqs is
...
...
@@ -518,15 +577,10 @@ class Scheduler:
if
curr_loras
is
not
None
and
seq_group
.
lora_int_id
>
0
:
curr_loras
.
add
(
seq_group
.
lora_int_id
)
return
SchedulerRunningOutputs
(
decode_seq_groups
=
decode_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
preempted
=
preempted
,
swapped_out
=
swapped_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
False
))
self
.
_scheduler_running_outputs_cache
.
reset
()
self
.
_scheduled_seq_group_cache
.
reset
()
return
ret
def
_schedule_swapped
(
self
,
...
...
@@ -820,11 +874,15 @@ class Scheduler:
# Update waiting requests.
self
.
waiting
.
extendleft
(
running_scheduled
.
preempted
)
# Update new running requests.
if
len
(
prefills
.
seq_groups
)
>
0
:
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
self
.
running
.
extend
(
running_scheduled
.
decode_seq_groups_list
)
if
len
(
swapped_in
.
decode_seq_groups
)
>
0
:
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
swapped_in
.
decode_seq_groups
])
# Update swapped requests.
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
preempted
=
(
len
(
running_scheduled
.
preempted
)
+
...
...
@@ -834,18 +892,30 @@ class Scheduler:
# doesn't allow chunked prefills.
assert
len
(
running_scheduled
.
prefill_seq_groups
)
==
0
assert
len
(
swapped_in
.
prefill_seq_groups
)
==
0
# Merge lists
num_prefill_groups
=
len
(
prefills
.
seq_groups
)
if
num_prefill_groups
>
0
:
scheduled_seq_groups
=
prefills
.
seq_groups
scheduled_seq_groups
.
extend
(
running_scheduled
.
decode_seq_groups
)
else
:
scheduled_seq_groups
=
running_scheduled
.
decode_seq_groups
scheduled_seq_groups
.
extend
(
swapped_in
.
decode_seq_groups
)
blocks_to_copy
=
running_scheduled
.
blocks_to_copy
blocks_to_copy
.
extend
(
swapped_in
.
blocks_to_copy
)
ignored_seq_groups
=
prefills
.
ignored_seq_groups
ignored_seq_groups
.
extend
(
swapped_in
.
infeasible_seq_groups
)
return
SchedulerOutputs
(
scheduled_seq_groups
=
(
prefills
.
seq_groups
+
running_scheduled
.
decode_seq_groups
+
swapped_in
.
decode_seq_groups
),
num_prefill_groups
=
len
(
prefills
.
seq_groups
),
scheduled_seq_groups
=
scheduled_seq_groups
,
num_prefill_groups
=
num_prefill_groups
,
num_batched_tokens
=
budget
.
num_batched_tokens
,
blocks_to_swap_in
=
swapped_in
.
blocks_to_swap_in
,
blocks_to_swap_out
=
running_scheduled
.
blocks_to_swap_out
,
blocks_to_copy
=
running_scheduled
.
blocks_to_copy
+
swapped_in
.
blocks_to_copy
,
ignored_seq_groups
=
prefills
.
ignored_seq_groups
+
swapped_in
.
infeasible_seq_groups
,
blocks_to_copy
=
blocks_to_copy
,
ignored_seq_groups
=
ignored_seq_groups
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
running_queue_size
=
len
(
self
.
running
),
preempted
=
preempted
,
...
...
@@ -963,6 +1033,9 @@ class Scheduler:
scheduler_outputs
=
self
.
_schedule
()
now
=
time
.
time
()
if
not
self
.
cache_config
.
enable_prefix_caching
:
common_computed_block_nums
=
[]
# Create input data structures.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
for
i
,
scheduled_seq_group
in
enumerate
(
...
...
@@ -971,10 +1044,15 @@ class Scheduler:
token_chunk_size
=
scheduled_seq_group
.
token_chunk_size
seq_group
.
maybe_set_first_scheduled_time
(
now
)
seq_group_metadata
=
self
.
_seq_group_metadata_cache
.
get_object
()
seq_group_metadata
.
seq_data
.
clear
()
seq_group_metadata
.
block_tables
.
clear
()
# seq_id -> SequenceData
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
seq_data
:
Dict
[
int
,
SequenceData
]
=
seq_group_metadata
.
seq_data
# seq_id -> physical block numbers
block_tables
:
Dict
[
int
,
List
[
int
]]
=
{}
block_tables
:
Dict
[
int
,
List
[
int
]]
=
seq_group_metadata
.
block_tables
if
seq_group
.
is_encoder_decoder
():
# Encoder associated with SequenceGroup
...
...
@@ -993,6 +1071,7 @@ class Scheduler:
block_tables
[
seq_id
]
=
self
.
block_manager
.
get_block_table
(
seq
)
self
.
block_manager
.
access_all_blocks_in_seq
(
seq
,
now
)
if
self
.
cache_config
.
enable_prefix_caching
:
common_computed_block_nums
=
(
self
.
block_manager
.
get_common_computed_block_ids
(
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)))
...
...
@@ -1014,7 +1093,8 @@ class Scheduler:
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
is_prompt
=
seq_group
.
is_prefill
()
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
.
__init__
(
request_id
=
seq_group
.
request_id
,
is_prompt
=
is_prompt
,
seq_data
=
seq_data
,
...
...
@@ -1045,6 +1125,8 @@ class Scheduler:
self
.
block_manager
.
mark_blocks_as_computed
(
scheduled_seq_group
.
seq_group
)
self
.
_seq_group_metadata_cache
.
reset
()
return
seq_group_metadata_list
,
scheduler_outputs
def
fork_seq
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
...
...
@@ -1093,6 +1175,7 @@ class Scheduler:
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
cows
=
self
.
block_manager
.
append_slots
(
seq
,
num_lookahead_slots
)
if
len
(
cows
)
>
0
:
blocks_to_copy
.
extend
(
cows
)
def
_preempt
(
...
...
vllm/model_executor/__init__.py
View file @
e02ac556
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
PackedvLLMParameter
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
SamplingMetadataCache
)
from
vllm.model_executor.utils
import
set_random_seed
__all__
=
[
"SamplingMetadata"
,
"SamplingMetadataCache"
,
"set_random_seed"
,
"BasevLLMParameter"
,
"PackedvLLMParameter"
,
...
...
vllm/model_executor/sampling_metadata.py
View file @
e02ac556
...
...
@@ -8,8 +8,9 @@ import torch
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SequenceData
,
SequenceGroupMetadata
from
vllm.triton_utils.sample
import
get_num_triton_sampler_splits
from
vllm.utils
import
(
async_tensor_h2d
,
is_pin_memory_available
,
make_tensor_with_pad
,
maybe_expand_dim
)
from
vllm.utils
import
(
PyObjectCache
,
async_tensor_h2d
,
is_pin_memory_available
,
make_tensor_with_pad
,
maybe_expand_dim
)
_SAMPLING_EPS
=
1e-5
_SEED_0_REPLACEMENT
=
3403598558
...
...
@@ -62,6 +63,39 @@ class SequenceGroupToSample:
assert
self
.
query_len
is
not
None
def gen_seq_group_to_sample_builder(num_seqs: int):
    """Return a zero-argument factory for placeholder SequenceGroupToSample.

    Every object the factory produces has a ``seq_ids`` list of length
    ``num_seqs`` filled with zeros, plus fresh empty index lists.
    """

    def _build():
        return SequenceGroupToSample(
            seq_ids=[0] * num_seqs,
            sampling_params=None,
            seq_data=None,  # type: ignore
            seq_len=0,
            query_len=0,
            generator=None,
            is_prompt=True,
            prompt_logprob_indices=[],
            sample_indices=[],
        )

    return _build
class SamplingMetadataCache:
    """Used to cache SamplingMetadata objects between scheduler iterations.

    Internally keeps one PyObjectCache per distinct number of sequences,
    created on first request for that count.
    """

    def __init__(self):
        # Mapping: num_seqs -> pool of SequenceGroupToSample objects whose
        # seq_ids list has that length.
        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}

    def get_cached_seq_group_to_sample(self, num_seqs):
        """Return the next cached object from the pool for ``num_seqs``."""
        pools = self._seq_group_to_sample_cache
        if num_seqs not in pools:
            # First time this group size is seen: build its pool.
            pools[num_seqs] = PyObjectCache(
                gen_seq_group_to_sample_builder(num_seqs))

        return pools[num_seqs].get_object()

    def reset(self):
        """Reset every per-size pool."""
        for pool in self._seq_group_to_sample_cache.values():
            pool.reset()
class
SamplingMetadata
:
"""Metadata for input sequences. Used in sampler.
...
...
@@ -121,6 +155,7 @@ class SamplingMetadata:
device
:
str
,
pin_memory
:
bool
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
cache
:
Optional
[
SamplingMetadataCache
]
=
None
,
)
->
"SamplingMetadata"
:
(
seq_groups
,
...
...
@@ -128,7 +163,7 @@ class SamplingMetadata:
categorized_sample_indices
,
num_prompts
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
device
,
generators
)
device
,
generators
,
cache
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
target_device
=
device
,
...
...
@@ -164,6 +199,7 @@ def _prepare_seq_groups(
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
cache
:
Optional
[
SamplingMetadataCache
]
=
None
,
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]],
int
]:
"""Prepare sequence groups and indices for sampling.
...
...
@@ -210,15 +246,27 @@ def _prepare_seq_groups(
num_prompts
=
0
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
seq_ids
=
seq_group_metadata
.
seq_data
.
keys
()
if
cache
is
not
None
:
sample_obj
=
cache
.
get_cached_seq_group_to_sample
(
len
(
seq_ids
))
for
j
,
seq_id
in
enumerate
(
seq_ids
):
sample_obj
.
seq_ids
[
j
]
=
seq_id
sample_obj
.
prompt_logprob_indices
.
clear
()
sample_obj
.
sample_indices
.
clear
()
sampling_params
=
seq_group_metadata
.
sampling_params
is_prompt
=
seq_group_metadata
.
is_prompt
generator
:
Optional
[
torch
.
Generator
]
=
None
# If the current seq group is in decode stage, it is None.
seq_len
:
Optional
[
int
]
=
None
query_len
:
Optional
[
int
]
=
None
prompt_logprob_indices
:
List
[
int
]
=
[]
sample_indices
:
List
[
int
]
=
[]
prompt_logprob_indices
:
List
[
int
]
=
\
sample_obj
.
prompt_logprob_indices
if
cache
is
not
None
else
[]
sample_indices
:
List
[
int
]
=
\
sample_obj
.
sample_indices
if
cache
is
not
None
else
[]
do_sample
=
seq_group_metadata
.
do_sample
if
seq_group_metadata
.
is_prompt
:
...
...
@@ -290,9 +338,16 @@ def _prepare_seq_groups(
logit_idx
+=
sample_len
sample_idx
+=
sample_len
seq_groups
.
append
(
SequenceGroupToSample
(
seq_ids
=
seq_ids
,
if
cache
is
not
None
:
sample_obj
.
sampling_params
=
sampling_params
sample_obj
.
seq_data
=
seq_group_metadata
.
seq_data
sample_obj
.
seq_len
=
seq_len
sample_obj
.
query_len
=
query_len
sample_obj
.
generator
=
generator
sample_obj
.
is_prompt
=
is_prompt
else
:
sample_obj
=
SequenceGroupToSample
(
seq_ids
=
list
(
seq_ids
),
sampling_params
=
sampling_params
,
seq_data
=
seq_group_metadata
.
seq_data
,
seq_len
=
seq_len
,
...
...
@@ -300,7 +355,13 @@ def _prepare_seq_groups(
generator
=
generator
,
is_prompt
=
is_prompt
,
prompt_logprob_indices
=
list
(
prompt_logprob_indices
),
sample_indices
=
list
(
sample_indices
)))
sample_indices
=
list
(
sample_indices
))
seq_groups
.
append
(
sample_obj
)
if
cache
is
not
None
:
cache
.
reset
()
return
(
seq_groups
,
selected_token_indices
,
categorized_sample_indices
,
num_prompts
)
...
...
vllm/outputs.py
View file @
e02ac556
...
...
@@ -139,7 +139,7 @@ class RequestOutput:
CompletionOutput
(
seqs
.
index
(
seq
),
seq
.
get_output_text_to_return
(
text_buffer_length
),
seq
.
get
_output_token_ids
(),
seq
.
data
.
_output_token_ids
,
# type: ignore
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
seq
.
output_logprobs
if
include_logprobs
else
None
,
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
...
...
vllm/sequence.py
View file @
e02ac556
"""Sequence and its related classes."""
import
copy
import
enum
import
math
from
abc
import
ABC
,
abstractmethod
from
array
import
array
from
collections
import
defaultdict
...
...
@@ -330,7 +329,7 @@ class Sequence:
@
property
def
n_blocks
(
self
)
->
int
:
return
math
.
ceil
(
self
.
get_len
()
/
self
.
block_size
)
return
(
self
.
get_len
()
+
self
.
block_size
-
1
)
//
self
.
block_size
@
property
def
prompt
(
self
)
->
Optional
[
str
]:
...
...
@@ -514,7 +513,9 @@ class SequenceGroup:
)
->
None
:
self
.
request_id
=
request_id
self
.
seqs
=
seqs
self
.
is_single_seq
=
len
(
seqs
)
==
1
self
.
seqs_dict
=
{
seq
.
seq_id
:
seq
for
seq
in
seqs
}
self
.
sampling_params
=
sampling_params
self
.
metrics
=
RequestMetrics
(
arrival_time
=
arrival_time
,
last_token_time
=
arrival_time
,
...
...
@@ -635,6 +636,10 @@ class SequenceGroup:
)
->
List
[
Sequence
]:
if
status
is
None
:
return
self
.
seqs
if
self
.
is_single_seq
:
return
self
.
seqs
if
self
.
seqs
[
0
].
status
==
status
else
[]
return
[
seq
for
seq
in
self
.
seqs
if
seq
.
status
==
status
]
def
is_encoder_decoder
(
self
)
->
bool
:
...
...
@@ -644,6 +649,9 @@ class SequenceGroup:
return
self
.
encoder_seq
def get_unfinished_seqs(self) -> List[Sequence]:
    """Return the sequences whose ``is_finished()`` is false.

    In the single-sequence case ``self.seqs`` itself is returned (no copy)
    when its only sequence is unfinished, otherwise a new empty list.
    """
    if not self.is_single_seq:
        return [s for s in self.seqs if not s.is_finished()]

    # Fast path: only one sequence to inspect, no list needs building.
    if self.seqs[0].is_finished():
        return []
    return self.seqs
def
get_finished_seqs
(
self
)
->
List
[
Sequence
]:
...
...
@@ -668,12 +676,21 @@ class SequenceGroup:
if
status
is
None
:
return
len
(
self
.
seqs
)
if
self
.
is_single_seq
:
return
1
if
self
.
seqs
[
0
].
status
==
status
else
0
return
len
(
self
.
get_seqs
(
status
))
def num_unfinished_seqs(self) -> int:
    """Return how many sequences in the group are not finished."""
    if not self.is_single_seq:
        return len(self.get_unfinished_seqs())

    # Single sequence: answer directly without building a list.
    return 0 if self.seqs[0].is_finished() else 1
def num_finished_seqs(self) -> int:
    """Return how many sequences in the group are finished."""
    if not self.is_single_seq:
        return len(self.get_finished_seqs())

    # Single sequence: answer directly without building a list.
    only_seq = self.seqs[0]
    if only_seq.is_finished():
        return 1
    return 0
def
find
(
self
,
seq_id
:
int
)
->
Sequence
:
...
...
@@ -686,12 +703,14 @@ class SequenceGroup:
raise
ValueError
(
f
"Sequence
{
seq
.
seq_id
}
already exists."
)
self
.
seqs_dict
[
seq
.
seq_id
]
=
seq
self
.
seqs
.
append
(
seq
)
self
.
is_single_seq
=
len
(
self
.
seqs
)
==
1
def remove(self, seq_id: int) -> None:
    """Remove the sequence registered under ``seq_id`` from the group.

    Raises:
        ValueError: if no sequence is stored under ``seq_id``.
    """
    removed = self.seqs_dict.pop(seq_id, None)
    if removed is None:
        message = f"Sequence {seq_id} not found."
        raise ValueError(message)

    remaining = self.seqs
    remaining.remove(removed)
    # Keep the single-sequence fast-path flag in step with the list.
    self.is_single_seq = (len(remaining) == 1)
def is_finished(self) -> bool:
    """Return True when every sequence in the group is finished.

    An empty group counts as finished.
    """
    for seq in self.seqs:
        if not seq.is_finished():
            return False
    return True
...
...
@@ -775,9 +794,10 @@ class SequenceGroupMetadata:
# TODO: We should maintain this states out of the sequence group.
self
.
num_speculative_tokens
=
None
if
self
.
_token_chunk_size
is
None
:
if
seq_data
is
not
None
and
self
.
_token_chunk_size
is
None
:
if
is_prompt
:
self
.
_token_chunk_size
=
list
(
seq_data
.
values
())[
0
].
get_len
()
self
.
_token_chunk_size
=
next
(
iter
(
seq_data
.
values
())).
get_len
()
else
:
self
.
_token_chunk_size
=
1
...
...
vllm/utils.py
View file @
e02ac556
...
...
@@ -261,6 +261,44 @@ class LRUCache(Generic[T]):
self
.
cache
.
clear
()
class PyObjectCache:
    """Used to cache python objects to avoid object allocations
    across scheduler iterations.

    Objects are built eagerly by ``obj_builder`` and handed out in order by
    ``get_object``. ``reset`` rewinds the hand-out position so the same
    objects are reused; callers are responsible for re-initializing them.
    """

    def __init__(self, obj_builder, initial_size: int = 128):
        """Pre-allocate ``initial_size`` objects.

        Args:
            obj_builder: zero-argument callable returning a new object.
            initial_size: number of objects to build up front.

        Raises:
            ValueError: if ``initial_size`` is negative.
        """
        if initial_size < 0:
            raise ValueError(
                f"initial_size must be non-negative, got {initial_size}")

        self._obj_builder = obj_builder
        self._index = 0
        self._obj_cache = [obj_builder() for _ in range(initial_size)]

    def _grow_cache(self):
        # Double the size of the cache. An empty cache grows by one object,
        # since doubling zero would never make progress.
        num_objs = max(1, len(self._obj_cache))
        self._obj_cache.extend(self._obj_builder() for _ in range(num_objs))

    def get_object(self):
        """Returns a pre-allocated cached object. If there are not enough
        objects, then the cache size will double.
        """
        if self._index >= len(self._obj_cache):
            self._grow_cache()
            assert self._index < len(self._obj_cache)

        obj = self._obj_cache[self._index]
        self._index += 1

        return obj

    def reset(self):
        """Makes all cached-objects available for the next scheduler iteration.
        """
        self._index = 0
def
is_hip
()
->
bool
:
return
torch
.
version
.
hip
is
not
None
...
...
vllm/worker/model_runner.py
View file @
e02ac556
import
dataclasses
import
gc
import
itertools
import
time
import
warnings
import
weakref
...
...
@@ -35,7 +36,7 @@ from vllm.logger import init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
,
SamplingMetadataCache
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.models.interfaces
import
(
supports_lora
,
...
...
@@ -50,8 +51,8 @@ from vllm.prompt_adapter.worker_manager import (
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
async_tensor_h2d
,
flatten_2d_lists
,
get_kv_cache_torch_dtype
,
is_hip
,
from
vllm.utils
import
(
CudaMemoryProfiler
,
PyObjectCache
,
async_tensor_h2d
,
flatten_2d_lists
,
get_kv_cache_torch_dtype
,
is_hip
,
is_pin_memory_available
)
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
...
...
@@ -178,6 +179,20 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
class
InterDataForSeqGroup
:
"""Intermediate data for the current sequence group."""
def simple_reinit(self):
    """Reset this object's buffers in place, touching only slot 0.

    Per-sequence containers have their first entry cleared or zeroed, and
    the LoRA / prompt-adapter containers are emptied. No container is
    replaced, so existing references to them stay valid.
    """
    # Slot 0 holds a mutable container that is emptied in place.
    for name in ("input_tokens", "input_positions"):
        getattr(self, name)[0].clear()

    # Slot 0 holds a plain number that is zeroed.
    for name in ("seq_lens", "orig_seq_lens", "query_lens", "context_lens",
                 "curr_sliding_window_blocks"):
        getattr(self, name)[0] = 0

    # Containers that are emptied as a whole.
    for name in ("lora_index_mapping", "lora_prompt_mapping",
                 "lora_requests", "prompt_adapter_index_mapping",
                 "prompt_adapter_prompt_mapping"):
        getattr(self, name).clear()
def
__init__
(
self
,
*
,
...
...
@@ -220,34 +235,120 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit
:
bool
=
False
,
reinit
:
bool
=
False
,
reinit_use_defaults
:
bool
=
False
,
):
self
.
request_id
=
request_id
if
reinit
:
assert
len
(
self
.
seq_ids
)
==
len
(
seq_ids
)
# type: ignore
for
i
,
seq_id
in
enumerate
(
seq_ids
):
self
.
seq_ids
[
i
]
=
seq_id
# type: ignore
else
:
self
.
seq_ids
=
seq_ids
self
.
request_id
=
request_id
self
.
is_prompt
=
is_prompt
self
.
block_tables
=
block_tables
self
.
computed_block_nums
=
computed_block_nums
self
.
n_seqs
=
n_seqs
if
reinit
:
if
len
(
self
.
seq_ids
)
==
1
and
reinit_use_defaults
:
self
.
simple_reinit
()
else
:
if
input_tokens
:
self
.
input_tokens
=
input_tokens
else
:
for
seq_id
in
range
(
len
(
self
.
seq_ids
)):
self
.
input_tokens
[
seq_id
].
clear
()
if
input_positions
:
self
.
input_positions
=
input_positions
else
:
for
seq_id
in
range
(
len
(
self
.
seq_ids
)):
self
.
input_positions
[
seq_id
].
clear
()
if
seq_lens
:
self
.
seq_lens
=
seq_lens
else
:
for
seq_id
in
range
(
len
(
self
.
seq_ids
)):
self
.
seq_lens
[
seq_id
]
=
0
if
orig_seq_lens
:
self
.
orig_seq_lens
=
orig_seq_lens
else
:
for
seq_id
in
range
(
len
(
self
.
seq_ids
)):
self
.
orig_seq_lens
[
seq_id
]
=
0
if
query_lens
:
self
.
query_lens
=
query_lens
else
:
for
seq_id
in
range
(
len
(
self
.
seq_ids
)):
self
.
query_lens
[
seq_id
]
=
0
if
context_lens
:
self
.
context_lens
=
context_lens
else
:
for
seq_id
in
range
(
len
(
self
.
seq_ids
)):
self
.
context_lens
[
seq_id
]
=
0
if
curr_sliding_window_blocks
:
self
.
curr_sliding_window_blocks
=
\
curr_sliding_window_blocks
else
:
for
seq_id
in
range
(
len
(
self
.
seq_ids
)):
self
.
curr_sliding_window_blocks
[
seq_id
]
=
0
if
lora_index_mapping
:
self
.
lora_index_mapping
=
lora_index_mapping
else
:
self
.
lora_index_mapping
.
clear
()
if
lora_prompt_mapping
:
self
.
lora_prompt_mapping
=
lora_prompt_mapping
else
:
self
.
lora_prompt_mapping
.
clear
()
if
lora_requests
:
self
.
lora_requests
=
lora_requests
else
:
self
.
lora_requests
.
clear
()
if
prompt_adapter_index_mapping
:
self
.
prompt_adapter_index_mapping
=
\
prompt_adapter_index_mapping
else
:
self
.
prompt_adapter_index_mapping
.
clear
()
if
prompt_adapter_prompt_mapping
:
self
.
prompt_adapter_prompt_mapping
=
\
prompt_adapter_prompt_mapping
else
:
self
.
prompt_adapter_prompt_mapping
.
clear
()
else
:
self
.
input_tokens
=
input_tokens
or
[]
self
.
input_positions
=
input_positions
or
[]
self
.
seq_lens
=
seq_lens
or
[]
self
.
orig_seq_lens
=
orig_seq_lens
or
[]
self
.
query_lens
=
query_lens
or
[]
self
.
context_lens
=
context_lens
or
[]
self
.
curr_sliding_window_blocks
=
curr_sliding_window_blocks
or
[]
self
.
curr_sliding_window_blocks
=
\
curr_sliding_window_blocks
or
[]
self
.
lora_index_mapping
=
lora_index_mapping
or
[]
self
.
lora_prompt_mapping
=
lora_prompt_mapping
or
[]
self
.
lora_requests
=
lora_requests
or
set
()
self
.
prompt_adapter_index_mapping
=
(
prompt_adapter_index_mapping
or
[])
self
.
prompt_adapter_prompt_mapping
=
(
prompt_adapter_prompt_mapping
or
[])
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
prompt_adapter_index_mapping
=
(
prompt_adapter_index_mapping
or
[])
self
.
prompt_adapter_prompt_mapping
=
(
prompt_adapter_prompt_mapping
or
[])
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
multi_modal_inputs
=
multi_modal_inputs
self
.
prefix_cache_hit
=
prefix_cache_hit
if
not
reinit
:
self
.
__post_init__
()
def
__post_init__
(
self
):
...
...
@@ -261,8 +362,36 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
context_lens
=
[
0
]
*
self
.
n_seqs
self
.
curr_sliding_window_blocks
=
[
0
]
*
self
.
n_seqs
self
.
lora_index_mapping
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
self
.
lora_prompt_mapping
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
self
.
lora_index_mapping
=
[]
self
.
lora_prompt_mapping
=
[]
def gen_inter_data_builder(self, num_seqs: int):
    """Return a zero-argument factory for placeholder inter-data objects.

    Each call of the returned factory constructs a fresh
    ``ModelInputForGPUBuilder.InterDataForSeqGroup`` whose ``seq_ids``
    list has ``num_seqs`` zero entries, with an empty request id,
    ``is_prompt=True``, no block tables and an empty
    ``computed_block_nums`` list.

    Args:
        num_seqs: Number of sequence slots the placeholder should hold.
    """

    def _make_placeholder():
        # The field values are dummies; only the length of ``seq_ids``
        # depends on ``num_seqs``.
        return ModelInputForGPUBuilder.InterDataForSeqGroup(
            request_id="",
            seq_ids=[0] * num_seqs,
            is_prompt=True,
            block_tables=None,
            computed_block_nums=[])

    return _make_placeholder
def init_cached_inter_data(self, *args, **kwargs):
    """Fetch a pooled inter-data object and re-run ``__init__`` on it.

    Only keyword arguments are accepted, and ``seq_ids`` is mandatory
    because its length selects which pool the object is taken from.
    A pool for that length is created on first use.

    Returns:
        The pooled object after ``__init__(*args, **kwargs)`` has been
        invoked on it.
    """
    assert len(args) == 0
    assert "seq_ids" in kwargs
    n_slots = len(kwargs["seq_ids"])

    # The inter-data cache is per model_runner
    pools = self.runner.inter_data_cache
    pool = pools.get(n_slots)
    if pool is None:
        # First request with this many sequences: build its pool.
        pool = PyObjectCache(self.gen_inter_data_builder(n_slots))
        pools[n_slots] = pool

    inter_data = pool.get_object()
    inter_data.__init__(*args, **kwargs)
    return inter_data
def reset_cached_inter_data(self):
    """Call ``reset()`` on every cache stored in the runner's inter-data cache."""
    pools = self.runner.inter_data_cache
    for pool in pools.values():
        pool.reset()
def
__init__
(
self
,
runner
:
"GPUModelRunnerBase"
,
...
...
@@ -337,17 +466,29 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Compute tokens.
if
inter_data
.
is_prompt
:
tokens
=
seq_data
.
get_token_ids
()[
context_len
:
seq_len
]
tokens
=
seq_data
.
get_token_ids
()
if
context_len
!=
0
or
seq_len
<
len
(
tokens
):
tokens
=
tokens
[
context_len
:
seq_len
]
else
:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens
=
[
seq_data
.
get_last_token_id
()
]
tokens
=
seq_data
.
get_last_token_id
()
inter_data
.
seq_lens
[
seq_idx
]
=
seq_len
inter_data
.
orig_seq_lens
[
seq_idx
]
=
seq_len
inter_data
.
context_lens
[
seq_idx
]
=
context_len
inter_data
.
input_tokens
[
seq_idx
]
=
tokens
inter_data
.
input_positions
[
seq_idx
]
=
list
(
range
(
context_len
,
seq_len
))
if
isinstance
(
tokens
,
list
):
inter_data
.
input_tokens
[
seq_idx
].
extend
(
tokens
)
else
:
inter_data
.
input_tokens
[
seq_idx
].
append
(
tokens
)
if
(
seq_len
-
context_len
)
==
1
:
inter_data
.
input_positions
[
seq_idx
].
append
(
seq_len
-
1
)
else
:
inter_data
.
input_positions
[
seq_idx
].
extend
(
range
(
context_len
,
seq_len
))
inter_data
.
query_lens
[
seq_idx
]
=
seq_len
-
context_len
if
inter_data
.
is_prompt
else
1
...
...
@@ -471,7 +612,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
):
"""Add a sequence group to the builder."""
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
()
)
seq_ids
=
seq_group_metadata
.
seq_data
.
keys
()
n_seqs
=
len
(
seq_ids
)
is_prompt
=
seq_group_metadata
.
is_prompt
...
...
@@ -479,12 +620,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
assert
n_seqs
==
1
self
.
decode_only
=
False
inter_data
=
self
.
I
nter
D
ata
ForSeqGroup
(
inter_data
=
self
.
init_cached_i
nter
_d
ata
(
request_id
=
seq_group_metadata
.
request_id
,
seq_ids
=
seq_ids
,
is_prompt
=
is_prompt
,
block_tables
=
seq_group_metadata
.
block_tables
,
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
)
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
,
reinit
=
True
,
reinit_use_defaults
=
True
)
self
.
inter_data_list
.
append
(
inter_data
)
for
seq_idx
in
range
(
n_seqs
):
...
...
@@ -504,18 +648,21 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
create on-device tensors.
"""
# Combine and flatten intermediate data.
input_tokens
=
flatten_2d_lists
([
flatten_2d_lists
(
inter_data
.
input_tokens
)
for
inter_data
in
self
.
inter_data_list
])
input_tokens
=
[]
for
inter_data
in
self
.
inter_data_list
:
for
cur_input_tokens
in
inter_data
.
input_tokens
:
input_tokens
.
extend
(
cur_input_tokens
)
if
not
input_tokens
:
# This may happen when all prefill requests hit
# prefix caching and there is no decode request.
return
self
.
model_input_cls
()
input_positions
=
flatten_2d_lists
([
flatten_2d_lists
(
inter_data
.
input_positions
)
for
inter_data
in
self
.
inter_data_list
])
input_positions
=
[]
for
inter_data
in
self
.
inter_data_list
:
for
cur_input_positions
in
inter_data
.
input_positions
:
input_positions
.
extend
(
cur_input_positions
)
seq_lens
=
[]
max_decode_seq_len
=
0
for
inter_data
in
self
.
inter_data_list
:
...
...
@@ -523,8 +670,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if
not
inter_data
.
is_prompt
:
max_decode_seq_len
=
max
(
max_decode_seq_len
,
max
(
inter_data
.
seq_lens
))
query_lens
=
flatten_2d_lists
(
[
inter_data
.
query_lens
for
inter_data
in
self
.
inter_data_list
])
query_lens
=
[]
for
inter_data
in
self
.
inter_data_list
:
query_lens
.
extend
(
inter_data
.
query_lens
)
# Mapping from request IDs to sequence IDs. Used for Jamba models
# that manage the cache by themselves.
request_ids_to_seq_ids
=
{
...
...
@@ -547,8 +696,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
batch_size
=
graph_batch_size
# Tokens and positions.
input_tokens
.
extend
([
0
]
*
cuda_graph_pad_size
)
input_positions
.
extend
([
0
]
*
cuda_graph_pad_size
)
if
cuda_graph_pad_size
:
input_tokens
.
extend
(
itertools
.
repeat
(
0
,
cuda_graph_pad_size
))
input_positions
.
extend
(
itertools
.
repeat
(
0
,
cuda_graph_pad_size
))
assert
self
.
runner
.
device
is
not
None
input_tokens_tensor
=
async_tensor_h2d
(
input_tokens
,
torch
.
long
,
self
.
runner
.
device
,
...
...
@@ -558,7 +708,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
runner
.
pin_memory
)
# Sequence and query lengths.
seq_lens
.
extend
([
1
]
*
cuda_graph_pad_size
)
if
cuda_graph_pad_size
:
seq_lens
.
extend
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
# Attention metadata.
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
...
...
@@ -574,11 +725,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
flatten_2d_lists
(
inter_data
.
lora_index_mapping
)
for
inter_data
in
self
.
inter_data_list
])
lora_index_mapping
.
extend
([
0
]
*
cuda_graph_pad_size
)
if
cuda_graph_pad_size
:
lora_index_mapping
.
extend
(
itertools
.
repeat
(
0
,
cuda_graph_pad_size
))
lora_prompt_mapping
=
flatten_2d_lists
([
flatten_2d_lists
(
inter_data
.
lora_prompt_mapping
)
for
inter_data
in
self
.
inter_data_list
])
lora_mapping
=
LoRAMapping
(
**
dict
(
index_mapping
=
lora_index_mapping
,
prompt_mapping
=
lora_prompt_mapping
,
...
...
@@ -595,7 +749,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data
.
prompt_adapter_index_mapping
for
inter_data
in
self
.
inter_data_list
])
prompt_adapter_index_mapping
.
extend
([
0
]
*
cuda_graph_pad_size
)
if
cuda_graph_pad_size
:
prompt_adapter_index_mapping
.
extend
(
itertools
.
repeat
(
0
,
cuda_graph_pad_size
))
prompt_adapter_prompt_mapping
=
flatten_2d_lists
([
inter_data
.
prompt_adapter_prompt_mapping
for
inter_data
in
self
.
inter_data_list
...
...
@@ -717,6 +873,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
set_cpu_offload_max_bytes
(
int
(
self
.
cache_config
.
cpu_offload_gb
*
1024
**
3
))
# Used to cache python objects
self
.
inter_data_cache
:
Dict
[
int
,
PyObjectCache
]
=
{}
self
.
sampling_metadata_cache
:
SamplingMetadataCache
=
\
SamplingMetadataCache
()
def
load_model
(
self
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
with
CudaMemoryProfiler
()
as
m
:
...
...
@@ -843,6 +1004,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
builder
=
self
.
_builder_cls
(
weakref
.
proxy
(
self
),
finished_requests_ids
)
for
seq_group_metadata
in
seq_group_metadata_list
:
builder
.
add_seq_group
(
seq_group_metadata
)
builder
.
reset_cached_inter_data
()
return
builder
.
build
()
# type: ignore
@
torch
.
inference_mode
()
...
...
@@ -1276,7 +1440,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
model_input
.
seq_lens
,
model_input
.
query_lens
,
self
.
device
,
self
.
pin_memory
,
generators
)
generators
,
self
.
sampling_metadata_cache
)
else
:
sampling_metadata
=
None
is_prompt
=
(
seq_group_metadata_list
[
0
].
is_prompt
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment