Unverified Commit c9d3ecf0 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Merged multi-modal processor for Molmo (#12966)

parent fdcf64d3
...@@ -793,7 +793,7 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -793,7 +793,7 @@ See [this page](#generative-models) for more information on how to use generativ
- * `MolmoForCausalLM` - * `MolmoForCausalLM`
* Molmo * Molmo
* T + I * T + I
* `allenai/Molmo-7B-D-0924`, `allenai/Molmo-72B-0924`, etc. * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
* ✅︎ * ✅︎
......
...@@ -27,7 +27,7 @@ from ...utils import check_logprobs_close ...@@ -27,7 +27,7 @@ from ...utils import check_logprobs_close
marks=[pytest.mark.core_model, pytest.mark.cpu_model], marks=[pytest.mark.core_model, pytest.mark.cpu_model],
), ),
pytest.param( pytest.param(
"THUDM/chatglm3-6b", # ChatGLM (text-only) "THUDM/chatglm3-6b", # chatglm (text-only)
), ),
pytest.param( pytest.param(
"meta-llama/Llama-3.2-1B-Instruct", # llama "meta-llama/Llama-3.2-1B-Instruct", # llama
......
...@@ -404,11 +404,10 @@ VLM_TEST_SETTINGS = { ...@@ -404,11 +404,10 @@ VLM_TEST_SETTINGS = {
"molmo": VLMTestInfo( "molmo": VLMTestInfo(
models=["allenai/Molmo-7B-D-0924"], models=["allenai/Molmo-7B-D-0924"],
test_type=(VLMTestType.IMAGE), test_type=(VLMTestType.IMAGE),
prompt_formatter=lambda img_prompt:"User: " + img_prompt + " Assistant:", # noqa: E501 prompt_formatter=identity,
max_model_len=4096, max_model_len=4096,
max_num_seqs=2, max_num_seqs=2,
image_size_factors=[(),(1.0, 1.0, 1.0)], patch_hf_runner=model_utils.molmo_patch_hf_runner,
patch_hf_runner=model_utils.mlomo_patch_hf_runner,
postprocess_inputs=model_utils.molmo_post_processor, postprocess_inputs=model_utils.molmo_post_processor,
), ),
# Tests for phi3v currently live in another file because of a bug in # Tests for phi3v currently live in another file because of a bug in
......
...@@ -6,7 +6,7 @@ typically specific to a small subset of models. ...@@ -6,7 +6,7 @@ typically specific to a small subset of models.
import re import re
import types import types
from pathlib import PosixPath from pathlib import PosixPath
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from PIL.Image import Image from PIL.Image import Image
...@@ -17,9 +17,7 @@ from vllm.sequence import SampleLogprobs ...@@ -17,9 +17,7 @@ from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import patch_padding_side from vllm.transformers_utils.tokenizer import patch_padding_side
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from .....conftest import (HfRunner, ImageAsset, PromptAudioInput, from .....conftest import HfRunner, ImageAsset, _ImageAssets
PromptImageInput, PromptVideoInput, _ImageAssets)
from ....utils import TokensTextLogprobs
from .types import RunnerOutput from .types import RunnerOutput
...@@ -522,74 +520,7 @@ def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: ...@@ -522,74 +520,7 @@ def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model return hf_model
def _generate_greedy_logprobs_limit( def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
self,
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
**kwargs: Any,
) -> List[TokensTextLogprobs]:
all_inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
# Process in batches for inference.
if len(all_inputs):
input_ids_lst = []
images_lst = []
images_input_idx_lst = []
imges_masks_lst = []
for inputs in all_inputs:
input_ids_lst.append(inputs["input_ids"])
images_lst.append(inputs["images"])
images_input_idx_lst.append(inputs["image_input_idx"])
imges_masks_lst.append(inputs["image_masks"])
batch_inputs = {}
batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
batch_inputs['images'] = torch.cat(images_lst, dim=0)
batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
dim=0)
batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0)
outputs = self.model.generate_from_batch(
batch=self.wrap_device(batch_inputs,
device=self.model.device.type),
generation_config=GenerationConfig(
max_new_tokens=max_tokens,
stop_strings="<|endoftext|>",
do_sample=False,
),
tokenizer=self.tokenizer,
output_hidden_states=True,
return_dict_in_generate=True,
)
all_logprobs: List[List[Dict[int, float]]] = []
all_output_ids: List[List[int]] = []
all_output_strs: List[str] = []
for index in range(len(all_inputs)):
(
seq_logprobs_lst,
output_len,
) = self._hidden_states_to_logprobs(outputs.hidden_states,
num_logprobs)
all_logprobs.append(seq_logprobs_lst)
seq_ids = outputs.sequences[index]
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
####### Molmo-specific HuggingFace runner patchers
def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for Molmo.""" """Patches and returns an instance of the HfRunner to use for Molmo."""
hf_processor = hf_model.processor hf_processor = hf_model.processor
...@@ -598,10 +529,23 @@ def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: ...@@ -598,10 +529,23 @@ def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hf_model.processor = _processor hf_model.processor = _processor
setattr( # noqa: B010 def _generate(self, max_new_tokens=None, do_sample=None, **kwargs):
hf_model, batch = {
"generate_greedy_logprobs_limit", k: kwargs.pop(k)
types.MethodType(_generate_greedy_logprobs_limit, hf_model), for k in ("input_ids", "images", "image_input_idx", "image_masks")
if k in kwargs
}
return self.generate_from_batch(
batch,
generation_config=GenerationConfig(
max_new_tokens=max_new_tokens,
stop_strings="<|endoftext|>",
do_sample=do_sample,
),
**kwargs,
) )
hf_model.model.generate = types.MethodType(_generate, hf_model.model)
return hf_model return hf_model
...@@ -168,6 +168,8 @@ def _test_processing_correctness( ...@@ -168,6 +168,8 @@ def _test_processing_correctness(
"mistral-community/pixtral-12b", "mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6", "openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6", "openbmb/MiniCPM-V-2_6",
"allenai/Molmo-7B-D-0924",
"allenai/Molmo-7B-O-0924",
"nvidia/NVLM-D-72B", "nvidia/NVLM-D-72B",
"Qwen/Qwen-VL-Chat", "Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2-VL-2B-Instruct",
......
...@@ -256,6 +256,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -256,6 +256,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-V-2_6", "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-V-2_6",
trust_remote_code=True), trust_remote_code=True),
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
extras={"olmo": "allenai/Molmo-7B-O-0924"}, # noqa: E501
trust_remote_code=True), trust_remote_code=True),
"NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
trust_remote_code=True), trust_remote_code=True),
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
import re
from array import array
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache, partial from functools import cached_property, partial
from typing import Iterable, List, Mapping, Optional, Set, Tuple, TypedDict from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union, cast)
import numpy as np
import torch import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from PIL import Image from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
from torch import nn TensorType)
from torch.nn import functional as F from transformers.image_utils import ImageInput
from transformers import PretrainedConfig from transformers.tokenization_utils_base import TextInput
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
...@@ -22,8 +24,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, ...@@ -22,8 +24,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather) tensor_model_parallel_all_gather)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU, from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
SiluAndMul) SiluAndMul)
...@@ -40,15 +40,21 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -40,15 +40,21 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.utils import cached_get_tokenizer NestedTensors)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
SequenceData) MultiModalDataItems)
from vllm.transformers_utils.processor import get_processor from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, json_map_leaves
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
...@@ -56,38 +62,39 @@ from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, ...@@ -56,38 +62,39 @@ from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
VIT_LAYERS = [-2, -9] VIT_LAYERS = [-2, -9]
NUM_PREFIX_TOKENS = 1 NUM_PREFIX_TOKENS = 1
ADDITIONAL_VOCAB_SIZE = 128 ADDITIONAL_VOCAB_SIZE = 128
DEFAULT_IMAGE_PATCH_TOKEN_ID = 152066 IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN_ID = 152067 IM_COL_TOKEN = "<im_col>"
DEFAULT_IM_END_TOKEN_ID = 152064 IM_START_TOKEN = "<im_start>"
DEFAULT_IM_COL_TOKEN_ID = 152065 IM_END_TOKEN = "<im_end>"
POOLING_SIZE = 2
class MolmoImageInputs(TypedDict): class MolmoImageInputs(TypedDict):
images: torch.Tensor images: Union[torch.Tensor, List[torch.Tensor]]
"""Shape: """Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
`(batch_size, num_crops, num_patch, patch_dim)`
""" image_masks: Optional[Union[torch.Tensor, List[torch.Tensor]]]
"""Shape: `(batch_size, num_crops, num_patch)`"""
image_input_idx: torch.Tensor feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
"""Shape:
`(batch_size, num_crops, num_patch)`
""" """
A boolean mask indicating which image features correspond
to patch tokens.
seq_len: torch.Tensor Shape: `(batch_size, num_crops, num_patch)`
"""Shape:
`(batch_size, )`
""" """
image_masks: Optional[torch.Tensor] embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
"""Shape:
`(batch_size, num_crops, num_patch)`
""" """
A boolean mask indicating which image embeddings correspond
to patch tokens.
image_start_end: Tuple[int, int] Shape: `(batch_size, num_embeds)`
"""Starting and ending index of placeholder
tokens
""" """
num_crops: torch.Tensor
"""Shape: `(batch_size, num_images)`"""
@dataclass @dataclass
class VisionBackboneConfig: class VisionBackboneConfig:
...@@ -335,7 +342,7 @@ class VisionTransformer(nn.Module): ...@@ -335,7 +342,7 @@ class VisionTransformer(nn.Module):
def forward(self, def forward(self,
x: torch.Tensor, x: torch.Tensor,
patch_num: int = None) -> List[torch.Tensor]: patch_num: Optional[int] = None) -> List[torch.Tensor]:
""" """
: param x: (batch_size, num_patch, n_pixels) : param x: (batch_size, num_patch, n_pixels)
""" """
...@@ -465,7 +472,7 @@ class MolmoAttention(nn.Module): ...@@ -465,7 +472,7 @@ class MolmoAttention(nn.Module):
return output return output
class LanuageModelMLP(nn.Module): class LanguageModelMLP(nn.Module):
"""Molmo's LLM mlp.""" """Molmo's LLM mlp."""
def __init__(self, def __init__(self,
...@@ -559,7 +566,7 @@ class MolmoDecoderLayer(nn.Module): ...@@ -559,7 +566,7 @@ class MolmoDecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn") prefix=f"{prefix}.self_attn")
# MLP block. # MLP block.
self.mlp = LanuageModelMLP(config, quant_config=quant_config) self.mlp = LanguageModelMLP(config, quant_config=quant_config)
# LayerNorm # LayerNorm
assert config.layer_norm_type == "rms" assert config.layer_norm_type == "rms"
...@@ -638,8 +645,8 @@ class MolmoVisionBackbone(nn.Module): ...@@ -638,8 +645,8 @@ class MolmoVisionBackbone(nn.Module):
self.vit_layers = VIT_LAYERS self.vit_layers = VIT_LAYERS
self.image_num_patch = vision_config.image_num_patch self.image_num_patch = vision_config.image_num_patch
self.llm_patches_per_crop = ( self.llm_patches_per_crop = (
(self.image_num_patch[0] + 1) // 2, (self.image_num_patch[0] + 1) // POOLING_SIZE,
(self.image_num_patch[1] + 1) // 2, (self.image_num_patch[1] + 1) // POOLING_SIZE,
) )
self.image_vit = VisionTransformer(vision_config, self.image_vit = VisionTransformer(vision_config,
quant_config=quant_config) quant_config=quant_config)
...@@ -723,19 +730,19 @@ class MolmoVisionBackbone(nn.Module): ...@@ -723,19 +730,19 @@ class MolmoVisionBackbone(nn.Module):
image_features = image_features.reshape( image_features = image_features.reshape(
(batch_size, num_image) + self.image_num_patch + (-1, ), ) (batch_size, num_image) + self.image_num_patch + (-1, ), )
if self.image_num_patch[0] % 2 == 1: if (missing_w := self.image_num_patch[0] % POOLING_SIZE):
# Pad so we can still pool 2x2 patches # Padding for image pooling (see below)
image_features = F.pad( image_features = F.pad(
image_features, image_features,
(0, 0, 0, 1, 0, 1, 0, 0, 0, 0), (0, 0, 0, missing_w, 0, missing_w, 0, 0, 0, 0),
) )
# image pooling # image pooling
image_features = rearrange( image_features = rearrange(
image_features, image_features,
'b n (h dh) (w dw) c -> (b n h w) (dh dw) c', 'b n (h dh) (w dw) c -> (b n h w) (dh dw) c',
dh=2, dh=POOLING_SIZE,
dw=2, dw=POOLING_SIZE,
) )
query = image_features.mean(-2, keepdim=True) query = image_features.mean(-2, keepdim=True)
...@@ -888,249 +895,513 @@ class MolmoModel(nn.Module): ...@@ -888,249 +895,513 @@ class MolmoModel(nn.Module):
return loaded_params return loaded_params
cached_get_processor = lru_cache(get_processor) def _lowest_multiple(x: int, k: int) -> int:
return (x // k) * k
def get_num_patches(num_tiles: int, crop_patches: int, left_margin: int, def get_num_patches(
right_margin: int, pooling_size: int) -> int: num_tiles: int,
*,
crop_patches: int,
left_margin: int,
right_margin: int,
pooling_size: int,
) -> int:
if num_tiles == 1:
return _lowest_multiple(crop_patches + pooling_size - 1, pooling_size)
crop_window_patches = crop_patches - (left_margin + right_margin) crop_window_patches = crop_patches - (left_margin + right_margin)
if num_tiles > 1:
left_crop_window_patches = (crop_window_patches + left_margin +
pooling_size -
1) // pooling_size * pooling_size
middle_crop_window_patches = (crop_window_patches + pooling_size -
1) // pooling_size * pooling_size
right_crop_window_patches = (crop_window_patches + right_margin +
pooling_size -
1) // pooling_size * pooling_size
return left_crop_window_patches + (
num_tiles -
2) * middle_crop_window_patches + right_crop_window_patches
else:
single_crop_window_patches = (crop_patches + pooling_size -
1) // pooling_size * pooling_size
return single_crop_window_patches
def get_tokens(tiling_h: int, tiling_w: int, crop_patches: int,
left_margin: int, right_margin: int, pooling_size: int) -> int:
h = get_num_patches(tiling_h, crop_patches, left_margin, right_margin,
pooling_size)
w = get_num_patches(tiling_w, crop_patches, left_margin, right_margin,
pooling_size)
per_row = w // pooling_size + 1
joint = per_row * (h // pooling_size) + 2
image_token_length = (crop_patches + pooling_size - 1) // pooling_size
resize = (image_token_length + 1) * image_token_length + 2
return resize + joint
left_num = _lowest_multiple(
crop_window_patches + left_margin + pooling_size - 1,
pooling_size,
)
middle_num = _lowest_multiple(
crop_window_patches + pooling_size - 1,
pooling_size,
)
right_num = _lowest_multiple(
crop_window_patches + right_margin + pooling_size - 1,
pooling_size,
)
def get_max_tokens(max_crops: int, crop_patches: int, left_margin: int, return left_num + (num_tiles - 2) * middle_num + right_num
right_margin: int, pooling_size: int) -> int:
tilings = []
for i in range(1, max_crops + 1): def get_patches_grid_size(
for j in range(1, max_crops + 1): *,
if i * j <= max_crops: tiling_h: int,
tilings.append((i, j)) tiling_w: int,
tokens = [ crop_patches: int,
get_tokens(tilings[i][0], tilings[i][1], crop_patches, left_margin, left_margin: int,
right_margin, pooling_size) for i in range(len(tilings)) right_margin: int,
] pooling_size: int,
return max(tokens) ) -> tuple[int, int]:
nrows = get_num_patches(
tiling_h,
def get_max_molmo_image_tokens(ctx: InputContext) -> int: crop_patches=crop_patches,
processor = cached_get_processor( left_margin=left_margin,
ctx.model_config.model, right_margin=right_margin,
trust_remote_code=ctx.model_config.trust_remote_code, pooling_size=pooling_size,
revision=ctx.model_config.code_revision) )
image_processor = processor.image_processor ncols = get_num_patches(
max_llm_image_tokens = get_max_tokens( tiling_w,
image_processor.max_crops, crop_patches=crop_patches,
image_processor.base_image_input_size[0] // left_margin=left_margin,
image_processor.image_patch_size, right_margin=right_margin,
image_processor.overlap_margins[0], pooling_size=pooling_size,
image_processor.overlap_margins[1],
2,
) )
return max_llm_image_tokens
return nrows, ncols
def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]:
tilings = [(i, j) for i in range(1, max_num + 1)
for j in range(1, max_num + 1) if i * j <= max_num]
return sorted(tilings, key=lambda x: x[0] * x[1])
# NOTE: preprocessing for the image data has been included in the
# 'input_processor_for_molmo' function def select_tiling(
def image_input_mapper_for_molmo( *,
ctx: InputContext, height: int,
data: object, width: int,
patch_size: int,
max_num_patches: int,
): ):
if isinstance(data, list): tilings = get_candidate_tilings(max_num_patches)
assert len(data) == 1, "Molmo supports only one image per prompt." candidate_tilings = np.array(tilings, dtype=np.int32)
data = data[0] candidate_resolutions = candidate_tilings * patch_size
original_size = np.array([height, width], dtype=np.float32)
required_scale_d = candidate_resolutions.astype(np.float32) / original_size
required_scale = required_scale_d.min(axis=-1, keepdims=True)
return MultiModalKwargs(data) if (required_scale < 1).all():
ix = required_scale.argmax()
else:
ix = np.where(required_scale < 1.0, 10e9, required_scale).argmin()
return candidate_tilings[ix]
def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
processor = cached_get_processor(
ctx.model_config.model,
trust_remote_code=ctx.model_config.trust_remote_code,
revision=ctx.model_config.code_revision)
image_processor = processor.image_processor
base_image_input_d = image_processor.image_patch_size class MolmoProcessorWrapper:
left_margin, right_margin = image_processor.overlap_margins """
Wraps :class:`MolmoProcessor` so that it can be called directly.
The original definition can be found here:
https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py
"""
def __init__(self, processor: ProcessorMixin):
super().__init__()
self.processor = processor
@cached_property
def vocab(self) -> dict[str, int]:
return self.processor.tokenizer.vocab # type: ignore
@cached_property
def max_crops(self) -> int:
image_processor = self.processor.image_processor # type: ignore
max_crops = image_processor.max_crops max_crops = image_processor.max_crops
assert isinstance(max_crops, int)
return max_crops
@cached_property
def base_image_input_size(self) -> tuple[int, int]:
image_processor = self.processor.image_processor # type: ignore
base_image_input_size = image_processor.base_image_input_size
if isinstance(base_image_input_size, int):
return base_image_input_size, base_image_input_size
return tuple(base_image_input_size)
@cached_property
def image_patch_size(self) -> int:
image_processor = self.processor.image_processor # type: ignore
image_patch_size = image_processor.image_patch_size
assert isinstance(image_patch_size, int)
return image_patch_size
@cached_property
def overlap_margins(self) -> tuple[int, int]:
image_processor = self.processor.image_processor # type: ignore
left_margin, right_margin = image_processor.overlap_margins
assert isinstance(left_margin, int)
assert isinstance(right_margin, int)
# Assume: prompt_token_ids always starts with bos_token_id followed image tokens # noqa: E501 return left_margin, right_margin
max_llm_image_tokens = get_max_molmo_image_tokens(ctx)
if seq_len - max_llm_image_tokens - 1 < 0: @cached_property
raise RuntimeError( def image_token_length_w(self) -> int:
f"Molmo cannot process {max_crops} crops in a prompt, " image_processor = self.processor.image_processor # type: ignore
"please increase max_model_len or reduce number of crops")
image_token_length_w = image_processor.image_token_length_w
assert isinstance(image_token_length_w, int)
return image_token_length_w
@cached_property
def image_token_length_h(self) -> int:
image_processor = self.processor.image_processor # type: ignore
image_token_length_h = image_processor.image_token_length_h
assert isinstance(image_token_length_h, int)
return image_token_length_h
@property
def message_format(self) -> Optional[str]:
return "role"
@property
def always_start_with_space(self) -> bool:
return True
@cached_property
def image_patch_id(self) -> int:
return self.vocab[IMAGE_PATCH_TOKEN]
@cached_property
def im_col_id(self) -> int:
return self.vocab[IM_COL_TOKEN]
@cached_property
def im_start_id(self) -> int:
return self.vocab[IM_START_TOKEN]
@cached_property
def im_end_id(self) -> int:
return self.vocab[IM_END_TOKEN]
@property
def pooling_size(self) -> int:
return POOLING_SIZE
def select_tiling(
self,
*,
image_width: int,
image_height: int,
) -> tuple[int, int]:
max_crops = self.max_crops
left_margin, right_margin = self.overlap_margins
base_image_input_size = self.base_image_input_size
base_image_input_d = self.image_patch_size
# The vertical image has the maximum number of image tokens due to column tokens. # noqa: E501
tiling = (max_crops, 1)
total_margin_pixels = base_image_input_d * (right_margin + left_margin) total_margin_pixels = base_image_input_d * (right_margin + left_margin)
crop_patches = image_processor.base_image_input_size[ crop_patches = base_image_input_size[0] // base_image_input_d
0] // base_image_input_d
crop_window_patches = crop_patches - (right_margin + left_margin) crop_window_patches = crop_patches - (right_margin + left_margin)
crop_window_size = crop_window_patches * base_image_input_d crop_window_size = crop_window_patches * base_image_input_d
tiling_h, tiling_w = select_tiling(
height=image_height - total_margin_pixels,
width=image_width - total_margin_pixels,
patch_size=crop_window_size,
max_num_patches=max_crops,
)
h = crop_window_size * tiling[0] + total_margin_pixels return tiling_w, tiling_h
w = crop_window_size * tiling[1] + total_margin_pixels
dummy_image = Image.new("RGB", (w, h), color="red") def get_patches_grid_size(
self,
*,
image_width: int,
image_height: int,
) -> tuple[int, int]:
left_margin, right_margin = self.overlap_margins
base_image_input_size = self.base_image_input_size
base_image_input_d = self.image_patch_size
pooling_size = self.pooling_size
crop_patches = base_image_input_size[0] // base_image_input_d
tiling_w, tiling_h = self.select_tiling(
image_height=image_height,
image_width=image_width,
)
out = processor.process("dummy prompt", dummy_image) nrows, ncols = get_patches_grid_size(
tiling_h=tiling_h,
tiling_w=tiling_w,
crop_patches=crop_patches,
left_margin=left_margin,
right_margin=right_margin,
pooling_size=pooling_size,
)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, return ncols, nrows
out["input_ids"][:1 + max_llm_image_tokens])
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, def __call__(
[0]) * (seq_len - max_llm_image_tokens - 1) self,
dummy_seqdata = SequenceData(token_ids) text: Optional[Union[TextInput, list[TextInput]]] = None,
dummy_imgdata = { images: Optional[Union[ImageInput, list[ImageInput]]] = None,
"images": out["images"], return_tensors: Optional[Union[str, TensorType]] = None,
"image_input_idx": out["image_input_idx"], **kwargs,
} ) -> BatchFeature:
if "image_masks" in out: outputs = self.processor.process( # type: ignore
dummy_imgdata["image_masks"] = out["image_masks"] text, images, **kwargs)
dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long)
size = 0 if images is None:
offset = -1 images = []
for i in range(len(token_ids)): if not isinstance(images, list):
if token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID, images = [images]
DEFAULT_IM_START_TOKEN_ID, DEFAULT_IM_END_TOKEN_ID,
DEFAULT_IM_COL_TOKEN_ID): input_ids: torch.Tensor = outputs.pop("input_ids")
if offset < 0: outputs["input_ids"] = input_ids.unsqueeze(0)
offset = i
size += 1 image_input_idx = outputs.pop("image_input_idx", None)
dummy_imgdata["image_start_end"] = (offset, offset + size) if image_input_idx is not None:
return DummyData(seq_data=dummy_seqdata, input_is_patch = input_ids == self.image_patch_id
multi_modal_data={"image": dummy_imgdata}, image_input_idx_flat: torch.Tensor = image_input_idx.view(-1)
multi_modal_placeholders={ image_valid_flat = image_input_idx_flat >= 0
feat_is_patch_flat = image_valid_flat.clone()
feat_is_patch_flat[image_valid_flat] = (
input_is_patch[image_input_idx_flat[image_valid_flat]])
feat_is_patch = feat_is_patch_flat.view(*image_input_idx.shape)
input_is_embed = torch.isin(
input_ids,
torch.tensor([
self.image_patch_id,
self.im_col_id,
self.im_start_id,
self.im_end_id,
]),
)
embed_ids = input_ids[input_is_embed]
embed_is_patch = embed_ids == self.image_patch_id
assert embed_is_patch.sum() == feat_is_patch.sum()
tilings = [
self.select_tiling(
image_width=image.size[0],
image_height=image.size[1],
) for image in images
]
# For each image: tiling_h * tiling_w + extra
num_crops = torch.tensor(tilings).prod(-1) + 1
assert num_crops.sum() == len(feat_is_patch)
outputs["feat_is_patch"] = feat_is_patch
outputs["embed_is_patch"] = embed_is_patch
outputs["num_crops"] = num_crops
outputs["img_patch_id"] = self.image_patch_id
return BatchFeature(outputs, tensor_type=return_tensors)
class MolmoProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self) -> MolmoProcessorWrapper:
processor = self.ctx.get_hf_processor()
return MolmoProcessorWrapper(processor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional[MolmoProcessorWrapper],
) -> int:
if processor is None:
processor = self.get_hf_processor()
ncols, nrows = processor.get_patches_grid_size(
image_width=image_width,
image_height=image_height,
)
pooling_size = processor.pooling_size
base_image_input_size = processor.base_image_input_size
base_image_input_d = processor.image_patch_size
crop_patches = base_image_input_size[0] // base_image_input_d
per_row = ncols // pooling_size + 1
joint = per_row * (nrows // pooling_size) + 2
image_token_length = (crop_patches + pooling_size - 1) // pooling_size
resize = (image_token_length + 1) * image_token_length + 2
return resize + joint
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
tilings = get_candidate_tilings(processor.max_crops)
base_h, base_w = processor.base_image_input_size
largest_feature_size, largest_feature_pinpoint = 0, None
for wr, hr in tilings:
width, height = base_w * wr, base_h * hr
feat_size = self.get_num_image_tokens(
image_width=width,
image_height=height,
processor=processor,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width,
height=height)
if largest_feature_size == 0 or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_pinpoint
class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
mm_data = {
"image": "image":
[PlaceholderRange(offset=offset, length=size)] self._get_dummy_images(width=target_width,
}) height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
def pad_images(
max_total_crops: int, class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
images: torch.Tensor,
image_input_idx: torch.Tensor, def _apply_hf_processor_tokens_only(
image_masks: Optional[torch.Tensor] = None, self,
): prompt_tokens: list[int],
n = max_total_crops - images.shape[0] ) -> list[int]:
images = F.pad(images, (0, 0, 0, 0, 0, n), value=-1) processor = self.info.get_hf_processor()
image_input_idx = F.pad(image_input_idx, (0, 0, 0, n), value=-1)
if image_masks is not None: # Apply the chat template to the tokens
image_masks = F.pad(image_masks, (0, 0, 0, n), value=-1) tokens = processor.processor.get_tokens_input( # type: ignore
return images, image_input_idx, image_masks self.info.get_tokenizer().decode(prompt_tokens),
message_format=processor.message_format,
always_start_with_space=processor.always_start_with_space,
def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
prompt = inputs.get("prompt")
multi_modal_data = inputs.get("multi_modal_data")
image = None if multi_modal_data is None else multi_modal_data.get("image")
model_config = ctx.model_config
processor = cached_get_processor(
ctx.model_config.model,
trust_remote_code=model_config.trust_remote_code,
revision=ctx.model_config.code_revision)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
# NOTE: message formatting for raw text prompt is only applied for
# offline inference; for online serving, the prompt is always in
# instruction format and tokenized.
if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$",
prompt):
out = processor.process(prompt, image, message_format="none")
elif prompt is not None:
out = processor.process(prompt, image)
else:
out = processor.process(None, image, tokens=inputs["prompt_token_ids"])
# If there is no image, return directly.
if image is None:
new_prompt_token_ids = out["input_ids"].tolist()
prompt = inputs.get("prompt")
if prompt is None:
prompt = tokenizer.decode(new_prompt_token_ids)
return token_inputs(
prompt_token_ids=new_prompt_token_ids,
prompt=prompt,
) )
image_processor = processor.image_processor processed_data = self.info.ctx.call_hf_processor(
max_total_crops = 1 + image_processor.max_crops processor, # type: ignore
images, image_input_idx, image_masks = pad_images( dict(tokens=tokens),
max_total_crops,
out["images"],
out["image_input_idx"],
out.get("image_masks"),
) )
image_data = dict( prompt_ids, = processed_data.pop("input_ids").tolist()
images=images,
image_input_idx=image_input_idx, return prompt_ids
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_crops = hf_inputs.get("num_crops", torch.empty(0))
num_images = len(num_crops)
return dict(
images=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
image_masks=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
embed_is_patch=MultiModalFieldConfig.shared("image", num_images),
num_crops=MultiModalFieldConfig.batched("image"),
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
) )
if image_masks is not None:
image_data["image_masks"] = image_masks def _get_prompt_replacements(
self,
new_prompt_token_ids = out["input_ids"].tolist() mm_items: MultiModalDataItems,
image_data["seq_len"] = torch.tensor(len(new_prompt_token_ids), hf_processor_mm_kwargs: Mapping[str, object],
dtype=torch.long) out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
multi_modal_data = dict(image=image_data) processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
size = 0 tokenizer = self.info.get_tokenizer()
offset = -1
for i in range(len(new_prompt_token_ids)): image_token_length_w = processor.image_token_length_w
if new_prompt_token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID, image_token_length_h = processor.image_token_length_h
DEFAULT_IM_START_TOKEN_ID, pooling_size = processor.pooling_size
DEFAULT_IM_END_TOKEN_ID,
DEFAULT_IM_COL_TOKEN_ID): user_str = "User:"
if offset < 0: if processor.always_start_with_space:
offset = i user_str = " " + user_str
size += 1
image_data["image_start_end"] = (offset, offset + size) user_tokens = tokenizer.encode(user_str, add_special_tokens=False)
prompt = inputs.get("prompt")
if prompt is None: img_patch_id = processor.image_patch_id
prompt = tokenizer.decode(new_prompt_token_ids) img_col_id = processor.im_col_id
return token_inputs( img_start_id = processor.im_start_id
prompt_token_ids=new_prompt_token_ids, img_end_id = processor.im_end_id
prompt=prompt,
multi_modal_data=multi_modal_data, extra_row = [img_patch_id] * image_token_length_w + [img_col_id]
multi_modal_placeholders={ extra_joint = ([img_start_id] + extra_row * image_token_length_h +
"image": [PlaceholderRange(offset=offset, length=size)] [img_end_id])
},
def get_replacement_molmo(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
ncols, nrows = processor.get_patches_grid_size(
image_width=image_size.width,
image_height=image_size.height,
) )
joint_row = ([img_patch_id] * ((ncols + 1) // pooling_size) +
[img_col_id])
joint = ([img_start_id] + joint_row *
((nrows + 1) // pooling_size) + [img_end_id])
@MULTIMODAL_REGISTRY.register_image_input_mapper(image_input_mapper_for_molmo) image_tokens = extra_joint + joint
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) return PromptReplacementDetails(
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) full=image_tokens + user_tokens,
features=image_tokens,
)
return [
PromptReplacement(
modality="image",
target=user_str,
replacement=get_replacement_molmo,
)
]
@MULTIMODAL_REGISTRY.register_processor(MolmoMultiModalProcessor,
info=MolmoProcessingInfo,
dummy_inputs=MolmoDummyInputsBuilder)
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA): SupportsLoRA):
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
...@@ -1202,6 +1473,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -1202,6 +1473,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
quant_config) quant_config)
self.model = MolmoModel(vllm_config=vllm_config, self.model = MolmoModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.img_patch_id = None
if self.config.weight_tying: if self.config.weight_tying:
self.lm_head = self.model.transformer.wte self.lm_head = self.model.transformer.wte
...@@ -1224,33 +1496,69 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -1224,33 +1496,69 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
**kwargs: object, **kwargs: object,
) -> Optional[MolmoImageInputs]: ) -> Optional[MolmoImageInputs]:
images = kwargs.pop("images", None) images = kwargs.pop("images", None)
image_masks = kwargs.pop("image_masks", None)
image_start_end = kwargs.pop("image_start_end", None)
if images is None: if images is None:
return None return None
image_input_idx = kwargs.pop("image_input_idx", None) if not isinstance(images, (torch.Tensor, list)):
seq_len = kwargs.pop("seq_len", None) raise ValueError("Incorrect type of images. "
if image_input_idx is None: f"Got type: {type(images)}")
raise ValueError("image_input_idx is required for Molmo model.")
if seq_len is None: image_masks = kwargs.pop("image_masks", None)
raise ValueError("seq_len is required for Molmo model.") if not (image_masks is None or isinstance(image_masks,
if not isinstance(seq_len, torch.Tensor): (torch.Tensor, list))):
seq_len = torch.tensor(seq_len) raise ValueError("Incorrect type of image_masks. "
f"Got type: {type(image_masks)}")
feat_is_patch = kwargs.pop("feat_is_patch", None)
if not isinstance(feat_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of feat_is_patch. "
f"Got type: {type(feat_is_patch)}")
embed_is_patch = kwargs.pop("embed_is_patch", None)
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_crops = kwargs.pop("num_crops", None)
if not isinstance(num_crops, torch.Tensor):
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")
img_patch_id = kwargs.pop("img_patch_id", None)
if not isinstance(img_patch_id, torch.Tensor):
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")
self.img_patch_id = img_patch_id.flatten().unique().item()
return MolmoImageInputs( return MolmoImageInputs(
images=images, images=images,
image_input_idx=image_input_idx,
seq_len=seq_len,
image_masks=image_masks, image_masks=image_masks,
image_start_end=image_start_end, feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops,
) )
def _process_image_input( def _process_image_input(
self, self,
image_input: MolmoImageInputs, image_input: MolmoImageInputs,
) -> torch.Tensor: ) -> Union[torch.Tensor, List[torch.Tensor]]:
if isinstance(image_input["images"], list):
# Call the vision backbone on the whole batch at once
images_flat = flatten_bn(image_input["images"], concat=True)
image_masks_flat = (None if (image_masks :=
image_input["image_masks"]) is None
else flatten_bn(image_masks, concat=True))
image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0),
image_masks=(None if image_masks_flat is None else
image_masks_flat.unsqueeze(0)),
).squeeze(0)
# Reconstruct the batch dimension
image_features = image_features_flat.split(
image_input["num_crops"].sum(-1).tolist())
else:
image_features = self.vision_backbone( image_features = self.vision_backbone(
images=image_input["images"], images=image_input["images"],
image_masks=image_input["image_masks"], image_masks=image_input["image_masks"],
...@@ -1258,51 +1566,73 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -1258,51 +1566,73 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return image_features return image_features
def _get_mm_embeds(
self,
features: torch.Tensor, # Shape: (num_crop, num_patch, d)
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
num_crops: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
) -> list[torch.Tensor]:
"""
Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Note:
The original code only considers patch tokens as feature
tokens, but our processor considers all image-related tokens
as feature tokens because the feature tokens need to be
consecutive in `input_ids`.
Example:
A simplified example for one item in the batch:
.. code-block::
Embedding tokens (from HF processor):
[<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
embed_is_patch (from HF processor):
[ False True True False True True False False ]
Encoder outputs (from model):
[ p1 p2 0 p3 p4 0 ]
feat_is_patch (from HF processor):
[ True True False True True False ]
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
"""
num_crops_per_image = num_crops.tolist()
feats_per_image = features.split(num_crops_per_image)
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
_, _, embed_dim = features.shape
(num_embeds, ) = embed_is_patch.shape
embeds_in_batch = list[torch.Tensor]()
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
embeds[embed_is_patch] = feats[f_is_patch]
embeds_in_batch.append(embeds)
return embeds_in_batch
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) image_features = self._process_image_input(image_input)
image_input_idx = image_input["image_input_idx"]
seq_len = image_input["seq_len"] return [
batch_size, num_image, num_patch = image_features.shape[:3] self._get_mm_embeds(*args) for args in zip(
assert image_input_idx.shape == (batch_size, num_image, num_patch) image_features,
image_input["feat_is_patch"],
# insert the image feature into the embedding. image_input["num_crops"],
image_features = image_features.view(batch_size, num_image * num_patch, image_input["embed_is_patch"],
-1) )
image_input_idx = image_input_idx.view(batch_size, ]
num_image * num_patch)
valid = image_input_idx >= 0
image_features = image_features * valid[:, :, None].to(
image_features.dtype)
image_features = image_features.view(
batch_size * num_image * num_patch, -1).contiguous()
image_input_idx = image_input_idx * valid.to(image_input_idx.dtype)
offset = torch.cat([seq_len.new_zeros(1),
seq_len.cumsum(dim=0)[:-1]],
dim=0)[:, None]
image_input_idx = image_input_idx + offset.to(image_input_idx.dtype)
image_input_idx = image_input_idx.flatten()[:, None]
mat = image_input_idx == torch.arange(
seq_len.sum().item(), device=image_features.device)[None, :]
mat = mat.to(image_features.dtype)
# Note: In this original implementation from AI2, the final
# vision_embeddings will be always be the same length
# of input embeddings.
vision_embeddings = torch.einsum('nd,nm->md', image_features, mat)
# Split by the sizes of the input sequences. For each full embedding,
# extract the actual vision embeddings to be merged.
vision_embeddings = list(vision_embeddings.split(seq_len.tolist()))
for i in range(len(vision_embeddings)):
start, end = image_input['image_start_end'][i]
vision_embeddings[i] = vision_embeddings[i][start:end]
return vision_embeddings
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -1311,11 +1641,20 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -1311,11 +1641,20 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids) inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
assert self.img_patch_id is not None
# Extract the patch tokens scattered in _get_mm_embeds
patch_embeddings = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [ input_ids,
DEFAULT_IMAGE_PATCH_TOKEN_ID, DEFAULT_IM_START_TOKEN_ID, inputs_embeds,
DEFAULT_IM_END_TOKEN_ID, DEFAULT_IM_COL_TOKEN_ID cast(NestedTensors, patch_embeddings),
]) self.img_patch_id,
)
return inputs_embeds return inputs_embeds
def forward( def forward(
......
...@@ -33,8 +33,7 @@ from dataclasses import dataclass, field ...@@ -33,8 +33,7 @@ from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps from functools import cache, lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generator, Generic, Iterator, List, Literal, Dict, Generator, Generic, Iterator, List, Literal,
NamedTuple, Optional, Tuple, Type, TypeVar, Union, NamedTuple, Optional, Tuple, Type, TypeVar, Union)
overload)
from uuid import uuid4 from uuid import uuid4
import cloudpickle import cloudpickle
...@@ -826,38 +825,6 @@ JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"], ...@@ -826,38 +825,6 @@ JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
"""A nested JSON structure where the leaves need not be JSON-serializable.""" """A nested JSON structure where the leaves need not be JSON-serializable."""
@overload
def json_map_leaves(
func: Callable[[T], U],
value: Dict[str, JSONTree[T]],
) -> Dict[str, JSONTree[U]]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: List[JSONTree[T]],
) -> List[JSONTree[U]]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: Tuple[JSONTree[T], ...],
) -> Tuple[JSONTree[U], ...]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: JSONTree[T],
) -> JSONTree[U]:
...
def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]: def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
if isinstance(value, dict): if isinstance(value, dict):
return {k: json_map_leaves(func, v) for k, v in value.items()} return {k: json_map_leaves(func, v) for k, v in value.items()}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment