Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3f674a49
Unverified
Commit
3f674a49
authored
Aug 15, 2024
by
Cyrus Leung
Committed by
GitHub
Aug 14, 2024
Browse files
[VLM][Core] Support profiling with multiple multi-modal inputs per prompt (#7126)
parent
70b746ef
Changes
38
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
292 additions
and
162 deletions
+292
-162
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+14
-9
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+5
-3
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+16
-11
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+2
-5
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+8
-3
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+10
-6
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+10
-4
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+11
-10
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+8
-5
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+8
-4
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+5
-3
vllm/multimodal/base.py
vllm/multimodal/base.py
+33
-15
vllm/multimodal/image.py
vllm/multimodal/image.py
+6
-3
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+101
-16
vllm/utils.py
vllm/utils.py
+0
-28
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+16
-3
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+20
-17
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+19
-17
No files found.
vllm/model_executor/models/chameleon.py
View file @
3f674a49
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
TypedDict
)
Tuple
,
TypedDict
)
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -19,8 +19,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -19,8 +19,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -61,6 +60,7 @@ def get_max_chameleon_image_tokens(ctx: InputContext):
...
@@ -61,6 +60,7 @@ def get_max_chameleon_image_tokens(ctx: InputContext):
def
dummy_seq_data_for_chameleon
(
def
dummy_seq_data_for_chameleon
(
seq_len
:
int
,
seq_len
:
int
,
num_images
:
int
,
*
,
*
,
image_token_id
:
int
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
...
@@ -70,12 +70,14 @@ def dummy_seq_data_for_chameleon(
...
@@ -70,12 +70,14 @@ def dummy_seq_data_for_chameleon(
else
:
else
:
image_feature_size
=
image_feature_size_override
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
token_ids
=
[
image_token_id
]
*
image_feature_size
*
num_images
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_chameleon
(
def
dummy_image_for_chameleon
(
num_images
:
int
,
*
,
image_width_override
:
Optional
[
int
]
=
None
,
image_width_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
):
):
...
@@ -87,17 +89,20 @@ def dummy_image_for_chameleon(
...
@@ -87,17 +89,20 @@ def dummy_image_for_chameleon(
height
=
image_height_override
height
=
image_height_override
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
}
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
dummy_data_for_chameleon
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_chameleon
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
num_images
=
mm_counts
[
"image"
]
seq_data
=
dummy_seq_data_for_chameleon
(
seq_data
=
dummy_seq_data_for_chameleon
(
seq_len
,
seq_len
,
num_images
,
image_token_id
=
CHAMELEON_IMAGE_TOKEN_ID
,
image_token_id
=
CHAMELEON_IMAGE_TOKEN_ID
,
)
)
mm_data
=
dummy_image_for_chameleon
()
mm_data
=
dummy_image_for_chameleon
(
num_images
)
return
seq_data
,
mm_data
return
seq_data
,
mm_data
...
...
vllm/model_executor/models/clip.py
View file @
3f674a49
...
@@ -43,6 +43,7 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
...
@@ -43,6 +43,7 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
def
dummy_seq_data_for_clip
(
def
dummy_seq_data_for_clip
(
hf_config
:
CLIPVisionConfig
,
hf_config
:
CLIPVisionConfig
,
seq_len
:
int
,
seq_len
:
int
,
num_images
:
int
,
*
,
*
,
image_token_id
:
int
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
...
@@ -52,13 +53,14 @@ def dummy_seq_data_for_clip(
...
@@ -52,13 +53,14 @@ def dummy_seq_data_for_clip(
else
:
else
:
image_feature_size
=
image_feature_size_override
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
token_ids
=
[
image_token_id
]
*
image_feature_size
*
num_images
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_clip
(
def
dummy_image_for_clip
(
hf_config
:
CLIPVisionConfig
,
hf_config
:
CLIPVisionConfig
,
num_images
:
int
,
*
,
*
,
image_width_override
:
Optional
[
int
]
=
None
,
image_width_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
...
@@ -70,7 +72,7 @@ def dummy_image_for_clip(
...
@@ -70,7 +72,7 @@ def dummy_image_for_clip(
height
=
image_height_override
height
=
image_height_override
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
}
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
input_processor_for_clip
(
def
input_processor_for_clip
(
...
...
vllm/model_executor/models/fuyu.py
View file @
3f674a49
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
# limitations under the License.
# limitations under the License.
""" PyTorch Fuyu model."""
""" PyTorch Fuyu model."""
import
math
import
math
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
from
typing
import
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -29,8 +29,7 @@ from vllm.config import CacheConfig, MultiModalConfig
...
@@ -29,8 +29,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
QuantizationConfig
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.persimmon
import
PersimmonForCausalLM
from
vllm.model_executor.models.persimmon
import
PersimmonForCausalLM
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -94,27 +93,33 @@ def get_max_fuyu_image_tokens(ctx: InputContext):
...
@@ -94,27 +93,33 @@ def get_max_fuyu_image_tokens(ctx: InputContext):
return
(
ncol
+
1
)
*
nrow
return
(
ncol
+
1
)
*
nrow
def
dummy_seq_data_for_fuyu
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_seq_data_for_fuyu
(
ctx
:
InputContext
,
seq_len
:
int
,
num_images
:
int
):
ncol
,
nrow
=
get_max_fuyu_image_feature_size
()
ncol
,
nrow
=
get_max_fuyu_image_feature_size
()
image_feature_size
=
get_max_fuyu_image_tokens
(
ctx
)
image_feature_size
=
get_max_fuyu_image_tokens
(
ctx
)
token_ids
=
([
_IMAGE_TOKEN_ID
]
*
ncol
+
[
_NEWLINE_TOKEN_ID
])
*
nrow
image_token_ids
=
([
_IMAGE_TOKEN_ID
]
*
ncol
+
[
_NEWLINE_TOKEN_ID
])
*
nrow
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
token_ids
=
image_token_ids
*
num_images
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_fuyu
(
def
dummy_image_for_fuyu
(
num_images
:
int
,
*
,
image_width
:
int
,
image_width
:
int
,
image_height
:
int
,
image_height
:
int
,
):
):
image
=
Image
.
new
(
"RGB"
,
(
image_width
,
image_height
),
color
=
0
)
image
=
Image
.
new
(
"RGB"
,
(
image_width
,
image_height
),
color
=
0
)
return
{
"image"
:
image
}
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
dummy_data_for_fuyu
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_fuyu
(
ctx
:
InputContext
,
seq_len
:
int
,
seq_data
=
dummy_seq_data_for_fuyu
(
ctx
,
seq_len
)
mm_counts
:
Mapping
[
str
,
int
]):
mm_data
=
dummy_image_for_fuyu
(
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
num_images
=
mm_counts
[
"image"
]
MAX_IMAGE_FEATURE_SIZE_HEIGHT
)
seq_data
=
dummy_seq_data_for_fuyu
(
ctx
,
seq_len
,
num_images
)
mm_data
=
dummy_image_for_fuyu
(
num_images
,
image_width
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_height
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
)
return
seq_data
,
mm_data
return
seq_data
,
mm_data
...
...
vllm/model_executor/models/interfaces.py
View file @
3f674a49
...
@@ -11,14 +11,11 @@ logger = init_logger(__name__)
...
@@ -11,14 +11,11 @@ logger = init_logger(__name__)
@
runtime_checkable
@
runtime_checkable
class
SupportsMultiModal
(
Protocol
):
class
SupportsMultiModal
(
Protocol
):
"""
"""The interface required for all multi-modal models."""
The interface required for all multimodal (vision or audio) language
models.
"""
supports_multimodal
:
ClassVar
[
Literal
[
True
]]
=
True
supports_multimodal
:
ClassVar
[
Literal
[
True
]]
=
True
"""
"""
A flag that indicates this model supports multimodal inputs.
A flag that indicates this model supports multi
-
modal inputs.
Note:
Note:
There is no need to redefine this flag if this class is in the
There is no need to redefine this flag if this class is in the
...
...
vllm/model_executor/models/internvl.py
View file @
3f674a49
...
@@ -5,7 +5,8 @@
...
@@ -5,7 +5,8 @@
# Licensed under The MIT License [see LICENSE for details]
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
# --------------------------------------------------------
import
itertools
import
itertools
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -230,7 +231,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -230,7 +231,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
def
input_mapper_for_internvl
(
ctx
:
InputContext
,
data
:
object
):
def
input_mapper_for_internvl
(
ctx
:
InputContext
,
data
:
object
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
()
use_thumbnail
=
hf_config
.
use_thumbnail
use_thumbnail
=
hf_config
.
use_thumbnail
min_num
=
hf_config
.
min_dynamic_patch
min_num
=
hf_config
.
min_dynamic_patch
...
@@ -256,7 +257,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
...
@@ -256,7 +257,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
})
})
def
dummy_data_for_internvl
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_internvl
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_internvl_image_tokens
(
ctx
)
image_feature_size
=
get_max_internvl_image_tokens
(
ctx
)
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
...
@@ -268,6 +271,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
...
@@ -268,6 +271,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
seq_data
=
dummy_seq_data_for_clip
(
seq_data
=
dummy_seq_data_for_clip
(
vision_config
,
vision_config
,
seq_len
,
seq_len
,
num_images
,
image_token_id
=
tokenizer
.
encode
(
IMG_CONTEXT
,
image_token_id
=
tokenizer
.
encode
(
IMG_CONTEXT
,
add_special_tokens
=
False
)[
0
],
add_special_tokens
=
False
)[
0
],
image_feature_size_override
=
image_feature_size
,
image_feature_size_override
=
image_feature_size
,
...
@@ -281,6 +285,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
...
@@ -281,6 +285,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
mm_data
=
dummy_image_for_clip
(
mm_data
=
dummy_image_for_clip
(
vision_config
,
vision_config
,
num_images
,
image_width_override
=
max_image_width
,
image_width_override
=
max_image_width
,
image_height_override
=
max_image_height
,
image_height_override
=
max_image_height
,
)
)
...
...
vllm/model_executor/models/llava.py
View file @
3f674a49
import
itertools
import
itertools
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -9,8 +10,7 @@ from vllm.attention import AttentionMetadata
...
@@ -9,8 +10,7 @@ from vllm.attention import AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
QuantizationConfig
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
@@ -88,9 +88,11 @@ def get_max_llava_image_tokens(ctx: InputContext):
...
@@ -88,9 +88,11 @@ def get_max_llava_image_tokens(ctx: InputContext):
raise
ValueError
(
f
"Unexpected select feature strategy:
{
strategy
}
"
)
raise
ValueError
(
f
"Unexpected select feature strategy:
{
strategy
}
"
)
def
dummy_data_for_llava
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_llava
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
vision_config
=
hf_config
.
vision_config
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_llava_image_tokens
(
ctx
)
image_feature_size
=
get_max_llava_image_tokens
(
ctx
)
...
@@ -98,21 +100,23 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
...
@@ -98,21 +100,23 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
seq_data
=
dummy_seq_data_for_clip
(
seq_data
=
dummy_seq_data_for_clip
(
vision_config
,
vision_config
,
seq_len
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
image_feature_size_override
=
image_feature_size
,
)
)
mm_data
=
dummy_image_for_clip
(
vision_config
)
mm_data
=
dummy_image_for_clip
(
vision_config
,
num_images
)
return
seq_data
,
mm_data
return
seq_data
,
mm_data
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
seq_data
=
dummy_seq_data_for_siglip
(
seq_data
=
dummy_seq_data_for_siglip
(
vision_config
,
vision_config
,
seq_len
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
image_feature_size_override
=
image_feature_size
,
)
)
mm_data
=
dummy_image_for_siglip
(
vision_config
)
mm_data
=
dummy_image_for_siglip
(
vision_config
,
num_images
)
return
seq_data
,
mm_data
return
seq_data
,
mm_data
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
...
...
vllm/model_executor/models/llava_next.py
View file @
3f674a49
import
itertools
import
itertools
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -13,8 +14,7 @@ from vllm.attention import AttentionMetadata
...
@@ -13,8 +14,7 @@ from vllm.attention import AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
QuantizationConfig
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
@@ -158,9 +158,11 @@ def get_max_llava_next_image_tokens(ctx: InputContext):
...
@@ -158,9 +158,11 @@ def get_max_llava_next_image_tokens(ctx: InputContext):
)
)
def
dummy_data_for_llava_next
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_llava_next
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
hf_config
=
ctx
.
get_hf_config
(
LlavaNextConfig
)
hf_config
=
ctx
.
get_hf_config
(
LlavaNextConfig
)
vision_config
=
hf_config
.
vision_config
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_llava_next_image_tokens
(
ctx
)
image_feature_size
=
get_max_llava_next_image_tokens
(
ctx
)
...
@@ -168,12 +170,14 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
...
@@ -168,12 +170,14 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
seq_data
=
dummy_seq_data_for_clip
(
seq_data
=
dummy_seq_data_for_clip
(
vision_config
,
vision_config
,
seq_len
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
image_feature_size_override
=
image_feature_size
,
)
)
mm_data
=
dummy_image_for_clip
(
mm_data
=
dummy_image_for_clip
(
vision_config
,
vision_config
,
num_images
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
)
)
...
@@ -183,12 +187,14 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
...
@@ -183,12 +187,14 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
seq_data
=
dummy_seq_data_for_siglip
(
seq_data
=
dummy_seq_data_for_siglip
(
vision_config
,
vision_config
,
seq_len
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
image_feature_size_override
=
image_feature_size
,
)
)
mm_data
=
dummy_image_for_siglip
(
mm_data
=
dummy_image_for_siglip
(
vision_config
,
vision_config
,
num_images
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
)
)
...
...
vllm/model_executor/models/minicpmv.py
View file @
3f674a49
...
@@ -24,8 +24,8 @@
...
@@ -24,8 +24,8 @@
import
math
import
math
import
re
import
re
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Optional
,
Tuple
,
TypedDict
,
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
Union
)
TypedDict
,
Union
)
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -42,8 +42,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
...
@@ -42,8 +42,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
...
@@ -408,22 +407,24 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
...
@@ -408,22 +407,24 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
return
getattr
(
hf_config
,
"query_num"
,
64
)
return
getattr
(
hf_config
,
"query_num"
,
64
)
def
dummy_seq_data_for_minicpmv
(
seq_len
:
int
):
def
dummy_seq_data_for_minicpmv
(
seq_len
:
int
,
num_images
:
int
):
token_ids
=
[
0
]
*
seq_len
token_ids
=
[
0
]
*
seq_len
return
SequenceData
(
token_ids
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_minicpmv
(
hf_config
:
PretrainedConfig
):
def
dummy_image_for_minicpmv
(
hf_config
:
PretrainedConfig
,
num_images
:
int
):
width
=
height
=
hf_config
.
image_size
width
=
height
=
hf_config
.
image_size
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
}
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
dummy_data_for_minicpmv
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_minicpmv
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
hf_config
=
ctx
.
get_hf_config
()
hf_config
=
ctx
.
get_hf_config
()
num_images
=
mm_counts
[
"image"
]
seq_data
=
dummy_seq_data_for_minicpmv
(
seq_len
)
seq_data
=
dummy_seq_data_for_minicpmv
(
seq_len
,
num_images
)
mm_data
=
dummy_image_for_minicpmv
(
hf_config
)
mm_data
=
dummy_image_for_minicpmv
(
hf_config
,
num_images
)
return
seq_data
,
mm_data
return
seq_data
,
mm_data
...
...
vllm/model_executor/models/paligemma.py
View file @
3f674a49
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -9,8 +10,7 @@ from vllm.config import CacheConfig, MultiModalConfig
...
@@ -9,8 +10,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.gemma
import
GemmaModel
from
vllm.model_executor.models.gemma
import
GemmaModel
...
@@ -57,17 +57,20 @@ def get_max_paligemma_image_tokens(ctx: InputContext):
...
@@ -57,17 +57,20 @@ def get_max_paligemma_image_tokens(ctx: InputContext):
return
get_max_siglip_image_tokens
(
vision_config
)
return
get_max_siglip_image_tokens
(
vision_config
)
def
dummy_data_for_paligemma
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_paligemma
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
vision_config
=
hf_config
.
vision_config
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
[
"image"
]
seq_data
=
dummy_seq_data_for_siglip
(
seq_data
=
dummy_seq_data_for_siglip
(
vision_config
,
vision_config
,
seq_len
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
image_token_id
=
hf_config
.
image_token_index
,
)
)
mm_data
=
dummy_image_for_siglip
(
vision_config
)
mm_data
=
dummy_image_for_siglip
(
vision_config
,
num_images
)
return
seq_data
,
mm_data
return
seq_data
,
mm_data
...
...
vllm/model_executor/models/phi3v.py
View file @
3f674a49
...
@@ -15,7 +15,8 @@
...
@@ -15,7 +15,8 @@
# limitations under the License.
# limitations under the License.
import
re
import
re
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -28,8 +29,7 @@ from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
...
@@ -28,8 +29,7 @@ from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -347,18 +347,22 @@ def get_max_phi3v_image_tokens(ctx: InputContext):
...
@@ -347,18 +347,22 @@ def get_max_phi3v_image_tokens(ctx: InputContext):
)
)
def
dummy_data_for_phi3v
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_phi3v
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_phi3v_image_tokens
(
ctx
)
image_feature_size
=
get_max_phi3v_image_tokens
(
ctx
)
seq_data
=
dummy_seq_data_for_clip
(
seq_data
=
dummy_seq_data_for_clip
(
CLIP_VIT_LARGE_PATCH14_336_CONFIG
,
CLIP_VIT_LARGE_PATCH14_336_CONFIG
,
seq_len
,
seq_len
,
num_images
,
image_token_id
=
_IMAGE_TOKEN_ID
,
image_token_id
=
_IMAGE_TOKEN_ID
,
image_feature_size_override
=
image_feature_size
,
image_feature_size_override
=
image_feature_size
,
)
)
mm_data
=
dummy_image_for_clip
(
mm_data
=
dummy_image_for_clip
(
CLIP_VIT_LARGE_PATCH14_336_CONFIG
,
CLIP_VIT_LARGE_PATCH14_336_CONFIG
,
num_images
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
)
)
...
...
vllm/model_executor/models/siglip.py
View file @
3f674a49
...
@@ -52,6 +52,7 @@ def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
...
@@ -52,6 +52,7 @@ def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
def
dummy_seq_data_for_siglip
(
def
dummy_seq_data_for_siglip
(
hf_config
:
SiglipVisionConfig
,
hf_config
:
SiglipVisionConfig
,
seq_len
:
int
,
seq_len
:
int
,
num_images
:
int
,
*
,
*
,
image_token_id
:
int
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
...
@@ -61,13 +62,14 @@ def dummy_seq_data_for_siglip(
...
@@ -61,13 +62,14 @@ def dummy_seq_data_for_siglip(
else
:
else
:
image_feature_size
=
image_feature_size_override
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
token_ids
=
[
image_token_id
]
*
image_feature_size
*
num_images
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_siglip
(
def
dummy_image_for_siglip
(
hf_config
:
SiglipVisionConfig
,
hf_config
:
SiglipVisionConfig
,
num_images
:
int
,
*
,
*
,
image_width_override
:
Optional
[
int
]
=
None
,
image_width_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
...
@@ -79,7 +81,7 @@ def dummy_image_for_siglip(
...
@@ -79,7 +81,7 @@ def dummy_image_for_siglip(
height
=
image_height_override
height
=
image_height_override
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
}
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
input_processor_for_siglip
(
def
input_processor_for_siglip
(
...
...
vllm/multimodal/base.py
View file @
3f674a49
import
sys
import
sys
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
UserDict
,
defaultdict
from
collections
import
UserDict
,
defaultdict
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
from
typing
import
Callable
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Tuple
,
Type
,
TypedDict
,
TypeVar
,
Union
,
cast
from
typing
import
Tuple
,
Type
,
TypedDict
,
TypeVar
,
Union
,
cast
,
final
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -116,17 +116,30 @@ class MultiModalInputs(_MultiModalInputsBase):
...
@@ -116,17 +116,30 @@ class MultiModalInputs(_MultiModalInputsBase):
batched_inputs
)
batched_inputs
)
_T
=
TypeVar
(
"_T"
)
MultiModalData
:
TypeAlias
=
Union
[
_T
,
List
[
_T
]]
"""
Either a single data instance, or a list of data instances.
The number of data instances allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""
@
final
class
MultiModalDataBuiltins
(
TypedDict
,
total
=
False
):
class
MultiModalDataBuiltins
(
TypedDict
,
total
=
False
):
"""Modality types that are predefined by vLLM."""
"""Modality types that are predefined by vLLM."""
image
:
Image
.
Image
image
:
MultiModalData
[
Image
.
Image
]
"""The input image."""
"""The input image
(s)
."""
audio
:
Tuple
[
np
.
ndarray
,
Union
[
int
,
float
]]
audio
:
MultiModalData
[
Tuple
[
np
.
ndarray
,
Union
[
int
,
float
]]
]
"""The input audio
and its
sampling rate."""
"""The input audio
item(s) and corresponding
sampling rate
(s)
."""
MultiModalDataDict
=
Union
[
MultiModalDataBuiltins
,
Dict
[
str
,
Any
]]
MultiModalDataDict
=
Union
[
MultiModalDataBuiltins
,
Mapping
[
str
,
MultiModalData
[
object
]]]
"""
"""
A dictionary containing an item for each modality type to input.
A dictionary containing an item for each modality type to input.
...
@@ -137,7 +150,8 @@ Note:
...
@@ -137,7 +150,8 @@ Note:
Read more on that :ref:`here <adding_multimodal_plugin>`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
"""
MultiModalInputMapper
=
Callable
[[
InputContext
,
object
],
MultiModalInputs
]
MultiModalInputMapper
=
Callable
[[
InputContext
,
MultiModalData
[
object
]],
MultiModalInputs
]
"""
"""
Return a dictionary to be passed as keyword arguments to
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
...
@@ -181,8 +195,11 @@ class MultiModalPlugin(ABC):
...
@@ -181,8 +195,11 @@ class MultiModalPlugin(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
_default_input_mapper
(
self
,
ctx
:
InputContext
,
def
_default_input_mapper
(
data
:
object
)
->
MultiModalInputs
:
self
,
ctx
:
InputContext
,
data
:
MultiModalData
[
object
],
)
->
MultiModalInputs
:
"""
"""
Return a dictionary to be passed as keyword arguments to
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to
:meth:`~torch.nn.Module.forward`. This is similar in concept to
...
@@ -225,7 +242,7 @@ class MultiModalPlugin(ABC):
...
@@ -225,7 +242,7 @@ class MultiModalPlugin(ABC):
return
wrapper
return
wrapper
def
map_input
(
self
,
model_config
:
ModelConfig
,
def
map_input
(
self
,
model_config
:
ModelConfig
,
data
:
object
)
->
MultiModalInputs
:
data
:
MultiModalData
[
object
]
)
->
MultiModalInputs
:
"""
"""
Transform the data into a dictionary of model inputs using the
Transform the data into a dictionary of model inputs using the
input mapper registered for that model.
input mapper registered for that model.
...
@@ -254,8 +271,8 @@ class MultiModalPlugin(ABC):
...
@@ -254,8 +271,8 @@ class MultiModalPlugin(ABC):
@
abstractmethod
@
abstractmethod
def
_default_max_multimodal_tokens
(
self
,
ctx
:
InputContext
)
->
int
:
def
_default_max_multimodal_tokens
(
self
,
ctx
:
InputContext
)
->
int
:
"""
"""
Calculate the maximum number of
multimodal tokens input to the languag
e
Calculate the maximum number of
tokens, corresponding to a singl
e
model. This does not include tokens that correspond to the input text
.
instance of multimodal data, that are passed to the language model
.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -269,8 +286,9 @@ class MultiModalPlugin(ABC):
...
@@ -269,8 +286,9 @@ class MultiModalPlugin(ABC):
max_mm_tokens
:
Optional
[
MultiModalTokensCalc
]
=
None
,
max_mm_tokens
:
Optional
[
MultiModalTokensCalc
]
=
None
,
):
):
"""
"""
Register the maximum number of multi-modal tokens input to the
Register the maximum number of tokens, corresponding to a single
language model for a model class.
instance of multimodal data, that are passed to the language model
for a model class.
If `None` is provided, then the default calculation is used instead.
If `None` is provided, then the default calculation is used instead.
...
...
vllm/multimodal/image.py
View file @
3f674a49
...
@@ -11,7 +11,7 @@ from vllm.transformers_utils.image_processor import get_image_processor
...
@@ -11,7 +11,7 @@ from vllm.transformers_utils.image_processor import get_image_processor
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
get_tokenizer
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
from
.base
import
MultiModalInputs
,
MultiModalPlugin
from
.base
import
MultiModalData
,
MultiModalInputs
,
MultiModalPlugin
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -110,8 +110,11 @@ class ImagePlugin(MultiModalPlugin):
...
@@ -110,8 +110,11 @@ class ImagePlugin(MultiModalPlugin):
model_config
.
model
,
model_config
.
model
,
trust_remote_code
=
model_config
.
trust_remote_code
)
trust_remote_code
=
model_config
.
trust_remote_code
)
def
_default_input_mapper
(
self
,
ctx
:
InputContext
,
def
_default_input_mapper
(
data
:
object
)
->
MultiModalInputs
:
self
,
ctx
:
InputContext
,
data
:
MultiModalData
[
object
],
)
->
MultiModalInputs
:
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
# PIL image
# PIL image
...
...
vllm/multimodal/registry.py
View file @
3f674a49
import
functools
import
functools
from
typing
import
Dict
,
Optional
,
Sequence
from
collections
import
UserDict
from
typing
import
Dict
,
Mapping
,
Optional
,
Sequence
import
torch
from
vllm.config
import
ModelConfig
,
MultiModalConfig
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
.audio
import
AudioPlugin
from
.audio
import
AudioPlugin
from
.base
import
(
MultiModalDataDict
,
MultiModalInputMapper
,
MultiModalInputs
,
from
.base
import
(
MultiModalDataDict
,
MultiModalInputMapper
,
MultiModalInputs
,
MultiModalPlugin
,
MultiModalTokensCalc
)
MultiModalPlugin
,
MultiModalTokensCalc
,
NestedTensors
)
from
.image
import
ImagePlugin
from
.image
import
ImagePlugin
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
_MultiModalLimits
(
UserDict
):
"""
Wraps `_limits_by_model` for a more informative error message
when attempting to access a model that does not exist.
"""
def
__getitem__
(
self
,
key
:
ModelConfig
)
->
Dict
[
str
,
int
]:
try
:
return
super
().
__getitem__
(
key
)
except
KeyError
as
exc
:
msg
=
(
f
"Cannot find `mm_limits` for model=
{
key
.
model
}
. Did you "
"forget to call `init_mm_limits_per_prompt`?"
)
raise
KeyError
(
msg
)
from
exc
class
MultiModalRegistry
:
class
MultiModalRegistry
:
"""
"""
A registry that dispatches data processing to the
A registry that dispatches data processing to the
...
@@ -28,6 +42,11 @@ class MultiModalRegistry:
...
@@ -28,6 +42,11 @@ class MultiModalRegistry:
plugins
:
Sequence
[
MultiModalPlugin
]
=
DEFAULT_PLUGINS
)
->
None
:
plugins
:
Sequence
[
MultiModalPlugin
]
=
DEFAULT_PLUGINS
)
->
None
:
self
.
_plugins
=
{
p
.
get_data_key
():
p
for
p
in
plugins
}
self
.
_plugins
=
{
p
.
get_data_key
():
p
for
p
in
plugins
}
# This is used for non-multimodal models
self
.
_disabled_limits_per_plugin
=
{
k
:
0
for
k
in
self
.
_plugins
}
self
.
_limits_by_model
=
_MultiModalLimits
()
def
register_plugin
(
self
,
plugin
:
MultiModalPlugin
)
->
None
:
def
register_plugin
(
self
,
plugin
:
MultiModalPlugin
)
->
None
:
"""
"""
Register a multi-modal plugin so it can be recognized by vLLM.
Register a multi-modal plugin so it can be recognized by vLLM.
...
@@ -86,13 +105,24 @@ class MultiModalRegistry:
...
@@ -86,13 +105,24 @@ class MultiModalRegistry:
via the input mapper registered for that model.
via the input mapper registered for that model.
See :meth:`MultiModalPlugin.map_input` for more details.
See :meth:`MultiModalPlugin.map_input` for more details.
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
"""
merged_dict
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
merged_dict
:
Dict
[
str
,
Nested
Tensor
s
]
=
{}
for
data_key
,
data_value
in
data
.
items
():
for
data_key
,
data_value
in
data
.
items
():
input_dict
=
self
.
_get_plugin
(
data_key
)
\
plugin
=
self
.
_get_plugin
(
data_key
)
.
map_input
(
model_config
,
data_value
)
num_items
=
len
(
data_value
)
if
isinstance
(
data_value
,
list
)
else
1
max_items
=
self
.
_limits_by_model
[
model_config
][
data_key
]
if
num_items
>
max_items
:
raise
ValueError
(
f
"You set
{
data_key
}
=
{
max_items
}
(or defaulted to 1) in "
f
"`--limit-mm-per-prompt`, but found
{
num_items
}
items "
"in the same prompt."
)
input_dict
=
plugin
.
map_input
(
model_config
,
data_value
)
for
input_key
,
input_tensor
in
input_dict
.
items
():
for
input_key
,
input_tensor
in
input_dict
.
items
():
if
input_key
in
merged_dict
:
if
input_key
in
merged_dict
:
raise
ValueError
(
f
"The input mappers (keys=
{
set
(
data
)
}
) "
raise
ValueError
(
f
"The input mappers (keys=
{
set
(
data
)
}
) "
...
@@ -115,8 +145,9 @@ class MultiModalRegistry:
...
@@ -115,8 +145,9 @@ class MultiModalRegistry:
max_mm_tokens
:
Optional
[
MultiModalTokensCalc
]
=
None
,
max_mm_tokens
:
Optional
[
MultiModalTokensCalc
]
=
None
,
):
):
"""
"""
Register the maximum number of tokens, belonging to a
Register the maximum number of tokens, corresponding to a single
specific modality, input to the language model for a model class.
instance of multimodal data belonging to a specific modality, that are
passed to the language model for a model class.
"""
"""
return
self
.
_get_plugin
(
data_type_key
)
\
return
self
.
_get_plugin
(
data_type_key
)
\
.
register_max_multimodal_tokens
(
max_mm_tokens
)
.
register_max_multimodal_tokens
(
max_mm_tokens
)
...
@@ -126,8 +157,8 @@ class MultiModalRegistry:
...
@@ -126,8 +157,8 @@ class MultiModalRegistry:
max_mm_tokens
:
Optional
[
MultiModalTokensCalc
]
=
None
,
max_mm_tokens
:
Optional
[
MultiModalTokensCalc
]
=
None
,
):
):
"""
"""
Register the maximum number of image tokens
Register the maximum number of image tokens
, corresponding to a single
i
nput
to the language model for a model class.
i
mage, that are passed
to the language model for a model class.
"""
"""
return
self
.
register_max_multimodal_tokens
(
"image"
,
max_mm_tokens
)
return
self
.
register_max_multimodal_tokens
(
"image"
,
max_mm_tokens
)
...
@@ -137,7 +168,61 @@ class MultiModalRegistry:
...
@@ -137,7 +168,61 @@ class MultiModalRegistry:
for profiling the memory usage of a model.
for profiling the memory usage of a model.
See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
limits_per_plugin
=
self
.
_limits_by_model
[
model_config
]
return
sum
((
limits_per_plugin
[
key
]
*
plugin
.
get_max_multimodal_tokens
(
model_config
))
for
key
,
plugin
in
self
.
_plugins
.
items
())
def
init_mm_limits_per_prompt
(
self
,
model_config
:
ModelConfig
,
multimodal_config
:
Optional
[
MultiModalConfig
],
)
->
None
:
"""
Initialize the maximum number of multi-modal input instances for each
modality that are allowed per prompt for a model class.
"""
if
model_config
in
self
.
_limits_by_model
:
logger
.
warning
(
"`mm_limits` has already been set for model=%s, and will "
"be overwritten by the new values."
,
model_config
.
model
)
if
multimodal_config
is
None
:
limits_per_plugin
=
self
.
_disabled_limits_per_plugin
else
:
config_limits_per_plugin
=
multimodal_config
.
limit_per_prompt
extra_keys
=
config_limits_per_plugin
.
keys
()
-
self
.
_plugins
.
keys
()
if
extra_keys
:
logger
.
warning
(
"Detected extra keys in `--limit-mm-per-prompt` which "
"are not registered as multi-modal plugins: %s. "
"They will be ignored."
,
extra_keys
)
# NOTE: Currently the default is set to 1 for each plugin
# TODO: Automatically determine the limits based on budget
# once more models support multi-image inputs
limits_per_plugin
=
{
key
:
config_limits_per_plugin
.
get
(
key
,
1
)
for
key
in
self
.
_plugins
}
self
.
_limits_by_model
[
model_config
]
=
limits_per_plugin
def
get_mm_limits_per_prompt
(
self
,
model_config
:
ModelConfig
,
)
->
Mapping
[
str
,
int
]:
"""
Get the maximum number of multi-modal input instances for each modality
that are allowed per prompt for a model class.
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
"""
return
sum
(
return
self
.
_limits_by_model
[
model_config
]
plugin
.
get_max_multimodal_tokens
(
model_config
)
for
plugin
in
self
.
_plugins
.
values
())
vllm/utils.py
View file @
3f674a49
...
@@ -13,7 +13,6 @@ import threading
...
@@ -13,7 +13,6 @@ import threading
import
uuid
import
uuid
import
warnings
import
warnings
from
asyncio
import
FIRST_COMPLETED
,
ensure_future
from
asyncio
import
FIRST_COMPLETED
,
ensure_future
from
collections
import
defaultdict
from
functools
import
lru_cache
,
partial
,
wraps
from
functools
import
lru_cache
,
partial
,
wraps
from
platform
import
uname
from
platform
import
uname
from
typing
import
(
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
Dict
,
Generic
,
from
typing
import
(
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
Dict
,
Generic
,
...
@@ -760,16 +759,6 @@ class CudaMemoryProfiler:
...
@@ -760,16 +759,6 @@ class CudaMemoryProfiler:
gc
.
collect
()
gc
.
collect
()
def
str_to_int_tuple
(
s
:
str
)
->
Tuple
[
int
,
...]:
"""Convert a string to a tuple of integers."""
try
:
return
tuple
(
map
(
int
,
s
.
split
(
","
)))
except
ValueError
as
e
:
raise
ValueError
(
"String must be a series of integers separated by commas "
f
"(e.g., 1, 2, 3). Given input:
{
s
}
"
)
from
e
def
make_ndarray_with_pad
(
def
make_ndarray_with_pad
(
x
:
List
[
List
[
T
]],
x
:
List
[
List
[
T
]],
pad
:
T
,
pad
:
T
,
...
@@ -863,23 +852,6 @@ def is_list_of(
...
@@ -863,23 +852,6 @@ def is_list_of(
assert_never
(
check
)
assert_never
(
check
)
def
merge_dicts
(
dict1
:
Dict
[
K
,
List
[
T
]],
dict2
:
Dict
[
K
,
List
[
T
]])
->
Dict
[
K
,
List
[
T
]]:
"""Merge 2 dicts that have key -> List of items.
When a key conflicts, the values in dict1 is prioritized.
"""
merged_dict
:
Dict
[
K
,
List
[
T
]]
=
defaultdict
(
list
)
for
key
,
value
in
dict1
.
items
():
merged_dict
[
key
].
extend
(
value
)
for
key
,
value
in
dict2
.
items
():
merged_dict
[
key
].
extend
(
value
)
return
dict
(
merged_dict
)
JSONTree
=
Union
[
Dict
[
str
,
"JSONTree[T]"
],
List
[
"JSONTree[T]"
],
JSONTree
=
Union
[
Dict
[
str
,
"JSONTree[T]"
],
List
[
"JSONTree[T]"
],
Tuple
[
"JSONTree[T]"
,
...],
T
]
Tuple
[
"JSONTree[T]"
,
...],
T
]
"""A nested JSON structure where the leaves need not be JSON-serializable."""
"""A nested JSON structure where the leaves need not be JSON-serializable."""
...
...
vllm/worker/enc_dec_model_runner.py
View file @
3f674a49
...
@@ -12,9 +12,10 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
...
@@ -12,9 +12,10 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ObservabilityConfig
,
ModelConfig
,
MultiModalConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
SamplerOutput
,
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
SamplerOutput
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
...
@@ -83,6 +84,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -83,6 +84,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
,
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
,
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
,
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
):
):
'''
'''
EncoderDecoderModelRunner constructor.
EncoderDecoderModelRunner constructor.
...
@@ -271,6 +274,16 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -271,6 +274,16 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
model_config
=
self
.
model_config
model_config
=
self
.
model_config
mm_config
=
self
.
multimodal_config
input_registry
=
self
.
input_registry
mm_registry
=
self
.
mm_registry
mm_registry
.
init_mm_limits_per_prompt
(
model_config
,
mm_config
)
max_mm_tokens
=
mm_registry
.
get_max_multimodal_tokens
(
model_config
)
if
max_mm_tokens
>
0
:
raise
NotImplementedError
(
"Multi-modal encoder-decoder models are not supported yet"
)
batch_size
=
0
batch_size
=
0
for
group_id
in
range
(
max_num_seqs
):
for
group_id
in
range
(
max_num_seqs
):
...
@@ -278,8 +291,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -278,8 +291,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
batch_size
+=
seq_len
batch_size
+=
seq_len
seq_data
,
_
=
INPUT_REGISTRY
\
seq_data
,
_
=
input_registry
\
.
dummy_data_for_profiling
(
model_config
,
seq_len
)
.
dummy_data_for_profiling
(
model_config
,
seq_len
,
mm_registry
)
# Having more tokens is over-conservative but otherwise fine
# Having more tokens is over-conservative but otherwise fine
assert
len
(
seq_data
.
prompt_token_ids
)
>=
seq_len
,
(
assert
len
(
seq_data
.
prompt_token_ids
)
>=
seq_len
,
(
...
...
vllm/worker/model_runner.py
View file @
3f674a49
...
@@ -31,7 +31,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -31,7 +31,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.distributed
import
get_pp_group
from
vllm.distributed
import
get_pp_group
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -43,7 +43,7 @@ from vllm.model_executor.models.interfaces import (supports_lora,
...
@@ -43,7 +43,7 @@ from vllm.model_executor.models.interfaces import (supports_lora,
supports_multimodal
)
supports_multimodal
)
from
vllm.model_executor.models.utils
import
set_cpu_offload_max_bytes
from
vllm.model_executor.models.utils
import
set_cpu_offload_max_bytes
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalInputs
)
MultiModalInputs
,
MultiModalRegistry
)
from
vllm.prompt_adapter.layers
import
PromptAdapterMapping
from
vllm.prompt_adapter.layers
import
PromptAdapterMapping
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.worker_manager
import
(
from
vllm.prompt_adapter.worker_manager
import
(
...
@@ -807,6 +807,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -807,6 +807,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
,
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
,
return_hidden_states
:
bool
=
False
,
return_hidden_states
:
bool
=
False
,
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
,
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
):
):
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
...
@@ -860,8 +862,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -860,8 +862,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
)
if
num_attn_heads
else
None
)
if
num_attn_heads
else
None
# Multi-modal data support
# Multi-modal data support
self
.
multi_modal_input_mapper
=
MULTIMODAL_REGISTRY
\
self
.
input_registry
=
input_registry
.
create_input_mapper
(
self
.
model_config
)
self
.
mm_registry
=
mm_registry
self
.
multi_modal_input_mapper
=
mm_registry
\
.
create_input_mapper
(
model_config
)
# Lazy initialization
# Lazy initialization
self
.
model
:
nn
.
Module
# Set after load_model
self
.
model
:
nn
.
Module
# Set after load_model
...
@@ -902,7 +906,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -902,7 +906,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
assert
supports_lora
(
self
.
model
),
"Model does not support LoRA"
assert
supports_lora
(
self
.
model
),
"Model does not support LoRA"
assert
not
supports_multimodal
(
assert
not
supports_multimodal
(
self
.
model
self
.
model
),
"To be tested:
m
ultimodal
language
model with LoRA settings."
),
"To be tested:
M
ulti
-
modal model with LoRA settings."
self
.
lora_manager
=
LRUCacheWorkerLoRAManager
(
self
.
lora_manager
=
LRUCacheWorkerLoRAManager
(
self
.
scheduler_config
.
max_num_seqs
,
self
.
scheduler_config
.
max_num_seqs
,
...
@@ -1046,17 +1050,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1046,17 +1050,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# Profile memory usage with max_num_sequences sequences and the total
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
# number of tokens equal to max_num_batched_tokens.
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
# Additional GPU memory may be needed for
vision
encoding, which
needs
# Additional GPU memory may be needed for
multi-modal
encoding, which
# to be accounted for when calculating the GPU blocks for
#
needs
to be accounted for when calculating the GPU blocks for
# vLLM blocker manager.
# vLLM blocker manager.
# To exercise the worst scenario for GPU memory consumption,
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
# of images processed.
model_config
=
self
.
model_config
model_config
=
self
.
model_config
mm_config
=
self
.
multimodal_config
if
supports_multimodal
(
self
.
model
):
input_registry
=
self
.
input_registry
max_mm_tokens
=
MULTIMODAL_REGISTRY
\
mm_registry
=
self
.
mm_registry
.
get_max_multimodal_tokens
(
model_config
)
mm_registry
.
init_mm_limits_per_prompt
(
model_config
,
mm_config
)
max_mm_tokens
=
mm_registry
.
get_max_multimodal_tokens
(
model_config
)
if
max_mm_tokens
>
0
:
max_num_seqs_orig
=
max_num_seqs
max_num_seqs_orig
=
max_num_seqs
max_num_seqs
=
min
(
max_num_seqs
,
max_num_seqs
=
min
(
max_num_seqs
,
max_num_batched_tokens
//
max_mm_tokens
)
max_num_batched_tokens
//
max_mm_tokens
)
...
@@ -1074,13 +1082,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1074,13 +1082,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
batch_size
+=
seq_len
batch_size
+=
seq_len
seq_data
,
dummy_multi_modal_data
=
INPUT_REGISTRY
\
seq_data
,
dummy_multi_modal_data
=
input_registry
\
.
dummy_data_for_profiling
(
model_config
,
seq_len
)
.
dummy_data_for_profiling
(
model_config
,
seq_len
,
mm_registry
)
# Having more tokens is over-conservative but otherwise fine
assert
len
(
seq_data
.
prompt_token_ids
)
>=
seq_len
,
(
f
"Expected at least
{
seq_len
}
dummy tokens for profiling, "
f
"but got:
{
len
(
seq_data
.
prompt_token_ids
)
}
"
)
seq
=
SequenceGroupMetadata
(
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
request_id
=
str
(
group_id
),
...
...
vllm/worker/xpu_model_runner.py
View file @
3f674a49
...
@@ -9,12 +9,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -9,12 +9,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.models.interfaces
import
supports_multimodal
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalInputs
)
MultiModalInputs
,
MultiModalRegistry
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
...
@@ -89,6 +88,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
...
@@ -89,6 +88,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
is_driver_worker
:
bool
=
False
,
is_driver_worker
:
bool
=
False
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
*
args
,
*
args
,
**
kwargs
,
**
kwargs
,
):
):
...
@@ -120,8 +121,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
...
@@ -120,8 +121,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
)
)
# Multi-modal data support
# Multi-modal data support
self
.
multi_modal_input_mapper
=
MULTIMODAL_REGISTRY
\
self
.
input_registry
=
input_registry
.
create_input_mapper
(
self
.
model_config
)
self
.
mm_registry
=
mm_registry
self
.
multi_modal_input_mapper
=
mm_registry
\
.
create_input_mapper
(
model_config
)
# Lazy initialization.
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
self
.
model
:
nn
.
Module
# Set after init_Model
...
@@ -157,17 +160,21 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
...
@@ -157,17 +160,21 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# Profile memory usage with max_num_sequences sequences and the total
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
# number of tokens equal to max_num_batched_tokens.
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
# Additional GPU memory may be needed for
vision
encoding, which
needs
# Additional GPU memory may be needed for
multi-modal
encoding, which
# to be accounted for when calculating the GPU blocks for
#
needs
to be accounted for when calculating the GPU blocks for
# vLLM blocker manager.
# vLLM blocker manager.
# To exercise the worst scenario for GPU memory consumption,
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
# of images processed.
model_config
=
self
.
model_config
model_config
=
self
.
model_config
mm_config
=
self
.
multimodal_config
if
supports_multimodal
(
self
.
model
):
input_registry
=
self
.
input_registry
max_mm_tokens
=
MULTIMODAL_REGISTRY
\
mm_registry
=
self
.
mm_registry
.
get_max_multimodal_tokens
(
model_config
)
mm_registry
.
init_mm_limits_per_prompt
(
model_config
,
mm_config
)
max_mm_tokens
=
mm_registry
.
get_max_multimodal_tokens
(
model_config
)
if
max_mm_tokens
>
0
:
max_num_seqs_orig
=
max_num_seqs
max_num_seqs_orig
=
max_num_seqs
max_num_seqs
=
min
(
max_num_seqs
,
max_num_seqs
=
min
(
max_num_seqs
,
max_num_batched_tokens
//
max_mm_tokens
)
max_num_batched_tokens
//
max_mm_tokens
)
...
@@ -183,13 +190,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
...
@@ -183,13 +190,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
seq_data
,
dummy_multi_modal_data
=
INPUT_REGISTRY
\
seq_data
,
dummy_multi_modal_data
=
input_registry
\
.
dummy_data_for_profiling
(
model_config
,
seq_len
)
.
dummy_data_for_profiling
(
model_config
,
seq_len
,
mm_registry
)
# Having more tokens is over-conservative but otherwise fine
assert
len
(
seq_data
.
prompt_token_ids
)
>=
seq_len
,
(
f
"Expected at least
{
seq_len
}
dummy tokens for profiling, "
f
"but got:
{
len
(
seq_data
.
prompt_token_ids
)
}
"
)
seq
=
SequenceGroupMetadata
(
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
request_id
=
str
(
group_id
),
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment