Unverified Commit 7e6f1238 authored by sangho.lee's avatar sangho.lee Committed by GitHub
Browse files

Add Molmo2 multimodal model support (#30997)


Signed-off-by: default avatarsanghol <sanghol@allenai.org>
Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 9312a6c0
......@@ -698,6 +698,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ |
| `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ |
| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ |
| `Molmo2ForConditionalGeneration` | Molmo2 | T + I<sup>+</sup> / V | `allenai/Molmo2-4B`, `allenai/Molmo2-8B`, `allenai/Molmo2-O-7B` | ✅︎ | ✅︎ |
| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ |
| `OpenCUAForConditionalGeneration` | OpenCUA-7B | T + I<sup>E+</sup> | `xlangai/OpenCUA-7B` | ✅︎ | ✅︎ |
| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ |
......
......@@ -1227,6 +1227,36 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
)
# Molmo2
def run_molmo2(questions: list[str], modality: str) -> ModelRequestData:
model_name = "allenai/Molmo2-8B"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
dtype="bfloat16",
limit_mm_per_prompt={modality: 1},
max_num_batched_tokens=36864,
)
if modality == "image":
placeholder = "<|image|>"
elif modality == "video":
placeholder = "<|video|>"
else:
raise ValueError(f"Unsupported modality for molmo2: {modality}")
prompts = [
f"{placeholder}<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Nemontron_VL
def run_nemotron_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"
......@@ -1920,6 +1950,7 @@ model_example_map = {
"minimax_vl_01": run_minimax_vl_01,
"mistral3": run_mistral3,
"molmo": run_molmo,
"molmo2": run_molmo2,
"nemotron_vl": run_nemotron_vl,
"NVLM_D": run_nvlm_d,
"ovis": run_ovis,
......@@ -1949,6 +1980,7 @@ MODELS_NEED_VIDEO_METADATA = [
"glm4_1v",
"glm4_5v",
"glm4_5v_fp8",
"molmo2",
"qwen3_vl",
"qwen3_vl_moe",
]
......
......@@ -1301,6 +1301,43 @@ def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData:
)
def load_molmo2(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "allenai/Molmo2-8B"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
dtype="bfloat16",
limit_mm_per_prompt={"image": len(image_urls)},
max_num_batched_tokens=36864,
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=image_data,
)
model_example_map = {
"aria": load_aria,
"aya_vision": load_aya_vision,
......@@ -1323,6 +1360,7 @@ model_example_map = {
"llava-next": load_llava_next,
"llava-onevision": load_llava_onevision,
"mistral3": load_mistral3,
"molmo2": load_molmo2,
"NVLM_D": load_nvlm_d,
"ovis": load_ovis,
"ovis2_5": load_ovis2_5,
......
......@@ -123,6 +123,7 @@ MM_DATA_PATCHES = {
"glm4v": glm4_1v_patch_mm_data,
"glm4v_moe": glm4_1v_patch_mm_data,
"glmasr": glmasr_patch_mm_data,
"molmo2": qwen3_vl_patch_mm_data,
"qwen3_vl": qwen3_vl_patch_mm_data,
"qwen3_vl_moe": qwen3_vl_patch_mm_data,
}
......
......@@ -92,6 +92,11 @@ class _HfExamplesInfo:
length that is too large to fit into memory in CI.
"""
max_num_batched_tokens: int | None = None
"""
The maximum number of tokens to be processed in a single batch.
"""
revision: str | None = None
"""
The specific revision (commit hash, tag, or branch) to use for the model.
......@@ -817,6 +822,14 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"olmo": "allenai/Molmo-7B-O-0924"},
trust_remote_code=True,
),
"Molmo2ForConditionalGeneration": _HfExamplesInfo(
"allenai/Molmo2-8B",
extras={"olmo": "allenai/Molmo2-O-7B"},
min_transformers_version="4.51",
trust_remote_code=True,
# required by current PrefixLM implementation
max_num_batched_tokens=31872,
),
"NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", trust_remote_code=True),
"Llama_Nemotron_Nano_VL": _HfExamplesInfo(
"nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1",
......
......@@ -140,6 +140,7 @@ def can_initialize(
else None,
trust_remote_code=model_info.trust_remote_code,
max_model_len=model_info.max_model_len,
max_num_batched_tokens=model_info.max_num_batched_tokens,
# these tests seem to produce leftover memory
gpu_memory_utilization=0.80,
load_format="dummy",
......
......@@ -1127,6 +1127,7 @@ class ModelConfig:
"""Whether to use bidirectional attention for mm positions."""
MM_PREFIX_LM_MODELS = (
"gemma3",
"molmo2",
"paligemma",
)
if not hasattr(self.hf_config, "model_type"):
......
This diff is collapsed.
......@@ -384,6 +384,7 @@ _MULTIMODAL_MODELS = {
"Mistral3ForConditionalGeneration",
),
"MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
"Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),
"Ovis": ("ovis", "Ovis"),
"Ovis2_5": ("ovis2_5", "Ovis2_5"),
......
......@@ -386,6 +386,21 @@ class PromptUpdateDetails(Generic[_S]):
return PromptUpdateDetails(full=seq, is_embed=is_embed)
@staticmethod
def select_token_ids(
seq: _S,
embed_token_ids: list[int],
) -> "PromptUpdateDetails[_S]":
def is_embed(tokenizer: TokenizerLike | None, full: PromptSeq) -> torch.Tensor:
token_ids = _seq2tokens(tokenizer, full)
return torch.isin(
torch.tensor(token_ids),
torch.tensor(embed_token_ids),
)
return PromptUpdateDetails(full=seq, is_embed=is_embed)
PromptUpdateInfo: TypeAlias = PromptSeq | PromptUpdateDetails
"""
......
......@@ -6,7 +6,7 @@ from abc import abstractmethod
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
import numpy as np
import numpy.typing as npt
......@@ -439,6 +439,324 @@ class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
return frames, metadata
@VIDEO_LOADER_REGISTRY.register("molmo2")
class Molmo2VideoBackend(VideoLoader):
def get_cv2_video_api(self):
import cv2.videoio_registry as vr
api_pref = None
for backend in vr.getStreamBufferedBackends():
if not vr.hasBackend(backend):
continue
if not vr.isBackendBuiltIn(backend):
_, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
if abi < 1 or (abi == 1 and api < 2):
continue
api_pref = backend
break
return api_pref
@classmethod
def get_candidate_target_fps(
cls,
video_fps: float,
sampling_fps: float,
max_fps: float = 8.0,
) -> list[float]:
"""
Return the subset of `video_fps` factors that remain multiples
of `sampling_fps`.
Examples:
>>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
[2, 6]
>>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
[1, 5]
>>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
[2]
>>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
Traceback (most recent call last):
...
ValueError: sampling_fps=2 must divide video_fps=5 to produce
consistent frame steps.
"""
video_fps = int(video_fps)
sampling_fps = int(sampling_fps)
max_fps = int(max_fps)
if sampling_fps is None:
raise ValueError("sampling_fps must be provided")
if video_fps <= 0 or sampling_fps <= 0:
raise ValueError(
"video_fps and sampling_fps must be positive "
f"(got {video_fps}, {sampling_fps})"
)
if video_fps % sampling_fps != 0:
raise ValueError(
f"sampling_fps={sampling_fps} must divide video_fps={video_fps}."
)
candidates = []
for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
if candidate > max_fps:
break
if video_fps % candidate == 0:
candidates.append(float(candidate))
return candidates
@classmethod
def get_target_fps(
cls,
video_fps: float,
max_frames: int,
total_frames: int,
frame_sample_mode: str,
candidate_target_fps: list[float],
) -> float | None:
"""
Get the target fps that best spans the videoand has the most frames sampled
"""
num_frames_sampled = 0
selected_target_fps = None
for target_fps in candidate_target_fps:
step_size = max(int(video_fps / target_fps), 1)
num_frames_sampled_at_fps = int(total_frames / step_size)
if num_frames_sampled == 0:
if (
"uniform" in frame_sample_mode
and num_frames_sampled_at_fps > max_frames
):
break
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
else:
# the candidate sampling fps increases so frame count can't decrease
assert num_frames_sampled <= num_frames_sampled_at_fps
if num_frames_sampled_at_fps > max_frames:
# choose the sampling fps that spans the video
continue
elif num_frames_sampled_at_fps > num_frames_sampled:
# both are less than max_frames; choose the one with higher
# density of frames sampled
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
return selected_target_fps
@classmethod
def get_frame_times_and_chosen_fps(
cls,
selected_target_fps: float | None,
total_frames: int,
max_frames: int,
video_fps: float,
) -> tuple[float | None, npt.NDArray]:
if selected_target_fps is None:
frame_indices = np.linspace(
0, total_frames, max_frames, endpoint=False, dtype=int
)
else:
step_size = max(int(video_fps / selected_target_fps), 1)
frame_indices = np.arange(0, total_frames, step_size)
if len(frame_indices) > max_frames:
frame_indices = frame_indices[:max_frames]
return selected_target_fps, frame_indices
@classmethod
def sample_times(
cls,
duration: float,
max_frames: int,
frame_sample_mode: str,
max_fps: int | None,
candidate_target_fps: list[float] | None = None,
**kwargs,
) -> npt.NDArray:
if frame_sample_mode == "fps":
assert candidate_target_fps is not None
# Try larger and larger FPSs until we hit one that can't span the video
sampling_fps = candidate_target_fps[0]
for candidate_fps in candidate_target_fps[1:]:
if max_frames / candidate_fps < duration:
break
sampling_fps = candidate_fps
times = np.arange(0, max_frames) / sampling_fps
times = times[times < duration]
return times
elif frame_sample_mode == "uniform_last_frame":
if max_fps is not None:
max_duration = (
max_frames - 1
) / max_fps # -1 to include the last frame
if max_duration < duration:
times = np.linspace(
0, duration, num=max_frames, endpoint=True, dtype=np.float64
)
else:
times = np.arange(0.0, stop=duration, step=1 / max_fps)
times = np.concatenate([times, [duration]], axis=0)
assert len(times) <= max_frames
else:
times = np.linspace(
0, duration, num=max_frames, endpoint=True, dtype=np.float64
)
return times
else:
raise NotImplementedError(frame_sample_mode)
@classmethod
def _sample_frames(
cls,
total_num_frames: int,
video_fps: float,
duration: float,
frame_sample_mode: str,
num_frames: int,
max_fps: int,
sampling_fps: int,
) -> npt.NDArray:
if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
if total_num_frames <= 2:
indices = np.arange(total_num_frames).astype(int)
elif duration > (num_frames - 1) / max_fps: # -1 to include the last frame
# uniform fallback
indices = np.linspace(
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
else:
float_indices = np.arange(
0.0,
stop=total_num_frames - 1,
step=float(video_fps / max_fps),
)
if np.round(float_indices[-1]) != total_num_frames - 1:
float_indices = np.concatenate(
[float_indices, [total_num_frames - 1]], axis=0
)
indices = np.round(float_indices).astype(int)
assert indices[-1] < total_num_frames
assert len(float_indices) <= num_frames
elif frame_sample_mode == "uniform_last_frame":
indices = np.linspace(
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
elif frame_sample_mode == "fps":
candidate_target_fps = cls.get_candidate_target_fps(video_fps, sampling_fps)
selected_target_fps = cls.get_target_fps(
video_fps,
num_frames,
total_num_frames,
frame_sample_mode,
candidate_target_fps,
)
_, indices = cls.get_frame_times_and_chosen_fps(
selected_target_fps,
total_num_frames,
num_frames,
video_fps,
)
else:
raise NotImplementedError(frame_sample_mode)
return indices
@classmethod
def load_bytes_opencv(
cls,
data: bytes,
frame_sample_mode: str | None = None,
num_frames: int = -1,
max_fps: int = 2,
sampling_fps: int = 2,
**kwargs,
) -> tuple[npt.NDArray, dict[str, Any]]:
import cv2
backend = cls().get_cv2_video_api()
cap = cv2.VideoCapture(BytesIO(data), backend, [])
if not cap.isOpened():
raise ValueError("Could not open video stream")
total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
original_fps = cap.get(cv2.CAP_PROP_FPS)
duration = total_frames_num / original_fps if original_fps > 0 else 0
if frame_sample_mode is None:
# Use transformers transformers.video_utils.VideoMetadata format
frame_idx = list(range(0, total_frames_num))
frame_idx_set = set(frame_idx)
frames, valid_num_frames, valid_frame_indices = cls._read_frames(
cap, frame_idx_set, total_frames_num, max(frame_idx)
)
do_sample_frames = valid_num_frames == total_frames_num
metadata = {
"total_num_frames": total_frames_num,
"fps": original_fps,
"duration": duration,
"video_backend": "opencv",
"do_sample_frames": do_sample_frames,
}
if not do_sample_frames:
metadata["frames_indices"] = valid_frame_indices
return frames, metadata
frame_idx = cls._sample_frames(
total_frames_num,
original_fps,
duration,
frame_sample_mode,
num_frames,
max_fps,
sampling_fps,
).tolist()
frames, valid_num_frames, valid_frame_indices = cls._read_frames(
cap,
set(frame_idx),
len(frame_idx),
total_frames_num - 1,
)
metadata = {
"total_num_frames": total_frames_num,
"fps": original_fps,
"duration": duration,
"video_backend": "opencv",
"frames_indices": valid_frame_indices,
"do_sample_frames": False,
}
return frames, metadata
@classmethod
def load_bytes(
cls,
data: bytes,
num_frames: int = -1,
**kwargs,
) -> tuple[npt.NDArray, dict[str, Any]]:
frame_sample_mode = cast(str | None, kwargs.pop("frame_sample_mode", None))
max_fps = cast(int, kwargs.pop("max_fps", 2))
sampling_fps = cast(int, kwargs.pop("sampling_fps", 2))
out = cls.load_bytes_opencv(
data,
frame_sample_mode,
num_frames,
max_fps,
sampling_fps,
**kwargs,
)
return out
class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
def __init__(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment