Unverified commit c9d3ecf0, authored by Cyrus Leung and committed by GitHub
Browse files

[VLM] Merged multi-modal processor for Molmo (#12966)

parent fdcf64d3
...@@ -793,7 +793,7 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -793,7 +793,7 @@ See [this page](#generative-models) for more information on how to use generativ
- * `MolmoForCausalLM` - * `MolmoForCausalLM`
* Molmo * Molmo
* T + I * T + I
* `allenai/Molmo-7B-D-0924`, `allenai/Molmo-72B-0924`, etc. * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
* ✅︎ * ✅︎
......
...@@ -27,7 +27,7 @@ from ...utils import check_logprobs_close ...@@ -27,7 +27,7 @@ from ...utils import check_logprobs_close
marks=[pytest.mark.core_model, pytest.mark.cpu_model], marks=[pytest.mark.core_model, pytest.mark.cpu_model],
), ),
pytest.param( pytest.param(
"THUDM/chatglm3-6b", # ChatGLM (text-only) "THUDM/chatglm3-6b", # chatglm (text-only)
), ),
pytest.param( pytest.param(
"meta-llama/Llama-3.2-1B-Instruct", # llama "meta-llama/Llama-3.2-1B-Instruct", # llama
......
...@@ -404,11 +404,10 @@ VLM_TEST_SETTINGS = { ...@@ -404,11 +404,10 @@ VLM_TEST_SETTINGS = {
"molmo": VLMTestInfo( "molmo": VLMTestInfo(
models=["allenai/Molmo-7B-D-0924"], models=["allenai/Molmo-7B-D-0924"],
test_type=(VLMTestType.IMAGE), test_type=(VLMTestType.IMAGE),
prompt_formatter=lambda img_prompt:"User: " + img_prompt + " Assistant:", # noqa: E501 prompt_formatter=identity,
max_model_len=4096, max_model_len=4096,
max_num_seqs=2, max_num_seqs=2,
image_size_factors=[(),(1.0, 1.0, 1.0)], patch_hf_runner=model_utils.molmo_patch_hf_runner,
patch_hf_runner=model_utils.mlomo_patch_hf_runner,
postprocess_inputs=model_utils.molmo_post_processor, postprocess_inputs=model_utils.molmo_post_processor,
), ),
# Tests for phi3v currently live in another file because of a bug in # Tests for phi3v currently live in another file because of a bug in
......
...@@ -6,7 +6,7 @@ typically specific to a small subset of models. ...@@ -6,7 +6,7 @@ typically specific to a small subset of models.
import re import re
import types import types
from pathlib import PosixPath from pathlib import PosixPath
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from PIL.Image import Image from PIL.Image import Image
...@@ -17,9 +17,7 @@ from vllm.sequence import SampleLogprobs ...@@ -17,9 +17,7 @@ from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import patch_padding_side from vllm.transformers_utils.tokenizer import patch_padding_side
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from .....conftest import (HfRunner, ImageAsset, PromptAudioInput, from .....conftest import HfRunner, ImageAsset, _ImageAssets
PromptImageInput, PromptVideoInput, _ImageAssets)
from ....utils import TokensTextLogprobs
from .types import RunnerOutput from .types import RunnerOutput
...@@ -522,74 +520,7 @@ def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: ...@@ -522,74 +520,7 @@ def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model return hf_model
def _generate_greedy_logprobs_limit( def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
self,
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
**kwargs: Any,
) -> List[TokensTextLogprobs]:
all_inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
# Process in batches for inference.
if len(all_inputs):
input_ids_lst = []
images_lst = []
images_input_idx_lst = []
imges_masks_lst = []
for inputs in all_inputs:
input_ids_lst.append(inputs["input_ids"])
images_lst.append(inputs["images"])
images_input_idx_lst.append(inputs["image_input_idx"])
imges_masks_lst.append(inputs["image_masks"])
batch_inputs = {}
batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
batch_inputs['images'] = torch.cat(images_lst, dim=0)
batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
dim=0)
batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0)
outputs = self.model.generate_from_batch(
batch=self.wrap_device(batch_inputs,
device=self.model.device.type),
generation_config=GenerationConfig(
max_new_tokens=max_tokens,
stop_strings="<|endoftext|>",
do_sample=False,
),
tokenizer=self.tokenizer,
output_hidden_states=True,
return_dict_in_generate=True,
)
all_logprobs: List[List[Dict[int, float]]] = []
all_output_ids: List[List[int]] = []
all_output_strs: List[str] = []
for index in range(len(all_inputs)):
(
seq_logprobs_lst,
output_len,
) = self._hidden_states_to_logprobs(outputs.hidden_states,
num_logprobs)
all_logprobs.append(seq_logprobs_lst)
seq_ids = outputs.sequences[index]
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
####### Molmo-specific HuggingFace runner patchers
def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for Molmo.""" """Patches and returns an instance of the HfRunner to use for Molmo."""
hf_processor = hf_model.processor hf_processor = hf_model.processor
...@@ -598,10 +529,23 @@ def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: ...@@ -598,10 +529,23 @@ def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hf_model.processor = _processor hf_model.processor = _processor
setattr( # noqa: B010 def _generate(self, max_new_tokens=None, do_sample=None, **kwargs):
hf_model, batch = {
"generate_greedy_logprobs_limit", k: kwargs.pop(k)
types.MethodType(_generate_greedy_logprobs_limit, hf_model), for k in ("input_ids", "images", "image_input_idx", "image_masks")
if k in kwargs
}
return self.generate_from_batch(
batch,
generation_config=GenerationConfig(
max_new_tokens=max_new_tokens,
stop_strings="<|endoftext|>",
do_sample=do_sample,
),
**kwargs,
) )
hf_model.model.generate = types.MethodType(_generate, hf_model.model)
return hf_model return hf_model
...@@ -168,6 +168,8 @@ def _test_processing_correctness( ...@@ -168,6 +168,8 @@ def _test_processing_correctness(
"mistral-community/pixtral-12b", "mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6", "openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6", "openbmb/MiniCPM-V-2_6",
"allenai/Molmo-7B-D-0924",
"allenai/Molmo-7B-O-0924",
"nvidia/NVLM-D-72B", "nvidia/NVLM-D-72B",
"Qwen/Qwen-VL-Chat", "Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2-VL-2B-Instruct",
......
...@@ -256,6 +256,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -256,6 +256,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-V-2_6", "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-V-2_6",
trust_remote_code=True), trust_remote_code=True),
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
extras={"olmo": "allenai/Molmo-7B-O-0924"}, # noqa: E501
trust_remote_code=True), trust_remote_code=True),
"NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
trust_remote_code=True), trust_remote_code=True),
......
This diff is collapsed.
...@@ -33,8 +33,7 @@ from dataclasses import dataclass, field ...@@ -33,8 +33,7 @@ from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps from functools import cache, lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generator, Generic, Iterator, List, Literal, Dict, Generator, Generic, Iterator, List, Literal,
NamedTuple, Optional, Tuple, Type, TypeVar, Union, NamedTuple, Optional, Tuple, Type, TypeVar, Union)
overload)
from uuid import uuid4 from uuid import uuid4
import cloudpickle import cloudpickle
...@@ -826,38 +825,6 @@ JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"], ...@@ -826,38 +825,6 @@ JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
"""A nested JSON structure where the leaves need not be JSON-serializable.""" """A nested JSON structure where the leaves need not be JSON-serializable."""
@overload
def json_map_leaves(
func: Callable[[T], U],
value: Dict[str, JSONTree[T]],
) -> Dict[str, JSONTree[U]]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: List[JSONTree[T]],
) -> List[JSONTree[U]]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: Tuple[JSONTree[T], ...],
) -> Tuple[JSONTree[U], ...]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: JSONTree[T],
) -> JSONTree[U]:
...
def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]: def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
if isinstance(value, dict): if isinstance(value, dict):
return {k: json_map_leaves(func, v) for k, v in value.items()} return {k: json_map_leaves(func, v) for k, v in value.items()}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment