Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c9d3ecf0
Unverified
Commit
c9d3ecf0
authored
Feb 13, 2025
by
Cyrus Leung
Committed by
GitHub
Feb 13, 2025
Browse files
[VLM] Merged multi-modal processor for Molmo (#12966)
parent
fdcf64d3
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
750 additions
and
498 deletions
+750
-498
docs/source/models/supported_models.md
docs/source/models/supported_models.md
+1
-1
tests/models/decoder_only/language/test_models.py
tests/models/decoder_only/language/test_models.py
+1
-1
tests/models/decoder_only/vision_language/test_models.py
tests/models/decoder_only/vision_language/test_models.py
+2
-3
tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
...els/decoder_only/vision_language/vlm_utils/model_utils.py
+21
-77
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+2
-0
tests/models/registry.py
tests/models/registry.py
+1
-0
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+681
-342
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+40
-40
vllm/utils.py
vllm/utils.py
+1
-34
No files found.
docs/source/models/supported_models.md
View file @
c9d3ecf0
...
@@ -793,7 +793,7 @@ See [this page](#generative-models) for more information on how to use generativ
...
@@ -793,7 +793,7 @@ See [this page](#generative-models) for more information on how to use generativ
-
*
`MolmoForCausalLM`
-
*
`MolmoForCausalLM`
*
Molmo
*
Molmo
*
T + I
*
T + I
*
`allenai/Molmo-7B-D-0924`
,
`allenai/Molmo-7
2
B-0924`
, etc.
*
`allenai/Molmo-7B-D-0924`
,
`allenai/Molmo-7B
-O
-0924`
, etc.
*
✅︎
*
✅︎
*
✅︎
*
✅︎
*
✅︎
*
✅︎
...
...
tests/models/decoder_only/language/test_models.py
View file @
c9d3ecf0
...
@@ -27,7 +27,7 @@ from ...utils import check_logprobs_close
...
@@ -27,7 +27,7 @@ from ...utils import check_logprobs_close
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
),
),
pytest
.
param
(
pytest
.
param
(
"THUDM/chatglm3-6b"
,
#
C
hat
GLM
(text-only)
"THUDM/chatglm3-6b"
,
#
c
hat
glm
(text-only)
),
),
pytest
.
param
(
pytest
.
param
(
"meta-llama/Llama-3.2-1B-Instruct"
,
# llama
"meta-llama/Llama-3.2-1B-Instruct"
,
# llama
...
...
tests/models/decoder_only/vision_language/test_models.py
View file @
c9d3ecf0
...
@@ -404,11 +404,10 @@ VLM_TEST_SETTINGS = {
...
@@ -404,11 +404,10 @@ VLM_TEST_SETTINGS = {
"molmo"
:
VLMTestInfo
(
"molmo"
:
VLMTestInfo
(
models
=
[
"allenai/Molmo-7B-D-0924"
],
models
=
[
"allenai/Molmo-7B-D-0924"
],
test_type
=
(
VLMTestType
.
IMAGE
),
test_type
=
(
VLMTestType
.
IMAGE
),
prompt_formatter
=
lambda
img_prompt
:
"User: "
+
img_prompt
+
" Assistant:"
,
# noqa: E501
prompt_formatter
=
identity
,
max_model_len
=
4096
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
max_num_seqs
=
2
,
image_size_factors
=
[(),(
1.0
,
1.0
,
1.0
)],
patch_hf_runner
=
model_utils
.
molmo_patch_hf_runner
,
patch_hf_runner
=
model_utils
.
mlomo_patch_hf_runner
,
postprocess_inputs
=
model_utils
.
molmo_post_processor
,
postprocess_inputs
=
model_utils
.
molmo_post_processor
,
),
),
# Tests for phi3v currently live in another file because of a bug in
# Tests for phi3v currently live in another file because of a bug in
...
...
tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
View file @
c9d3ecf0
...
@@ -6,7 +6,7 @@ typically specific to a small subset of models.
...
@@ -6,7 +6,7 @@ typically specific to a small subset of models.
import
re
import
re
import
types
import
types
from
pathlib
import
PosixPath
from
pathlib
import
PosixPath
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
PIL.Image
import
Image
from
PIL.Image
import
Image
...
@@ -17,9 +17,7 @@ from vllm.sequence import SampleLogprobs
...
@@ -17,9 +17,7 @@ from vllm.sequence import SampleLogprobs
from
vllm.transformers_utils.tokenizer
import
patch_padding_side
from
vllm.transformers_utils.tokenizer
import
patch_padding_side
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
.....conftest
import
(
HfRunner
,
ImageAsset
,
PromptAudioInput
,
from
.....conftest
import
HfRunner
,
ImageAsset
,
_ImageAssets
PromptImageInput
,
PromptVideoInput
,
_ImageAssets
)
from
....utils
import
TokensTextLogprobs
from
.types
import
RunnerOutput
from
.types
import
RunnerOutput
...
@@ -522,74 +520,7 @@ def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
...
@@ -522,74 +520,7 @@ def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return
hf_model
return
hf_model
def
_generate_greedy_logprobs_limit
(
def
molmo_patch_hf_runner
(
hf_model
:
HfRunner
)
->
HfRunner
:
self
,
prompts
:
List
[
str
],
max_tokens
:
int
,
num_logprobs
:
int
,
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
TokensTextLogprobs
]:
all_inputs
=
self
.
get_inputs
(
prompts
,
images
=
images
,
videos
=
videos
,
audios
=
audios
)
# Process in batches for inference.
if
len
(
all_inputs
):
input_ids_lst
=
[]
images_lst
=
[]
images_input_idx_lst
=
[]
imges_masks_lst
=
[]
for
inputs
in
all_inputs
:
input_ids_lst
.
append
(
inputs
[
"input_ids"
])
images_lst
.
append
(
inputs
[
"images"
])
images_input_idx_lst
.
append
(
inputs
[
"image_input_idx"
])
imges_masks_lst
.
append
(
inputs
[
"image_masks"
])
batch_inputs
=
{}
batch_inputs
[
'input_ids'
]
=
torch
.
cat
(
input_ids_lst
,
dim
=
0
)
batch_inputs
[
'images'
]
=
torch
.
cat
(
images_lst
,
dim
=
0
)
batch_inputs
[
'image_input_idx'
]
=
torch
.
cat
(
images_input_idx_lst
,
dim
=
0
)
batch_inputs
[
'image_masks'
]
=
torch
.
cat
(
imges_masks_lst
,
dim
=
0
)
outputs
=
self
.
model
.
generate_from_batch
(
batch
=
self
.
wrap_device
(
batch_inputs
,
device
=
self
.
model
.
device
.
type
),
generation_config
=
GenerationConfig
(
max_new_tokens
=
max_tokens
,
stop_strings
=
"<|endoftext|>"
,
do_sample
=
False
,
),
tokenizer
=
self
.
tokenizer
,
output_hidden_states
=
True
,
return_dict_in_generate
=
True
,
)
all_logprobs
:
List
[
List
[
Dict
[
int
,
float
]]]
=
[]
all_output_ids
:
List
[
List
[
int
]]
=
[]
all_output_strs
:
List
[
str
]
=
[]
for
index
in
range
(
len
(
all_inputs
)):
(
seq_logprobs_lst
,
output_len
,
)
=
self
.
_hidden_states_to_logprobs
(
outputs
.
hidden_states
,
num_logprobs
)
all_logprobs
.
append
(
seq_logprobs_lst
)
seq_ids
=
outputs
.
sequences
[
index
]
output_ids
=
seq_ids
[
-
output_len
:]
all_output_ids
.
append
(
output_ids
.
tolist
())
all_output_strs
.
append
(
self
.
tokenizer
.
decode
(
output_ids
))
outputs
=
zip
(
all_output_ids
,
all_output_strs
,
all_logprobs
)
return
[(
output_ids
,
output_str
,
output_logprobs
)
for
output_ids
,
output_str
,
output_logprobs
in
outputs
]
####### Molmo-specific HuggingFace runner patchers
def
mlomo_patch_hf_runner
(
hf_model
:
HfRunner
)
->
HfRunner
:
"""Patches and returns an instance of the HfRunner to use for Molmo."""
"""Patches and returns an instance of the HfRunner to use for Molmo."""
hf_processor
=
hf_model
.
processor
hf_processor
=
hf_model
.
processor
...
@@ -598,10 +529,23 @@ def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
...
@@ -598,10 +529,23 @@ def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hf_model
.
processor
=
_processor
hf_model
.
processor
=
_processor
setattr
(
# noqa: B010
def
_generate
(
self
,
max_new_tokens
=
None
,
do_sample
=
None
,
**
kwargs
):
hf_model
,
batch
=
{
"generate_greedy_logprobs_limit"
,
k
:
kwargs
.
pop
(
k
)
types
.
MethodType
(
_generate_greedy_logprobs_limit
,
hf_model
),
for
k
in
(
"input_ids"
,
"images"
,
"image_input_idx"
,
"image_masks"
)
if
k
in
kwargs
}
return
self
.
generate_from_batch
(
batch
,
generation_config
=
GenerationConfig
(
max_new_tokens
=
max_new_tokens
,
stop_strings
=
"<|endoftext|>"
,
do_sample
=
do_sample
,
),
**
kwargs
,
)
)
hf_model
.
model
.
generate
=
types
.
MethodType
(
_generate
,
hf_model
.
model
)
return
hf_model
return
hf_model
tests/models/multimodal/processing/test_common.py
View file @
c9d3ecf0
...
@@ -168,6 +168,8 @@ def _test_processing_correctness(
...
@@ -168,6 +168,8 @@ def _test_processing_correctness(
"mistral-community/pixtral-12b"
,
"mistral-community/pixtral-12b"
,
"openbmb/MiniCPM-o-2_6"
,
"openbmb/MiniCPM-o-2_6"
,
"openbmb/MiniCPM-V-2_6"
,
"openbmb/MiniCPM-V-2_6"
,
"allenai/Molmo-7B-D-0924"
,
"allenai/Molmo-7B-O-0924"
,
"nvidia/NVLM-D-72B"
,
"nvidia/NVLM-D-72B"
,
"Qwen/Qwen-VL-Chat"
,
"Qwen/Qwen-VL-Chat"
,
"Qwen/Qwen2-VL-2B-Instruct"
,
"Qwen/Qwen2-VL-2B-Instruct"
,
...
...
tests/models/registry.py
View file @
c9d3ecf0
...
@@ -256,6 +256,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -256,6 +256,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"MiniCPMV"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-V-2_6"
,
"MiniCPMV"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-V-2_6"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"MolmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/Molmo-7B-D-0924"
,
"MolmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/Molmo-7B-D-0924"
,
extras
=
{
"olmo"
:
"allenai/Molmo-7B-O-0924"
},
# noqa: E501
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"NVLM_D"
:
_HfExamplesInfo
(
"nvidia/NVLM-D-72B"
,
"NVLM_D"
:
_HfExamplesInfo
(
"nvidia/NVLM-D-72B"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
...
...
vllm/model_executor/models/molmo.py
View file @
c9d3ecf0
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
math
import
math
import
re
from
array
import
array
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
lru_cache
,
partial
from
functools
import
cached_property
,
partial
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
from
typing
import
(
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
,
cast
)
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
einops
import
rearrange
from
PIL
import
Image
from
transformers
import
(
BatchFeature
,
PretrainedConfig
,
ProcessorMixin
,
from
torch
import
nn
TensorType
)
from
t
orch.nn
import
functional
as
F
from
t
ransformers.image_utils
import
ImageInput
from
transformers
import
PretrainedConfig
from
transformers
.tokenization_utils_base
import
TextInput
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.attention.layer
import
MultiHeadAttention
...
@@ -22,8 +24,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
...
@@ -22,8 +24,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
)
tensor_model_parallel_all_gather
)
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.activation
import
(
MulAndSilu
,
QuickGELU
,
from
vllm.model_executor.layers.activation
import
(
MulAndSilu
,
QuickGELU
,
SiluAndMul
)
SiluAndMul
)
...
@@ -40,15 +40,21 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -40,15 +40,21 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
NestedTensors
,
PlaceholderRange
from
vllm.multimodal.inputs
import
(
MultiModalFieldConfig
,
MultiModalKwargs
,
from
vllm.multimodal.utils
import
cached_get_tokenizer
NestedTensors
)
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
SequenceData
)
MultiModalDataItems
)
from
vllm.transformers_utils.processor
import
get_processor
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
PromptReplacementDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
JSONTree
,
json_map_leaves
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
...
@@ -56,38 +62,39 @@ from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
...
@@ -56,38 +62,39 @@ from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
VIT_LAYERS
=
[
-
2
,
-
9
]
VIT_LAYERS
=
[
-
2
,
-
9
]
NUM_PREFIX_TOKENS
=
1
NUM_PREFIX_TOKENS
=
1
ADDITIONAL_VOCAB_SIZE
=
128
ADDITIONAL_VOCAB_SIZE
=
128
DEFAULT_IMAGE_PATCH_TOKEN_ID
=
152066
IMAGE_PATCH_TOKEN
=
"<im_patch>"
DEFAULT_IM_START_TOKEN_ID
=
152067
IM_COL_TOKEN
=
"<im_col>"
DEFAULT_IM_END_TOKEN_ID
=
152064
IM_START_TOKEN
=
"<im_start>"
DEFAULT_IM_COL_TOKEN_ID
=
152065
IM_END_TOKEN
=
"<im_end>"
POOLING_SIZE
=
2
class
MolmoImageInputs
(
TypedDict
):
class
MolmoImageInputs
(
TypedDict
):
images
:
torch
.
Tensor
images
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
"""Shape:
"""Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
`(batch_size, num_crops, num_patch, patch_dim)`
"""
image_masks
:
Optional
[
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]]
"""Shape: `(batch_size, num_crops, num_patch)`"""
image_input_idx
:
torch
.
Tensor
feat_is_patch
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
"""Shape:
`(batch_size, num_crops, num_patch)`
"""
"""
A boolean mask indicating which image features correspond
to patch tokens.
seq_len
:
torch
.
Tensor
Shape: `(batch_size, num_crops, num_patch)`
"""Shape:
`(batch_size, )`
"""
"""
image_masks
:
Optional
[
torch
.
Tensor
]
embed_is_patch
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
"""Shape:
`(batch_size, num_crops, num_patch)`
"""
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
image_start_end
:
Tuple
[
int
,
int
]
Shape: `(batch_size, num_embeds)`
"""Starting and ending index of placeholder
tokens
"""
"""
num_crops
:
torch
.
Tensor
"""Shape: `(batch_size, num_images)`"""
@
dataclass
@
dataclass
class
VisionBackboneConfig
:
class
VisionBackboneConfig
:
...
@@ -335,7 +342,7 @@ class VisionTransformer(nn.Module):
...
@@ -335,7 +342,7 @@ class VisionTransformer(nn.Module):
def
forward
(
self
,
def
forward
(
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
patch_num
:
int
=
None
)
->
List
[
torch
.
Tensor
]:
patch_num
:
Optional
[
int
]
=
None
)
->
List
[
torch
.
Tensor
]:
"""
"""
: param x: (batch_size, num_patch, n_pixels)
: param x: (batch_size, num_patch, n_pixels)
"""
"""
...
@@ -465,7 +472,7 @@ class MolmoAttention(nn.Module):
...
@@ -465,7 +472,7 @@ class MolmoAttention(nn.Module):
return
output
return
output
class
LanuageModelMLP
(
nn
.
Module
):
class
Lan
g
uageModelMLP
(
nn
.
Module
):
"""Molmo's LLM mlp."""
"""Molmo's LLM mlp."""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -559,7 +566,7 @@ class MolmoDecoderLayer(nn.Module):
...
@@ -559,7 +566,7 @@ class MolmoDecoderLayer(nn.Module):
prefix
=
f
"
{
prefix
}
.self_attn"
)
prefix
=
f
"
{
prefix
}
.self_attn"
)
# MLP block.
# MLP block.
self
.
mlp
=
LanuageModelMLP
(
config
,
quant_config
=
quant_config
)
self
.
mlp
=
Lan
g
uageModelMLP
(
config
,
quant_config
=
quant_config
)
# LayerNorm
# LayerNorm
assert
config
.
layer_norm_type
==
"rms"
assert
config
.
layer_norm_type
==
"rms"
...
@@ -638,8 +645,8 @@ class MolmoVisionBackbone(nn.Module):
...
@@ -638,8 +645,8 @@ class MolmoVisionBackbone(nn.Module):
self
.
vit_layers
=
VIT_LAYERS
self
.
vit_layers
=
VIT_LAYERS
self
.
image_num_patch
=
vision_config
.
image_num_patch
self
.
image_num_patch
=
vision_config
.
image_num_patch
self
.
llm_patches_per_crop
=
(
self
.
llm_patches_per_crop
=
(
(
self
.
image_num_patch
[
0
]
+
1
)
//
2
,
(
self
.
image_num_patch
[
0
]
+
1
)
//
POOLING_SIZE
,
(
self
.
image_num_patch
[
1
]
+
1
)
//
2
,
(
self
.
image_num_patch
[
1
]
+
1
)
//
POOLING_SIZE
,
)
)
self
.
image_vit
=
VisionTransformer
(
vision_config
,
self
.
image_vit
=
VisionTransformer
(
vision_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
...
@@ -723,19 +730,19 @@ class MolmoVisionBackbone(nn.Module):
...
@@ -723,19 +730,19 @@ class MolmoVisionBackbone(nn.Module):
image_features
=
image_features
.
reshape
(
image_features
=
image_features
.
reshape
(
(
batch_size
,
num_image
)
+
self
.
image_num_patch
+
(
-
1
,
),
)
(
batch_size
,
num_image
)
+
self
.
image_num_patch
+
(
-
1
,
),
)
if
self
.
image_num_patch
[
0
]
%
2
==
1
:
if
(
missing_w
:
=
self
.
image_num_patch
[
0
]
%
POOLING_SIZE
)
:
# Pad
so we can still pool 2x2 patches
# Pad
ding for image pooling (see below)
image_features
=
F
.
pad
(
image_features
=
F
.
pad
(
image_features
,
image_features
,
(
0
,
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
),
(
0
,
0
,
0
,
missing_w
,
0
,
missing_w
,
0
,
0
,
0
,
0
),
)
)
# image pooling
# image pooling
image_features
=
rearrange
(
image_features
=
rearrange
(
image_features
,
image_features
,
'b n (h dh) (w dw) c -> (b n h w) (dh dw) c'
,
'b n (h dh) (w dw) c -> (b n h w) (dh dw) c'
,
dh
=
2
,
dh
=
POOLING_SIZE
,
dw
=
2
,
dw
=
POOLING_SIZE
,
)
)
query
=
image_features
.
mean
(
-
2
,
keepdim
=
True
)
query
=
image_features
.
mean
(
-
2
,
keepdim
=
True
)
...
@@ -888,249 +895,513 @@ class MolmoModel(nn.Module):
...
@@ -888,249 +895,513 @@ class MolmoModel(nn.Module):
return
loaded_params
return
loaded_params
cached_get_processor
=
lru_cache
(
get_processor
)
def
_lowest_multiple
(
x
:
int
,
k
:
int
)
->
int
:
return
(
x
//
k
)
*
k
def
get_num_patches
(
num_tiles
:
int
,
crop_patches
:
int
,
left_margin
:
int
,
def
get_num_patches
(
right_margin
:
int
,
pooling_size
:
int
)
->
int
:
num_tiles
:
int
,
*
,
crop_patches
:
int
,
left_margin
:
int
,
right_margin
:
int
,
pooling_size
:
int
,
)
->
int
:
if
num_tiles
==
1
:
return
_lowest_multiple
(
crop_patches
+
pooling_size
-
1
,
pooling_size
)
crop_window_patches
=
crop_patches
-
(
left_margin
+
right_margin
)
crop_window_patches
=
crop_patches
-
(
left_margin
+
right_margin
)
if
num_tiles
>
1
:
left_crop_window_patches
=
(
crop_window_patches
+
left_margin
+
pooling_size
-
1
)
//
pooling_size
*
pooling_size
middle_crop_window_patches
=
(
crop_window_patches
+
pooling_size
-
1
)
//
pooling_size
*
pooling_size
right_crop_window_patches
=
(
crop_window_patches
+
right_margin
+
pooling_size
-
1
)
//
pooling_size
*
pooling_size
return
left_crop_window_patches
+
(
num_tiles
-
2
)
*
middle_crop_window_patches
+
right_crop_window_patches
else
:
single_crop_window_patches
=
(
crop_patches
+
pooling_size
-
1
)
//
pooling_size
*
pooling_size
return
single_crop_window_patches
def
get_tokens
(
tiling_h
:
int
,
tiling_w
:
int
,
crop_patches
:
int
,
left_margin
:
int
,
right_margin
:
int
,
pooling_size
:
int
)
->
int
:
h
=
get_num_patches
(
tiling_h
,
crop_patches
,
left_margin
,
right_margin
,
pooling_size
)
w
=
get_num_patches
(
tiling_w
,
crop_patches
,
left_margin
,
right_margin
,
pooling_size
)
per_row
=
w
//
pooling_size
+
1
joint
=
per_row
*
(
h
//
pooling_size
)
+
2
image_token_length
=
(
crop_patches
+
pooling_size
-
1
)
//
pooling_size
resize
=
(
image_token_length
+
1
)
*
image_token_length
+
2
return
resize
+
joint
left_num
=
_lowest_multiple
(
crop_window_patches
+
left_margin
+
pooling_size
-
1
,
pooling_size
,
)
middle_num
=
_lowest_multiple
(
crop_window_patches
+
pooling_size
-
1
,
pooling_size
,
)
right_num
=
_lowest_multiple
(
crop_window_patches
+
right_margin
+
pooling_size
-
1
,
pooling_size
,
)
def
get_max_tokens
(
max_crops
:
int
,
crop_patches
:
int
,
left_margin
:
int
,
return
left_num
+
(
num_tiles
-
2
)
*
middle_num
+
right_num
right_margin
:
int
,
pooling_size
:
int
)
->
int
:
tilings
=
[]
for
i
in
range
(
1
,
max_crops
+
1
):
def
get_patches_grid_size
(
for
j
in
range
(
1
,
max_crops
+
1
):
*
,
if
i
*
j
<=
max_crops
:
tiling_h
:
int
,
tilings
.
append
((
i
,
j
))
tiling_w
:
int
,
tokens
=
[
crop_patches
:
int
,
get_tokens
(
tilings
[
i
][
0
],
tilings
[
i
][
1
],
crop_patches
,
left_margin
,
left_margin
:
int
,
right_margin
,
pooling_size
)
for
i
in
range
(
len
(
tilings
))
right_margin
:
int
,
]
pooling_size
:
int
,
return
max
(
tokens
)
)
->
tuple
[
int
,
int
]:
nrows
=
get_num_patches
(
tiling_h
,
def
get_max_molmo_image_tokens
(
ctx
:
InputContext
)
->
int
:
crop_patches
=
crop_patches
,
processor
=
cached_get_processor
(
left_margin
=
left_margin
,
ctx
.
model_config
.
model
,
right_margin
=
right_margin
,
trust_remote_code
=
ctx
.
model_config
.
trust_remote_code
,
pooling_size
=
pooling_size
,
revision
=
ctx
.
model_config
.
code_revision
)
)
image_processor
=
processor
.
image_processor
ncols
=
get_num_patches
(
max_llm_image_tokens
=
get_max_tokens
(
tiling_w
,
image_processor
.
max_crops
,
crop_patches
=
crop_patches
,
image_processor
.
base_image_input_size
[
0
]
//
left_margin
=
left_margin
,
image_processor
.
image_patch_size
,
right_margin
=
right_margin
,
image_processor
.
overlap_margins
[
0
],
pooling_size
=
pooling_size
,
image_processor
.
overlap_margins
[
1
],
2
,
)
)
return
max_llm_image_tokens
return
nrows
,
ncols
def
get_candidate_tilings
(
max_num
:
int
)
->
list
[
tuple
[
int
,
int
]]:
tilings
=
[(
i
,
j
)
for
i
in
range
(
1
,
max_num
+
1
)
for
j
in
range
(
1
,
max_num
+
1
)
if
i
*
j
<=
max_num
]
return
sorted
(
tilings
,
key
=
lambda
x
:
x
[
0
]
*
x
[
1
])
# NOTE: preprocessing for the image data has been included in the
# 'input_processor_for_molmo' function
def
select_tiling
(
def
image_input_mapper_for_molmo
(
*
,
ctx
:
InputContext
,
height
:
int
,
data
:
object
,
width
:
int
,
patch_size
:
int
,
max_num_patches
:
int
,
):
):
if
isinstance
(
data
,
list
):
tilings
=
get_candidate_tilings
(
max_num_patches
)
assert
len
(
data
)
==
1
,
"Molmo supports only one image per prompt."
candidate_tilings
=
np
.
array
(
tilings
,
dtype
=
np
.
int32
)
data
=
data
[
0
]
candidate_resolutions
=
candidate_tilings
*
patch_size
original_size
=
np
.
array
([
height
,
width
],
dtype
=
np
.
float32
)
required_scale_d
=
candidate_resolutions
.
astype
(
np
.
float32
)
/
original_size
required_scale
=
required_scale_d
.
min
(
axis
=-
1
,
keepdims
=
True
)
return
MultiModalKwargs
(
data
)
if
(
required_scale
<
1
).
all
():
ix
=
required_scale
.
argmax
()
else
:
ix
=
np
.
where
(
required_scale
<
1.0
,
10e9
,
required_scale
).
argmin
()
return
candidate_tilings
[
ix
]
def
dummy_data_for_molmo
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
processor
=
cached_get_processor
(
ctx
.
model_config
.
model
,
trust_remote_code
=
ctx
.
model_config
.
trust_remote_code
,
revision
=
ctx
.
model_config
.
code_revision
)
image_processor
=
processor
.
image_processor
base_image_input_d
=
image_processor
.
image_patch_size
class
MolmoProcessorWrapper
:
left_margin
,
right_margin
=
image_processor
.
overlap_margins
"""
Wraps :class:`MolmoProcessor` so that it can be called directly.
The original definition can be found here:
https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py
"""
def
__init__
(
self
,
processor
:
ProcessorMixin
):
super
().
__init__
()
self
.
processor
=
processor
@
cached_property
def
vocab
(
self
)
->
dict
[
str
,
int
]:
return
self
.
processor
.
tokenizer
.
vocab
# type: ignore
@
cached_property
def
max_crops
(
self
)
->
int
:
image_processor
=
self
.
processor
.
image_processor
# type: ignore
max_crops
=
image_processor
.
max_crops
max_crops
=
image_processor
.
max_crops
assert
isinstance
(
max_crops
,
int
)
return
max_crops
@
cached_property
def
base_image_input_size
(
self
)
->
tuple
[
int
,
int
]:
image_processor
=
self
.
processor
.
image_processor
# type: ignore
base_image_input_size
=
image_processor
.
base_image_input_size
if
isinstance
(
base_image_input_size
,
int
):
return
base_image_input_size
,
base_image_input_size
return
tuple
(
base_image_input_size
)
@
cached_property
def
image_patch_size
(
self
)
->
int
:
image_processor
=
self
.
processor
.
image_processor
# type: ignore
image_patch_size
=
image_processor
.
image_patch_size
assert
isinstance
(
image_patch_size
,
int
)
return
image_patch_size
@
cached_property
def
overlap_margins
(
self
)
->
tuple
[
int
,
int
]:
image_processor
=
self
.
processor
.
image_processor
# type: ignore
left_margin
,
right_margin
=
image_processor
.
overlap_margins
assert
isinstance
(
left_margin
,
int
)
assert
isinstance
(
right_margin
,
int
)
# Assume: prompt_token_ids always starts with bos_token_id followed image tokens # noqa: E501
return
left_margin
,
right_margin
max_llm_image_tokens
=
get_max_molmo_image_tokens
(
ctx
)
if
seq_len
-
max_llm_image_tokens
-
1
<
0
:
@
cached_property
raise
RuntimeError
(
def
image_token_length_w
(
self
)
->
int
:
f
"Molmo cannot process
{
max_crops
}
crops in a prompt, "
image_processor
=
self
.
processor
.
image_processor
# type: ignore
"please increase max_model_len or reduce number of crops"
)
image_token_length_w
=
image_processor
.
image_token_length_w
assert
isinstance
(
image_token_length_w
,
int
)
return
image_token_length_w
@
cached_property
def
image_token_length_h
(
self
)
->
int
:
image_processor
=
self
.
processor
.
image_processor
# type: ignore
image_token_length_h
=
image_processor
.
image_token_length_h
assert
isinstance
(
image_token_length_h
,
int
)
return
image_token_length_h
@
property
def
message_format
(
self
)
->
Optional
[
str
]:
return
"role"
@
property
def
always_start_with_space
(
self
)
->
bool
:
return
True
@
cached_property
def
image_patch_id
(
self
)
->
int
:
return
self
.
vocab
[
IMAGE_PATCH_TOKEN
]
@
cached_property
def
im_col_id
(
self
)
->
int
:
return
self
.
vocab
[
IM_COL_TOKEN
]
@
cached_property
def
im_start_id
(
self
)
->
int
:
return
self
.
vocab
[
IM_START_TOKEN
]
@
cached_property
def
im_end_id
(
self
)
->
int
:
return
self
.
vocab
[
IM_END_TOKEN
]
@
property
def
pooling_size
(
self
)
->
int
:
return
POOLING_SIZE
def
select_tiling
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
tuple
[
int
,
int
]:
max_crops
=
self
.
max_crops
left_margin
,
right_margin
=
self
.
overlap_margins
base_image_input_size
=
self
.
base_image_input_size
base_image_input_d
=
self
.
image_patch_size
# The vertical image has the maximum number of image tokens due to column tokens. # noqa: E501
tiling
=
(
max_crops
,
1
)
total_margin_pixels
=
base_image_input_d
*
(
right_margin
+
left_margin
)
total_margin_pixels
=
base_image_input_d
*
(
right_margin
+
left_margin
)
crop_patches
=
image_processor
.
base_image_input_size
[
crop_patches
=
base_image_input_size
[
0
]
//
base_image_input_d
0
]
//
base_image_input_d
crop_window_patches
=
crop_patches
-
(
right_margin
+
left_margin
)
crop_window_patches
=
crop_patches
-
(
right_margin
+
left_margin
)
crop_window_size
=
crop_window_patches
*
base_image_input_d
crop_window_size
=
crop_window_patches
*
base_image_input_d
tiling_h
,
tiling_w
=
select_tiling
(
height
=
image_height
-
total_margin_pixels
,
width
=
image_width
-
total_margin_pixels
,
patch_size
=
crop_window_size
,
max_num_patches
=
max_crops
,
)
h
=
crop_window_size
*
tiling
[
0
]
+
total_margin_pixels
return
tiling_w
,
tiling_h
w
=
crop_window_size
*
tiling
[
1
]
+
total_margin_pixels
dummy_image
=
Image
.
new
(
"RGB"
,
(
w
,
h
),
color
=
"red"
)
def
get_patches_grid_size
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
tuple
[
int
,
int
]:
left_margin
,
right_margin
=
self
.
overlap_margins
base_image_input_size
=
self
.
base_image_input_size
base_image_input_d
=
self
.
image_patch_size
pooling_size
=
self
.
pooling_size
crop_patches
=
base_image_input_size
[
0
]
//
base_image_input_d
tiling_w
,
tiling_h
=
self
.
select_tiling
(
image_height
=
image_height
,
image_width
=
image_width
,
)
out
=
processor
.
process
(
"dummy prompt"
,
dummy_image
)
nrows
,
ncols
=
get_patches_grid_size
(
tiling_h
=
tiling_h
,
tiling_w
=
tiling_w
,
crop_patches
=
crop_patches
,
left_margin
=
left_margin
,
right_margin
=
right_margin
,
pooling_size
=
pooling_size
,
)
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
return
ncols
,
nrows
out
[
"input_ids"
][:
1
+
max_llm_image_tokens
])
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
def
__call__
(
[
0
])
*
(
seq_len
-
max_llm_image_tokens
-
1
)
self
,
dummy_seqdata
=
SequenceData
(
token_ids
)
text
:
Optional
[
Union
[
TextInput
,
list
[
TextInput
]]]
=
None
,
dummy_imgdata
=
{
images
:
Optional
[
Union
[
ImageInput
,
list
[
ImageInput
]]]
=
None
,
"images"
:
out
[
"images"
],
return_tensors
:
Optional
[
Union
[
str
,
TensorType
]]
=
None
,
"image_input_idx"
:
out
[
"image_input_idx"
],
**
kwargs
,
}
)
->
BatchFeature
:
if
"image_masks"
in
out
:
outputs
=
self
.
processor
.
process
(
# type: ignore
dummy_imgdata
[
"image_masks"
]
=
out
[
"image_masks"
]
text
,
images
,
**
kwargs
)
dummy_imgdata
[
"seq_len"
]
=
torch
.
tensor
(
seq_len
,
dtype
=
torch
.
long
)
size
=
0
if
images
is
None
:
offset
=
-
1
images
=
[]
for
i
in
range
(
len
(
token_ids
)):
if
not
isinstance
(
images
,
list
):
if
token_ids
[
i
]
in
(
DEFAULT_IMAGE_PATCH_TOKEN_ID
,
images
=
[
images
]
DEFAULT_IM_START_TOKEN_ID
,
DEFAULT_IM_END_TOKEN_ID
,
DEFAULT_IM_COL_TOKEN_ID
):
input_ids
:
torch
.
Tensor
=
outputs
.
pop
(
"input_ids"
)
if
offset
<
0
:
outputs
[
"input_ids"
]
=
input_ids
.
unsqueeze
(
0
)
offset
=
i
size
+=
1
image_input_idx
=
outputs
.
pop
(
"image_input_idx"
,
None
)
dummy_imgdata
[
"image_start_end"
]
=
(
offset
,
offset
+
size
)
if
image_input_idx
is
not
None
:
return
DummyData
(
seq_data
=
dummy_seqdata
,
input_is_patch
=
input_ids
==
self
.
image_patch_id
multi_modal_data
=
{
"image"
:
dummy_imgdata
},
image_input_idx_flat
:
torch
.
Tensor
=
image_input_idx
.
view
(
-
1
)
multi_modal_placeholders
=
{
image_valid_flat
=
image_input_idx_flat
>=
0
feat_is_patch_flat
=
image_valid_flat
.
clone
()
feat_is_patch_flat
[
image_valid_flat
]
=
(
input_is_patch
[
image_input_idx_flat
[
image_valid_flat
]])
feat_is_patch
=
feat_is_patch_flat
.
view
(
*
image_input_idx
.
shape
)
input_is_embed
=
torch
.
isin
(
input_ids
,
torch
.
tensor
([
self
.
image_patch_id
,
self
.
im_col_id
,
self
.
im_start_id
,
self
.
im_end_id
,
]),
)
embed_ids
=
input_ids
[
input_is_embed
]
embed_is_patch
=
embed_ids
==
self
.
image_patch_id
assert
embed_is_patch
.
sum
()
==
feat_is_patch
.
sum
()
tilings
=
[
self
.
select_tiling
(
image_width
=
image
.
size
[
0
],
image_height
=
image
.
size
[
1
],
)
for
image
in
images
]
# For each image: tiling_h * tiling_w + extra
num_crops
=
torch
.
tensor
(
tilings
).
prod
(
-
1
)
+
1
assert
num_crops
.
sum
()
==
len
(
feat_is_patch
)
outputs
[
"feat_is_patch"
]
=
feat_is_patch
outputs
[
"embed_is_patch"
]
=
embed_is_patch
outputs
[
"num_crops"
]
=
num_crops
outputs
[
"img_patch_id"
]
=
self
.
image_patch_id
return
BatchFeature
(
outputs
,
tensor_type
=
return_tensors
)
class MolmoProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Molmo.

    Exposes the wrapped HF processor, the per-prompt multimodal limits and
    the image-token budget used when sizing multimodal inputs.
    """

    def get_hf_processor(self) -> MolmoProcessorWrapper:
        """Return the context's HF processor wrapped for Molmo."""
        return MolmoProcessorWrapper(self.ctx.get_hf_processor())

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limit: a limit of 1 for "image"."""
        return {"image": 1}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum token count of a single image.

        ``seq_len`` and ``mm_counts`` are accepted but not consulted here.
        """
        max_image_tokens = self.get_max_image_tokens()
        return {"image": max_image_tokens}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[MolmoProcessorWrapper],
    ) -> int:
        """Return the number of image tokens for an image of the given size.

        The result is the sum of two parts, each including 2 extra tokens
        and one extra token per row:

        * a fixed-size part derived from ``base_image_input_size``,
          ``image_patch_size`` and ``pooling_size``;
        * a size-dependent part derived from the patch grid reported by
          ``processor.get_patches_grid_size``.

        When ``processor`` is ``None`` the wrapped HF processor is used.
        """
        if processor is None:
            processor = self.get_hf_processor()

        grid_cols, grid_rows = processor.get_patches_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        pool = processor.pooling_size
        crop_size = processor.base_image_input_size
        patch_size = processor.image_patch_size

        # Fixed-size part: pooled side length is a ceiling division.
        patches_per_side = crop_size[0] // patch_size
        pooled_side = (patches_per_side + pool - 1) // pool
        resized_tokens = (pooled_side + 1) * pooled_side + 2

        # Size-dependent part.
        # NOTE(review): this floors with `// pool`, while the prompt
        # replacement code in this file uses `(n + 1) // pool` -- confirm
        # the two agree for odd grid sizes.
        tokens_per_row = grid_cols // pool + 1
        tiled_tokens = tokens_per_row * (grid_rows // pool) + 2

        return resized_tokens + tiled_tokens

    def get_max_image_tokens(self) -> int:
        """Return the token count of the image size with the most features."""
        best_width, best_height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=best_width,
            image_height=best_height,
            processor=None,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the candidate image size that yields the most tokens.

        Candidates are the base crop size scaled by every tiling returned
        by ``get_candidate_tilings(processor.max_crops)``. The first
        candidate with the strictly largest token count wins.

        Raises:
            ValueError: if no candidate yields a positive token count.
        """
        processor = self.get_hf_processor()
        candidate_tilings = get_candidate_tilings(processor.max_crops)
        base_h, base_w = processor.base_image_input_size

        best_tokens = 0
        best_size = None
        for wr, hr in candidate_tilings:
            cand_width = base_w * wr
            cand_height = base_h * hr
            cand_tokens = self.get_num_image_tokens(
                image_width=cand_width,
                image_height=cand_height,
                processor=processor,
            )
            if cand_tokens <= best_tokens:
                continue
            best_tokens = cand_tokens
            best_size = ImageSize(width=cand_width, height=cand_height)

        if best_tokens == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size
class
MolmoDummyInputsBuilder
(
BaseDummyInputsBuilder
[
MolmoProcessingInfo
]):
def
get_dummy_processor_inputs
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
num_images
=
mm_counts
.
get
(
"image"
,
0
)
mm_data
=
{
"image"
:
"image"
:
[
PlaceholderRange
(
offset
=
offset
,
length
=
size
)]
self
.
_get_dummy_images
(
width
=
target_width
,
})
height
=
target_height
,
num_images
=
num_images
)
}
return
ProcessorInputs
(
prompt_text
=
""
,
mm_data
=
mm_data
,
)
def
pad_images
(
max_total_crops
:
int
,
class
MolmoMultiModalProcessor
(
BaseMultiModalProcessor
[
MolmoProcessingInfo
]):
images
:
torch
.
Tensor
,
image_input_idx
:
torch
.
Tensor
,
def
_apply_hf_processor_tokens_only
(
image_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
self
,
):
prompt_tokens
:
list
[
int
],
n
=
max_total_crops
-
images
.
shape
[
0
]
)
->
list
[
int
]:
images
=
F
.
pad
(
images
,
(
0
,
0
,
0
,
0
,
0
,
n
),
value
=-
1
)
processor
=
self
.
info
.
get_hf_processor
()
image_input_idx
=
F
.
pad
(
image_input_idx
,
(
0
,
0
,
0
,
n
),
value
=-
1
)
if
image_masks
is
not
None
:
# Apply the chat template to the tokens
image_masks
=
F
.
pad
(
image_masks
,
(
0
,
0
,
0
,
n
),
value
=-
1
)
tokens
=
processor
.
processor
.
get_tokens_input
(
# type: ignore
return
images
,
image_input_idx
,
image_masks
self
.
info
.
get_tokenizer
().
decode
(
prompt_tokens
),
message_format
=
processor
.
message_format
,
always_start_with_space
=
processor
.
always_start_with_space
,
def
input_processor_for_molmo
(
ctx
:
InputContext
,
inputs
:
DecoderOnlyInputs
):
prompt
=
inputs
.
get
(
"prompt"
)
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
image
=
None
if
multi_modal_data
is
None
else
multi_modal_data
.
get
(
"image"
)
model_config
=
ctx
.
model_config
processor
=
cached_get_processor
(
ctx
.
model_config
.
model
,
trust_remote_code
=
model_config
.
trust_remote_code
,
revision
=
ctx
.
model_config
.
code_revision
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
# NOTE: message formatting for raw text prompt is only applied for
# offline inference; for online serving, the prompt is always in
# instruction format and tokenized.
if
prompt
is
not
None
and
re
.
match
(
r
"^User:[\s\S]*?(Assistant:)*$"
,
prompt
):
out
=
processor
.
process
(
prompt
,
image
,
message_format
=
"none"
)
elif
prompt
is
not
None
:
out
=
processor
.
process
(
prompt
,
image
)
else
:
out
=
processor
.
process
(
None
,
image
,
tokens
=
inputs
[
"prompt_token_ids"
])
# If there is no image, return directly.
if
image
is
None
:
new_prompt_token_ids
=
out
[
"input_ids"
].
tolist
()
prompt
=
inputs
.
get
(
"prompt"
)
if
prompt
is
None
:
prompt
=
tokenizer
.
decode
(
new_prompt_token_ids
)
return
token_inputs
(
prompt_token_ids
=
new_prompt_token_ids
,
prompt
=
prompt
,
)
)
image_processor
=
processor
.
image_processor
processed_data
=
self
.
info
.
ctx
.
call_hf_processor
(
max_total_crops
=
1
+
image_processor
.
max_crops
processor
,
# type: ignore
images
,
image_input_idx
,
image_masks
=
pad_images
(
dict
(
tokens
=
tokens
),
max_total_crops
,
out
[
"images"
],
out
[
"image_input_idx"
],
out
.
get
(
"image_masks"
),
)
)
image_data
=
dict
(
prompt_ids
,
=
processed_data
.
pop
(
"input_ids"
).
tolist
()
images
=
images
,
image_input_idx
=
image_input_idx
,
return
prompt_ids
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
num_crops
=
hf_inputs
.
get
(
"num_crops"
,
torch
.
empty
(
0
))
num_images
=
len
(
num_crops
)
return
dict
(
images
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_crops
),
image_masks
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_crops
),
feat_is_patch
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_crops
),
embed_is_patch
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
num_crops
=
MultiModalFieldConfig
.
batched
(
"image"
),
img_patch_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
)
)
if
image_masks
is
not
None
:
image_data
[
"image_masks"
]
=
image_masks
def
_get_prompt_replacements
(
self
,
new_prompt_token_ids
=
out
[
"input_ids"
].
tolist
()
mm_items
:
MultiModalDataItems
,
image_data
[
"seq_len"
]
=
torch
.
tensor
(
len
(
new_prompt_token_ids
),
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
dtype
=
torch
.
long
)
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
multi_modal_data
=
dict
(
image
=
image_data
)
processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
size
=
0
tokenizer
=
self
.
info
.
get_tokenizer
()
offset
=
-
1
for
i
in
range
(
len
(
new_prompt_token_ids
)):
image_token_length_w
=
processor
.
image_token_length_w
if
new_prompt_token_ids
[
i
]
in
(
DEFAULT_IMAGE_PATCH_TOKEN_ID
,
image_token_length_h
=
processor
.
image_token_length_h
DEFAULT_IM_START_TOKEN_ID
,
pooling_size
=
processor
.
pooling_size
DEFAULT_IM_END_TOKEN_ID
,
DEFAULT_IM_COL_TOKEN_ID
):
user_str
=
"User:"
if
offset
<
0
:
if
processor
.
always_start_with_space
:
offset
=
i
user_str
=
" "
+
user_str
size
+=
1
image_data
[
"image_start_end"
]
=
(
offset
,
offset
+
size
)
user_tokens
=
tokenizer
.
encode
(
user_str
,
add_special_tokens
=
False
)
prompt
=
inputs
.
get
(
"prompt"
)
if
prompt
is
None
:
img_patch_id
=
processor
.
image_patch_id
prompt
=
tokenizer
.
decode
(
new_prompt_token_ids
)
img_col_id
=
processor
.
im_col_id
return
token_inputs
(
img_start_id
=
processor
.
im_start_id
prompt_token_ids
=
new_prompt_token_ids
,
img_end_id
=
processor
.
im_end_id
prompt
=
prompt
,
multi_modal_data
=
multi_modal_data
,
extra_row
=
[
img_patch_id
]
*
image_token_length_w
+
[
img_col_id
]
multi_modal_placeholders
=
{
extra_joint
=
([
img_start_id
]
+
extra_row
*
image_token_length_h
+
"image"
:
[
PlaceholderRange
(
offset
=
offset
,
length
=
size
)]
[
img_end_id
])
},
def
get_replacement_molmo
(
item_idx
:
int
):
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
image_size
=
images
.
get_image_size
(
item_idx
)
ncols
,
nrows
=
processor
.
get_patches_grid_size
(
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
)
)
joint_row
=
([
img_patch_id
]
*
((
ncols
+
1
)
//
pooling_size
)
+
[
img_col_id
])
joint
=
([
img_start_id
]
+
joint_row
*
((
nrows
+
1
)
//
pooling_size
)
+
[
img_end_id
])
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
image_input_mapper_for_molmo
)
image_tokens
=
extra_joint
+
joint
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_molmo_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_molmo
)
return
PromptReplacementDetails
(
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_molmo
)
full
=
image_tokens
+
user_tokens
,
features
=
image_tokens
,
)
return
[
PromptReplacement
(
modality
=
"image"
,
target
=
user_str
,
replacement
=
get_replacement_molmo
,
)
]
@
MULTIMODAL_REGISTRY
.
register_processor
(
MolmoMultiModalProcessor
,
info
=
MolmoProcessingInfo
,
dummy_inputs
=
MolmoDummyInputsBuilder
)
class
MolmoForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
class
MolmoForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsLoRA
):
SupportsLoRA
):
hf_to_vllm_mapper
=
WeightsMapper
(
hf_to_vllm_mapper
=
WeightsMapper
(
...
@@ -1202,6 +1473,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -1202,6 +1473,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
quant_config
)
quant_config
)
self
.
model
=
MolmoModel
(
vllm_config
=
vllm_config
,
self
.
model
=
MolmoModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
img_patch_id
=
None
if
self
.
config
.
weight_tying
:
if
self
.
config
.
weight_tying
:
self
.
lm_head
=
self
.
model
.
transformer
.
wte
self
.
lm_head
=
self
.
model
.
transformer
.
wte
...
@@ -1224,33 +1496,69 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -1224,33 +1496,69 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
**
kwargs
:
object
,
**
kwargs
:
object
,
)
->
Optional
[
MolmoImageInputs
]:
)
->
Optional
[
MolmoImageInputs
]:
images
=
kwargs
.
pop
(
"images"
,
None
)
images
=
kwargs
.
pop
(
"images"
,
None
)
image_masks
=
kwargs
.
pop
(
"image_masks"
,
None
)
image_start_end
=
kwargs
.
pop
(
"image_start_end"
,
None
)
if
images
is
None
:
if
images
is
None
:
return
None
return
None
image_input_idx
=
kwargs
.
pop
(
"image_input_idx"
,
None
)
if
not
isinstance
(
images
,
(
torch
.
Tensor
,
list
)):
seq_len
=
kwargs
.
pop
(
"seq_len"
,
None
)
raise
ValueError
(
"Incorrect type of images. "
if
image_input_idx
is
None
:
f
"Got type:
{
type
(
images
)
}
"
)
raise
ValueError
(
"image_input_idx is required for Molmo model."
)
if
seq_len
is
None
:
image_masks
=
kwargs
.
pop
(
"image_masks"
,
None
)
raise
ValueError
(
"seq_len is required for Molmo model."
)
if
not
(
image_masks
is
None
or
isinstance
(
image_masks
,
if
not
isinstance
(
seq_len
,
torch
.
Tensor
):
(
torch
.
Tensor
,
list
))):
seq_len
=
torch
.
tensor
(
seq_len
)
raise
ValueError
(
"Incorrect type of image_masks. "
f
"Got type:
{
type
(
image_masks
)
}
"
)
feat_is_patch
=
kwargs
.
pop
(
"feat_is_patch"
,
None
)
if
not
isinstance
(
feat_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of feat_is_patch. "
f
"Got type:
{
type
(
feat_is_patch
)
}
"
)
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
,
None
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
num_crops
=
kwargs
.
pop
(
"num_crops"
,
None
)
if
not
isinstance
(
num_crops
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of num_crops. "
f
"Got type:
{
type
(
num_crops
)
}
"
)
img_patch_id
=
kwargs
.
pop
(
"img_patch_id"
,
None
)
if
not
isinstance
(
img_patch_id
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of num_crops. "
f
"Got type:
{
type
(
num_crops
)
}
"
)
self
.
img_patch_id
=
img_patch_id
.
flatten
().
unique
().
item
()
return
MolmoImageInputs
(
return
MolmoImageInputs
(
images
=
images
,
images
=
images
,
image_input_idx
=
image_input_idx
,
seq_len
=
seq_len
,
image_masks
=
image_masks
,
image_masks
=
image_masks
,
image_start_end
=
image_start_end
,
feat_is_patch
=
feat_is_patch
,
embed_is_patch
=
embed_is_patch
,
num_crops
=
num_crops
,
)
)
def
_process_image_input
(
def
_process_image_input
(
self
,
self
,
image_input
:
MolmoImageInputs
,
image_input
:
MolmoImageInputs
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]:
if
isinstance
(
image_input
[
"images"
],
list
):
# Call the vision backbone on the whole batch at once
images_flat
=
flatten_bn
(
image_input
[
"images"
],
concat
=
True
)
image_masks_flat
=
(
None
if
(
image_masks
:
=
image_input
[
"image_masks"
])
is
None
else
flatten_bn
(
image_masks
,
concat
=
True
))
image_features_flat
=
self
.
vision_backbone
(
images
=
images_flat
.
unsqueeze
(
0
),
image_masks
=
(
None
if
image_masks_flat
is
None
else
image_masks_flat
.
unsqueeze
(
0
)),
).
squeeze
(
0
)
# Reconstruct the batch dimension
image_features
=
image_features_flat
.
split
(
image_input
[
"num_crops"
].
sum
(
-
1
).
tolist
())
else
:
image_features
=
self
.
vision_backbone
(
image_features
=
self
.
vision_backbone
(
images
=
image_input
[
"images"
],
images
=
image_input
[
"images"
],
image_masks
=
image_input
[
"image_masks"
],
image_masks
=
image_input
[
"image_masks"
],
...
@@ -1258,51 +1566,73 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -1258,51 +1566,73 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return
image_features
return
image_features
def _get_mm_embeds(
    self,
    features: torch.Tensor,  # Shape: (num_crop, num_patch, d)
    feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
    num_crops: torch.Tensor,  # Shape: (num_images,)
    embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
) -> list[torch.Tensor]:
    """
    Scatter encoder patch features into per-image embedding tensors laid
    out like the embedding tokens defined by the multimodal processor.

    For every image, a ``(num_embeds, d)`` tensor filled with NaN is
    created, and the rows selected by ``embed_is_patch`` are overwritten
    with the features selected by that image's slice of ``feat_is_patch``.

    Note:
        The original code only considers patch tokens as feature
        tokens, but our processor considers all image-related tokens
        as feature tokens because the feature tokens need to be
        consecutive in `input_ids`.

    Example:

        A simplified example for one item in the batch:

        .. code-block::

            Embedding tokens (from HF processor):
            [<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]

            embed_is_patch (from HF processor):
            [ False   True    True   False  True    True   False False ]

            Encoder outputs (from model):
            [  p1      p2      0      p3     p4      0   ]

            feat_is_patch (from HF processor):
            [ True    True   False   True   True   False ]

            The resulting embedding tensor is:
            [  nan     p1      p2     nan    p3      p4     nan   nan  ]
    """
    # Crops are stored flat along dim 0; cut them back into per-image runs.
    crops_per_image = num_crops.tolist()
    feature_chunks = features.split(crops_per_image)
    mask_chunks = feat_is_patch.split(crops_per_image)

    # The unpacking also enforces a 3-D `features` and a 1-D mask.
    _, _, hidden_dim = features.shape
    (total_embeds, ) = embed_is_patch.shape

    def scatter(crop_feats: torch.Tensor,
                crop_mask: torch.Tensor) -> torch.Tensor:
        # NaN marks every non-patch position of the embedding sequence.
        scattered = crop_feats.new_full((total_embeds, hidden_dim),
                                        torch.nan)
        scattered[embed_is_patch] = crop_feats[crop_mask]
        return scattered

    return [
        scatter(crop_feats, crop_mask)
        for crop_feats, crop_mask in zip(feature_chunks, mask_chunks)
    ]
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
image_features
=
self
.
_process_image_input
(
image_input
)
image_features
=
self
.
_process_image_input
(
image_input
)
image_input_idx
=
image_input
[
"image_input_idx"
]
seq_len
=
image_input
[
"seq_len"
]
return
[
batch_size
,
num_image
,
num_patch
=
image_features
.
shape
[:
3
]
self
.
_get_mm_embeds
(
*
args
)
for
args
in
zip
(
assert
image_input_idx
.
shape
==
(
batch_size
,
num_image
,
num_patch
)
image_features
,
image_input
[
"feat_is_patch"
],
# insert the image feature into the embedding.
image_input
[
"num_crops"
],
image_features
=
image_features
.
view
(
batch_size
,
num_image
*
num_patch
,
image_input
[
"embed_is_patch"
],
-
1
)
)
image_input_idx
=
image_input_idx
.
view
(
batch_size
,
]
num_image
*
num_patch
)
valid
=
image_input_idx
>=
0
image_features
=
image_features
*
valid
[:,
:,
None
].
to
(
image_features
.
dtype
)
image_features
=
image_features
.
view
(
batch_size
*
num_image
*
num_patch
,
-
1
).
contiguous
()
image_input_idx
=
image_input_idx
*
valid
.
to
(
image_input_idx
.
dtype
)
offset
=
torch
.
cat
([
seq_len
.
new_zeros
(
1
),
seq_len
.
cumsum
(
dim
=
0
)[:
-
1
]],
dim
=
0
)[:,
None
]
image_input_idx
=
image_input_idx
+
offset
.
to
(
image_input_idx
.
dtype
)
image_input_idx
=
image_input_idx
.
flatten
()[:,
None
]
mat
=
image_input_idx
==
torch
.
arange
(
seq_len
.
sum
().
item
(),
device
=
image_features
.
device
)[
None
,
:]
mat
=
mat
.
to
(
image_features
.
dtype
)
# Note: In this original implementation from AI2, the final
# vision_embeddings will always be the same length
# of input embeddings.
vision_embeddings
=
torch
.
einsum
(
'nd,nm->md'
,
image_features
,
mat
)
# Split by the sizes of the input sequences. For each full embedding,
# extract the actual vision embeddings to be merged.
vision_embeddings
=
list
(
vision_embeddings
.
split
(
seq_len
.
tolist
()))
for
i
in
range
(
len
(
vision_embeddings
)):
start
,
end
=
image_input
[
'image_start_end'
][
i
]
vision_embeddings
[
i
]
=
vision_embeddings
[
i
][
start
:
end
]
return
vision_embeddings
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -1311,11 +1641,20 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -1311,11 +1641,20 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
:
assert
self
.
img_patch_id
is
not
None
# Extract the patch tokens scattered in _get_mm_embeds
patch_embeddings
=
json_map_leaves
(
lambda
x
:
x
[
~
x
.
isnan
()].
view
(
-
1
,
*
x
.
shape
[
1
:]),
cast
(
JSONTree
[
torch
.
Tensor
],
multimodal_embeddings
),
)
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
[
input_ids
,
DEFAULT_IMAGE_PATCH_TOKEN_ID
,
DEFAULT_IM_START_TOKEN_ID
,
inputs_embeds
,
DEFAULT_IM_END_TOKEN_ID
,
DEFAULT_IM_COL_TOKEN_ID
cast
(
NestedTensors
,
patch_embeddings
),
])
self
.
img_patch_id
,
)
return
inputs_embeds
return
inputs_embeds
def
forward
(
def
forward
(
...
...
vllm/multimodal/inputs.py
View file @
c9d3ecf0
vllm/utils.py
View file @
c9d3ecf0
...
@@ -33,8 +33,7 @@ from dataclasses import dataclass, field
...
@@ -33,8 +33,7 @@ from dataclasses import dataclass, field
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
Dict
,
Generator
,
Generic
,
Iterator
,
List
,
Literal
,
Dict
,
Generator
,
Generic
,
Iterator
,
List
,
Literal
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
)
overload
)
from
uuid
import
uuid4
from
uuid
import
uuid4
import
cloudpickle
import
cloudpickle
...
@@ -826,38 +825,6 @@ JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
...
@@ -826,38 +825,6 @@ JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
"""A nested JSON structure where the leaves need not be JSON-serializable."""
"""A nested JSON structure where the leaves need not be JSON-serializable."""
@
overload
def
json_map_leaves
(
func
:
Callable
[[
T
],
U
],
value
:
Dict
[
str
,
JSONTree
[
T
]],
)
->
Dict
[
str
,
JSONTree
[
U
]]:
...
@
overload
def
json_map_leaves
(
func
:
Callable
[[
T
],
U
],
value
:
List
[
JSONTree
[
T
]],
)
->
List
[
JSONTree
[
U
]]:
...
@
overload
def
json_map_leaves
(
func
:
Callable
[[
T
],
U
],
value
:
Tuple
[
JSONTree
[
T
],
...],
)
->
Tuple
[
JSONTree
[
U
],
...]:
...
@
overload
def
json_map_leaves
(
func
:
Callable
[[
T
],
U
],
value
:
JSONTree
[
T
],
)
->
JSONTree
[
U
]:
...
def
json_map_leaves
(
func
:
Callable
[[
T
],
U
],
value
:
JSONTree
[
T
])
->
JSONTree
[
U
]:
def
json_map_leaves
(
func
:
Callable
[[
T
],
U
],
value
:
JSONTree
[
T
])
->
JSONTree
[
U
]:
if
isinstance
(
value
,
dict
):
if
isinstance
(
value
,
dict
):
return
{
k
:
json_map_leaves
(
func
,
v
)
for
k
,
v
in
value
.
items
()}
return
{
k
:
json_map_leaves
(
func
,
v
)
for
k
,
v
in
value
.
items
()}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment