Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e551ca15
Unverified
Commit
e551ca15
authored
Sep 23, 2024
by
Isotr0py
Committed by
GitHub
Sep 23, 2024
Browse files
[Hardware][CPU] Refactor CPU model runner (#8729)
parent
9b8c8ba1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
193 additions
and
109 deletions
+193
-109
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+193
-109
No files found.
vllm/worker/cpu_model_runner.py
View file @
e551ca15
import
dataclasses
import
weakref
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
...
...
@@ -17,7 +19,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.utils
import
STR_NOT_IMPL_ENC_DEC_ERR_STRS
,
make_tensor_with_pad
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
...
...
@@ -32,16 +34,17 @@ _PAD_SLOT_ID = -1
@
dataclass
(
frozen
=
True
)
class
CPU
ModelInput
(
ModelRunnerInputBase
):
class
ModelInput
ForCPU
(
ModelRunnerInputBase
):
"""
Used by the CPUModelRunner.
Base class contains metadata needed for the base model forward pass on CPU
"""
input_tokens
:
Optional
[
torch
.
Tensor
]
=
None
input_positions
:
Optional
[
torch
.
Tensor
]
=
None
attn_metadata
:
Optional
[
"AttentionMetadata"
]
=
None
sampling_metadata
:
Optional
[
"SamplingMetadata"
]
=
None
multi_modal_kwargs
:
Optional
[
BatchedTensorInputs
]
=
None
virtual_engine
:
Optional
[
int
]
=
None
seq_lens
:
Optional
[
List
[
int
]]
=
None
query_lens
:
Optional
[
List
[
int
]]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Union
[
int
,
torch
.
Tensor
]]:
...
...
@@ -51,88 +54,96 @@ class CPUModelInput(ModelRunnerInputBase):
"multi_modal_kwargs"
:
self
.
multi_modal_kwargs
,
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
_add_sampling_metadata_broadcastable_dict
(
tensor_dict
,
self
.
sampling_metadata
)
return
tensor_dict
@
classmethod
def
from_broadcasted_tensor_dict
(
cls
:
Type
[
"
CPU
ModelInput"
],
cls
:
Type
[
"ModelInput
ForCPU
"
],
tensor_dict
:
Dict
[
str
,
Any
],
attn_backend
:
Optional
[
"AttentionBackend"
]
=
None
)
->
"CPUModelInput"
:
tensor_dict
=
_init_sampling_metadata_from_tensor_dict
(
tensor_dict
)
)
->
"ModelInputForCPU"
:
if
attn_backend
is
not
None
:
tensor_dict
=
_init_attn_metadata_from_tensor_dict
(
attn_backend
,
tensor_dict
)
return
cls
(
**
tensor_dict
)
class
CPUModelRunner
(
ModelRunnerBase
[
CPUModelInput
]):
@
dataclass
(
frozen
=
True
)
class
ModelInputForCPUWithSamplingMetadata
(
ModelInputForCPU
):
"""
Used by the ModelRunner.
"""
sampling_metadata
:
Optional
[
"SamplingMetadata"
]
=
None
def
__init__
(
self
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
is_driver_worker
:
bool
=
False
,
*
args
,
**
kwargs
,
):
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
# Currently, CPU worker doesn't support chunked prefill.
assert
self
.
scheduler_config
.
chunked_prefill_enabled
is
False
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
prompt_adapter_config
=
prompt_adapter_config
self
.
load_config
=
load_config
self
.
is_driver_worker
=
is_driver_worker
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
"input_tokens"
:
self
.
input_tokens
,
"input_positions"
:
self
.
input_positions
,
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
_add_sampling_metadata_broadcastable_dict
(
tensor_dict
,
self
.
sampling_metadata
)
return
tensor_dict
self
.
device
=
self
.
device_config
.
device
@
classmethod
def
from_broadcasted_tensor_dict
(
cls
,
tensor_dict
:
Dict
[
str
,
Any
],
attn_backend
:
Optional
[
"AttentionBackend"
]
=
None
,
)
->
"ModelInputForCPUWithSamplingMetadata"
:
tensor_dict
=
_init_sampling_metadata_from_tensor_dict
(
tensor_dict
)
if
attn_backend
is
not
None
:
tensor_dict
=
_init_attn_metadata_from_tensor_dict
(
attn_backend
,
tensor_dict
)
return
cls
(
**
tensor_dict
)
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
)
# Multi-modal data support
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
multi_modal_input_mapper
=
self
.
mm_registry
\
.
create_input_mapper
(
self
.
model_config
)
self
.
mm_registry
.
init_mm_limits_per_prompt
(
self
.
model_config
)
class
ModelInputForCPUBuilder
(
ModelRunnerInputBuilderBase
[
ModelInputForCPU
]):
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
def
__init__
(
self
,
runner
:
"CPUModelRunner"
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
super
().
__init__
()
self
.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
self
.
runner
=
runner
self
.
model_input_cls
=
self
.
runner
.
_model_input_cls
self
.
attn_backend
=
self
.
runner
.
attn_backend
self
.
sliding_window
=
self
.
runner
.
sliding_window
self
.
block_size
=
self
.
runner
.
block_size
self
.
device
=
self
.
runner
.
device
self
.
multi_modal_input_mapper
=
self
.
runner
.
multi_modal_input_mapper
if
self
.
model_config
.
is_encoder_decoder_model
:
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_CPU'
])
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
):
self
.
seq_group_metadata_list
.
append
(
seq_group_metadata
)
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
load_config
=
self
.
load_config
,
device_config
=
self
.
device_config
,
lora_config
=
self
.
lora_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
,
cache_config
=
self
.
cache_config
)
def
build
(
self
)
->
ModelInputForCPU
:
multi_modal_kwargs
=
None
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt
=
self
.
seq_group_metadata_list
[
0
].
is_prompt
# Prepare input tensors.
if
is_prompt
:
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
multi_modal_kwargs
)
=
self
.
_prepare_prompt
(
self
.
seq_group_metadata_list
)
else
:
(
input_tokens
,
input_positions
,
attn_metadata
)
=
self
.
_prepare_decode
(
self
.
seq_group_metadata_list
)
seq_lens
=
[]
return
self
.
model_input_cls
(
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
attn_metadata
=
attn_metadata
,
multi_modal_kwargs
=
multi_modal_kwargs
,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens
=
seq_lens
,
query_lens
=
seq_lens
,
)
def
_prepare_prompt
(
self
,
...
...
@@ -165,8 +176,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
# is always the first token in the sequence.
input_positions
.
extend
(
list
(
range
(
computed_len
,
seq_len
)))
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
:
if
(
mm_data
:
=
seq_group_metadata
.
multi_modal_data
):
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
...
...
@@ -302,56 +312,130 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
attn_metadata
,
)
class
CPUModelRunner
(
ModelRunnerBase
[
ModelInputForCPU
]):
_model_input_cls
:
Type
[
ModelInputForCPUWithSamplingMetadata
]
=
(
ModelInputForCPUWithSamplingMetadata
)
_builder_cls
:
Type
[
ModelInputForCPUBuilder
]
=
ModelInputForCPUBuilder
def
__init__
(
self
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
is_driver_worker
:
bool
=
False
,
*
args
,
**
kwargs
,
):
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
# Currently, CPU worker doesn't support chunked prefill.
assert
self
.
scheduler_config
.
chunked_prefill_enabled
is
False
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
prompt_adapter_config
=
prompt_adapter_config
self
.
load_config
=
load_config
self
.
is_driver_worker
=
is_driver_worker
self
.
device
=
self
.
device_config
.
device
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
)
# Multi-modal data support
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
multi_modal_input_mapper
=
self
.
mm_registry
\
.
create_input_mapper
(
self
.
model_config
)
self
.
mm_registry
.
init_mm_limits_per_prompt
(
self
.
model_config
)
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
if
self
.
model_config
.
is_encoder_decoder_model
:
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_CPU'
])
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
load_config
=
self
.
load_config
,
device_config
=
self
.
device_config
,
lora_config
=
self
.
lora_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
,
cache_config
=
self
.
cache_config
)
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
],
)
->
CPU
ModelInput
:
return
CPU
ModelInput
.
from_broadcasted_tensor_dict
(
)
->
ModelInput
ForCPU
:
return
ModelInput
ForCPU
.
from_broadcasted_tensor_dict
(
tensor_dict
,
attn_backend
=
self
.
attn_backend
,
)
def
_prepare_model_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
ModelInputForCPUWithSamplingMetadata
:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
"""
builder
=
self
.
_builder_cls
(
weakref
.
proxy
(
self
),
finished_requests_ids
)
for
seq_group_metadata
in
seq_group_metadata_list
:
builder
.
add_seq_group
(
seq_group_metadata
)
return
builder
.
build
()
# type: ignore
def
prepare_model_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
CPUModelInput
:
multi_modal_kwargs
=
None
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
# Prepare input tensors.
if
is_prompt
:
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
multi_modal_kwargs
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
(
input_tokens
,
input_positions
,
attn_metadata
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
seq_lens
=
[]
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens
,
)
->
ModelInputForCPUWithSamplingMetadata
:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
"""
model_input
=
self
.
_prepare_model_input_tensors
(
seq_group_metadata_list
,
finished_requests_ids
)
# Sampling metadata is only required for the final pp group
generators
=
self
.
get_generators
(
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
model_input
.
seq_lens
,
model_input
.
query_lens
,
self
.
device
,
pin_memory
=
False
,
generators
=
self
.
get_generators
(
finished_requests_ids
))
return
CPUModelInput
(
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
attn_metadata
=
attn_metadata
,
generators
=
generators
)
return
dataclasses
.
replace
(
model_input
,
sampling_metadata
=
sampling_metadata
,
multi_modal_kwargs
=
multi_modal_kwargs
,
)
virtual_engine
=
virtual_engine
)
@
torch
.
no_grad
()
def
execute_model
(
self
,
model_input
:
CPU
ModelInput
,
model_input
:
ModelInput
ForCPUWithSamplingMetadata
,
kv_caches
:
List
[
torch
.
Tensor
],
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment