Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
296cdf8a
Unverified
Commit
296cdf8a
authored
Apr 22, 2024
by
Isotr0py
Committed by
GitHub
Apr 22, 2024
Browse files
[Misc] Add vision language model support to CPU backend (#3968)
parent
747b1a71
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
53 additions
and
32 deletions
+53
-32
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+1
-0
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+37
-23
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+15
-9
No files found.
vllm/executor/cpu_executor.py
View file @
296cdf8a
...
@@ -45,6 +45,7 @@ class CPUExecutor(ExecutorBase):
...
@@ -45,6 +45,7 @@ class CPUExecutor(ExecutorBase):
rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
distributed_init_method
=
distributed_init_method
,
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
vision_language_config
=
self
.
vision_language_config
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
is_driver_worker
=
True
,
is_driver_worker
=
True
,
)
)
...
...
vllm/worker/cpu_model_runner.py
View file @
296cdf8a
...
@@ -5,7 +5,7 @@ from torch import nn
...
@@ -5,7 +5,7 @@ from torch import nn
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
...
@@ -29,6 +29,7 @@ class CPUModelRunner:
...
@@ -29,6 +29,7 @@ class CPUModelRunner:
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
load_config
:
LoadConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
is_driver_worker
:
bool
=
False
,
*
args
,
*
args
,
...
@@ -38,6 +39,7 @@ class CPUModelRunner:
...
@@ -38,6 +39,7 @@ class CPUModelRunner:
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
vision_language_config
=
vision_language_config
self
.
load_config
=
load_config
self
.
load_config
=
load_config
self
.
is_driver_worker
=
is_driver_worker
self
.
is_driver_worker
=
is_driver_worker
...
@@ -59,10 +61,11 @@ class CPUModelRunner:
...
@@ -59,10 +61,11 @@ class CPUModelRunner:
self
.
block_size
:
int
# Set after initial profiling.
self
.
block_size
:
int
# Set after initial profiling.
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
load_config
=
self
.
load_config
,
load_config
=
self
.
load_config
,
device_config
=
self
.
device_config
,
device_config
=
self
.
device_config
,
vision_language_config
=
None
,
vision_language_config
=
self
.
vision_language_config
,
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
parallel_config
=
self
.
parallel_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
)
scheduler_config
=
self
.
scheduler_config
)
...
@@ -76,6 +79,7 @@ class CPUModelRunner:
...
@@ -76,6 +79,7 @@ class CPUModelRunner:
input_positions
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
prompt_lens
:
List
[
int
]
=
[]
prompt_lens
:
List
[
int
]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
is_prompt
...
@@ -96,6 +100,10 @@ class CPUModelRunner:
...
@@ -96,6 +100,10 @@ class CPUModelRunner:
# is always the first token in the sequence.
# is always the first token in the sequence.
input_positions
.
extend
(
list
(
range
(
computed_len
,
prompt_len
)))
input_positions
.
extend
(
list
(
range
(
computed_len
,
prompt_len
)))
if
seq_group_metadata
.
multi_modal_data
:
multi_modal_input_list
.
append
(
seq_group_metadata
.
multi_modal_data
.
data
)
# Compute the slot mapping.
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
...
@@ -118,6 +126,15 @@ class CPUModelRunner:
...
@@ -118,6 +126,15 @@ class CPUModelRunner:
slot
=
block_number
*
self
.
block_size
+
block_offset
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
slot_mapping
.
append
(
slot
)
if
multi_modal_input_list
:
assert
self
.
vision_language_config
,
(
"Multi-modal inputs are only supported by "
"vision language models."
)
multi_modal_input
=
torch
.
cat
(
multi_modal_input_list
,
dim
=
0
).
to
(
self
.
device
)
else
:
multi_modal_input
=
None
num_prompt_tokens
=
len
(
input_tokens
)
num_prompt_tokens
=
len
(
input_tokens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
input_tokens
=
torch
.
tensor
(
input_tokens
,
...
@@ -144,12 +161,8 @@ class CPUModelRunner:
...
@@ -144,12 +161,8 @@ class CPUModelRunner:
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
)
return
(
return
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
,
input_tokens
,
multi_modal_input
)
input_positions
,
attn_metadata
,
prompt_lens
,
)
def
_prepare_decode
(
def
_prepare_decode
(
self
,
self
,
...
@@ -336,14 +349,16 @@ class CPUModelRunner:
...
@@ -336,14 +349,16 @@ class CPUModelRunner:
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
]:
SamplingMetadata
]:
multi_modal_input
=
None
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
# NOTE: We assume that all sequences in the group are all prompts or
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
# all decodes.
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
# Prepare input tensors.
# Prepare input tensors.
if
is_prompt
:
if
is_prompt
:
(
input_tokens
,
input_positions
,
attn_metadata
,
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
,
prompt_lens
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
multi_modal_input
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
else
:
(
input_tokens
,
input_positions
,
(
input_tokens
,
input_positions
,
attn_metadata
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
attn_metadata
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
...
@@ -376,12 +391,8 @@ class CPUModelRunner:
...
@@ -376,12 +391,8 @@ class CPUModelRunner:
perform_sampling
=
False
,
perform_sampling
=
False
,
)
)
return
(
return
(
input_tokens
,
input_positions
,
attn_metadata
,
input_tokens
,
sampling_metadata
,
multi_modal_input
)
input_positions
,
attn_metadata
,
sampling_metadata
,
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
...
@@ -389,7 +400,8 @@ class CPUModelRunner:
...
@@ -389,7 +400,8 @@ class CPUModelRunner:
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
multi_modal_input
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
model_executable
=
self
.
model
model_executable
=
self
.
model
...
@@ -399,6 +411,8 @@ class CPUModelRunner:
...
@@ -399,6 +411,8 @@ class CPUModelRunner:
"kv_caches"
:
kv_caches
,
"kv_caches"
:
kv_caches
,
"attn_metadata"
:
attn_metadata
,
"attn_metadata"
:
attn_metadata
,
}
}
if
self
.
vision_language_config
:
execute_model_kwargs
.
update
({
"image_input"
:
multi_modal_input
})
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
...
...
vllm/worker/cpu_worker.py
View file @
296cdf8a
...
@@ -6,7 +6,8 @@ import torch.distributed
...
@@ -6,7 +6,8 @@ import torch.distributed
from
vllm.attention
import
get_attn_backend
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
(
broadcast_tensor_dict
,
from
vllm.distributed
import
(
broadcast_tensor_dict
,
ensure_model_parallel_initialized
,
ensure_model_parallel_initialized
,
init_distributed_environment
)
init_distributed_environment
)
...
@@ -122,6 +123,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
...
@@ -122,6 +123,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
rank
:
int
,
rank
:
int
,
distributed_init_method
:
str
,
distributed_init_method
:
str
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
vision_language_config
:
Optional
[
VisionLanguageConfig
]
=
None
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
is_driver_worker
:
bool
=
False
,
)
->
None
:
)
->
None
:
...
@@ -135,19 +137,23 @@ class CPUWorker(LoraNotSupportedWorkerBase):
...
@@ -135,19 +137,23 @@ class CPUWorker(LoraNotSupportedWorkerBase):
self
.
rank
=
rank
self
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
self
.
distributed_init_method
=
distributed_init_method
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
vision_language_config
=
vision_language_config
self
.
is_driver_worker
=
is_driver_worker
self
.
is_driver_worker
=
is_driver_worker
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
assert
self
.
rank
==
0
,
"The driver worker must have rank 0."
assert
self
.
rank
==
0
,
"The driver worker must have rank 0."
if
self
.
model_config
.
trust_remote_code
:
if
self
.
model_config
.
trust_remote_code
:
# note: lazy import to avoid importing torch before initializing
# note: lazy import to avoid importing torch before initializing
from
vllm.utils
import
init_cached_hf_modules
from
vllm.utils
import
init_cached_hf_modules
init_cached_hf_modules
()
init_cached_hf_modules
()
self
.
model_runner
=
CPUModelRunner
(
model_config
,
self
.
model_runner
=
CPUModelRunner
(
model_config
,
parallel_config
,
parallel_config
,
scheduler_config
,
scheduler_config
,
device_config
,
device_config
,
load_config
=
self
.
load_config
,
load_config
=
self
.
load_config
,
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
vision_language_config
=
self
.
vision_language_config
,
kv_cache_dtype
=
kv_cache_dtype
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
is_driver_worker
)
is_driver_worker
=
is_driver_worker
)
# Uninitialized cache engine. Will be initialized by
# Uninitialized cache engine. Will be initialized by
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment