Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
66818e5b
Unverified
Commit
66818e5b
authored
Jan 22, 2025
by
youkaichao
Committed by
GitHub
Jan 22, 2025
Browse files
[core] separate builder init and builder prepare for each batch (#12253)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
222a9dc3
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
90 additions
and
47 deletions
+90
-47
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+6
-5
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+6
-5
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+8
-6
vllm/attention/backends/placeholder_attn.py
vllm/attention/backends/placeholder_attn.py
+5
-3
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+4
-1
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+7
-6
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+17
-7
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+24
-12
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+5
-0
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+8
-2
No files found.
vllm/attention/backends/abstract.py
View file @
66818e5b
...
...
@@ -65,11 +65,6 @@ class AttentionBackend(ABC):
def
get_builder_cls
()
->
Type
[
"AttentionMetadataBuilder"
]:
raise
NotImplementedError
@
classmethod
def
make_metadata_builder
(
cls
,
*
args
,
**
kwargs
)
->
"AttentionMetadataBuilder"
:
return
cls
.
get_builder_cls
()(
*
args
,
**
kwargs
)
@
staticmethod
@
abstractmethod
def
get_kv_cache_shape
(
...
...
@@ -214,6 +209,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
@
abstractmethod
def
__init__
(
self
,
input_builder
:
"ModelRunnerInputBuilderBase"
)
->
None
:
"""Create the builder, remember some configuration and parameters."""
raise
NotImplementedError
@
abstractmethod
def
prepare
(
self
)
->
None
:
"""Prepare for one batch."""
raise
NotImplementedError
@
abstractmethod
...
...
vllm/attention/backends/flash_attn.py
View file @
66818e5b
...
...
@@ -375,6 +375,12 @@ class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
def
prepare
(
self
):
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
...
...
@@ -388,11 +394,6 @@ class FlashAttentionMetadataBuilder(
self
.
num_decode_tokens
=
0
self
.
has_prefix_cache_hit
=
False
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
def
_add_seq_group
(
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
chunked_prefill_enabled
:
bool
,
prefix_cache_hit
:
bool
):
...
...
vllm/attention/backends/flashinfer.py
View file @
66818e5b
...
...
@@ -488,6 +488,14 @@ class FlashInferMetadata(AttentionMetadata):
class
FlashInferMetadataBuilder
(
AttentionMetadataBuilder
[
FlashInferMetadata
]):
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
def
prepare
(
self
):
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
...
...
@@ -500,12 +508,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
...
...
vllm/attention/backends/placeholder_attn.py
View file @
66818e5b
...
...
@@ -253,6 +253,11 @@ class PlaceholderAttentionMetadataBuilder(
AttentionMetadataBuilder
[
PlaceholderAttentionMetadata
]):
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
def
prepare
(
self
):
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
...
...
@@ -263,9 +268,6 @@ class PlaceholderAttentionMetadataBuilder(
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
def
_add_seq_group
(
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
chunked_prefill_enabled
:
bool
):
...
...
vllm/attention/backends/torch_sdpa.py
View file @
66818e5b
...
...
@@ -282,7 +282,10 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
def
__init__
(
self
,
input_builder
:
ModelInputForCPUBuilder
)
->
None
:
self
.
chunked_prefill
=
input_builder
.
chunked_prefill
self
.
input_data
=
input_builder
.
input_data
self
.
input_builder
=
input_builder
def
prepare
(
self
):
self
.
input_data
=
self
.
input_builder
.
input_data
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
)
->
TorchSDPAMetadata
:
...
...
vllm/attention/backends/utils.py
View file @
66818e5b
...
...
@@ -122,6 +122,13 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
_metadata_cls
:
Type
[
TAttentionMetadata
]
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
def
prepare
(
self
):
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
...
...
@@ -134,12 +141,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
def
_add_seq_group
(
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
chunked_prefill_enabled
:
bool
):
...
...
vllm/worker/cpu_model_runner.py
View file @
66818e5b
...
...
@@ -144,9 +144,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
runner
:
"CPUModelRunner"
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
super
().
__init__
()
self
.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
self
.
runner
=
runner
self
.
chunked_prefill
=
(
runner
.
scheduler_config
.
chunked_prefill_enabled
or
runner
.
cache_config
.
enable_prefix_caching
)
self
.
model_input_cls
=
self
.
runner
.
_model_input_cls
...
...
@@ -156,10 +154,17 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
self
.
device
=
self
.
runner
.
device
self
.
multi_modal_input_mapper
=
self
.
runner
.
multi_modal_input_mapper
self
.
enable_lora
=
self
.
runner
.
lora_config
is
not
None
if
self
.
runner
.
attn_backend
is
not
None
:
# spec decode (e.g. Medusa) does not have an attention backend
attn_backend
=
self
.
runner
.
attn_backend
self
.
att_metadata_builder
=
attn_backend
.
get_builder_cls
()(
self
)
def
prepare
(
self
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
self
.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
self
.
input_data
=
ModelInputForCPUBuilder
.
ModelInputData
(
self
.
runner
.
model_config
.
uses_mrope
)
self
.
att_metadata_builder
=
self
.
runner
.
attn_backend
.
get_builder_cls
()(
self
)
self
.
att_metadata_builder
.
prepare
()
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
):
self
.
seq_group_metadata_list
.
append
(
seq_group_metadata
)
...
...
@@ -431,6 +436,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
"""
_model_input_cls
:
Type
[
TModelInputForCPU
]
_builder_cls
:
Type
[
ModelInputForCPUBuilder
]
builder
:
ModelInputForCPUBuilder
def
__init__
(
self
,
...
...
@@ -477,6 +483,10 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
# Set after load_model.
self
.
lora_manager
:
Optional
[
LRUCacheWorkerLoRAManager
]
=
None
if
hasattr
(
self
,
"_builder_cls"
):
# multi-step model runner does not have `_builder_cls`
self
.
builder
=
self
.
_builder_cls
(
weakref
.
proxy
(
self
))
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
...
...
@@ -522,10 +532,10 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
metadata for possible additional steps, e.g., sampling.
"""
builder
=
self
.
_
builder
_cls
(
weakref
.
proxy
(
self
),
finished_requests_ids
)
builder
.
set_seq_group_list
(
seq_group_metadata_list
)
self
.
builder
.
prepare
(
finished_requests_ids
)
self
.
builder
.
set_seq_group_list
(
seq_group_metadata_list
)
return
builder
.
build
()
# type: ignore
return
self
.
builder
.
build
()
# type: ignore
# sampler property will be used by spec_decode_worker
@
property
...
...
vllm/worker/model_runner.py
View file @
66818e5b
...
...
@@ -457,16 +457,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
enable_prompt_adapter
=
(
self
.
runner
.
prompt_adapter_config
is
not
None
)
self
.
multi_modal_input_mapper
=
self
.
runner
.
multi_modal_input_mapper
self
.
finished_requests_ids
=
finished_requests_ids
self
.
decode_only
=
True
# Intermediate data (data in CPU before going to GPU) for
# the current sequence group.
self
.
inter_data_list
:
List
[
ModelInputForGPUBuilder
.
InterDataForSeqGroup
]
=
[]
# Attention metadata inputs.
self
.
attn_metadata_builder
=
self
.
attn_backend
.
make_metadata_builder
(
if
self
.
attn_backend
is
not
None
:
# spec decode (e.g. Medusa) does not have an attention backend
self
.
attn_metadata_builder
=
self
.
attn_backend
.
get_builder_cls
()(
weakref
.
proxy
(
self
))
# Engine/Model configurations.
...
...
@@ -479,6 +475,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
block_aligned_sliding_window
=
\
self
.
sliding_window_blocks
*
self
.
block_size
def
prepare
(
self
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
self
.
finished_requests_ids
=
finished_requests_ids
# Intermediate data (data in CPU before going to GPU) for
# the current sequence group.
self
.
inter_data_list
:
List
[
ModelInputForGPUBuilder
.
InterDataForSeqGroup
]
=
[]
self
.
attn_metadata_builder
.
prepare
()
def
_compute_lens
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
seq_group_metadata
:
SequenceGroupMetadata
):
"""Compute context length, sequence length and tokens
...
...
@@ -993,6 +1000,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"""
_model_input_cls
:
Type
[
TModelInputForGPU
]
_builder_cls
:
Type
[
ModelInputForGPUBuilder
]
builder
:
ModelInputForGPUBuilder
def
__init__
(
self
,
...
...
@@ -1093,6 +1101,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
SamplingMetadataCache
()
\
if
self
.
parallel_config
.
pipeline_parallel_size
==
1
else
None
if
hasattr
(
self
,
"_builder_cls"
):
# multi-step model runner does not have `_builder_cls`
self
.
builder
=
self
.
_builder_cls
(
weakref
.
proxy
(
self
))
def
load_model
(
self
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
with
DeviceMemoryProfiler
()
as
m
:
...
...
@@ -1226,13 +1238,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
If cuda graph is required, this API automatically pads inputs.
"""
builder
=
self
.
_
builder
_cls
(
weakref
.
proxy
(
self
),
finished_requests_ids
)
self
.
builder
.
prepare
(
finished_requests_ids
)
for
seq_group_metadata
in
seq_group_metadata_list
:
builder
.
add_seq_group
(
seq_group_metadata
)
self
.
builder
.
add_seq_group
(
seq_group_metadata
)
builder
.
reset_cached_inter_data
()
self
.
builder
.
reset_cached_inter_data
()
return
builder
.
build
()
# type: ignore
return
self
.
builder
.
build
()
# type: ignore
@
contextmanager
def
set_in_profile_run
(
self
):
...
...
vllm/worker/model_runner_base.py
View file @
66818e5b
...
...
@@ -200,6 +200,11 @@ class ModelRunnerInputBuilderBase(ABC, Generic[T]):
"""A builder to create ModelRunnerInputBase objects.
"""
@
abstractmethod
def
prepare
(
self
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
add_seq_group
(
self
,
seq_group_metadata
):
"""TBA"""
...
...
vllm/worker/xpu_model_runner.py
View file @
66818e5b
...
...
@@ -113,7 +113,6 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
runner
:
"XPUModelRunner"
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
super
().
__init__
()
self
.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
self
.
runner
=
runner
self
.
model_input_cls
=
self
.
runner
.
_model_input_cls
self
.
attn_backend
=
self
.
runner
.
attn_backend
...
...
@@ -121,6 +120,10 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
self
.
block_size
=
self
.
runner
.
block_size
self
.
device
=
self
.
runner
.
device
def
prepare
(
self
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
self
.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
):
self
.
seq_group_metadata_list
.
append
(
seq_group_metadata
)
...
...
@@ -408,6 +411,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
SamplingMetadataCache
()
\
if
self
.
parallel_config
.
pipeline_parallel_size
==
1
else
None
self
.
builder
=
self
.
_builder_cls
(
weakref
.
proxy
(
self
))
def
load_model
(
self
)
->
None
:
with
DeviceMemoryProfiler
()
as
m
:
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
...
...
@@ -517,7 +522,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
metadata for possible additional steps, e.g., sampling.
"""
builder
=
self
.
_builder_cls
(
weakref
.
proxy
(
self
),
finished_requests_ids
)
builder
=
self
.
builder
builder
.
prepare
(
finished_requests_ids
)
for
seq_group_metadata
in
seq_group_metadata_list
:
builder
.
add_seq_group
(
seq_group_metadata
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment