Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
64172a97
Unverified
Commit
64172a97
authored
Mar 25, 2024
by
xwjiang2010
Committed by
GitHub
Mar 25, 2024
Browse files
[Feature] Add vision language model support. (#3042)
parent
f408d05c
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
407 additions
and
31 deletions
+407
-31
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+2
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+9
-2
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+246
-0
vllm/sequence.py
vllm/sequence.py
+25
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+14
-0
vllm/utils.py
vllm/utils.py
+10
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+85
-21
vllm/worker/worker.py
vllm/worker/worker.py
+16
-8
No files found.
vllm/model_executor/models/__init__.py
View file @
64172a97
...
...
@@ -29,6 +29,8 @@ _MODELS = {
"InternLM2ForCausalLM"
:
(
"internlm2"
,
"InternLM2ForCausalLM"
),
"JAISLMHeadModel"
:
(
"jais"
,
"JAISLMHeadModel"
),
"LlamaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"LlavaForConditionalGeneration"
:
(
"llava"
,
"LlavaForConditionalGeneration"
),
# For decapoda-research/llama-*
"LLaMAForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"MistralForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
...
...
vllm/model_executor/models/llama.py
View file @
64172a97
...
...
@@ -250,14 +250,21 @@ class LlamaModel(nn.Module):
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
Optional
[
torch
.
Tensor
]
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
...
...
vllm/model_executor/models/llava.py
0 → 100644
View file @
64172a97
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from
transformers
import
CLIPVisionModel
,
LlavaConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VisionLanguageConfig
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.lm_head"
:
"lm_head"
,
"language_model.model"
:
"language_model"
,
}
# TODO(xwjiang): Run benchmark and decide if TP.
class
LlavaMultiModalProjector
(
nn
.
Module
):
def
__init__
(
self
,
vision_hidden_size
:
int
,
text_hidden_size
:
int
,
projector_hidden_act
:
str
):
super
().
__init__
()
self
.
linear_1
=
nn
.
Linear
(
vision_hidden_size
,
text_hidden_size
,
bias
=
True
)
self
.
act
=
get_act_fn
(
projector_hidden_act
)
self
.
linear_2
=
nn
.
Linear
(
text_hidden_size
,
text_hidden_size
,
bias
=
True
)
def
forward
(
self
,
image_features
):
hidden_states
=
self
.
linear_1
(
image_features
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
linear_2
(
hidden_states
)
return
hidden_states
def
_merge_vision_embeddings
(
input_ids
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
vision_embeddings
:
torch
.
Tensor
,
image_token_id
:
int
):
"""In place merges in vision_embeddings with inputs_embeds."""
mask
=
(
input_ids
==
image_token_id
)
inputs_embeds
[
mask
]
=
vision_embeddings
.
view
(
-
1
,
vision_embeddings
.
shape
[
-
1
])
class
LlavaForConditionalGeneration
(
nn
.
Module
):
def
__init__
(
self
,
config
:
"LlavaConfig"
,
vision_language_config
:
VisionLanguageConfig
,
linear_method
:
Optional
[
"LinearMethodBase"
]
=
None
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
vision_language_config
=
vision_language_config
assert
self
.
vision_language_config
,
(
"Provide `image_input_type` and other vision "
"related configurations through LLM entrypoint "
"or engine arguments."
)
if
self
.
vision_language_config
.
image_input_type
==
(
VisionLanguageConfig
.
ImageInputType
.
PIXEL_VALUES
):
self
.
vision_tower
=
CLIPVisionModel
(
config
.
vision_config
)
else
:
self
.
vision_tower
=
None
self
.
multi_modal_projector
=
LlavaMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
linear_method
=
linear_method
self
.
language_model
=
LlamaModel
(
config
.
text_config
,
linear_method
)
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
text_config
.
hidden_size
,
org_num_embeddings
=
self
.
language_model
.
org_vocab_size
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
attn_metadata
:
AttentionMetadata
,
image_input
:
Optional
[
torch
.
Tensor
]
=
None
)
->
SamplerOutput
:
# noqa: E501
"""Run forward pass for Llava 1.5.
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings.
Concretely, consider a text prompt:
"<image>
\n
USER: What's the content of the image?
\n
ASSISTANT:".
Tokenizer outputs:
[1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
The to-be-inserted image has a size of 576 (24 * 24) along the context
length dimension.
`input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
9047, 13566, 29901].
There will be 576 `32000` in the `input_ids`.
(32000 is the token id for `<image>`.)
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
The model takes two types of image inputs:
PIXEL_VALUES and IMAGE_FEATURES.
The following shows how each maps to huggingface implementation.
PIXEL_VALUES:
- https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
IMAGE_FEATURES:
- https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
before going through the multi modal projector.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
image_input: A batch of image inputs.
For PIXEL_VALUES, expecting [1, 3, 336, 336].
For IMAGE_FEATURES, expecting [1, 576, 1024].
"""
if
image_input
is
not
None
:
if
list
(
image_input
.
shape
[
1
:])
!=
list
(
self
.
vision_language_config
.
image_input_shape
[
1
:]):
raise
ValueError
(
f
"The expected image tensor shape is batch dimension "
f
"plus "
f
"
{
self
.
vision_language_config
.
image_input_shape
[
1
:]
}
."
f
" You supplied
{
image_input
.
shape
}
. "
f
"If you are using vLLM's entrypoint, make sure your "
f
"supplied image input is consistent with "
f
"image_input_shape in engine args."
)
if
self
.
vision_tower
is
not
None
:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs
=
self
.
vision_tower
(
image_input
,
output_hidden_states
=
True
)
image_features
=
image_outputs
.
hidden_states
[
self
.
config
.
vision_feature_layer
]
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if
self
.
config
.
vision_feature_select_strategy
==
"default"
:
image_features
=
image_features
[:,
1
:]
elif
self
.
config
.
vision_feature_select_strategy
==
"full"
:
image_features
=
image_features
else
:
raise
ValueError
(
f
"Unexpected select feature strategy: "
f
"
{
self
.
config
.
vision_feature_select_strategy
}
"
)
else
:
image_features
=
image_input
vision_embeddings
=
self
.
multi_modal_projector
(
image_features
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
_merge_vision_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
vision_language_config
.
image_token_id
)
input_ids
=
None
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
# only doing this for language model part for now.
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
"rotary_emb.inv_freq"
in
name
:
continue
for
key_to_modify
,
new_key
in
_KEYS_TO_MODIFY_MAPPING
.
items
():
if
key_to_modify
in
name
:
name
=
name
.
replace
(
key_to_modify
,
new_key
)
use_default_weight_loading
=
False
if
"vision"
in
name
:
if
self
.
vision_tower
is
not
None
:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading
=
True
else
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
use_default_weight_loading
=
True
if
use_default_weight_loading
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/sequence.py
View file @
64172a97
...
...
@@ -303,6 +303,25 @@ class SequenceGroupState:
generator
:
Optional
=
None
class
MultiModalData
:
"""Multi modal request.
Args:
type: The data type.
data: The actual data.
The required shape and semantic meaning of it depends on the vision
language config of the hosted model.
See `VisionLanguageConfig` in `config.py`.
"""
class
Type
(
enum
.
Enum
):
IMAGE
=
enum
.
auto
()
def
__init__
(
self
,
type
:
Type
,
data
:
"torch.Tensor"
):
self
.
type
=
type
self
.
data
=
data
class
SequenceGroup
:
"""A group of sequences that are generated from the same prompt.
...
...
@@ -312,6 +331,7 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
lora_request: LoRA request.
multi_modal_data: Multi modal data associated with the request.
"""
def
__init__
(
...
...
@@ -321,6 +341,7 @@ class SequenceGroup:
sampling_params
:
SamplingParams
,
arrival_time
:
float
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
None
:
self
.
request_id
=
request_id
self
.
seqs_dict
=
{
seq
.
seq_id
:
seq
for
seq
in
seqs
}
...
...
@@ -333,6 +354,7 @@ class SequenceGroup:
self
.
lora_request
=
lora_request
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
state
=
SequenceGroupState
()
self
.
multi_modal_data
=
multi_modal_data
@
property
def
prompt
(
self
)
->
str
:
...
...
@@ -450,6 +472,7 @@ class SequenceGroupMetadata:
numbers)
state: Internal state tied to this sequence group.
lora_request: LoRA request.
multi_modal_data: Multi modal data.
"""
def
__init__
(
...
...
@@ -462,6 +485,7 @@ class SequenceGroupMetadata:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
state
:
Optional
[
SequenceGroupState
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
None
:
self
.
request_id
=
request_id
self
.
is_prompt
=
is_prompt
...
...
@@ -470,6 +494,7 @@ class SequenceGroupMetadata:
self
.
block_tables
=
block_tables
self
.
lora_request
=
lora_request
self
.
computed_block_nums
=
computed_block_nums
self
.
multi_modal_data
=
multi_modal_data
self
.
state
=
SequenceGroupState
()
if
state
is
None
else
state
@
property
...
...
vllm/transformers_utils/config.py
View file @
64172a97
...
...
@@ -40,3 +40,17 @@ def get_config(model: str,
revision
=
revision
,
code_revision
=
code_revision
)
return
config
def
get_hf_text_config
(
config
:
PretrainedConfig
):
"""Get the "sub" config relevant to llm for multi modal models.
No op for pure text models.
"""
if
hasattr
(
config
,
"text_config"
):
# The code operates under the assumption that text_config should have
# `num_attention_heads` (among others). Assert here to fail early
# if transformers config doesn't align with this assumption.
assert
hasattr
(
config
.
text_config
,
"num_attention_heads"
)
return
config
.
text_config
else
:
return
config
vllm/utils.py
View file @
64172a97
...
...
@@ -377,6 +377,16 @@ class CudaMemoryProfiler:
gc
.
collect
()
def
str_to_int_tuple
(
s
:
str
)
->
Tuple
[
int
]:
"""Convert a string to a tuple of integers."""
try
:
return
tuple
(
map
(
int
,
s
.
split
(
","
)))
except
ValueError
as
e
:
raise
ValueError
(
"String must be a series of integers separated by commas "
f
"(e.g., 1, 2, 3). Given input:
{
s
}
"
)
from
e
def
pad_to_max_length
(
x
:
List
[
int
],
max_len
:
int
,
pad
:
int
)
->
List
[
int
]:
assert
len
(
x
)
<=
max_len
return
x
+
[
pad
]
*
(
max_len
-
len
(
x
))
...
...
vllm/worker/model_runner.py
View file @
64172a97
...
...
@@ -8,7 +8,7 @@ import torch.nn as nn
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
(
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
...
...
@@ -21,7 +21,8 @@ from vllm.model_executor.parallel_utils.communication_op import (
from
vllm.model_executor.parallel_utils.parallel_state
import
(
with_cupy_nccl_for_all_reduce
)
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
async_tensor_h2d
,
is_pin_memory_available
,
make_tensor_with_pad
,
maybe_expand_dim
)
...
...
@@ -49,6 +50,7 @@ class ModelRunner:
lora_config
:
Optional
[
LoRAConfig
],
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
vision_language_config
:
Optional
[
VisionLanguageConfig
]
=
None
,
):
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
...
...
@@ -83,17 +85,20 @@ class ModelRunner:
self
.
graph_block_tables
=
None
# Set after initial profiling.
self
.
pin_memory
=
is_pin_memory_available
()
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
vision_language_config
=
vision_language_config
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
dtype
if
model_config
is
not
None
else
None
)
def
load_model
(
self
)
->
None
:
with
CudaMemoryProfiler
()
as
m
:
self
.
model
=
get_model
(
self
.
model_config
,
self
.
device_config
,
lora_config
=
self
.
lora_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
)
self
.
model
=
get_model
(
self
.
model_config
,
self
.
device_config
,
lora_config
=
self
.
lora_config
,
vision_language_config
=
self
.
vision_language_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
)
self
.
model_memory_usage
=
m
.
consumed_memory
logger
.
info
(
f
"Loading model weights took "
...
...
@@ -130,7 +135,8 @@ class ModelRunner:
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
List
[
int
],
List
[
int
],
List
[
int
],
List
[
int
],
Set
[
LoRARequest
]]:
List
[
int
],
List
[
int
],
List
[
int
],
Set
[
LoRARequest
],
torch
.
Tensor
]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
...
...
@@ -143,6 +149,7 @@ class ModelRunner:
context_lens
:
List
[
int
]
=
[]
subquery_lens
:
List
[
int
]
=
[]
prefix_block_tables
:
List
[
List
[
int
]]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
...
...
@@ -188,6 +195,10 @@ class ModelRunner:
(
prompt_len
-
computed_len
if
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
if
seq_group_metadata
.
multi_modal_data
:
multi_modal_input_list
.
append
(
seq_group_metadata
.
multi_modal_data
.
data
)
if
seq_group_metadata
.
block_tables
is
None
:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
...
...
@@ -236,6 +247,16 @@ class ModelRunner:
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
if
multi_modal_input_list
:
assert
self
.
vision_language_config
,
(
"Multi-modal inputs are only supported by "
"vision language models."
)
multi_modal_input
=
torch
.
cat
(
multi_modal_input_list
,
dim
=
0
).
to
(
self
.
device
)
else
:
multi_modal_input
=
None
# Prepare prefix block tables
max_prompt_block_table_len
=
max
(
len
(
t
)
for
t
in
prefix_block_tables
)
block_tables
=
make_tensor_with_pad
(
...
...
@@ -291,7 +312,7 @@ class ModelRunner:
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
,
subquery_lens
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_requests
)
lora_requests
,
multi_modal_input
)
def
_prepare_decode
(
self
,
...
...
@@ -525,7 +546,7 @@ class ModelRunner:
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
,
Set
[
int
],
LoRAMapping
]:
Set
[
int
],
LoRAMapping
,
torch
.
Tensor
]:
if
self
.
is_driver_worker
:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
...
...
@@ -534,13 +555,15 @@ class ModelRunner:
if
is_prompt
:
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
,
subquery_lens
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_requests
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
lora_requests
,
multi_modal_input
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
(
input_tokens
,
input_positions
,
attn_metadata
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_requests
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
prompt_lens
=
[]
subquery_lens
=
None
multi_modal_input
=
None
sampling_metadata
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
,
subquery_lens
)
...
...
@@ -561,6 +584,7 @@ class ModelRunner:
sampling_metadata
.
selected_token_indices
,
"lora_requests"
:
lora_requests
,
"lora_mapping"
:
lora_mapping
,
"multi_modal_input"
:
multi_modal_input
,
}
metadata_dict
.
update
(
attn_metadata
.
asdict_zerocopy
())
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
...
...
@@ -572,6 +596,7 @@ class ModelRunner:
"selected_token_indices"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
multi_modal_input
=
metadata_dict
.
pop
(
"multi_modal_input"
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
None
,
...
...
@@ -584,7 +609,8 @@ class ModelRunner:
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
lora_requests
,
lora_mapping
)
sampling_metadata
,
lora_requests
,
lora_mapping
,
multi_modal_input
)
@
torch
.
inference_mode
()
def
execute_model
(
...
...
@@ -593,8 +619,8 @@ class ModelRunner:
kv_caches
:
List
[
torch
.
Tensor
],
)
->
Optional
[
SamplerOutput
]:
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
lora_requests
,
lora_mapping
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
lora_requests
,
lora_mapping
,
multi_modal_input
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
if
self
.
lora_config
:
self
.
set_active_loras
(
lora_requests
,
lora_mapping
)
...
...
@@ -605,12 +631,15 @@ class ModelRunner:
model_executable
=
self
.
graph_runners
[
graph_batch_size
]
else
:
model_executable
=
self
.
model
hidden_states
=
model_executable
(
input_ids
=
input_tokens
,
positions
=
input_positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
)
execute_model_kwargs
=
{
"input_ids"
:
input_tokens
,
"positions"
:
input_positions
,
"kv_caches"
:
kv_caches
,
"attn_metadata"
:
attn_metadata
,
}
if
self
.
vision_language_config
:
execute_model_kwargs
.
update
({
"image_input"
:
multi_modal_input
})
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
# Compute the logits.
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
...
...
@@ -658,10 +687,22 @@ class ModelRunner:
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
# Additional GPU memory may be needed for vision encoding, which needs
# to be accounted for when calculating the GPU blocks for
# vLLM blocker manager.
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
if
self
.
vision_language_config
:
max_num_seqs
=
min
(
max_num_seqs
,
int
(
max_num_batched_tokens
/
self
.
vision_language_config
.
image_feature_size
))
for
group_id
in
range
(
max_num_seqs
):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
seq_data
=
SequenceData
([
0
]
*
seq_len
)
seq_data
,
fake_multi_modal_input
=
_prepare_fake_inputs
(
seq_len
,
self
.
vision_language_config
)
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
is_prompt
=
True
,
...
...
@@ -670,6 +711,7 @@ class ModelRunner:
block_tables
=
None
,
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
fake_multi_modal_input
,
)
seqs
.
append
(
seq
)
...
...
@@ -831,6 +873,7 @@ class CUDAGraphRunner:
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
memory_pool
,
**
kwargs
,
)
->
None
:
assert
self
.
graph
is
None
# Run the model once without capturing the graph.
...
...
@@ -842,6 +885,7 @@ class CUDAGraphRunner:
positions
,
kv_caches
,
attn_metadata
,
**
kwargs
,
)
torch
.
cuda
.
synchronize
()
...
...
@@ -856,6 +900,7 @@ class CUDAGraphRunner:
positions
,
kv_caches
,
attn_metadata
,
**
kwargs
,
)
torch
.
cuda
.
synchronize
()
...
...
@@ -877,6 +922,7 @@ class CUDAGraphRunner:
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
**
kwargs
,
)
->
torch
.
Tensor
:
# KV caches are fixed tensors, so we don't need to copy them.
del
kv_caches
...
...
@@ -922,3 +968,21 @@ def _get_graph_batch_size(batch_size: int) -> int:
else
:
return
((
batch_size
+
_BATCH_SIZE_ALIGNMENT
-
1
)
//
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
def
_prepare_fake_inputs
(
seq_len
:
int
,
vision_language_config
:
Optional
[
VisionLanguageConfig
]):
"""Prepare fake inputs for profile run."""
if
vision_language_config
:
prompt_tokens
=
[
vision_language_config
.
image_token_id
]
*
vision_language_config
.
image_feature_size
+
[
0
]
*
(
seq_len
-
vision_language_config
.
image_feature_size
)
fake_image_input
=
MultiModalData
(
type
=
MultiModalData
.
Type
.
IMAGE
,
data
=
torch
.
zeros
(
vision_language_config
.
image_input_shape
,
dtype
=
torch
.
float16
))
else
:
prompt_tokens
=
[
0
]
*
seq_len
fake_image_input
=
None
return
SequenceData
(
prompt_tokens
),
fake_image_input
vllm/worker/worker.py
View file @
64172a97
...
...
@@ -7,7 +7,7 @@ import torch
import
torch.distributed
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor.parallel_utils
import
cupy_utils
...
...
@@ -39,6 +39,7 @@ class Worker:
rank
:
int
,
distributed_init_method
:
str
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
vision_language_config
:
Optional
[
VisionLanguageConfig
]
=
None
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
)
->
None
:
...
...
@@ -54,13 +55,20 @@ class Worker:
if
self
.
is_driver_worker
:
assert
self
.
rank
==
0
,
"The driver worker must have rank 0."
self
.
model_runner
=
ModelRunner
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
is_driver_worker
)
self
.
vision_language_config
=
vision_language_config
if
self
.
vision_language_config
:
assert
not
self
.
lora_config
,
(
"To be tested: vision language model with LoRA settings."
)
self
.
model_runner
=
ModelRunner
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
is_driver_worker
,
vision_language_config
=
vision_language_config
)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self
.
cache_config
=
None
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment