Unverified Commit c9d3ecf0 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Merged multi-modal processor for Molmo (#12966)

parent fdcf64d3
......@@ -793,7 +793,7 @@ See [this page](#generative-models) for more information on how to use generativ
- * `MolmoForCausalLM`
* Molmo
* T + I
* `allenai/Molmo-7B-D-0924`, `allenai/Molmo-72B-0924`, etc.
* `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc.
* ✅︎
* ✅︎
* ✅︎
......
......@@ -27,7 +27,7 @@ from ...utils import check_logprobs_close
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
pytest.param(
"THUDM/chatglm3-6b", # ChatGLM (text-only)
"THUDM/chatglm3-6b", # chatglm (text-only)
),
pytest.param(
"meta-llama/Llama-3.2-1B-Instruct", # llama
......
......@@ -404,11 +404,10 @@ VLM_TEST_SETTINGS = {
"molmo": VLMTestInfo(
models=["allenai/Molmo-7B-D-0924"],
test_type=(VLMTestType.IMAGE),
prompt_formatter=lambda img_prompt:"User: " + img_prompt + " Assistant:", # noqa: E501
prompt_formatter=identity,
max_model_len=4096,
max_num_seqs=2,
image_size_factors=[(),(1.0, 1.0, 1.0)],
patch_hf_runner=model_utils.mlomo_patch_hf_runner,
patch_hf_runner=model_utils.molmo_patch_hf_runner,
postprocess_inputs=model_utils.molmo_post_processor,
),
# Tests for phi3v currently live in another file because of a bug in
......
......@@ -6,7 +6,7 @@ typically specific to a small subset of models.
import re
import types
from pathlib import PosixPath
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
from PIL.Image import Image
......@@ -17,9 +17,7 @@ from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import patch_padding_side
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from .....conftest import (HfRunner, ImageAsset, PromptAudioInput,
PromptImageInput, PromptVideoInput, _ImageAssets)
from ....utils import TokensTextLogprobs
from .....conftest import HfRunner, ImageAsset, _ImageAssets
from .types import RunnerOutput
......@@ -522,74 +520,7 @@ def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model
def _generate_greedy_logprobs_limit(
self,
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
**kwargs: Any,
) -> List[TokensTextLogprobs]:
all_inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
# Process in batches for inference.
if len(all_inputs):
input_ids_lst = []
images_lst = []
images_input_idx_lst = []
imges_masks_lst = []
for inputs in all_inputs:
input_ids_lst.append(inputs["input_ids"])
images_lst.append(inputs["images"])
images_input_idx_lst.append(inputs["image_input_idx"])
imges_masks_lst.append(inputs["image_masks"])
batch_inputs = {}
batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
batch_inputs['images'] = torch.cat(images_lst, dim=0)
batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
dim=0)
batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0)
outputs = self.model.generate_from_batch(
batch=self.wrap_device(batch_inputs,
device=self.model.device.type),
generation_config=GenerationConfig(
max_new_tokens=max_tokens,
stop_strings="<|endoftext|>",
do_sample=False,
),
tokenizer=self.tokenizer,
output_hidden_states=True,
return_dict_in_generate=True,
)
all_logprobs: List[List[Dict[int, float]]] = []
all_output_ids: List[List[int]] = []
all_output_strs: List[str] = []
for index in range(len(all_inputs)):
(
seq_logprobs_lst,
output_len,
) = self._hidden_states_to_logprobs(outputs.hidden_states,
num_logprobs)
all_logprobs.append(seq_logprobs_lst)
seq_ids = outputs.sequences[index]
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
####### Molmo-specific HuggingFace runner patchers
def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for Molmo."""
hf_processor = hf_model.processor
......@@ -598,10 +529,23 @@ def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hf_model.processor = _processor
setattr( # noqa: B010
hf_model,
"generate_greedy_logprobs_limit",
types.MethodType(_generate_greedy_logprobs_limit, hf_model),
)
def _generate(self, max_new_tokens=None, do_sample=None, **kwargs):
batch = {
k: kwargs.pop(k)
for k in ("input_ids", "images", "image_input_idx", "image_masks")
if k in kwargs
}
return self.generate_from_batch(
batch,
generation_config=GenerationConfig(
max_new_tokens=max_new_tokens,
stop_strings="<|endoftext|>",
do_sample=do_sample,
),
**kwargs,
)
hf_model.model.generate = types.MethodType(_generate, hf_model.model)
return hf_model
......@@ -168,6 +168,8 @@ def _test_processing_correctness(
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6",
"allenai/Molmo-7B-D-0924",
"allenai/Molmo-7B-O-0924",
"nvidia/NVLM-D-72B",
"Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct",
......
......@@ -256,6 +256,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-V-2_6",
trust_remote_code=True),
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
extras={"olmo": "allenai/Molmo-7B-O-0924"}, # noqa: E501
trust_remote_code=True),
"NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
trust_remote_code=True),
......
This diff is collapsed.
......@@ -353,17 +353,17 @@ class MultiModalFieldConfig:
Example:
.. code-block::
.. code-block::
Input:
Data: [[AAAA]
[BBBB]
[CCCC]]
Input:
Data: [[AAAA]
[BBBB]
[CCCC]]
Output:
Element 1: [AAAA]
Element 2: [BBBB]
Element 3: [CCCC]
Output:
Element 1: [AAAA]
Element 2: [BBBB]
Element 3: [CCCC]
"""
return MultiModalFieldConfig(
field=MultiModalBatchedField(),
......@@ -384,18 +384,18 @@ class MultiModalFieldConfig:
Example:
.. code-block::
Given:
slices: [slice(0, 3), slice(3, 7), slice(7, 9)]
.. code-block::
Given:
slices: [slice(0, 3), slice(3, 7), slice(7, 9)]
Input:
Data: [AAABBBBCC]
Input:
Data: [AAABBBBCC]
Output:
Element 1: [AAA]
Element 2: [BBBB]
Element 3: [CC]
Output:
Element 1: [AAA]
Element 2: [BBBB]
Element 3: [CC]
"""
return MultiModalFieldConfig(
field=MultiModalFlatField(slices=slices),
......@@ -416,18 +416,18 @@ class MultiModalFieldConfig:
Example:
.. code-block::
Given:
size_per_item: [3, 4, 2]
.. code-block::
Given:
size_per_item: [3, 4, 2]
Input:
Data: [AAABBBBCC]
Input:
Data: [AAABBBBCC]
Output:
Element 1: [AAA]
Element 2: [BBBB]
Element 3: [CC]
Output:
Element 1: [AAA]
Element 2: [BBBB]
Element 3: [CC]
See also:
:func:`MultiModalFieldConfig.flat`
......@@ -456,19 +456,19 @@ class MultiModalFieldConfig:
Example:
.. code-block::
Given:
batch_size: 4
.. code-block::
Given:
batch_size: 4
Input:
Data: [XYZ]
Input:
Data: [XYZ]
Output:
Element 1: [XYZ]
Element 2: [XYZ]
Element 3: [XYZ]
Element 4: [XYZ]
Output:
Element 1: [XYZ]
Element 2: [XYZ]
Element 3: [XYZ]
Element 4: [XYZ]
"""
return MultiModalFieldConfig(
field=MultiModalSharedField(batch_size),
......
......@@ -33,8 +33,7 @@ from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generator, Generic, Iterator, List, Literal,
NamedTuple, Optional, Tuple, Type, TypeVar, Union,
overload)
NamedTuple, Optional, Tuple, Type, TypeVar, Union)
from uuid import uuid4
import cloudpickle
......@@ -826,38 +825,6 @@ JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
"""A nested JSON structure where the leaves need not be JSON-serializable."""
@overload
def json_map_leaves(
func: Callable[[T], U],
value: Dict[str, JSONTree[T]],
) -> Dict[str, JSONTree[U]]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: List[JSONTree[T]],
) -> List[JSONTree[U]]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: Tuple[JSONTree[T], ...],
) -> Tuple[JSONTree[U], ...]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: JSONTree[T],
) -> JSONTree[U]:
...
def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
if isinstance(value, dict):
return {k: json_map_leaves(func, v) for k, v in value.items()}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment