Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9e169a4c
Unverified
Commit
9e169a4c
authored
Jul 25, 2024
by
Alphi
Committed by
GitHub
Jul 24, 2024
Browse files
[Model] Adding support for MiniCPM-V (#4087)
parent
5689e256
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
942 additions
and
18 deletions
+942
-18
docs/source/dev/multimodal/multimodal_index.rst
docs/source/dev/multimodal/multimodal_index.rst
+2
-0
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+4
-0
examples/minicpmv_example.py
examples/minicpmv_example.py
+53
-0
tests/conftest.py
tests/conftest.py
+6
-5
tests/models/test_minicpmv.py
tests/models/test_minicpmv.py
+163
-0
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+1
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+3
-1
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+2
-1
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+682
-0
vllm/multimodal/__init__.py
vllm/multimodal/__init__.py
+2
-1
vllm/multimodal/base.py
vllm/multimodal/base.py
+24
-10
No files found.
docs/source/dev/multimodal/multimodal_index.rst
View file @
9e169a4c
...
...
@@ -40,6 +40,8 @@ Registry
Base Classes
------------
.. autodata:: vllm.multimodal.NestedTensors
.. autodata:: vllm.multimodal.BatchedTensors
.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
...
...
docs/source/models/supported_models.rst
View file @
9e169a4c
...
...
@@ -206,6 +206,10 @@ Vision Language Models
- Phi-3-Vision
- :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
-
* - :code:`MiniCPM-V`
- MiniCPM-V
- :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
-
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`
...
...
examples/minicpmv_example.py
0 → 100644
View file @
9e169a4c
from
transformers
import
AutoTokenizer
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.image
import
ImageAsset
# 2.0
# MODEL_NAME = "HwwwH/MiniCPM-V-2"
# 2.5
MODEL_NAME
=
"openbmb/MiniCPM-Llama3-V-2_5"
image
=
ImageAsset
(
"stop_sign"
).
pil_image
.
convert
(
"RGB"
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_NAME
,
trust_remote_code
=
True
)
llm
=
LLM
(
model
=
MODEL_NAME
,
gpu_memory_utilization
=
1
,
trust_remote_code
=
True
,
max_model_len
=
4096
)
messages
=
[{
'role'
:
'user'
,
'content'
:
'(<image>./</image>)
\n
'
+
"What's the content of the image?"
}]
prompt
=
tokenizer
.
apply_chat_template
(
messages
,
tokenize
=
False
,
add_generation_prompt
=
True
)
# 2.0
# stop_token_ids = [tokenizer.eos_id]
# 2.5
stop_token_ids
=
[
tokenizer
.
eos_id
,
tokenizer
.
eot_id
]
sampling_params
=
SamplingParams
(
stop_token_ids
=
stop_token_ids
,
# temperature=0.7,
# top_p=0.8,
# top_k=100,
# seed=3472,
max_tokens
=
1024
,
# min_tokens=150,
temperature
=
0
,
use_beam_search
=
True
,
# length_penalty=1.2,
best_of
=
3
)
outputs
=
llm
.
generate
({
"prompt"
:
prompt
,
"multi_modal_data"
:
{
"image"
:
image
}
},
sampling_params
=
sampling_params
)
print
(
outputs
[
0
].
outputs
[
0
].
text
)
tests/conftest.py
View file @
9e169a4c
...
...
@@ -11,7 +11,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
PIL
import
Image
from
transformers
import
(
AutoModelForCausalLM
,
AutoModelForVision2Seq
,
AutoTokenizer
,
BatchEncoding
)
AutoTokenizer
,
BatchEncoding
,
BatchFeature
)
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.image
import
ImageAsset
...
...
@@ -133,7 +133,7 @@ def image_assets() -> _ImageAssets:
return
IMAGE_ASSETS
_T
=
TypeVar
(
"_T"
,
nn
.
Module
,
torch
.
Tensor
,
BatchEncoding
)
_T
=
TypeVar
(
"_T"
,
nn
.
Module
,
torch
.
Tensor
,
BatchEncoding
,
BatchFeature
)
class
HfRunner
:
...
...
@@ -339,7 +339,6 @@ class HfRunner:
processor_kwargs
[
"images"
]
=
images
[
i
]
inputs
=
self
.
processor
(
**
processor_kwargs
)
input_ids
=
inputs
.
input_ids
output
=
self
.
model
.
generate
(
**
self
.
wrap_device
(
inputs
),
...
...
@@ -381,7 +380,7 @@ class HfRunner:
all_logprobs
.
append
(
seq_logprobs_lst
)
seq_ids
=
output
.
sequences
[
0
]
output_len
=
seq_ids
.
shape
[
0
]
-
input_ids
.
shape
[
1
]
output_len
=
len
(
seq_logprobs_lst
)
output_ids
=
seq_ids
[
-
output_len
:]
all_output_ids
.
append
(
output_ids
.
tolist
())
all_output_strs
.
append
(
self
.
tokenizer
.
decode
(
output_ids
))
...
...
@@ -514,10 +513,12 @@ class VllmRunner:
max_tokens
:
int
,
num_logprobs
:
int
,
images
:
Optional
[
List
[
Image
.
Image
]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
greedy_logprobs_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
,
logprobs
=
num_logprobs
)
logprobs
=
num_logprobs
,
stop_token_ids
=
stop_token_ids
)
outputs
=
self
.
generate_w_logprobs
(
prompts
,
greedy_logprobs_params
,
images
=
images
)
...
...
tests/models/test_minicpmv.py
0 → 100644
View file @
9e169a4c
from
collections
import
UserDict
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
pytest
import
torch
import
torch.types
from
transformers
import
BatchFeature
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.sequence
import
SampleLogprobs
from
..conftest
import
IMAGE_ASSETS
,
HfRunner
,
VllmRunner
,
_ImageAssets
from
.utils
import
check_logprobs_close
pytestmark
=
pytest
.
mark
.
vlm
# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
"stop_sign"
:
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>
\n\n
"
\
"(<image>./</image>)
\n
What's the content of the image?<|eot_id|>"
\
"<|start_header_id|>assistant<|end_header_id|>
\n\n
"
,
# noqa: E501
"cherry_blossom"
:
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>
\n\n
"
\
"(<image>./</image>)
\n
What is the season?<|eot_id|>"
\
"<|start_header_id|>assistant<|end_header_id|>
\n\n
"
})
models
=
[
"openbmb/MiniCPM-Llama3-V-2_5"
]
def
trunc_hf_output
(
hf_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]):
output_ids
,
output_str
,
out_logprobs
=
hf_output
if
output_str
.
endswith
(
"<|eot_id|>"
):
output_str
=
output_str
.
split
(
"<|eot_id|>"
)[
0
]
return
output_ids
,
output_str
,
out_logprobs
target_dtype
=
"half"
def
run_test
(
hf_runner
:
Type
[
HfRunner
],
vllm_runner
:
Type
[
VllmRunner
],
image_assets
:
_ImageAssets
,
model
:
str
,
*
,
size_factors
:
List
[
float
],
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
inputs_per_image
=
[(
[
prompt
for
_
in
size_factors
],
[
rescale_image_size
(
image
,
factor
)
for
factor
in
size_factors
],
)
for
image
,
prompt
in
zip
(
images
,
HF_IMAGE_PROMPTS
)]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
with
vllm_runner
(
model
,
max_model_len
=
4096
,
max_num_seqs
=
1
,
dtype
=
dtype
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
)
as
vllm_model
:
tokenizer
=
vllm_model
.
model
.
get_tokenizer
()
stop_token_ids
=
[
tokenizer
.
eos_id
,
tokenizer
.
eot_id
]
vllm_outputs_per_image
=
[
vllm_model
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
,
images
=
vllm_images
,
stop_token_ids
=
stop_token_ids
)
for
prompts
,
vllm_images
in
inputs_per_image
]
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
,
torch
.
no_grad
():
class
NestedInputs
(
UserDict
):
def
__init__
(
self
,
model_inputs
:
BatchFeature
):
super
().
__init__
({
"model_inputs"
:
model_inputs
})
self
.
model_inputs
=
model_inputs
def
to
(
self
,
device
:
torch
.
types
.
Device
):
return
NestedInputs
(
self
.
model_inputs
.
to
(
device
))
hf_processor
=
hf_model
.
processor
hf_model
.
processor
=
lambda
**
kw
:
NestedInputs
(
hf_processor
(
**
kw
)
# type: ignore
)
hf_outputs_per_image
=
[
hf_model
.
generate_greedy_logprobs_limit
(
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
,
images
=
hf_images
,
tokenizer
=
tokenizer
)
for
prompts
,
hf_images
in
inputs_per_image
]
for
hf_outputs
,
vllm_outputs
in
zip
(
hf_outputs_per_image
,
vllm_outputs_per_image
):
check_logprobs_close
(
outputs_0_lst
=
[
trunc_hf_output
(
hf_output
)
for
hf_output
in
hf_outputs
],
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"size_factors"
,
[
# No image
[],
# Single-scale
[
1.0
],
# Single-scale, batched
[
1.0
,
1.0
,
1.0
],
# Multi-scale
[
0.25
,
0.5
,
1.0
],
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
target_dtype
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
size_factors
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
run_test
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
size_factors
=
size_factors
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
vllm/model_executor/models/__init__.py
View file @
9e169a4c
...
...
@@ -50,6 +50,7 @@ _GENERATION_MODELS = {
"MptForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MPTForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MiniCPMForCausalLM"
:
(
"minicpm"
,
"MiniCPMForCausalLM"
),
"MiniCPMV"
:
(
"minicpmv"
,
"MiniCPMV"
),
"OlmoForCausalLM"
:
(
"olmo"
,
"OlmoForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OrionForCausalLM"
:
(
"orion"
,
"OrionForCausalLM"
),
...
...
vllm/model_executor/models/llama.py
View file @
9e169a4c
...
...
@@ -418,9 +418,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
input_embeds
:
Optional
[
torch
.
Tensor
]
=
None
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
input_embeds
)
return
model_output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/minicpm.py
View file @
9e169a4c
...
...
@@ -463,10 +463,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
input_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
input_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/minicpmv.py
0 → 100644
View file @
9e169a4c
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-V-2 model compatible with HuggingFace weights."""
import
math
import
re
from
functools
import
partial
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
PIL
import
Image
from
torch
import
nn
from
torch.nn.init
import
trunc_normal_
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.models.idefics2.modeling_idefics2
import
(
Idefics2VisionTransformer
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsVision
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.models.minicpm
import
MiniCPMForCausalLM
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
(
cached_get_image_processor
,
cached_get_tokenizer
)
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.lm_head"
:
"lm_head"
,
"language_model.model"
:
"language_model"
,
}
def
get_abs_pos
(
abs_pos
,
tgt_size
):
# abs_pos: L, C
# tgt_size: (H, W)
# return: M, C
src_size
=
int
(
math
.
sqrt
(
abs_pos
.
size
(
0
)))
# tgt_size = int(math.sqrt(tgt_size))
dtype
=
abs_pos
.
dtype
return
F
.
interpolate
(
abs_pos
.
float
().
reshape
(
1
,
src_size
,
src_size
,
-
1
).
permute
(
0
,
3
,
1
,
2
),
size
=
(
tgt_size
[
0
],
tgt_size
[
1
]),
mode
=
"bicubic"
,
align_corners
=
False
,
).
permute
(
0
,
2
,
3
,
1
).
flatten
(
0
,
2
).
to
(
dtype
=
dtype
)
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def
get_2d_sincos_pos_embed
(
embed_dim
,
grid_size
,
cls_token
=
False
,
version
=
2.0
):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if
isinstance
(
grid_size
,
int
):
grid_h_size
,
grid_w_size
=
grid_size
,
grid_size
else
:
grid_h_size
,
grid_w_size
=
grid_size
[
0
],
grid_size
[
1
]
grid_h
=
np
.
arange
(
grid_h_size
,
dtype
=
np
.
float32
)
grid_w
=
np
.
arange
(
grid_w_size
,
dtype
=
np
.
float32
)
grid
=
np
.
meshgrid
(
grid_w
,
grid_h
)
# here w goes first
grid
=
np
.
stack
(
grid
,
axis
=
0
)
if
version
==
2.0
:
grid
=
grid
.
reshape
([
2
,
1
,
grid_h_size
,
grid_w_size
])
pos_embed
=
get_2d_sincos_pos_embed_from_grid
(
embed_dim
,
grid
,
version
)
if
cls_token
:
pos_embed
=
np
.
concatenate
([
np
.
zeros
([
1
,
embed_dim
]),
pos_embed
],
axis
=
0
)
else
:
pos_embed
=
get_2d_sincos_pos_embed_from_grid
(
embed_dim
,
grid
,
version
)
return
pos_embed
def
get_2d_sincos_pos_embed_from_grid
(
embed_dim
,
grid
,
version
=
2.0
):
assert
embed_dim
%
2
==
0
# use half of dimensions to encode grid_h
emb_h
=
get_1d_sincos_pos_embed_from_grid
(
embed_dim
//
2
,
grid
[
0
],
version
)
# (H*W, D/2) or (H, W, D/2)
emb_w
=
get_1d_sincos_pos_embed_from_grid
(
embed_dim
//
2
,
grid
[
1
],
version
)
# (H*W, D/2) or (H, W, D/2)
if
version
==
2.0
:
emb
=
np
.
concatenate
([
emb_h
,
emb_w
],
axis
=
1
)
# (H*W, D)
else
:
emb
=
np
.
concatenate
([
emb_h
,
emb_w
],
axis
=-
1
)
# (H, W, D)
return
emb
def
get_1d_sincos_pos_embed_from_grid
(
embed_dim
,
pos
,
version
=
2.0
):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) / (H, W)
out: (M, D) / (H, W, D)
"""
assert
embed_dim
%
2
==
0
omega
=
np
.
arange
(
embed_dim
//
2
,
dtype
=
np
.
float32
)
omega
/=
embed_dim
/
2.
omega
=
1.
/
10000
**
omega
# (D/2,)
if
version
==
2.0
:
pos
=
pos
.
reshape
(
-
1
)
# (M,)
out
=
np
.
einsum
(
'm,d->md'
,
pos
,
omega
)
# (M, D/2), outer product
emb_sin
=
np
.
sin
(
out
)
# (M, D/2)
emb_cos
=
np
.
cos
(
out
)
# (M, D/2)
emb
=
np
.
concatenate
([
emb_sin
,
emb_cos
],
axis
=
1
)
# (M, D)
else
:
out
=
np
.
einsum
(
'hw,d->hwd'
,
pos
,
omega
)
# (H, W, D/2), outer product
emb_sin
=
np
.
sin
(
out
)
# (H, W, D/2)
emb_cos
=
np
.
cos
(
out
)
# (H, W, D/2)
emb
=
np
.
concatenate
([
emb_sin
,
emb_cos
],
axis
=-
1
)
# (H, W, D)
return
emb
class
Resampler
(
nn
.
Module
):
"""
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
default_norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
def
__init__
(
self
,
num_queries
,
grid_size
,
embed_dim
,
num_heads
,
kv_dim
=
None
,
norm_layer
=
default_norm_layer
,
adaptive
=
False
,
max_size
=
(
70
,
70
),
version
=
2.0
):
super
().
__init__
()
self
.
version
=
version
if
self
.
version
==
2.0
:
self
.
num_queries
=
grid_size
**
2
else
:
self
.
num_queries
=
num_queries
self
.
max_size
=
max_size
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
adaptive
=
adaptive
self
.
query
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_queries
,
embed_dim
))
trunc_normal_
(
self
.
query
,
std
=
.
02
)
if
kv_dim
is
not
None
and
kv_dim
!=
embed_dim
:
self
.
kv_proj
=
nn
.
Linear
(
kv_dim
,
embed_dim
,
bias
=
False
)
else
:
self
.
kv_proj
=
nn
.
Identity
()
self
.
attn
=
nn
.
MultiheadAttention
(
embed_dim
,
num_heads
)
self
.
ln_q
=
norm_layer
(
embed_dim
)
self
.
ln_kv
=
norm_layer
(
embed_dim
)
self
.
ln_post
=
norm_layer
(
embed_dim
)
self
.
proj
=
nn
.
Parameter
(
(
embed_dim
**-
0.5
)
*
torch
.
randn
(
embed_dim
,
embed_dim
))
if
self
.
version
==
2.0
:
self
.
pos_embed
=
nn
.
Parameter
(
torch
.
from_numpy
(
get_2d_sincos_pos_embed
(
embed_dim
,
grid_size
,
version
=
self
.
version
)).
float
()).
requires_grad_
(
False
)
else
:
self
.
_set_2d_pos_cache
(
self
.
max_size
)
self
.
apply
(
self
.
_init_weights
)
def
_set_2d_pos_cache
(
self
,
max_size
,
device
=
'cpu'
):
pos_embed
=
torch
.
from_numpy
(
get_2d_sincos_pos_embed
(
self
.
embed_dim
,
max_size
,
version
=
self
.
version
)).
float
().
to
(
device
)
self
.
register_buffer
(
"pos_embed"
,
pos_embed
,
persistent
=
False
)
def
_adjust_pos_cache
(
self
,
tgt_sizes
,
device
):
max_h
=
torch
.
max
(
tgt_sizes
[:,
0
])
max_w
=
torch
.
max
(
tgt_sizes
[:,
1
])
if
max_h
>
self
.
max_size
[
0
]
or
max_w
>
self
.
max_size
[
1
]:
self
.
max_size
=
[
max
(
max_h
,
self
.
max_size
[
0
]),
max
(
max_w
,
self
.
max_size
[
1
])
]
self
.
_set_2d_pos_cache
(
self
.
max_size
,
device
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
def
forward_2_5
(
self
,
x
,
tgt_sizes
=
None
):
assert
x
.
shape
[
0
]
==
tgt_sizes
.
shape
[
0
]
bs
=
x
.
shape
[
0
]
device
=
x
.
device
dtype
=
x
.
dtype
patch_len
=
tgt_sizes
[:,
0
]
*
tgt_sizes
[:,
1
]
self
.
_adjust_pos_cache
(
tgt_sizes
,
device
=
device
)
max_patch_len
=
torch
.
max
(
patch_len
)
key_padding_mask
=
torch
.
zeros
((
bs
,
max_patch_len
),
dtype
=
torch
.
bool
,
device
=
device
)
pos_embed
=
[]
for
i
in
range
(
bs
):
tgt_h
,
tgt_w
=
tgt_sizes
[
i
]
pos_embed
.
append
(
self
.
pos_embed
[:
tgt_h
,
:
tgt_w
,
:].
reshape
(
(
tgt_h
*
tgt_w
,
-
1
)).
to
(
dtype
))
# patches * D
key_padding_mask
[
i
,
patch_len
[
i
]:]
=
True
pos_embed
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
pos_embed
,
batch_first
=
True
,
padding_value
=
0.0
).
permute
(
1
,
0
,
2
)
# BLD => L * B * D
x
=
self
.
kv_proj
(
x
)
# B * L * D
x
=
self
.
ln_kv
(
x
).
permute
(
1
,
0
,
2
)
# L * B * D
q
=
self
.
ln_q
(
self
.
query
)
# Q * D
out
=
self
.
attn
(
self
.
_repeat
(
q
,
bs
),
# Q * B * D
x
+
pos_embed
,
# L * B * D + L * B * D
x
,
key_padding_mask
=
key_padding_mask
)[
0
]
# out: Q * B * D
x
=
out
.
permute
(
1
,
0
,
2
)
# B * Q * D
x
=
self
.
ln_post
(
x
)
x
=
x
@
self
.
proj
return
x
def
forward_2
(
self
,
x
,
tgt_sizes
=
None
,
attn_mask
=
None
):
if
self
.
adaptive
:
pos_embed
=
torch
.
Tensor
(
get_2d_sincos_pos_embed
(
self
.
embed_dim
,
tgt_sizes
)).
float
().
to
(
device
=
x
.
device
,
dtype
=
x
.
dtype
)
else
:
pos_embed
=
get_abs_pos
(
self
.
pos_embed
,
tgt_sizes
)
x
=
self
.
kv_proj
(
x
)
x
=
self
.
ln_kv
(
x
).
permute
(
1
,
0
,
2
)
N
=
x
.
shape
[
1
]
q
=
self
.
ln_q
(
self
.
query
)
out
=
self
.
attn
(
self
.
_repeat
(
q
,
N
)
+
self
.
pos_embed
.
unsqueeze
(
1
),
x
+
pos_embed
.
unsqueeze
(
1
),
x
,
attn_mask
=
attn_mask
)[
0
]
x
=
out
.
permute
(
1
,
0
,
2
)
x
=
self
.
ln_post
(
x
)
x
=
x
@
self
.
proj
return
x
def
forward
(
self
,
x
,
tgt_sizes
=
None
,
attn_mask
=
None
):
if
self
.
version
==
2.0
:
return
self
.
forward_2
(
x
,
tgt_sizes
=
tgt_sizes
,
attn_mask
=
attn_mask
)
else
:
return
self
.
forward_2_5
(
x
,
tgt_sizes
=
tgt_sizes
)
def
_repeat
(
self
,
query
,
N
:
int
):
return
query
.
unsqueeze
(
1
).
repeat
(
1
,
N
,
1
)
def
get_max_minicpmv_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
return
getattr
(
hf_config
,
"query_num"
,
64
)
def
dummy_seq_data_for_minicpmv
(
seq_len
:
int
):
token_ids
=
[
0
]
*
seq_len
return
SequenceData
(
token_ids
)
def
dummy_image_for_minicpmv
(
hf_config
):
width
=
height
=
hf_config
.
image_size
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
}
def
dummy_data_for_minicpmv
(
ctx
:
InputContext
,
seq_len
:
int
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
# image_feature_size = get_max_minicpmv_image_tokens(ctx)
seq_data
=
dummy_seq_data_for_minicpmv
(
seq_len
)
mm_data
=
dummy_image_for_minicpmv
(
hf_config
)
return
seq_data
,
mm_data
def
input_processor_for_minicpmv
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
llm_inputs
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
prompt
=
llm_inputs
.
get
(
"prompt"
)
if
prompt
is
None
:
token_ids
=
llm_inputs
.
get
(
"prompt_token_ids"
)
prompt
=
tokenizer
.
decode
(
token_ids
)
image_processor
=
cached_get_image_processor
(
model_config
.
tokenizer
)
pattern
=
"(<image>./</image>)"
image
=
multi_modal_data
[
"image"
]
image_tags
=
re
.
findall
(
pattern
,
prompt
)
assert
len
(
image_tags
)
<=
1
text_chunks
=
prompt
.
split
(
pattern
)
new_prompt
=
text_chunks
[
0
]
\
+
image_processor
.
get_slice_image_placeholder
(
image
.
size
)
\
+
text_chunks
[
1
]
new_token_ids
=
tokenizer
.
encode
(
new_prompt
)
llm_inputs
=
LLMInputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
multi_modal_data
=
multi_modal_data
)
return
llm_inputs
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_minicpmv_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_minicpmv
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_minicpmv
)
class
MiniCPMV
(
nn
.
Module
,
SupportsVision
):
def
__init__
(
self
,
config
,
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
version
=
float
(
self
.
config
.
version
)
self
.
llm
=
self
.
init_llm
(
config
,
cache_config
,
quant_config
)
self
.
vpm
=
self
.
init_vision_module
()
param_dtype
=
torch
.
get_default_dtype
()
self
.
vpm
.
to
(
dtype
=
param_dtype
)
self
.
vision_dim
=
self
.
vpm
.
embed_dim
if
self
.
version
==
2.0
\
else
self
.
vpm
.
embeddings
.
embed_dim
self
.
embed_dim
=
self
.
llm
.
config
.
hidden_size
self
.
resampler
=
self
.
init_resampler
(
self
.
embed_dim
,
self
.
vision_dim
)
self
.
resampler
.
to
(
device
=
"cuda"
,
dtype
=
param_dtype
)
self
.
sampler
=
Sampler
()
def
init_llm
(
self
,
config
,
cache_config
,
quant_config
):
if
self
.
version
==
2.0
:
return
MiniCPMForCausalLM
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
else
:
return
LlamaForCausalLM
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
def
init_vision_module
(
self
):
if
self
.
version
==
2.0
:
try
:
import
timm
except
ImportError
:
raise
ImportError
(
'Please install timm==0.9.10'
)
from
ImportError
default_dtype
=
torch
.
get_default_dtype
()
torch
.
set_default_dtype
(
torch
.
float16
)
model
=
timm
.
create_model
(
'vit_so400m_patch14_siglip_384.webli'
,
pretrained
=
False
,
num_classes
=
0
,
dynamic_img_size
=
True
,
dynamic_img_pad
=
True
)
torch
.
set_default_dtype
(
default_dtype
)
if
isinstance
(
model
,
timm
.
models
.
VisionTransformer
)
and
model
.
attn_pool
is
not
None
:
model
.
attn_pool
=
torch
.
nn
.
Identity
()
if
self
.
config
.
drop_vision_last_layer
:
model
.
blocks
=
model
.
blocks
[:
-
1
]
else
:
model
=
Idefics2VisionTransformer
(
self
.
config
.
vision_config
)
if
self
.
config
.
drop_vision_last_layer
:
model
.
encoder
.
layers
=
model
.
encoder
.
layers
[:
-
1
]
return
model
def
init_resampler
(
self
,
embed_dim
,
vision_dim
):
default_dtype
=
torch
.
get_default_dtype
()
torch
.
set_default_dtype
(
torch
.
float16
)
if
self
.
version
==
2.0
:
resampler
=
Resampler
(
grid_size
=
int
(
math
.
sqrt
(
self
.
config
.
query_num
)),
num_queries
=
None
,
embed_dim
=
embed_dim
,
num_heads
=
embed_dim
//
128
,
kv_dim
=
vision_dim
,
adaptive
=
True
,
version
=
self
.
version
)
else
:
resampler
=
Resampler
(
num_queries
=
self
.
config
.
query_num
,
grid_size
=
None
,
embed_dim
=
embed_dim
,
num_heads
=
embed_dim
//
128
,
kv_dim
=
vision_dim
,
adaptive
=
True
,
version
=
self
.
version
)
torch
.
set_default_dtype
(
default_dtype
)
return
resampler
def
get_vision_embedding
(
self
,
pixel_values
,
patch_attn_mask
=
None
,
tgt_sizes
=
None
,
version
=
2.0
):
if
version
==
2.0
:
res
=
[]
dtype
=
self
.
vpm
.
pos_embed
.
data
.
dtype
for
pixel_value
in
pixel_values
:
# V2.0 start
H
,
W
=
pixel_value
[
0
].
shape
[
-
2
:]
tgt_size
=
(
math
.
ceil
(
H
/
self
.
vpm
.
patch_embed
.
patch_size
[
0
]),
math
.
ceil
(
W
/
self
.
vpm
.
patch_embed
.
patch_size
[
0
]))
# V2.0 end
vision_embedding
=
self
.
vpm
.
forward_features
(
pixel_value
.
unsqueeze
(
0
).
type
(
dtype
))
if
hasattr
(
self
.
vpm
,
'num_prefix_tokens'
)
and
self
.
vpm
.
num_prefix_tokens
>
0
:
vision_embedding
=
vision_embedding
[:,
self
.
vpm
.
num_prefix_tokens
:]
res
.
append
(
self
.
resampler
(
vision_embedding
,
tgt_size
))
return
torch
.
vstack
(
res
)
else
:
vision_embedding
=
self
.
vpm
(
pixel_values
.
type
(
dtype
),
patch_attention_mask
=
patch_attn_mask
).
last_hidden_state
vision_embedding
=
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
def
get_image_bounds
(
self
,
input_ids
):
tokenizer
=
cached_get_tokenizer
(
self
.
config
.
_name_or_path
,
trust_remote_code
=
True
)
im_start_token_id
=
tokenizer
.
im_start_id
im_end_token_id
=
tokenizer
.
im_end_id
image_start_tokens
=
torch
.
where
(
input_ids
==
im_start_token_id
)[
0
]
image_start_tokens
+=
1
image_end_tokens
=
torch
.
where
(
input_ids
==
im_end_token_id
)[
0
]
valid_image_nums
=
min
(
len
(
image_start_tokens
),
len
(
image_end_tokens
))
if
valid_image_nums
==
0
:
return
[]
image_bound
=
torch
.
hstack
([
image_start_tokens
[:
valid_image_nums
].
unsqueeze
(
-
1
),
image_end_tokens
[:
valid_image_nums
].
unsqueeze
(
-
1
),
])
return
image_bound
def
get_vision_hidden_states
(
self
,
data
):
if
"vision_hidden_states"
not
in
data
:
pixel_values
=
data
[
"pixel_values"
]
tgt_sizes
=
data
[
"tgt_sizes"
]
vision_hidden_states
=
[]
if
self
.
version
==
2.0
:
if
pixel_values
is
not
None
and
len
(
pixel_values
)
>
0
:
vision_hidden_states
=
self
.
get_vision_embedding
(
pixel_values
)
else
:
vision_hidden_states
=
torch
.
tensor
([]).
to
(
data
[
"input_ids"
].
device
)
else
:
device
=
self
.
vpm
.
embeddings
.
position_embedding
.
weight
.
device
dtype
=
self
.
vpm
.
embeddings
.
position_embedding
.
weight
.
dtype
all_pixel_values
=
[
i
.
flatten
(
end_dim
=
1
).
permute
(
1
,
0
)
for
i
in
pixel_values
]
if
all_pixel_values
:
tgt_sizes
=
torch
.
vstack
(
tgt_sizes
).
type
(
torch
.
int32
)
max_patches
=
torch
.
max
(
tgt_sizes
[:,
0
]
*
tgt_sizes
[:,
1
])
all_pixel_values
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
all_pixel_values
,
batch_first
=
True
,
padding_value
=
0.0
)
B
,
L
,
_
=
all_pixel_values
.
shape
all_pixel_values
=
all_pixel_values
.
permute
(
0
,
2
,
1
).
reshape
(
B
,
3
,
-
1
,
L
)
patch_attn_mask
=
torch
.
zeros
((
B
,
1
,
max_patches
),
dtype
=
torch
.
bool
,
device
=
device
)
for
i
in
range
(
B
):
patch_attn_mask
[
i
,
:
tgt_sizes
[
i
][
0
]
*
tgt_sizes
[
i
][
1
]]
=
True
vision_embedding
=
self
.
vpm
(
all_pixel_values
.
type
(
dtype
),
patch_attention_mask
=
patch_attn_mask
).
last_hidden_state
vision_hidden_states
=
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
else
:
# no image
dummy_feature
=
[]
vision_hidden_states
=
dummy_feature
else
:
vision_hidden_states
=
data
[
"vision_hidden_states"
]
return
vision_hidden_states
def
get_embedding
(
self
,
data
):
input_ids
=
data
[
"input_ids"
]
vision_hidden_states
=
self
.
get_vision_hidden_states
(
data
)
if
vision_hidden_states
is
not
None
and
len
(
vision_hidden_states
)
>
0
:
image_bounds
=
self
.
get_image_bounds
(
input_ids
)
else
:
image_bounds
=
[]
if
hasattr
(
self
.
llm
.
config
,
'scale_emb'
):
vlm_embedding
=
self
.
llm
.
model
.
embed_tokens
(
input_ids
)
*
self
.
llm
.
config
.
scale_emb
else
:
vlm_embedding
=
self
.
llm
.
model
.
embed_tokens
(
input_ids
)
vision_hidden_states
=
[
i
.
type
(
vlm_embedding
.
dtype
)
if
isinstance
(
i
,
torch
.
Tensor
)
else
i
for
i
in
vision_hidden_states
]
if
len
(
vision_hidden_states
)
>
0
and
len
(
image_bounds
)
>
0
:
vision_hidden_states
=
torch
.
cat
(
vision_hidden_states
,
dim
=
0
)
image_indices
=
torch
.
stack
([
torch
.
arange
(
r
[
0
],
r
[
1
],
dtype
=
torch
.
long
)
for
r
in
image_bounds
]).
to
(
vlm_embedding
.
device
)
vlm_embedding
.
scatter_
(
0
,
image_indices
.
view
(
-
1
,
1
).
repeat
(
1
,
vlm_embedding
.
shape
[
-
1
]),
vision_hidden_states
.
view
(
-
1
,
vision_hidden_states
.
shape
[
-
1
]))
return
vlm_embedding
,
vision_hidden_states
def
process_multimodal_inputs
(
self
,
inputs
):
pixel_values
=
[]
tgt_sizes
=
[]
for
b
in
range
(
len
(
inputs
[
"pixel_values"
])):
pixel_values
+=
inputs
[
"pixel_values"
][
b
]
tgt_sizes
+=
inputs
[
"tgt_sizes"
][
b
]
return
{
"pixel_values"
:
pixel_values
,
"input_ids"
:
inputs
[
"input_ids"
],
"tgt_sizes"
:
tgt_sizes
}
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
object
,
):
inputs
=
{
"pixel_values"
:
kwargs
.
pop
(
"pixel_values"
,
[]),
"input_ids"
:
input_ids
,
"tgt_sizes"
:
kwargs
.
pop
(
"tgt_sizes"
,
None
),
}
inputs
=
self
.
process_multimodal_inputs
(
inputs
)
vlm_embeddings
,
vision_hidden_states
=
self
.
get_embedding
(
inputs
)
output
=
self
.
llm
(
input_ids
=
None
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
input_embeds
=
vlm_embeddings
)
return
output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
return
self
.
llm
.
compute_logits
(
hidden_states
,
sampling_metadata
)
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
llm
.
sample
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
# for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
# if key_to_modify in name:
# name = name.replace(key_to_modify, new_key)
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
use_default_weight_loading
=
False
if
"vpm"
in
name
or
'resampler'
in
name
:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading
=
True
else
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
use_default_weight_loading
=
True
if
use_default_weight_loading
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/multimodal/__init__.py
View file @
9e169a4c
from
.base
import
(
BatchedTensors
,
MultiModalDataBuiltins
,
MultiModalDataDict
,
MultiModalInputs
,
MultiModalPlugin
)
MultiModalInputs
,
MultiModalPlugin
,
NestedTensors
)
from
.registry
import
MultiModalRegistry
MULTIMODAL_REGISTRY
=
MultiModalRegistry
()
...
...
@@ -17,6 +17,7 @@ __all__ = [
"MultiModalDataDict"
,
"MultiModalInputs"
,
"MultiModalPlugin"
,
"NestedTensors"
,
"MULTIMODAL_REGISTRY"
,
"MultiModalRegistry"
,
]
vllm/multimodal/base.py
View file @
9e169a4c
...
...
@@ -2,7 +2,7 @@ import sys
from
abc
import
ABC
,
abstractmethod
from
collections
import
UserDict
,
defaultdict
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Type
,
TypedDict
,
TypeVar
,
Union
)
TypeVar
,
Union
,
cast
)
import
torch
import
torch.types
...
...
@@ -15,10 +15,17 @@ from vllm.logger import init_logger
logger
=
init_logger
(
__name__
)
BatchedTensors
=
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
NestedTensors
=
Union
[
List
[
torch
.
Tensor
],
torch
.
Tensor
]
"""
Use a list instead of a tensor if the dimensions of each element do not match.
Currently only supports up to singly nested list of tensors.
"""
BatchedTensors
=
Union
[
List
[
NestedTensors
],
NestedTensors
]
"""
If each input tensor in the batch has the same size, this is a single batched
tensor; otherwise, this is a list of tensors with one element per batch.
tensor; otherwise, this is a list of :class:`NestedTensors` with one element
per item in the batch.
"""
if
sys
.
version_info
<
(
3
,
9
):
...
...
@@ -27,7 +34,7 @@ if sys.version_info < (3, 9):
pass
else
:
class
_MultiModalInputsBase
(
UserDict
[
str
,
torch
.
Tensor
]):
class
_MultiModalInputsBase
(
UserDict
[
str
,
Nested
Tensor
s
]):
pass
...
...
@@ -39,19 +46,26 @@ class MultiModalInputs(_MultiModalInputsBase):
@
staticmethod
def
try_concat
(
tensors
:
List
[
torch
.
Tensor
],
tensors
:
List
[
Nested
Tensor
s
],
*
,
device
:
torch
.
types
.
Device
,
)
->
BatchedTensors
:
unbatched_shape
=
tensors
[
0
].
shape
[
1
:]
# may be list rather than tensors
if
isinstance
(
tensors
[
0
],
list
):
return
[[
t
.
to
(
device
=
device
)
for
t
in
tensor
[
0
]]
for
tensor
in
tensors
]
tensors_
=
cast
(
List
[
torch
.
Tensor
],
tensors
)
unbatched_shape
=
tensors_
[
0
].
shape
[
1
:]
for
tensor
in
tensors
:
for
tensor
in
tensors
_
:
if
tensor
.
shape
[
1
:]
!=
unbatched_shape
:
return
[
tensor
.
squeeze
(
0
).
to
(
device
=
device
)
for
tensor
in
tensors
tensor
.
squeeze
(
0
).
to
(
device
=
device
)
for
tensor
in
tensors
_
]
return
torch
.
cat
(
tensors
,
dim
=
0
).
to
(
device
=
device
)
return
torch
.
cat
(
tensors
_
,
dim
=
0
).
to
(
device
=
device
)
@
staticmethod
def
batch
(
...
...
@@ -64,7 +78,7 @@ class MultiModalInputs(_MultiModalInputsBase):
keys
=
inputs_list
[
0
].
keys
()
item_lists
:
Dict
[
str
,
List
[
torch
.
Tensor
]]
=
defaultdict
(
list
)
item_lists
:
Dict
[
str
,
List
[
Nested
Tensor
s
]]
=
defaultdict
(
list
)
for
inputs
in
inputs_list
:
if
inputs
.
keys
()
!=
keys
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment