Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5448f676
Unverified
Commit
5448f676
authored
Jul 24, 2024
by
Antoni Baum
Committed by
GitHub
Jul 24, 2024
Browse files
[Core] Tweaks to model runner/input builder developer APIs (#6712)
parent
0e63494c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
109 additions
and
64 deletions
+109
-64
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+19
-16
vllm/worker/embedding_model_runner.py
vllm/worker/embedding_model_runner.py
+3
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+87
-47
No files found.
vllm/attention/backends/flashinfer.py
View file @
5448f676
...
...
@@ -297,23 +297,26 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if
is_profile_run
:
return
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound
=
seq_len
//
self
.
block_size
+
1
\
if
seq_len
%
self
.
block_size
!=
0
\
else
seq_len
//
self
.
block_size
block_table
=
block_tables
[
seq_id
]
self
.
paged_kv_indices
.
extend
(
block_table
[:
block_table_bound
])
self
.
paged_kv_indptr
.
append
(
self
.
paged_kv_indptr
[
-
1
]
+
block_table_bound
)
last_page_len
=
seq_len
%
self
.
block_size
if
last_page_len
==
0
:
last_page_len
=
self
.
block_size
self
.
paged_kv_last_page_len
.
append
(
last_page_len
)
self
.
_update_paged_kv_tensors
(
block_table
,
seq_len
)
def
_update_paged_kv_tensors
(
self
,
block_table
:
List
[
int
],
seq_len
:
int
):
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound
=
seq_len
//
self
.
block_size
+
1
\
if
seq_len
%
self
.
block_size
!=
0
\
else
seq_len
//
self
.
block_size
self
.
paged_kv_indices
.
extend
(
block_table
[:
block_table_bound
])
self
.
paged_kv_indptr
.
append
(
self
.
paged_kv_indptr
[
-
1
]
+
block_table_bound
)
last_page_len
=
seq_len
%
self
.
block_size
if
last_page_len
==
0
:
last_page_len
=
self
.
block_size
self
.
paged_kv_last_page_len
.
append
(
last_page_len
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
...
...
vllm/worker/embedding_model_runner.py
View file @
5448f676
...
...
@@ -11,7 +11,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.worker.model_runner
import
GPUModelRunnerBase
,
ModelInputForGPU
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPU
,
ModelInputForGPUBuilder
)
logger
=
init_logger
(
__name__
)
...
...
@@ -28,6 +29,7 @@ class EmbeddingModelRunner(
GPUModelRunnerBase
[
ModelInputForGPUWithPoolingMetadata
]):
_model_input_cls
:
Type
[
ModelInputForGPUWithPoolingMetadata
]
=
(
ModelInputForGPUWithPoolingMetadata
)
_builder_cls
:
Type
[
ModelInputForGPUBuilder
]
=
ModelInputForGPUBuilder
def
__init__
(
self
,
...
...
vllm/worker/model_runner.py
View file @
5448f676
...
...
@@ -3,7 +3,7 @@ import gc
import
time
import
warnings
import
weakref
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
,
Union
)
...
...
@@ -171,48 +171,83 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
class
ModelInputForGPUBuilder
(
ModelRunnerInputBuilderBase
[
ModelInputForGPU
]):
"""Build ModelInputForGPU from SequenceGroupMetadata."""
@
dataclass
# Note: ideally we would be using a dataclass(kw_only=True)
# here, so that this can be subclassed easily,
# but kw_only is not supported in python<3.10.
class
InterDataForSeqGroup
:
"""Intermediate data for the current sequence group."""
# From sequence group metadata.
request_id
:
str
seq_ids
:
List
[
int
]
is_prompt
:
bool
block_tables
:
Optional
[
Dict
[
int
,
List
[
int
]]]
computed_block_nums
:
List
[
int
]
n_seqs
:
int
=
0
# Input tokens and positions.
input_tokens
:
List
[
List
[
int
]]
=
field
(
default_factory
=
list
)
input_positions
:
List
[
List
[
int
]]
=
field
(
default_factory
=
list
)
# The sequence length (may be capped to the sliding window).
seq_lens
:
List
[
int
]
=
field
(
default_factory
=
list
)
# The original sequence length (before applying sliding window).
# This is used to compute slot mapping.
orig_seq_lens
:
List
[
int
]
=
field
(
default_factory
=
list
)
# The query length.
query_lens
:
List
[
int
]
=
field
(
default_factory
=
list
)
# The number of tokens that are already computed.
context_lens
:
List
[
int
]
=
field
(
default_factory
=
list
)
# The current sliding window block.
curr_sliding_window_blocks
:
List
[
int
]
=
field
(
default_factory
=
list
)
# LoRA inputs.
lora_index_mapping
:
List
[
List
[
int
]]
=
field
(
default_factory
=
list
)
lora_prompt_mapping
:
List
[
List
[
int
]]
=
field
(
default_factory
=
list
)
lora_requests
:
Set
[
LoRARequest
]
=
field
(
default_factory
=
set
)
# Prompt adapter inputs.
prompt_adapter_index_mapping
:
List
[
int
]
=
field
(
default_factory
=
list
)
prompt_adapter_prompt_mapping
:
List
[
int
]
=
field
(
default_factory
=
list
)
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
# Multi-modal inputs.
multi_modal_inputs
:
Optional
[
MultiModalInputs
]
=
None
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit
:
bool
=
False
def
__init__
(
self
,
*
,
# From sequence group metadata.
request_id
:
str
,
seq_ids
:
List
[
int
],
is_prompt
:
bool
,
block_tables
:
Optional
[
Dict
[
int
,
List
[
int
]]],
computed_block_nums
:
List
[
int
],
n_seqs
:
int
=
0
,
# Input tokens and positions.
input_tokens
:
Optional
[
List
[
List
[
int
]]]
=
None
,
input_positions
:
Optional
[
List
[
List
[
int
]]]
=
None
,
# The sequence length (may be capped to the sliding window).
seq_lens
:
Optional
[
List
[
int
]]
=
None
,
# The original sequence length (before applying sliding window).
# This is used to compute slot mapping.
orig_seq_lens
:
Optional
[
List
[
int
]]
=
None
,
# The query length.
query_lens
:
Optional
[
List
[
int
]]
=
None
,
# The number of tokens that are already computed.
context_lens
:
Optional
[
List
[
int
]]
=
None
,
# The current sliding window block.
curr_sliding_window_blocks
:
Optional
[
List
[
int
]]
=
None
,
# LoRA inputs.
lora_index_mapping
:
Optional
[
List
[
List
[
int
]]]
=
None
,
lora_prompt_mapping
:
Optional
[
List
[
List
[
int
]]]
=
None
,
lora_requests
:
Optional
[
Set
[
LoRARequest
]]
=
None
,
# Prompt adapter inputs.
prompt_adapter_index_mapping
:
Optional
[
List
[
int
]]
=
None
,
prompt_adapter_prompt_mapping
:
Optional
[
List
[
int
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
# Multi-modal inputs.
multi_modal_inputs
:
Optional
[
MultiModalInputs
]
=
None
,
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit
:
bool
=
False
,
):
self
.
request_id
=
request_id
self
.
seq_ids
=
seq_ids
self
.
is_prompt
=
is_prompt
self
.
block_tables
=
block_tables
self
.
computed_block_nums
=
computed_block_nums
self
.
n_seqs
=
n_seqs
self
.
input_tokens
=
input_tokens
or
[]
self
.
input_positions
=
input_positions
or
[]
self
.
seq_lens
=
seq_lens
or
[]
self
.
orig_seq_lens
=
orig_seq_lens
or
[]
self
.
query_lens
=
query_lens
or
[]
self
.
context_lens
=
context_lens
or
[]
self
.
curr_sliding_window_blocks
=
curr_sliding_window_blocks
or
[]
self
.
lora_index_mapping
=
lora_index_mapping
or
[]
self
.
lora_prompt_mapping
=
lora_prompt_mapping
or
[]
self
.
lora_requests
=
lora_requests
or
set
()
self
.
prompt_adapter_index_mapping
=
(
prompt_adapter_index_mapping
or
[])
self
.
prompt_adapter_prompt_mapping
=
(
prompt_adapter_prompt_mapping
or
[])
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
multi_modal_inputs
=
multi_modal_inputs
self
.
prefix_cache_hit
=
prefix_cache_hit
self
.
__post_init__
()
def
__post_init__
(
self
):
self
.
n_seqs
=
len
(
self
.
seq_ids
)
...
...
@@ -457,6 +492,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for
per_seq_group_fn
in
self
.
per_seq_group_compute_fns
:
per_seq_group_fn
(
inter_data
,
seq_group_metadata
)
def
_use_captured_graph
(
self
,
batch_size
:
int
,
max_decode_seq_len
:
int
)
->
bool
:
return
(
self
.
decode_only
and
not
self
.
runner
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
max_decode_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
)
def
build
(
self
)
->
ModelInputForGPU
:
"""Finalize the builder intermediate data and
create on-device tensors.
...
...
@@ -491,10 +532,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
}
batch_size
=
len
(
input_tokens
)
use_captured_graph
=
(
self
.
decode_only
and
not
self
.
runner
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
max_decode_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
)
use_captured_graph
=
self
.
_use_captured_graph
(
batch_size
,
max_decode_seq_len
)
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
...
...
@@ -592,6 +631,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
Helper class for shared methods between GPU model runners.
"""
_model_input_cls
:
Type
[
TModelInputForGPU
]
_builder_cls
:
Type
[
ModelInputForGPUBuilder
]
def
__init__
(
self
,
...
...
@@ -794,8 +834,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
If cuda graph is required, this API automatically pads inputs.
"""
builder
=
ModelInputForGPUBuilder
(
weakref
.
proxy
(
self
),
finished_requests_ids
)
builder
=
self
.
_builder_cls
(
weakref
.
proxy
(
self
),
finished_requests_ids
)
for
seq_group_metadata
in
seq_group_metadata_list
:
builder
.
add_seq_group
(
seq_group_metadata
)
return
builder
.
build
()
# type: ignore
...
...
@@ -1191,6 +1230,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"""
_model_input_cls
:
Type
[
ModelInputForGPUWithSamplingMetadata
]
=
(
ModelInputForGPUWithSamplingMetadata
)
_builder_cls
:
Type
[
ModelInputForGPUBuilder
]
=
ModelInputForGPUBuilder
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment