Unverified Commit 7e6f1238, authored by sangho.lee and committed by GitHub
Browse files

Add Molmo2 multimodal model support (#30997)


Signed-off-by: sanghol <sanghol@allenai.org>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 9312a6c0
...@@ -698,6 +698,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen ...@@ -698,6 +698,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | | `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ |
| `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ |
| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | | `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ |
| `Molmo2ForConditionalGeneration` | Molmo2 | T + I<sup>+</sup> / V | `allenai/Molmo2-4B`, `allenai/Molmo2-8B`, `allenai/Molmo2-O-7B` | ✅︎ | ✅︎ |
| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | | `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ |
| `OpenCUAForConditionalGeneration` | OpenCUA-7B | T + I<sup>E+</sup> | `xlangai/OpenCUA-7B` | ✅︎ | ✅︎ | | `OpenCUAForConditionalGeneration` | OpenCUA-7B | T + I<sup>E+</sup> | `xlangai/OpenCUA-7B` | ✅︎ | ✅︎ |
| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | | `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ |
......
...@@ -1227,6 +1227,36 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData: ...@@ -1227,6 +1227,36 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
) )
# Molmo2
def run_molmo2(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare engine arguments and prompts for a single-item Molmo2 run.

    Args:
        questions: One prompt is produced per question.
        modality: ``"image"`` or ``"video"``; selects the placeholder token
            and is the key of the per-prompt multimodal limit (set to 1).

    Returns:
        A ``ModelRequestData`` holding the engine arguments and the prompts.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
    """
    engine_args = EngineArgs(
        model="allenai/Molmo2-8B",
        trust_remote_code=True,
        dtype="bfloat16",
        limit_mm_per_prompt={modality: 1},
        max_num_batched_tokens=36864,
    )

    placeholder_by_modality = {"image": "<|image|>", "video": "<|video|>"}
    if modality not in placeholder_by_modality:
        raise ValueError(f"Unsupported modality for molmo2: {modality}")
    placeholder = placeholder_by_modality[modality]

    # The multimodal placeholder comes first, followed by the chat-formatted
    # user turn and the opening of the assistant turn.
    prompts = []
    for question in questions:
        prompts.append(
            f"{placeholder}<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
        )

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
# Nemontron_VL # Nemontron_VL
def run_nemotron_vl(questions: list[str], modality: str) -> ModelRequestData: def run_nemotron_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1" model_name = "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"
...@@ -1920,6 +1950,7 @@ model_example_map = { ...@@ -1920,6 +1950,7 @@ model_example_map = {
"minimax_vl_01": run_minimax_vl_01, "minimax_vl_01": run_minimax_vl_01,
"mistral3": run_mistral3, "mistral3": run_mistral3,
"molmo": run_molmo, "molmo": run_molmo,
"molmo2": run_molmo2,
"nemotron_vl": run_nemotron_vl, "nemotron_vl": run_nemotron_vl,
"NVLM_D": run_nvlm_d, "NVLM_D": run_nvlm_d,
"ovis": run_ovis, "ovis": run_ovis,
...@@ -1949,6 +1980,7 @@ MODELS_NEED_VIDEO_METADATA = [ ...@@ -1949,6 +1980,7 @@ MODELS_NEED_VIDEO_METADATA = [
"glm4_1v", "glm4_1v",
"glm4_5v", "glm4_5v",
"glm4_5v_fp8", "glm4_5v_fp8",
"molmo2",
"qwen3_vl", "qwen3_vl",
"qwen3_vl_moe", "qwen3_vl_moe",
] ]
......
...@@ -1301,6 +1301,43 @@ def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -1301,6 +1301,43 @@ def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData:
) )
def load_molmo2(question: str, image_urls: list[str]) -> ModelRequestData:
    """Build a multi-image Molmo2 request using the model's chat template.

    Args:
        question: Text placed after the images in the single user turn.
        image_urls: URLs of the images to attach. Every URL is fetched, and
            the per-prompt image limit is set to the number of URLs.

    Returns:
        A ``ModelRequestData`` with the engine arguments, the rendered prompt
        and the fetched images.
    """
    model_name = "allenai/Molmo2-8B"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        dtype="bfloat16",
        limit_mm_per_prompt={"image": len(image_urls)},
        max_num_batched_tokens=36864,
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [
        {
            "role": "user",
            "content": [
                *placeholders,
                {"type": "text", "text": question},
            ],
        },
    ]

    # The engine is created with trust_remote_code=True, so load the processor
    # under the same setting; otherwise a checkpoint that ships its own
    # processor code cannot be loaded non-interactively.
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    image_data = [fetch_image(url) for url in image_urls]

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=image_data,
    )
model_example_map = { model_example_map = {
"aria": load_aria, "aria": load_aria,
"aya_vision": load_aya_vision, "aya_vision": load_aya_vision,
...@@ -1323,6 +1360,7 @@ model_example_map = { ...@@ -1323,6 +1360,7 @@ model_example_map = {
"llava-next": load_llava_next, "llava-next": load_llava_next,
"llava-onevision": load_llava_onevision, "llava-onevision": load_llava_onevision,
"mistral3": load_mistral3, "mistral3": load_mistral3,
"molmo2": load_molmo2,
"NVLM_D": load_nvlm_d, "NVLM_D": load_nvlm_d,
"ovis": load_ovis, "ovis": load_ovis,
"ovis2_5": load_ovis2_5, "ovis2_5": load_ovis2_5,
......
...@@ -123,6 +123,7 @@ MM_DATA_PATCHES = { ...@@ -123,6 +123,7 @@ MM_DATA_PATCHES = {
"glm4v": glm4_1v_patch_mm_data, "glm4v": glm4_1v_patch_mm_data,
"glm4v_moe": glm4_1v_patch_mm_data, "glm4v_moe": glm4_1v_patch_mm_data,
"glmasr": glmasr_patch_mm_data, "glmasr": glmasr_patch_mm_data,
"molmo2": qwen3_vl_patch_mm_data,
"qwen3_vl": qwen3_vl_patch_mm_data, "qwen3_vl": qwen3_vl_patch_mm_data,
"qwen3_vl_moe": qwen3_vl_patch_mm_data, "qwen3_vl_moe": qwen3_vl_patch_mm_data,
} }
......
...@@ -92,6 +92,11 @@ class _HfExamplesInfo: ...@@ -92,6 +92,11 @@ class _HfExamplesInfo:
length that is too large to fit into memory in CI. length that is too large to fit into memory in CI.
""" """
max_num_batched_tokens: int | None = None
"""
The maximum number of tokens to be processed in a single batch.
"""
revision: str | None = None revision: str | None = None
""" """
The specific revision (commit hash, tag, or branch) to use for the model. The specific revision (commit hash, tag, or branch) to use for the model.
...@@ -817,6 +822,14 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -817,6 +822,14 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"olmo": "allenai/Molmo-7B-O-0924"}, extras={"olmo": "allenai/Molmo-7B-O-0924"},
trust_remote_code=True, trust_remote_code=True,
), ),
"Molmo2ForConditionalGeneration": _HfExamplesInfo(
"allenai/Molmo2-8B",
extras={"olmo": "allenai/Molmo2-O-7B"},
min_transformers_version="4.51",
trust_remote_code=True,
# required by current PrefixLM implementation
max_num_batched_tokens=31872,
),
"NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", trust_remote_code=True), "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", trust_remote_code=True),
"Llama_Nemotron_Nano_VL": _HfExamplesInfo( "Llama_Nemotron_Nano_VL": _HfExamplesInfo(
"nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1",
......
...@@ -140,6 +140,7 @@ def can_initialize( ...@@ -140,6 +140,7 @@ def can_initialize(
else None, else None,
trust_remote_code=model_info.trust_remote_code, trust_remote_code=model_info.trust_remote_code,
max_model_len=model_info.max_model_len, max_model_len=model_info.max_model_len,
max_num_batched_tokens=model_info.max_num_batched_tokens,
# these tests seem to produce leftover memory # these tests seem to produce leftover memory
gpu_memory_utilization=0.80, gpu_memory_utilization=0.80,
load_format="dummy", load_format="dummy",
......
...@@ -1127,6 +1127,7 @@ class ModelConfig: ...@@ -1127,6 +1127,7 @@ class ModelConfig:
"""Whether to use bidirectional attention for mm positions.""" """Whether to use bidirectional attention for mm positions."""
MM_PREFIX_LM_MODELS = ( MM_PREFIX_LM_MODELS = (
"gemma3", "gemma3",
"molmo2",
"paligemma", "paligemma",
) )
if not hasattr(self.hf_config, "model_type"): if not hasattr(self.hf_config, "model_type"):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields
from functools import cached_property, partial
from itertools import islice
from typing import Annotated, Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import ImageOps
from PIL.Image import Image
from transformers import (
BatchFeature,
PretrainedConfig,
ProcessorMixin,
TensorType,
)
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from transformers.video_utils import VideoInput, VideoMetadata
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import MulAndSilu, SiluAndMul, get_act_fn
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
VideoItem,
)
from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import round_down
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
SupportsQuant,
)
from .utils import (
AutoWeightsLoader,
WeightsMapper,
_merge_multimodal_embeddings,
extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
logger = init_logger(__name__)
# Special tokens. These should be present in any tokenizer we use
# because the preprocessor relies on them.
IMAGE_PROMPT = "<|image|>"
VIDEO_PROMPT = "<|video|>"
_MAX_VIDEO_FPS = 8
class Molmo2ImageInputs(TensorSchema):
    """
    Tensor schema for the image inputs of Molmo2.

    Dimensions:
        - nc: The total number of crops (dynamic)
        - np: The total number of patches per crop
        - cps: Number of channels * patch_size * patch_size
        - npp: Number of pooled patches (dynamic)
        - pp: pooling_size * pooling_size
        - ni: Number of images
        - nt: Number of image tokens (dynamic)
    """

    # Flattened pixel patches: one row of `cps` values per patch of each crop.
    pixel_values: Annotated[torch.Tensor, TensorShape("nc", "np", "cps")]

    token_pooling: Annotated[torch.Tensor, TensorShape("npp", "pp")]
    """
    An index tensor that maps image features to their corresponding
    patch tokens before pooling.
    """

    # One entry per image.
    # presumably the number of pooled patches of that image - TODO confirm
    num_pooled_patches: Annotated[torch.Tensor, TensorShape("ni")]

    # Boolean tensor with one entry per image token.
    image_tokens: Annotated[torch.BoolTensor, TensorShape("nt")]

    # One entry per image.
    # presumably the number of image tokens of that image - TODO confirm
    num_image_tokens: Annotated[torch.Tensor, TensorShape("ni")]
class Molmo2VideoInputs(TensorSchema):
    """
    Tensor schema for the video inputs of Molmo2.

    Dimensions:
        - nc: The total number of frames (dynamic)
        - np: The total number of patches per frame
        - cps: Number of channels * patch_size * patch_size
        - npp: Number of pooled patches (dynamic)
        - pp: pooling_size * pooling_size
        - nv: Number of videos
        - nt: Number of video tokens (dynamic)
    """

    # Flattened pixel patches: one row of `cps` values per patch of each frame.
    pixel_values_videos: Annotated[torch.Tensor, TensorShape("nc", "np", "cps")]

    token_pooling: Annotated[torch.Tensor, TensorShape("npp", "pp")]
    """
    An index tensor that maps image features to their corresponding
    patch tokens before pooling.
    """

    # One entry per video.
    # presumably the number of pooled patches of that video - TODO confirm
    num_pooled_patches: Annotated[torch.Tensor, TensorShape("nv")]

    # Boolean tensor with one entry per video token.
    video_tokens: Annotated[torch.BoolTensor, TensorShape("nt")]

    # One entry per video.
    # presumably the number of video tokens of that video - TODO confirm
    num_video_tokens: Annotated[torch.Tensor, TensorShape("nv")]
@dataclass
class VitConfig:
    """Hyper-parameters of the vision transformer (ViT) image encoder."""

    hidden_size: int = 1152
    intermediate_size: int = 4304
    num_hidden_layers: int = 27
    num_attention_heads: int = 16
    num_key_value_heads: int = 16
    head_dim: int = 72
    hidden_act: str = "gelu_pytorch_tanh"
    layer_norm_eps: float = 1e-6
    image_default_input_size: tuple[int, int] = (378, 378)
    image_patch_size: int = 14
    image_num_pos: int = 577

    def __post_init__(self) -> None:
        # Accept any iterable (for example a list) and always store a tuple.
        size = self.image_default_input_size
        self.image_default_input_size = tuple(size)  # type: ignore[assignment]

    @property
    def image_num_patch(self) -> tuple[int, int]:
        """Number of whole patches along (height, width) of the default input."""
        patch = self.image_patch_size
        rows, cols = (side // patch for side in self.image_default_input_size)
        return rows, cols
@dataclass
class AdapterConfig:
    """Config for a vit-llm adapter"""

    # ViT layer indices whose outputs feed the adapter; negative values count
    # from the last layer.
    vit_layers: tuple[int, int] = (-3, -9)
    # Whether the pooling attention uses a validity mask.
    pooling_attention_mask: bool = False
    hidden_size: int = 1152
    num_attention_heads: int = 16
    num_key_value_heads: int = 16
    head_dim: int = 72
    hidden_act: str = "silu"
    intermediate_size: int = 18944
    # Output width of the projector.
    text_hidden_size: int = 3584

    def __post_init__(self):
        # Store a tuple even when a list is supplied, so the field matches its
        # declared type (VitConfig normalises its tuple field the same way).
        self.vit_layers = tuple(self.vit_layers)  # type: ignore[assignment]
@dataclass
class TextConfig:
"""Configuration for a text model transformer"""
hidden_size: int = 3584
"""
The hidden size of the model.
"""
num_attention_heads: int = 28
"""
The number of self-attention heads.
"""
num_key_value_heads: int = 4
"""
The number of heads to use for keys and values.
"""
head_dim: int = 128
"""
The head dimensionality for the attention mechanism.
"""
vocab_size: int = 152064
"""Vocabulary size of the model."""
additional_vocab_size: int = 128
"""Number of additional tokens to have the input embeddings for"""
qkv_bias: bool = True
"""
Do QKV projection a bias
"""
num_hidden_layers: int = 48
"""
The number of layers/blocks.
"""
intermediate_size: int = 18944
"""
The hidden size for the MLP.
"""
hidden_act: str = "silu"
"""
The activation function to use within the MLP layers.
"""
max_position_embeddings: int = 4096
"""
Max positional embeddings to use in RoPE cache
"""
rope_theta: float = 1000000.0
"""
RoPE theta parameter.
"""
use_qk_norm: bool = False
"""
Apply layer norm to the keys and queries within the attention mechanism.
This can help stabilize training.
"""
qk_norm_type: str = "olmo"
"""
The type of layer norm to use for the keys and queries.
Can be "olmo" or "qwen3".
"""
layer_norm_eps: float = 1e-6
"""
epsilon for layer norms
"""
norm_after: bool = False
"""Apply layer norm before and after the attention and MLP blocks."""
rope_scaling_layers: tuple[int, ...] | None = None
"""
RoPE scaling layers.
"""
class ViTMLP(nn.Module):
    """Feed-forward network of a vision-transformer block.

    Applies ``w1`` (dim -> hidden_dim), the configured activation, then
    ``w2`` (hidden_dim -> dim). Both linear layers carry a bias.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        super().__init__()
        # Expansion projection.
        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=True, quant_config=quant_config
        )
        # Non-linearity looked up by name.
        self.act = get_act_fn(hidden_act)
        # Contraction projection back to the input width.
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=True, quant_config=quant_config
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP; the linear layers' second return value is discarded."""
        expanded, _ = self.w1(x)
        activated = self.act(expanded)
        contracted, _ = self.w2(activated)
        return contracted
class ViTMultiHeadDotProductAttention(nn.Module):
    """Self-attention of the vision transformer with a fused QKV projection.

    Head counts are divided by the tensor-parallel world size, so the sizes
    stored on the instance (``num_heads``, ``num_kv_heads``, ``q_size``,
    ``kv_size``) are per-rank values.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        use_bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads: must split evenly across heads and across ranks.
        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.head_dim = head_dim
        assert self.head_dim == self.hidden_size // self.total_num_heads

        # Key/value heads: either split across ranks, or (when fewer than
        # ranks) the rank count must be a multiple of the head count.
        self.total_num_kv_heads = num_key_value_heads
        if self.total_num_kv_heads < tp_size:
            assert tp_size % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank widths used to split the fused projection output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.merged_qkv = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=use_bias,
            quant_config=quant_config,
        )

        self.scale = self.head_dim**-0.5
        self.attn = MMEncoderAttention(
            self.num_heads,
            self.head_dim,
            self.scale,
            num_kv_heads=self.num_kv_heads,
            prefix=f"{prefix}.attn",
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Project to Q/K/V, attend, and apply the output projection."""
        fused_qkv, _ = self.merged_qkv(inputs)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        attended = self.attn(query, key, value)
        projected, _ = self.wo(attended)
        return projected
class Molmo2VisionBlock(nn.Module):
    """One vision-transformer layer: attention then MLP, each with a residual.

    A LayerNorm is applied to the input of each sub-layer before it is added
    back onto the residual stream.
    """

    def __init__(
        self,
        config: VitConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_eps

        self.attention = ViTMultiHeadDotProductAttention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.feed_forward = ViTMLP(
            dim=width,
            hidden_dim=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.attention_norm = nn.LayerNorm(width, eps=norm_eps)
        self.ffn_norm = nn.LayerNorm(width, eps=norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers with residuals."""
        attn_out = self.attention(self.attention_norm(x))
        x = x + attn_out
        mlp_out = self.feed_forward(self.ffn_norm(x))
        return x + mlp_out
class Molmo2VisionBlockCollection(nn.Module):
    """Stack of vision blocks that returns the output of every layer."""

    def __init__(
        self,
        config: VitConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        blocks = []
        for layer_idx in range(config.num_hidden_layers):
            block_prefix = f"{prefix}.resblocks.{layer_idx}"
            blocks.append(
                Molmo2VisionBlock(config, quant_config, prefix=block_prefix)
            )
        self.resblocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Run the blocks in order; element ``i`` is the output of block ``i``."""
        per_layer_outputs: list[torch.Tensor] = []
        current = x
        for block in self.resblocks:
            current = block(current)
            per_layer_outputs.append(current)
        return per_layer_outputs
class Molmo2VisionTransformer(nn.Module):
    """Vision Transformer used in Vision Backbone.

    Embeds flattened pixel patches with a linear layer, adds a learned
    positional embedding (resampled when the patch grid differs from the
    stored one) and runs the transformer blocks.
    """

    def __init__(
        self,
        config: VitConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        scale = config.hidden_size**-0.5
        self.num_prefix_tokens: int = 0  # no class embeddings
        # Default (rows, cols) patch grid derived from the config.
        self.patch_num = config.image_num_patch
        self.positional_embedding = nn.Parameter(
            torch.randn(config.image_num_pos, config.hidden_size) * scale,
        )
        image_patch_size = config.image_patch_size
        # Each patch is a flattened patch_size x patch_size x 3 vector.
        self.patch_embedding = nn.Linear(
            image_patch_size * image_patch_size * 3,
            config.hidden_size,
            bias=True,
        )
        self.transformer = Molmo2VisionBlockCollection(
            config,
            quant_config,
            prefix=f"{prefix}.transformer",
        )

    def add_pos_emb(
        self,
        x: torch.Tensor,
        patch_num: tuple[int, int],
    ) -> torch.Tensor:
        """Add the positional embedding to ``x``.

        Args:
            x: Patch embeddings; the positional embedding is broadcast over
                the leading dimension.
            patch_num: ``(rows, cols)`` of the patch grid that ``x`` covers.

        Returns:
            ``x`` plus the positional embedding, in ``x``'s dtype.
        """
        pos_emb = self.positional_embedding
        # The stored embedding is treated as a square grid; use an exact
        # integer square root for its side length.
        grid_side = math.isqrt(pos_emb.shape[0])
        pos_emb = pos_emb.reshape((grid_side, grid_side, pos_emb.shape[1]))

        patch_num_0, patch_num_1 = patch_num

        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
            # Resample the grid to the requested size.
            # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
            pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
            pos_emb = F.interpolate(
                pos_emb,
                size=(patch_num_0, patch_num_1),
                mode="bicubic",
                align_corners=False,
                antialias=True,
            )
            pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)

        # Flatten the grid back to one row per patch.
        pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
        x = x + pos_emb[None, :, :].to(x.dtype)
        return x

    def forward(
        self,
        x: torch.Tensor,
        patch_num: tuple[int, int] | None = None,
    ) -> list[torch.Tensor]:
        """Encode patches and return the hidden states of every block.

        Args:
            x: ``(batch_size, num_patch, n_pixels)`` flattened pixel patches.
            patch_num: ``(rows, cols)`` patch grid; defaults to the grid
                derived from the config.

        Returns:
            One tensor per transformer block, in layer order.
        """
        if patch_num is None:
            patch_num = self.patch_num

        x = self.patch_embedding(x)
        x = self.add_pos_emb(x, patch_num)

        hidden_states = self.transformer(x)
        return hidden_states
class ImagePoolingAttention(nn.Module):
    """Cross-attention that pools image features.

    Queries are projected from ``inputs_q`` and keys/values from
    ``inputs_kv``. Attention runs through PyTorch SDPA (which honours an
    attention mask) when ``use_pytorch_sdpa`` is set, and through
    ``MMEncoderAttention`` otherwise.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_size: int,
        num_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        use_bias: bool = True,
        use_pytorch_sdpa: bool = False,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()

        self.input_dim = input_dim
        self.hidden_size = hidden_size

        # Query heads: must split evenly across heads and across ranks.
        self.total_num_heads = num_heads
        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.head_dim = head_dim
        assert self.head_dim == self.hidden_size // self.total_num_heads

        # Key/value heads: either split across ranks, or (when fewer than
        # ranks) the rank count must be a multiple of the head count.
        self.total_num_kv_heads = num_key_value_heads
        if self.total_num_kv_heads < tp_size:
            assert tp_size % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank width of the key and of the value projection.
        self.kv_size = self.num_kv_heads * self.head_dim

        self.q_proj = ColumnParallelLinear(
            self.input_dim,
            self.total_num_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.merged_kv = MergedColumnParallelLinear(
            self.input_dim,
            [self.total_num_kv_heads * self.head_dim] * 2,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=use_bias,
            quant_config=quant_config,
        )

        self.scale = self.head_dim**-0.5
        self.use_pytorch_sdpa = use_pytorch_sdpa
        # No encoder-attention module is built on the SDPA path.
        self.attn = (
            None
            if use_pytorch_sdpa
            else MMEncoderAttention(
                self.num_heads,
                self.head_dim,
                self.scale,
                num_kv_heads=self.num_kv_heads,
            )
        )

    def forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Non-causal attention via ``F.scaled_dot_product_attention``.

        Inputs are 3-D with the sequence on dimension 1; the result has the
        query's batch and sequence sizes with all heads merged on the last
        dimension.
        """
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        q_heads = query.view(bsz, q_len, self.num_heads, self.head_dim)
        k_heads = key.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
        v_heads = value.view(bsz, kv_len, self.num_kv_heads, self.head_dim)

        if self.num_heads != self.num_kv_heads:
            # Repeat each key/value head so it lines up with the query heads.
            repeats = self.num_heads // self.num_kv_heads
            k_heads = torch.repeat_interleave(k_heads, repeats, dim=2)
            v_heads = torch.repeat_interleave(v_heads, repeats, dim=2)

        # SDPA expects the head dimension before the sequence dimension.
        attended = F.scaled_dot_product_attention(
            q_heads.transpose(1, 2),
            k_heads.transpose(1, 2),
            v_heads.transpose(1, 2),
            attn_mask=attn_mask,
            is_causal=False,
        )
        return attended.transpose(1, 2).reshape(bsz, q_len, -1)

    def forward(
        self,
        inputs_q: torch.Tensor,
        inputs_kv: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Pool ``inputs_kv`` with queries from ``inputs_q``.

        ``attn_mask`` is only used on the SDPA path.
        """
        queries, _ = self.q_proj(inputs_q)
        fused_kv, _ = self.merged_kv(inputs_kv)
        keys, values = fused_kv.split([self.kv_size, self.kv_size], dim=-1)

        if self.use_pytorch_sdpa:
            attended = self.forward_sdpa(queries, keys, values, attn_mask)
        else:
            attended = self.attn(queries, keys, values)

        projected, _ = self.o_proj(attended)
        return projected
class ImageProjectorMLP(nn.Module):
    """Gated MLP that projects pooled image features to ``output_dim``.

    A single fused, bias-free projection produces both halves consumed by
    ``SiluAndMul``; ``down_proj`` then maps to the output width. Only the
    ``"silu"`` activation is supported.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        super().__init__()

        # Two projections of width hidden_dim, fused into one layer.
        self.merged_linear = MergedColumnParallelLinear(
            input_dim, [hidden_dim] * 2, bias=False, quant_config=quant_config
        )

        # The fused activation below is SiLU-specific.
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

        # Output projection of the feed-forward network.
        self.down_proj = RowParallelLinear(
            hidden_dim, output_dim, bias=False, quant_config=quant_config
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fused projection, gated activation and output projection."""
        fused, _ = self.merged_linear(x)
        gated = self.act_fn(fused)
        projected, _ = self.down_proj(gated)
        return projected
class Molmo2VisionBackbone(nn.Module, SupportsQuant):
    """Vision tower of Molmo2: ViT encoder, attention pooling and projector.

    Features from the selected ViT layers are concatenated, gathered into
    pooling windows described by ``token_pooling``, pooled with
    ``ImagePoolingAttention`` and projected to the text hidden size.
    """

    # Fused parameter name -> names of the checkpoint weights packed into it.
    packed_modules_mapping = {
        "merged_qkv": ["wq", "wk", "wv"],  # vision backbone
        "merged_kv": ["k_proj", "v_proj"],  # image_pooling_2d
        "merged_linear": ["gate_proj", "up_proj"],
    }

    def __init__(
        self,
        vit_config: VitConfig,
        adapter_config: AdapterConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the ViT, the pooling attention and the projector.

        NOTE: ``vit_config.num_hidden_layers`` may be reduced in place (see
        below), so the caller's config object is modified.
        """
        super().__init__()
        self.vit_config = vit_config
        self.adapter_config = adapter_config

        # Resolve negative layer indices to absolute ones.
        self.vit_layers = []
        for layer in adapter_config.vit_layers:
            if layer >= 0:
                self.vit_layers.append(layer)
            else:
                self.vit_layers.append(layer + vit_config.num_hidden_layers)

        # Layers above the deepest selected one are never read, so the ViT is
        # built without them. This mutates the passed-in config.
        last_layer_needed = max(self.vit_layers) + 1
        if last_layer_needed < vit_config.num_hidden_layers:
            vit_config.num_hidden_layers = last_layer_needed

        self.image_vit = Molmo2VisionTransformer(
            vit_config,
            quant_config,
            prefix=f"{prefix}.image_vit",
        )

        self.num_prefix_tokens: int = self.image_vit.num_prefix_tokens

        # Selected layers are concatenated along the feature dimension.
        pool_dim = vit_config.hidden_size * len(adapter_config.vit_layers)

        self.image_pooling_2d = ImagePoolingAttention(
            input_dim=pool_dim,
            hidden_size=adapter_config.hidden_size,
            num_heads=adapter_config.num_attention_heads,
            num_key_value_heads=adapter_config.num_key_value_heads,
            head_dim=adapter_config.head_dim,
            # The SDPA path is the one that applies the pooling mask.
            use_pytorch_sdpa=adapter_config.pooling_attention_mask,
            quant_config=quant_config,
        )
        self.image_projector = ImageProjectorMLP(
            input_dim=adapter_config.hidden_size,
            hidden_dim=adapter_config.intermediate_size,
            output_dim=adapter_config.text_hidden_size,
            hidden_act=adapter_config.hidden_act,
            quant_config=quant_config,
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding weight."""
        return self.image_vit.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding weight."""
        return self.image_vit.patch_embedding.weight.device

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """
        Run the ViT and concatenate the selected layers' features.

        : param images: (batch_size, num_crops, num_patch, n_pixels)
        : return: (batch_size, num_crops, num_patch, concatenated feature dim)
        """
        B, T, N, D = images.shape
        # Fold crops into the batch dimension for the ViT.
        images = images.view(B * T, N, D)
        image_features = self.image_vit(images)

        # Keep only the hidden states of the configured layers.
        features = []
        for layer in self.vit_layers:
            features.append(image_features[layer])
        image_features = torch.cat(features, dim=-1)

        # NOTE(review): num_prefix_tokens is set to 0 by the vision
        # transformer, so this branch is not taken; it drops one leading
        # token regardless of the count.
        if self.num_prefix_tokens > 0:
            image_features = image_features[:, 1:]
        image_features = image_features.view(B, T, N, -1)
        return image_features

    def forward(
        self,
        images: torch.Tensor,
        token_pooling: torch.Tensor,
    ) -> torch.Tensor:
        """Encode ``images`` and return one projected feature per pooled token.

        ``token_pooling`` is indexed as a 3-D tensor here; its entries are
        indices into the flattened per-sample patch features, with negative
        values treated as padding. Rows whose pooling window has no valid
        index are removed from the result.
        """
        # image_features shape:
        # (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
        batch_size, num_image = images.shape[:2]
        images = images.to(device=self.device, dtype=self.dtype)
        image_features = self.encode_image(images)

        dim = image_features.shape[-1]
        # Negative indices mark padding; a window is kept if any entry is valid.
        valid = token_pooling >= 0
        valid_token = torch.any(valid, -1)

        # Use `token_pooling` to arrange the features for image pooling
        batch_idx = torch.arange(
            token_pooling.shape[0],
            dtype=torch.long,
            device=token_pooling.device,
        )
        # Broadcast the sample index to the shape of `token_pooling`.
        batch_idx = torch.tile(
            batch_idx.view(batch_size, 1, 1),
            [1, token_pooling.shape[1], token_pooling.shape[2]],
        )

        # Now [batch, num_features, num_pooled_patches, dim]
        # Padding indices are clipped to 0 for the gather and zeroed below.
        to_pool = image_features.reshape(batch_size, -1, dim)[
            batch_idx, torch.clip(token_pooling, 0)
        ]
        to_pool = to_pool * valid.to(self.dtype)[:, :, :, None]
        # One row of pooling candidates per pooled token.
        to_pool = to_pool.reshape([-1, token_pooling.shape[-1], dim])
        if self.adapter_config.pooling_attention_mask:
            attn_mask = valid.reshape([-1, 1, 1, valid.shape[-1]])
            # Query is the mean over valid entries only; guard empty windows
            # against division by zero.
            denom = valid.view(-1, to_pool.shape[-2]).float().sum(-1)
            denom = torch.where(denom == 0, 1, denom)
            query = to_pool.sum(-2, keepdim=True) / denom[:, None, None].to(
                to_pool.dtype
            )
        else:
            attn_mask = None
            # Plain mean over the window, zeroed padding entries included.
            query = to_pool.mean(-2, keepdim=True)
        pooled_features = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask)
        pooled_features = pooled_features.reshape(
            [batch_size, -1, pooled_features.shape[-1]]
        )

        # MLP layer to map the feature.
        pooled_features = self.image_projector(pooled_features)
        # Flatten and drop the rows that correspond to fully padded windows.
        return pooled_features.view(-1, pooled_features.shape[-1])[
            valid_token.flatten()
        ]

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, packing split weights into fused params.

        Returns the set of parameter names (after renaming) that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("merged_qkv", "wq", "q"),
            ("merged_qkv", "wk", "k"),
            ("merged_qkv", "wv", "v"),
            ("merged_kv", "k_proj", 0),
            ("merged_kv", "v_proj", 1),
            ("merged_linear", "gate_proj", 0),
            ("merged_linear", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Substring match against the checkpoint name.
                if weight_name not in name:
                    continue
                # NOTE: `name` is rewritten here and stays rewritten even if
                # one of the `continue`s below moves on to the next mapping.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No fused mapping was applied: load as a regular parameter.
                # These `continue`s skip to the next checkpoint weight.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class Molmo2Attention(nn.Module):
"""Molmo2's LLM Attention."""
def __init__(
self,
config: TextConfig,
rope_parameters: dict[str, Any],
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.hidden_size % self.total_num_heads == 0
assert self.total_num_heads % self.tp_size == 0
self.num_heads = self.total_num_heads // self.tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= self.tp_size:
assert self.total_num_kv_heads % self.tp_size == 0
else:
assert self.tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
self.head_dim = config.head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
# Attention input projection. Projects x -> (q, k, v)
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.qkv_bias,
quant_config=quant_config,
)
self.tp_rank: int | None = None
self.k_norm: nn.Module | None = None
self.q_norm: nn.Module | None = None
self.qk_norm_type: str | None = None
if config.use_qk_norm:
k_norm_size = (
self.head_dim
if config.qk_norm_type == "qwen3"
else self.total_num_kv_heads * self.head_dim
)
self.tp_rank = get_tensor_model_parallel_rank()
self.k_norm = RMSNorm(k_norm_size, eps=config.layer_norm_eps)
q_norm_size = (
self.head_dim
if config.qk_norm_type == "qwen3"
else self.total_num_heads * self.head_dim
)
self.q_norm = RMSNorm(q_norm_size, eps=config.layer_norm_eps)
self.qk_norm_type = config.qk_norm_type
# Rotary embeddings. Rope scaling is only applied on full attention layers.
layer_idx = extract_layer_index(prefix)
if (
config.rope_scaling_layers is not None
and layer_idx not in config.rope_scaling_layers
):
rope_theta = rope_parameters["rope_theta"]
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
self.rotary_emb = get_rope(
self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=rope_parameters,
)
self.scaling = self.head_dim**-0.5
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
# Attention output projection.
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
)
def _apply_qk_norm(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run ``q_norm`` / ``k_norm`` over the query and key projections.

    With tensor parallelism (``tp_size > 1``) the local tensors are first
    all-gathered, normalised, and then split back into ``tp_size`` partitions
    along the last dimension, of which only this rank's partition is kept.
    Without tensor parallelism the norms are applied directly.

    Returns:
        The normalised ``(q, k)`` pair.
    """
    is_sharded = self.tp_size > 1
    if is_sharded:
        # NOTE(review): presumably gathers along the last dim so the norm
        # sees all heads — confirm against tensor_model_parallel_all_gather.
        q = tensor_model_parallel_all_gather(q.contiguous())
        k = tensor_model_parallel_all_gather(k.contiguous())
    q = self.q_norm(q)
    k = self.k_norm(k)
    if not is_sharded:
        return q, k
    q_parts = split_tensor_along_last_dim(q, num_partitions=self.tp_size)
    k_parts = split_tensor_along_last_dim(k, num_partitions=self.tp_size)
    return q_parts[self.tp_rank], k_parts[self.tp_rank]
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    **kwargs: object,
) -> torch.Tensor:
    """Compute self-attention for ``hidden_states``.

    Steps: fused QKV projection, optional QK normalisation, rotary position
    embedding, attention, and the output projection.

    Args:
        positions: Position ids handed to the rotary embedding.
        hidden_states: Input activations for the QKV projection.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The output of ``o_proj``.
    """
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

    qk_norm_enabled = self.q_norm is not None and self.k_norm is not None
    if qk_norm_enabled and self.qk_norm_type == "olmo":
        # "olmo" style: norm over the full projection width.
        q, k = self._apply_qk_norm(q, k)
    elif qk_norm_enabled:
        # Any other norm type: normalise each head independently by viewing
        # the last dim as (num_heads, head_dim), then restore the flat shape.
        q_heads = self.q_norm(
            q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
        )
        q = q_heads.view(q.shape)
        k_heads = self.k_norm(
            k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
        )
        k = k_heads.view(k.shape)

    q, k = self.rotary_emb(positions, q, k)
    projected, _ = self.o_proj(self.attn(q, k, v))
    return projected
class LanguageModelMLP(nn.Module):
    """Molmo2's LLM mlp.

    A gated feed-forward block: one merged projection producing two
    ``intermediate_size``-wide outputs, a ``MulAndSilu`` activation that
    combines them, and a projection back to ``input_dim``.
    """

    def __init__(
        self,
        input_dim: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        """Create the projections.

        Args:
            input_dim: Width of the block's input and output.
            intermediate_size: Width of each half of the merged projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Optional quantization config forwarded to the
                linear layers.
        """
        super().__init__()
        # Merged projection: two outputs of width ``intermediate_size`` each.
        self.up_gate_proj = MergedColumnParallelLinear(
            input_dim,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        # Activation function.
        assert hidden_act == "silu"
        self.act_fn = MulAndSilu()
        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            intermediate_size,
            input_dim,
            bias=False,
            quant_config=quant_config,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Apply merged projection -> activation -> down projection."""
        # The linear layers return a tuple; only the first element is used.
        up_gate, _ = self.up_gate_proj(x)
        x = self.act_fn(up_gate)
        x, _ = self.down_proj(x)
        return x
class Molmo2DecoderLayer(nn.Module):
    """Decoder layer that normalises *before* attention and before the MLP.

    The residual stream is threaded through the RMSNorm calls: each norm is
    called with ``(hidden_states, residual)`` and returns an updated pair.
    """

    def __init__(
        self,
        config: TextConfig,
        rope_parameters: dict[str, Any],
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the attention block, the MLP and the two RMSNorm layers.

        Args:
            config: Text model configuration.
            rope_parameters: Rotary-embedding parameters passed to attention.
            cache_config: Optional KV-cache configuration.
            quant_config: Optional quantization configuration.
            prefix: Parameter-name prefix of this layer.
        """
        super().__init__()
        # Attention block.
        self.self_attn = Molmo2Attention(
            config,
            rope_parameters,
            cache_config,
            quant_config,
            prefix=f"{prefix}.self_attn",
        )
        # MLP block.
        self.mlp = LanguageModelMLP(
            config.hidden_size,
            config.intermediate_size,
            config.hidden_act,
            quant_config,
        )
        # LayerNorm
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs: object,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run one decoder layer.

        Args:
            positions: Position ids forwarded to attention.
            hidden_states: Layer input.
            residual: Residual stream from the previous layer, or ``None``
                for the first layer.
            **kwargs: Forwarded unchanged to the attention block.

        Returns:
            ``(hidden_states, residual)`` where ``hidden_states`` is the MLP
            output and ``residual`` is the pair element returned by the
            post-attention norm.
        """
        # Self Attention
        if residual is None:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            **kwargs,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
class Molmo2DecoderNormAfterLayer(Molmo2DecoderLayer):
    """Decoder layer variant that normalises *after* attention and the MLP.

    Reuses the sub-modules of :class:`Molmo2DecoderLayer`; only the order of
    operations in ``forward`` differs.
    """

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs: object,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run one norm-after decoder layer.

        The incoming ``residual`` argument is not read: the residual is taken
        from ``hidden_states`` and added back explicitly inside this method,
        so the returned residual is always ``None``.

        Returns:
            ``(hidden_states, None)``.
        """
        # Self Attention
        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            **kwargs,
        )

        # Norm is applied to the attention output, then the skip is added.
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual
        residual = hidden_states

        # Same pattern for the MLP: output -> norm -> add skip.
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = hidden_states + residual
        residual = None
        return hidden_states, residual
@support_torch_compile
class Molmo2TextModel(nn.Module, SupportsQuant):
    """Molmo2 text decoder: token embedding, decoder layers and a final RMSNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the text model from ``vllm_config``.

        Args:
            vllm_config: Provides the HF config, cache config and quant config.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        # The text sub-config lives under ``text_config`` when present,
        # otherwise under ``llm_config``.
        if hasattr(config, "text_config"):
            hf_text_config = config.text_config
        else:
            hf_text_config = config.llm_config

        # Copy every field declared on TextConfig from the HF text config.
        kwargs = {}
        for field in fields(TextConfig):
            kwargs[field.name] = getattr(hf_text_config, field.name)
        text_config = TextConfig(**kwargs)

        # Embedding table covers the base vocab plus any additional tokens.
        self.embedding_size = text_config.vocab_size
        self.embedding_size += text_config.additional_vocab_size or 0
        self.embed_tokens = VocabParallelEmbedding(
            self.embedding_size,
            text_config.hidden_size,
            quant_config=quant_config,
        )

        # ``norm_after`` selects the layer variant that normalises after
        # attention / MLP instead of before.
        decoder_layer = (
            Molmo2DecoderNormAfterLayer
            if text_config.norm_after
            else Molmo2DecoderLayer
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            text_config.num_hidden_layers,
            lambda prefix: decoder_layer(
                text_config,
                hf_text_config.rope_parameters,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(text_config.hidden_size, eps=text_config.layer_norm_eps)

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"],
            text_config.hidden_size,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder layers owned by this pipeline-parallel rank.

        Args:
            input_ids: Token ids; embedded only when ``inputs_embeds`` is None.
            positions: Position ids forwarded to every layer.
            intermediate_tensors: Output of the previous pipeline rank;
                required on every rank except the first.
            inputs_embeds: Pre-computed input embeddings, used instead of
                ``input_ids`` when given.
            **kwargs: Forwarded unchanged to every layer.

        Returns:
            On the last pipeline rank, the normalised hidden states. On any
            other rank, an ``IntermediateTensors`` holding ``hidden_states``
            and ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Apply blocks one-by-one.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
                **kwargs,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        # Layers may hand back a pending residual (norm-before variant) or
        # None (norm-after variant); the final norm is called accordingly.
        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(parameter_name, tensor)`` pairs.

        Returns:
            The names of the parameters that were loaded.

        Raises:
            KeyError: If a non-skipped name has no matching parameter.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # Checkpoint biases without a matching parameter are ignored.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Skip parameters reported as missing by is_pp_missing_parameter.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # Prefer the parameter's own loader when it defines one.
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params
def get_patches_grid_size(
    *,
    image_h: int,
    image_w: int,
    patch_size: int,
    pool_h: int,
    pool_w: int,
) -> tuple[int, int]:
    """Return the ``(rows, cols)`` of the pooled patch grid of an image.

    Each axis is divided into ``patch_size`` patches (floor division), the
    patch count is padded up to a multiple of the pooling size, and the
    padded count is divided by the pooling size.
    """

    def pooled_extent(num_pixels: int, pool: int) -> int:
        num_patches = num_pixels // patch_size
        padded_patches = round_down(num_patches + pool - 1, pool)
        return padded_patches // pool

    return pooled_extent(image_h, pool_h), pooled_extent(image_w, pool_w)
def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]:
    """List every ``(rows, cols)`` tiling that uses at most ``max_num`` tiles.

    The result is ordered by tile count, then by number of rows.
    """
    # For a fixed row count, the admissible column counts are 1..max_num//rows.
    tilings = [
        (rows, cols)
        for rows in range(1, max_num + 1)
        for cols in range(1, max_num // rows + 1)
    ]
    tilings.sort(key=lambda tiling: (tiling[0] * tiling[1], tiling[0]))
    return tilings
def select_tiling(
    *,
    height: int,
    width: int,
    patch_size: int,
    max_num_patches: int,
):
    """Pick the tiling whose resolution best matches a ``height x width`` image.

    For every candidate tiling the scale factor needed to fit the image into
    ``tiling * patch_size`` is computed (the smaller of the two per-axis
    ratios). If every candidate would require downscaling, the one with the
    largest scale is chosen; otherwise the candidate with the smallest scale
    that is not below 1 wins.

    Returns:
        The selected ``(rows, cols)`` entry as a NumPy ``int32`` array.
    """
    candidate_tilings = np.array(
        get_candidate_tilings(max_num_patches), dtype=np.int32
    )
    candidate_resolutions = candidate_tilings * patch_size
    image_size = np.array([height, width], dtype=np.float32)

    per_axis_scale = candidate_resolutions.astype(np.float32) / image_size
    fit_scale = per_axis_scale.min(axis=-1, keepdims=True)
    must_downscale = fit_scale < 1

    if must_downscale.all():
        best = fit_scale.argmax()
    else:
        # Mask out candidates that would shrink the image with a huge value.
        best = np.where(must_downscale, 10e9, fit_scale).argmin()
    return candidate_tilings[best]
def get_image_size(image: ImageInput) -> ImageSize:
    """Return the size of an image given as ``Image``, ndarray or tensor.

    Arrays and tensors must be 3-dimensional with the last axis of length
    1 or 3; their shape is unpacked as ``(h, w, c)``.

    Raises:
        ValueError: If ``image`` is none of the supported types.
    """
    if isinstance(image, Image):
        return ImageSize(*image.size)
    if not isinstance(image, (np.ndarray, torch.Tensor)):
        raise ValueError(f"Unknown image type: {type(image)}")

    assert image.ndim == 3
    rows, cols, channels = image.shape
    assert channels in [1, 3]
    return ImageSize(cols, rows)
def exif_tranpose(
    images: ImageInput | None,
) -> ImageInput | None:
    """Apply ``ImageOps.exif_transpose`` to every ``Image`` in ``images``.

    ``None`` is returned unchanged. A list or tuple is returned as a new list
    in which ``Image`` items are transposed and other items are kept as-is.
    A single ``Image`` is transposed; any other value is returned untouched.
    """
    if images is None:
        return None
    if isinstance(images, (list, tuple)):
        return [
            exif_tranpose(item) if isinstance(item, Image) else item
            for item in images
        ]
    if isinstance(images, Image):
        return ImageOps.exif_transpose(images)
    return images
def build_flat_image_bool_length(
    image_grids: torch.LongTensor,
    image_patch_id: int,
    low_res_image_start_id: int,
    image_start_id: int,
    image_col_id: int,
    image_end_id: int,
) -> tuple[torch.LongTensor, torch.LongTensor]:
    """Build the flattened placeholder token sequence for a batch of images.

    Each row of ``image_grids`` is ``(resized_h, resized_w, h, w)``. For every
    image the emitted tokens are::

        low_res_start, resized_h*resized_w patches, end,
        start, h rows of (w patches + col token), end

    Returns:
        ``(flat, lengths)``: the concatenated token ids of all images and the
        per-image token counts.
    """
    device = image_grids.device
    lengths = (
        image_grids[:, 0] * image_grids[:, 1]
        + image_grids[:, 2] * (image_grids[:, 3] + 1)
        + 4
    )  # [B]
    flat = torch.empty(int(lengths.sum().item()), dtype=torch.long, device=device)

    cursor = 0
    for grid, length in zip(image_grids.tolist(), lengths.tolist()):
        low_h, low_w, high_h, high_w = grid
        expected_len = int(length)
        segment_start = cursor

        # Low-resolution view: start marker, patches, end marker.
        flat[cursor] = low_res_image_start_id
        cursor += 1
        low_res_patches = low_h * low_w
        if low_res_patches > 0:
            flat[cursor : cursor + low_res_patches] = image_patch_id
            cursor += low_res_patches
        flat[cursor] = image_end_id
        cursor += 1

        # High-resolution view: start marker, rows ending in a column token.
        flat[cursor] = image_start_id
        cursor += 1
        row_len = high_w + 1
        if row_len > 0 and high_h > 0:
            # Fill the rows in place through a 2-D view of the output slice.
            rows = flat[cursor : cursor + high_h * row_len].view(high_h, row_len)
            if high_w > 0:
                rows[:, :high_w] = image_patch_id
            rows[:, high_w] = image_col_id
            cursor += high_h * row_len
        flat[cursor] = image_end_id
        cursor += 1

        assert cursor - segment_start == expected_len

    return flat, lengths
def build_flat_video_bool_length(
    video_grids: torch.LongTensor,
    image_patch_id: int,
    frame_start_id: int,
    frame_end_id: int,
) -> tuple[torch.LongTensor, torch.LongTensor]:
    """Build the flattened placeholder token sequence for a batch of videos.

    Each row of ``video_grids`` is ``(t, resized_h, resized_w)``. Every one of
    the ``t`` frames is emitted as ``frame_start``, ``resized_h * resized_w``
    patch tokens, ``frame_end``.

    Returns:
        ``(flat, lengths)``: the concatenated token ids of all videos and the
        per-video token counts.
    """
    device = video_grids.device
    frame_counts = video_grids[:, 0]
    patches_per_frame = video_grids[:, 1] * video_grids[:, 2]
    lengths = frame_counts * (patches_per_frame + 2)

    flat = torch.empty(int(lengths.sum().item()), dtype=torch.long, device=device)
    cursor = 0
    for idx in range(video_grids.shape[0]):
        num_frames = int(frame_counts[idx].item())
        num_patches = int(patches_per_frame[idx].item())
        segment_len = int(lengths[idx].item())

        # One frame: all patch tokens, with the two ends overwritten by the
        # start / end markers.
        frame = torch.full(
            (num_patches + 2,), image_patch_id, dtype=torch.long, device=device
        )
        frame[0] = frame_start_id
        frame[-1] = frame_end_id

        flat[cursor : cursor + segment_len] = frame.repeat(num_frames)
        cursor += segment_len

    return flat, lengths
class Molmo2ProcessorWrapper:
    """
    Adapter around the Hugging Face Molmo2 processor.

    Exposes processor and config values as cached properties, offers tiling /
    grid-size helpers, and post-processes the processor output in
    :meth:`__call__`.
    """

    def __init__(self, processor: ProcessorMixin, hf_config: PretrainedConfig):
        super().__init__()

        self.hf_config = hf_config
        self.processor = processor

    # ------------------------------------------------------------------ #
    # Values read from the wrapped processor
    # ------------------------------------------------------------------ #

    @cached_property
    def vocab(self) -> dict[str, int]:
        """The tokenizer's ``vocab`` mapping."""
        tokenizer = self.processor.tokenizer  # type: ignore
        return tokenizer.vocab

    @cached_property
    def max_crops(self) -> int:
        """``max_crops`` of the image processor."""
        value = self.processor.image_processor.max_crops  # type: ignore
        assert isinstance(value, int)
        return value

    @cached_property
    def image_pooling_h(self) -> int:
        """First entry of the image processor's ``pooling_size``."""
        value = self.processor.image_processor.pooling_size[0]  # type: ignore
        assert isinstance(value, int)
        return value

    @cached_property
    def image_pooling_w(self) -> int:
        """Second entry of the image processor's ``pooling_size``."""
        value = self.processor.image_processor.pooling_size[1]  # type: ignore
        assert isinstance(value, int)
        return value

    @cached_property
    def video_pooling_h(self) -> int:
        """First entry of the video processor's ``pooling_size``."""
        value = self.processor.video_processor.pooling_size[0]  # type: ignore
        assert isinstance(value, int)
        return value

    @cached_property
    def video_pooling_w(self) -> int:
        """Second entry of the video processor's ``pooling_size``."""
        value = self.processor.video_processor.pooling_size[1]  # type: ignore
        assert isinstance(value, int)
        return value

    def _image_or_video_processor(self):
        """Return the image processor if there is one, else the video processor."""
        if getattr(self.processor, "image_processor", None) is not None:
            return self.processor.image_processor  # type: ignore
        return self.processor.video_processor  # type: ignore

    @cached_property
    def base_image_input_size(self) -> tuple[int, int]:
        """``(height, width)`` taken from the processor's ``size`` mapping."""
        source = self._image_or_video_processor()
        return (source.size["height"], source.size["width"])

    @cached_property
    def image_patch_size(self) -> int:
        """``patch_size`` of the image (or, failing that, video) processor."""
        value = self._image_or_video_processor().patch_size
        assert isinstance(value, int)
        return value

    @cached_property
    def overlap_margins(self) -> tuple[int, int]:
        """``(left, right)`` overlap margins of the image processor."""
        margins = self.processor.image_processor.overlap_margins  # type: ignore
        left_margin, right_margin = margins
        assert isinstance(left_margin, int)
        assert isinstance(right_margin, int)
        return left_margin, right_margin

    @cached_property
    def bos_token(self) -> str:
        """The tokenizer's BOS token, falling back to its EOS token."""
        tokenizer = self.processor.tokenizer
        return tokenizer.bos_token or tokenizer.eos_token

    # ------------------------------------------------------------------ #
    # Special token ids read from the HF config / vocab
    # ------------------------------------------------------------------ #

    @cached_property
    def image_patch_id(self) -> int:
        return self.hf_config.image_patch_id

    @cached_property
    def im_col_id(self) -> int:
        return self.hf_config.image_col_id

    @cached_property
    def im_start_id(self) -> int:
        return self.hf_config.image_start_token_id

    @cached_property
    def im_end_id(self) -> int:
        return self.hf_config.image_end_token_id

    @cached_property
    def low_res_im_start_id(self) -> int:
        return self.hf_config.low_res_image_start_token_id

    @cached_property
    def frame_start_id(self) -> int:
        return self.hf_config.frame_start_token_id

    @cached_property
    def frame_end_id(self) -> int:
        return self.hf_config.frame_end_token_id

    @cached_property
    def im_low_res_id(self) -> int:
        return self.hf_config.image_low_res_id

    @cached_property
    def image_placeholder_id(self) -> int:
        return self.vocab[IMAGE_PROMPT]

    @cached_property
    def video_placeholder_id(self) -> int:
        return self.vocab[VIDEO_PROMPT]

    @cached_property
    def image_token_ids(self) -> list[int]:
        """All image / frame related special token ids."""
        return [
            self.image_patch_id,
            self.im_col_id,
            self.im_start_id,
            self.low_res_im_start_id,
            self.frame_start_id,
            self.im_end_id,
            self.frame_end_id,
            self.im_low_res_id,
        ]

    # ------------------------------------------------------------------ #
    # Tiling / grid helpers
    # ------------------------------------------------------------------ #

    def _crop_window_geometry(self) -> tuple[int, int]:
        """Return ``(crop_window_size, total_margin_pixels)`` in pixels.

        The crop window is a base crop minus the left and right overlap
        margins; both values are derived from the patch size.
        """
        left_margin, right_margin = self.overlap_margins
        base_size = self.base_image_input_size
        patch = self.image_patch_size

        margin_patches = right_margin + left_margin
        total_margin_pixels = patch * margin_patches
        crop_window_patches = base_size[0] // patch - margin_patches
        return crop_window_patches * patch, total_margin_pixels

    def select_tiling(
        self,
        *,
        image_height: int,
        image_width: int,
    ) -> tuple[int, int]:
        """Choose the ``(tiling_h, tiling_w)`` crop layout for an image."""
        max_crops = self.max_crops
        crop_window_size, total_margin_pixels = self._crop_window_geometry()

        # Module-level ``select_tiling``; margins are excluded from the image
        # extent before matching it against crop-window multiples.
        tiling_h, tiling_w = select_tiling(
            height=image_height - total_margin_pixels,
            width=image_width - total_margin_pixels,
            patch_size=crop_window_size,
            max_num_patches=max_crops,
        )
        return tiling_h, tiling_w

    def get_base_grid_size(self, is_video: bool) -> tuple[int, int]:
        """Pooled grid size of a single base-size crop (image or video frame)."""
        base_height, base_width = (
            self.base_image_input_size[0],
            self.base_image_input_size[1],
        )
        if is_video:
            pool_h, pool_w = self.video_pooling_h, self.video_pooling_w
        else:
            pool_h, pool_w = self.image_pooling_h, self.image_pooling_w
        return get_patches_grid_size(
            image_h=base_height,
            image_w=base_width,
            patch_size=self.image_patch_size,
            pool_h=pool_h,
            pool_w=pool_w,
        )

    def get_patches_grid_size(
        self,
        *,
        image_height: int,
        image_width: int,
    ) -> tuple[int, int]:
        """Pooled grid size of the tiled (multi-crop) view of an image."""
        crop_window_size, total_margin_pixels = self._crop_window_geometry()
        patch = self.image_patch_size

        tiling_h, tiling_w = self.select_tiling(
            image_height=image_height,
            image_width=image_width,
        )

        tiled_height = tiling_h * crop_window_size + total_margin_pixels
        tiled_width = tiling_w * crop_window_size + total_margin_pixels
        nrows, ncols = get_patches_grid_size(
            image_h=tiled_height,
            image_w=tiled_width,
            patch_size=patch,
            pool_h=self.image_pooling_h,
            pool_w=self.image_pooling_w,
        )
        return nrows, ncols

    # ------------------------------------------------------------------ #
    # Processor call
    # ------------------------------------------------------------------ #

    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        images: ImageInput | None = None,
        videos: VideoInput | None = None,
        return_tensors: str | TensorType = None,
        **kwargs: object,
    ) -> BatchFeature:
        """Run the wrapped processor and reshape its output.

        After calling the HF processor this method drops a leading BOS token
        from ``input_ids``, removes ``attention_mask`` / ``token_type_ids``,
        and replaces the ``image_grids`` / ``video_grids`` entries with
        per-item patch counts and flattened placeholder token sequences.
        """
        images = exif_tranpose(images)

        # Positional inputs depend on which sub-processors exist.
        positional = [text]
        if getattr(self.processor, "image_processor", None) is not None:
            positional.append(images)
        if getattr(self.processor, "video_processor", None) is not None:
            positional.append(videos)

        outputs = self.processor(  # type: ignore
            *positional,
            return_tensors=return_tensors,
            **kwargs,
        )

        # revert insert bos token
        bos_id = self.vocab[self.bos_token]
        if outputs["input_ids"][0, 0] == bos_id:
            outputs["input_ids"] = outputs["input_ids"][:, 1:]

        # Normalise both modalities to lists.
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if videos is None:
            videos = []
        if not isinstance(videos, list):
            videos = [videos]

        assert len(videos) in {0, 1}, "At most one video is supported for Molmo2"

        # Removed from the output; the values themselves are not used.
        _attention_mask: torch.Tensor = outputs.pop("attention_mask")
        _token_type_ids: torch.Tensor = outputs.pop("token_type_ids", None)

        if len(images) > 0:
            # For each image: tiling_h * tiling_w + global view
            sizes = [get_image_size(image) for image in images]
            num_crops = [
                np.prod(
                    self.select_tiling(
                        image_height=size.height,
                        image_width=size.width,
                    )
                )
                + 1
                for size in sizes
            ]

            total_crops = sum(num_crops)
            assert total_crops == len(outputs["pixel_values"])
            assert total_crops == outputs["image_num_crops"].sum().item()

            image_grids: torch.Tensor = outputs.pop("image_grids")
            low_res_patches = image_grids[:, :2].prod(dim=1)
            high_res_patches = image_grids[:, 2:].prod(dim=1)
            image_num_pooled_patches: torch.Tensor = (
                low_res_patches + high_res_patches
            )
            outputs["image_num_pooled_patches"] = image_num_pooled_patches

            patches_per_crop = outputs["pixel_values"].shape[1]
            outputs["image_num_patches"] = (
                outputs["image_num_crops"] * patches_per_crop
            )

            image_tokens, num_image_tokens = build_flat_image_bool_length(
                image_grids,
                self.image_patch_id,
                self.low_res_im_start_id,
                self.im_start_id,
                self.im_col_id,
                self.im_end_id,
            )
            outputs["image_tokens"] = image_tokens
            outputs["num_image_tokens"] = num_image_tokens

        if len(videos) > 0:
            video_grids: torch.Tensor = outputs.pop("video_grids")
            frames_per_video = video_grids[:, 0]
            assert frames_per_video.sum() == len(outputs["pixel_values_videos"])

            outputs["video_num_crops"] = frames_per_video
            outputs["video_num_pooled_patches"] = video_grids.prod(dim=1)

            patches_per_frame = outputs["pixel_values_videos"].shape[1]
            outputs["video_num_patches"] = (
                outputs["video_num_crops"] * patches_per_frame
            )

            video_tokens, num_video_tokens = build_flat_video_bool_length(
                video_grids,
                self.image_patch_id,
                self.frame_start_id,
                self.frame_end_id,
            )
            outputs["video_tokens"] = video_tokens
            outputs["num_video_tokens"] = num_video_tokens

        return BatchFeature(outputs)
def get_candidate_target_fps(
    video_fps: int | float,
    sampling_fps: int | float,
    max_fps: int | float = _MAX_VIDEO_FPS,
) -> list[float]:
    """
    Return the divisors of `video_fps` that are multiples of `sampling_fps`
    and do not exceed `max_fps`, in ascending order.

    All three arguments are truncated to integers before use.

    Examples:
        >>> get_candidate_target_fps(video_fps=6, sampling_fps=2, max_fps=8)
        [2.0, 6.0]
        >>> get_candidate_target_fps(video_fps=5, sampling_fps=1, max_fps=8)
        [1.0, 5.0]
        >>> get_candidate_target_fps(video_fps=2, sampling_fps=2, max_fps=8)
        [2.0]
        >>> get_candidate_target_fps(video_fps=5, sampling_fps=2, max_fps=8)
        Traceback (most recent call last):
            ...
        ValueError: sampling_fps=2 must divide video_fps=5.

    Raises:
        ValueError: If `sampling_fps` is None, if `video_fps` or
            `sampling_fps` is not positive after truncation, or if
            `sampling_fps` does not divide `video_fps`.
    """
    # Must be checked before the int() conversion below, which would
    # otherwise fail first with an unrelated TypeError.
    if sampling_fps is None:
        raise ValueError("sampling_fps must be provided")

    video_fps = int(video_fps)
    sampling_fps = int(sampling_fps)
    max_fps = int(max_fps)

    if video_fps <= 0 or sampling_fps <= 0:
        raise ValueError(
            "video_fps and sampling_fps must be positive "
            f"(got {video_fps}, {sampling_fps})"
        )
    if video_fps % sampling_fps != 0:
        raise ValueError(
            f"sampling_fps={sampling_fps} must divide video_fps={video_fps}."
        )

    # Walk the multiples of sampling_fps up to the smaller of the two caps
    # and keep those that divide video_fps evenly.
    upper_bound = min(video_fps, max_fps)
    return [
        float(candidate)
        for candidate in range(sampling_fps, upper_bound + 1, sampling_fps)
        if video_fps % candidate == 0
    ]
def get_target_fps(
video_fps: float,
max_frames: int,
total_frames: int,
frame_sample_mode: str,
candidate_target_fps: list[float],
) -> float | None:
"""
Get the target fps that best spans the video and has the most frames sampled
"""
num_frames_sampled = 0
selected_target_fps = None
for target_fps in candidate_target_fps:
step_size = max(int(video_fps / target_fps), 1)
num_frames_sampled_at_fps = int(total_frames / step_size)
if num_frames_sampled == 0:
if (
"uniform" in frame_sample_mode
and num_frames_sampled_at_fps > max_frames
):
break
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
else:
# the candidate sampling fps increases so frame count can't decrease
assert num_frames_sampled <= num_frames_sampled_at_fps
if num_frames_sampled_at_fps > max_frames:
# choose the sampling fps that spans the video
continue
elif num_frames_sampled_at_fps > num_frames_sampled:
# both are less than max_frames; choose the one with higher
# density of frames sampled
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
return selected_target_fps
def get_frame_times_and_chosen_fps(
    selected_target_fps, total_frames, max_frames, video_fps
):
    """Compute the frame indices to sample for the chosen fps.

    With a selected fps, frames are taken at a fixed stride of
    ``video_fps / selected_target_fps`` (at least 1). Without one,
    ``max_frames`` indices are spread evenly over ``[0, total_frames)``.
    The result never contains more than ``max_frames`` indices.

    Returns:
        ``(selected_target_fps, frame_indices)``.
    """
    if selected_target_fps is not None:
        stride = max(int(video_fps / selected_target_fps), 1)
        frame_indices = np.arange(0, total_frames, stride)
    else:
        frame_indices = np.linspace(
            0, total_frames, max_frames, endpoint=False, dtype=int
        )

    too_many = len(frame_indices) > max_frames
    if too_many:
        frame_indices = frame_indices[:max_frames]
    return selected_target_fps, frame_indices
class Molmo2ProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Molmo2: token counts, size limits, frame sampling."""

    def get_hf_processor(self, **kwargs: object) -> Molmo2ProcessorWrapper:
        """Return the HF processor wrapped in :class:`Molmo2ProcessorWrapper`."""
        return Molmo2ProcessorWrapper(
            self.ctx.get_hf_processor(**kwargs),
            self.ctx.get_hf_config(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Per-prompt limits: no fixed cap on images, one video."""
        return {"image": None, "video": 1}

    def get_num_image_tokens(
        self,
        *,
        image_height: int,
        image_width: int,
        processor: Molmo2ProcessorWrapper | None = None,
    ) -> int:
        """Number of placeholder tokens for one image of the given size.

        The total is the token count of the single base-size view plus that
        of the tiled view; each contributes two start/end tokens, the patch
        tokens, and optionally one column token per row.
        """
        if processor is None:
            processor = self.get_hf_processor()
        hf_processor = processor.processor  # type: ignore

        base_rows, base_cols = processor.get_base_grid_size(is_video=False)
        # The single-crop column-token flag overrides the general image flag
        # when it is set.
        single_crop_col_tokens = hf_processor.use_single_crop_col_tokens
        if single_crop_col_tokens is None:
            single_crop_col_tokens = hf_processor.image_use_col_tokens
        # start/end tokens + image patch token + col tokens
        base_view_tokens = 2 + base_rows * (base_cols + int(single_crop_col_tokens))

        tiled_rows, tiled_cols = processor.get_patches_grid_size(
            image_height=image_height,
            image_width=image_width,
        )
        tiled_view_tokens = 2 + tiled_rows * (
            tiled_cols + int(hf_processor.image_use_col_tokens)
        )

        return base_view_tokens + tiled_view_tokens

    def get_num_video_tokens(
        self,
        *,
        num_frames: int,
        processor: Molmo2ProcessorWrapper | None = None,
    ) -> int:
        """Number of placeholder tokens for a video with ``num_frames`` frames."""
        if processor is None:
            processor = self.get_hf_processor()

        frame_rows, frame_cols = processor.get_base_grid_size(is_video=True)
        col_tokens = int(processor.processor.video_use_col_tokens)
        # start/end tokens
        tokens_per_frame = 2 + frame_rows * (frame_cols + col_tokens)
        return num_frames * tokens_per_frame

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size that yields the largest number of image tokens.

        Every candidate tiling is converted to the pixel size that exactly
        fills it; the first size with the highest token count wins.

        Raises:
            ValueError: If no candidate produces a positive token count.
        """
        processor = self.get_hf_processor()

        left_margin, right_margin = processor.overlap_margins
        base_size = processor.base_image_input_size
        patch = processor.image_patch_size

        margin_patches = right_margin + left_margin
        total_margin_pixels = patch * margin_patches
        crop_window_size = (base_size[0] // patch - margin_patches) * patch

        best_tokens = 0
        best_size = None
        for tiles_h, tiles_w in get_candidate_tilings(processor.max_crops):
            height = tiles_h * crop_window_size + total_margin_pixels
            width = tiles_w * crop_window_size + total_margin_pixels

            num_tokens = self.get_num_image_tokens(
                image_height=height, image_width=width, processor=processor
            )
            if num_tokens > best_tokens:
                best_tokens = num_tokens
                best_size = ImageSize(width=width, height=height)

        if best_tokens == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """How many frames fit into ``max_tokens`` tokens (at least 1)."""
        tokens_per_frame = self.get_num_video_tokens(num_frames=1)
        return max(max_tokens // tokens_per_frame, 1)

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Largest per-video frame count allowed by ``seq_len`` and the processor.

        The token budget is shared between the requested number of videos and
        the result is capped at the video processor's ``num_frames``; it is
        never below 1.
        """
        video_processor = self.get_hf_processor().processor.video_processor
        processor_cap = video_processor.num_frames
        num_videos = mm_counts.get("video", 0)
        frame_budget = self._get_max_video_frames(seq_len)
        per_video = min(frame_budget // max(num_videos, 1), processor_cap)
        return max(per_video, 1)

    def _sample_frames(
        self,
        total_num_frames: int,
        video_fps: float,
        duration: float,
        frame_sample_mode: str,
        num_frames: int,
        max_fps: int,
        sampling_fps: int,
    ) -> np.ndarray:
        """Return the indices of the frames to sample from a video.

        Supported modes are ``"uniform_last_frame"`` and ``"fps"``.

        Raises:
            NotImplementedError: For any other ``frame_sample_mode``.
        """

        def uniform_indices() -> np.ndarray:
            # Evenly spaced over the whole video, first and last frame included.
            return np.linspace(
                0,
                total_num_frames - 1,
                num=min(num_frames, total_num_frames),
                endpoint=True,
            ).astype(int)

        if frame_sample_mode == "uniform_last_frame":
            if max_fps is None:
                return uniform_indices()
            if total_num_frames <= 2:
                return np.arange(total_num_frames).astype(int)
            # -1 to include the last frame
            if duration > (num_frames - 1) / max_fps:
                # Too long to cover at max_fps: uniform fallback.
                return uniform_indices()

            # Short enough: step through at max_fps and force the last frame in.
            float_indices = np.arange(
                0.0,
                stop=total_num_frames - 1,
                step=float(video_fps / max_fps),
            )
            if np.round(float_indices[-1]) != total_num_frames - 1:
                float_indices = np.concatenate(
                    [float_indices, [total_num_frames - 1]], axis=0
                )
            indices = np.round(float_indices).astype(int)
            assert indices[-1] < total_num_frames
            assert len(float_indices) <= num_frames
            return indices

        if frame_sample_mode == "fps":
            candidate_target_fps = get_candidate_target_fps(video_fps, sampling_fps)
            selected_target_fps = get_target_fps(
                video_fps,
                num_frames,
                total_num_frames,
                frame_sample_mode,
                candidate_target_fps,
            )
            _, indices = get_frame_times_and_chosen_fps(
                selected_target_fps,
                total_num_frames,
                num_frames,
                video_fps,
            )
            return indices

        raise NotImplementedError(frame_sample_mode)

    def _get_video_second_idx(
        self,
        metadata: dict[str, Any],
        do_sample_frames: bool | None = None,
    ) -> list[float]:
        """Return the timestamp (in seconds) of every sampled frame.

        Args:
            metadata: Video metadata; ``"fps"`` is always read, plus either
                ``"total_num_frames"`` (when sampling here) or
                ``"frames_indices"`` (when frames were sampled beforehand).
            do_sample_frames: Whether to compute the sampled indices here.
                When ``None`` the value is read from the metadata and
                defaults to ``False``.
        """
        video_processor = self.get_hf_processor().processor.video_processor
        # metadata["fps"] refers to the true fps of the input video.
        video_fps = metadata["fps"]
        frames_indices = metadata.get("frames_indices")
        if do_sample_frames is None:
            do_sample_frames = metadata.get("do_sample_frames", False)

        if not do_sample_frames:
            # Time-based sampling is done in vllm molmo2 video loader or molmo_utils
            assert frames_indices is not None
        else:
            # Frame-based sampling is applied in HF video processor
            total_num_frames = metadata["total_num_frames"]
            frames_indices = self._sample_frames(
                total_num_frames,
                video_fps,
                total_num_frames / video_fps,
                video_processor.frame_sample_mode,
                video_processor.num_frames,
                video_processor.max_fps,
                video_processor.sampling_fps,
            )

        return [frame_idx / video_fps for frame_idx in frames_indices]
class Molmo2DummyInputsBuilder(BaseDummyInputsBuilder[Molmo2ProcessingInfo]):
    """Builds dummy prompts and dummy image / video inputs for Molmo2."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return a prompt holding one placeholder per requested image / video.

        A single image is represented by the bare image placeholder; any other
        image count prefixes each placeholder with ``Image <n>``.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        if num_images == 1:
            image_part = IMAGE_PROMPT
        else:
            image_part = "".join(
                f"Image {idx + 1}" + IMAGE_PROMPT for idx in range(num_images)
            )

        return image_part + VIDEO_PROMPT * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Create dummy images and videos for the requested counts.

        Args:
            seq_len: Sequence length used to bound the number of video frames.
            mm_counts: Requested number of items per modality.
            mm_options: Optional per-modality overrides.

        Returns:
            A dict with ``"image"`` and ``"video"`` lists (empty when the
            corresponding count is zero).
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        dummy_images = []
        if num_images > 0:
            target_width, target_height = self.info.get_image_size_with_most_features()
            dummy_images = self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=mm_options.get("image") if mm_options else None,
            )

        dummy_videos = []
        if num_videos > 0:
            frame_height, frame_width = (
                self.info.get_hf_processor().base_image_input_size
            )
            target_num_frames = self.info.get_num_frames_with_most_features(
                seq_len, mm_counts
            )

            video_overrides = mm_options.get("video") if mm_options else None
            if video_overrides:
                assert isinstance(video_overrides, VideoDummyOptions)
                num_frames_override = video_overrides.num_frames
                if num_frames_override:
                    if num_frames_override > target_num_frames:
                        logger.warning(
                            "video.num_frames override (%d) exceeds model's "
                            "maximum number of frames (%d), will be ignored",
                            num_frames_override,
                            target_num_frames,
                        )
                    if num_frames_override < 2:
                        logger.warning(
                            "video.num_frames override (%d) cannot be less "
                            "than 2, will be ignored",
                            num_frames_override,
                        )
                    # NOTE(review): an override below 2 is still applied by
                    # the min() below although the warning says it is
                    # ignored — confirm which behaviour is intended.
                    target_num_frames = min(target_num_frames, num_frames_override)

            dummy_videos = self._get_dummy_videos(
                width=frame_width,
                height=frame_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            )

        return {
            "image": dummy_images,
            "video": dummy_videos,
        }

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[VideoItem]:
        """Return ``num_videos`` all-white videos paired with metadata.

        Every item is ``(frames, metadata)`` where ``frames`` is a
        ``(num_frames, height, width, 3)`` uint8 array filled with 255 and the
        metadata marks the frames as already sampled.
        """
        template = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)

        def make_item() -> VideoItem:
            # A fresh metadata dict and frame copy per video so items do not
            # share mutable state.
            metadata = {
                "fps": 2.0,
                "duration": num_frames / 2.0,
                "total_num_frames": num_frames,
                "frames_indices": list(range(num_frames)),
                "video_backend": "decord",
                "do_sample_frames": False,
                "height": height,
                "width": width,
            }
            return (template.copy(), metadata)

        return [make_item() for _ in range(num_videos)]
class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
def _apply_hf_processor_tokens_only(
    self,
    prompt_tokens: list[int],
) -> list[int]:
    """Ensure a non-empty tokenized prompt starts with the BOS token.

    The BOS id is the tokenizer's ``bos_token_id``; when the tokenizer has
    none (``None``), its ``eos_token_id`` is used instead. An empty prompt is
    returned unchanged.

    Args:
        prompt_tokens: Token ids of the prompt.

    Returns:
        ``prompt_tokens`` itself if it is empty or already starts with the
        BOS id, otherwise a new list with the BOS id prepended.
    """
    processor = self.info.get_hf_processor()
    tokenizer = processor.processor.tokenizer

    # Compare against None explicitly: a BOS id of 0 is valid and must not
    # fall through to the EOS id.
    bos_token_id = tokenizer.bos_token_id
    if bos_token_id is None:
        bos_token_id = tokenizer.eos_token_id

    if len(prompt_tokens) > 0 and prompt_tokens[0] != bos_token_id:
        # Prepend the bos token to the prompt tokens
        prompt_tokens = [bos_token_id] + prompt_tokens

    return prompt_tokens
def _get_data_parser(self) -> MultiModalDataParser:
    """Return a data parser configured with ``video_needs_metadata=True``."""
    parser = MultiModalDataParser(video_needs_metadata=True)
    return parser
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Run the HF processor, handling every video on its own.

    Each video is sent through the parent implementation separately with
    ``VIDEO_PROMPT`` as its prompt. The decoded ``input_ids`` of that call
    replace the next ``VIDEO_PROMPT`` occurrence in ``prompt`` and the
    per-video tensors are concatenated field by field. The remaining inputs
    are then processed with the rewritten prompt, a BOS token is prepended
    to ``input_ids`` when it is missing, and both results are merged.

    Args:
        prompt: Text prompt containing the multimodal placeholders.
        mm_data: Multimodal inputs; the ``"videos"`` entry (if any) holds
            ``(video_array, metadata)`` pairs.
        mm_kwargs: Keyword arguments for the HF processor.
        tok_kwargs: Keyword arguments for tokenization.

    Returns:
        A ``BatchFeature`` with the non-video outputs plus the video fields.
    """
    video_field_names = (
        "pixel_values_videos",
        "video_token_pooling",
        "video_num_crops",
        "video_num_pooled_patches",
        "video_num_patches",
        "video_tokens",
        "num_video_tokens",
    )

    mm_data = dict(mm_data)
    processor = self.info.get_hf_processor(**mm_kwargs)

    videos = mm_data.pop("videos", [])
    video_outputs = dict()
    if videos:
        per_field = {name: [] for name in video_field_names}
        for video_array, raw_metadata in videos:
            # NOTE: metadata.frames_indices indicates the sampled frames
            # indices of pre-sampled videos, which is used to calculate the
            # timestamps. Make sure that do_sample_frames in mm_kwargs is
            # false for presampled videos.
            # NOTE: a copy of mm_kwargs is created to update
            # do_sample_frames, otherwise mm_hash for the object will be
            # incorrect.
            video_mm_kwargs = dict(**mm_kwargs)
            if "do_sample_frames" not in video_mm_kwargs:
                # molmo_utils already has "do_sample_frames" in
                # mm_kwargs, don't overwrite it.
                video_mm_kwargs["do_sample_frames"] = raw_metadata.get(
                    "do_sample_frames", False
                )

            # "do_sample_frames" is dropped before building VideoMetadata.
            metadata = VideoMetadata(
                **{
                    key: value
                    for key, value in raw_metadata.items()
                    if key != "do_sample_frames"
                }
            )

            single_video_outputs = super()._call_hf_processor(
                prompt=VIDEO_PROMPT,
                mm_data={
                    "videos": [[video_array]],
                    "video_metadata": [[metadata]],
                },
                mm_kwargs=video_mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

            # Expand the next video placeholder with the decoded tokens.
            video_string = processor.processor.tokenizer.batch_decode(
                single_video_outputs.pop("input_ids")
            )[0]
            prompt = prompt.replace(VIDEO_PROMPT, video_string, 1)

            for name in video_field_names:
                per_field[name].append(single_video_outputs[name])

        video_outputs = {
            name: torch.cat(chunks) for name, chunks in per_field.items()
        }

    processed_outputs = super()._call_hf_processor(
        prompt=prompt,
        mm_data=mm_data,
        mm_kwargs=mm_kwargs,
        tok_kwargs=tok_kwargs,
    )

    bos_token_id = processor.vocab[processor.bos_token]
    input_ids = processed_outputs["input_ids"]
    # add bos token back to prompt start
    needs_bos = input_ids.numel() > 0 and input_ids[0, 0] != bos_token_id
    if needs_bos:
        bos_column = torch.tensor(
            [[bos_token_id]], device=input_ids.device, dtype=input_ids.dtype
        )
        processed_outputs["input_ids"] = torch.concat(
            [bos_column, input_ids], dim=1
        )

    return BatchFeature(dict(processed_outputs, **video_outputs))
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """Describe how each processor output field is split per item.

    Fields are either batched per item or flat tensors that are split using
    a per-item size tensor read from ``hf_inputs``. A missing size tensor is
    replaced by an empty tensor.

    Args:
        hf_inputs: Output of the HF processor.
        hf_processor_mm_kwargs: Processor keyword arguments (not used here).

    Returns:
        A dict mapping each field name to its ``MultiModalFieldConfig``.
    """

    def sizes(key: str):
        return hf_inputs.get(key, torch.empty(0))

    flat = MultiModalFieldConfig.flat_from_sizes
    batched = MultiModalFieldConfig.batched

    image_fields = {
        "pixel_values": flat("image", sizes("image_num_crops")),
        "image_token_pooling": flat("image", sizes("image_num_pooled_patches")),
        "image_num_crops": batched("image"),
        "image_num_pooled_patches": batched("image"),
        "image_num_patches": batched("image"),
        "image_tokens": flat("image", sizes("num_image_tokens")),
        "num_image_tokens": batched("image"),
    }
    video_fields = {
        "pixel_values_videos": flat("video", sizes("video_num_crops")),
        "video_token_pooling": flat("video", sizes("video_num_pooled_patches")),
        "video_num_crops": batched("video"),
        "video_num_pooled_patches": batched("video"),
        "video_num_patches": batched("video"),
        "video_tokens": flat("video", sizes("num_video_tokens")),
        "num_video_tokens": batched("video"),
    }
    return {**image_fields, **video_fields}
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
    """Build the prompt replacements for image and video placeholders.

    An image placeholder expands to a token block sized by the processor's
    base grid followed by a block sized by the image's patch grid. A video
    placeholder expands, per frame, to a tokenized ``"<time> "`` prefix and
    a token block sized by the video base grid. In both cases only the ids
    in ``processor.image_token_ids`` are selected as embedding positions.

    Args:
        mm_items: Parsed multimodal items of the request.
        hf_processor_mm_kwargs: Keyword arguments for the HF processor.
        out_mm_kwargs: Processed multimodal kwargs (not used here).

    Returns:
        One ``PromptReplacement`` for images and one for videos.
    """
    processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

    img_patch_id = processor.image_patch_id
    img_col_id = processor.im_col_id
    img_start_id = processor.im_start_id
    img_end_id = processor.im_end_id

    image_use_col_tokens = processor.processor.image_use_col_tokens
    use_single_crop_col_tokens = processor.processor.use_single_crop_col_tokens
    use_single_crop_start_token = processor.processor.use_single_crop_start_token
    video_use_col_tokens = processor.processor.video_use_col_tokens
    use_frame_special_tokens = processor.processor.use_frame_special_tokens

    def get_image_replacement_molmo2(
        item_idx: int,
    ) -> PromptUpdateDetails[list[int]]:
        images = mm_items.get_items("image", ImageProcessorItems)
        image = exif_tranpose(images.get(item_idx))

        resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False)
        # The base-grid block has its own column-token switch; when that is
        # unset it follows the general image setting.
        if use_single_crop_col_tokens is not None:
            use_col_tokens = use_single_crop_col_tokens
        else:
            use_col_tokens = image_use_col_tokens
        if use_single_crop_start_token:
            start_id = processor.low_res_im_start_id
        else:
            start_id = img_start_id

        # int(bool) appends either zero or one column token per row.
        extra_row = [img_patch_id] * resize_cols + [img_col_id] * int(
            use_col_tokens
        )
        extra_joint = [start_id] + extra_row * resize_nrows + [img_end_id]

        image_size = get_image_size(image)
        nrows, ncols = processor.get_patches_grid_size(
            image_height=image_size.height,
            image_width=image_size.width,
        )
        joint_row = [img_patch_id] * ncols + [img_col_id] * int(
            image_use_col_tokens
        )
        joint = [img_start_id] + joint_row * nrows + [img_end_id]

        return PromptUpdateDetails.select_token_ids(
            extra_joint + joint,
            processor.image_token_ids,
        )

    def get_video_replacement_molmo2(
        item_idx: int,
    ) -> PromptUpdateDetails[list[int]]:
        # Only the metadata is needed; the frame array itself is not used.
        _, metadata = mm_items["video"][item_idx]
        do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames")

        timestamps = self.info._get_video_second_idx(metadata, do_sample_frames)
        nrows, ncols = processor.get_base_grid_size(is_video=True)

        if use_frame_special_tokens:
            start_id = processor.frame_start_id
            end_id = processor.frame_end_id
        else:
            start_id = img_start_id
            end_id = img_end_id

        # Every frame uses the same token block, so build it once.
        joint_row = [img_patch_id] * ncols + [img_col_id] * int(
            video_use_col_tokens
        )
        frame_block = [start_id] + nrows * joint_row + [end_id]

        img_token_ids = []
        for frame_idx, frame_time in enumerate(timestamps):
            prev_space = " " if frame_idx > 0 else ""
            frame_prefix = (
                prev_space + f"{frame_time:.1f} "
            )  # explicit whitespace before/after image tokens

            img_token_ids += processor.processor.tokenizer.encode(
                frame_prefix,
                add_special_tokens=False,
            )
            img_token_ids += frame_block

        return PromptUpdateDetails.select_token_ids(
            img_token_ids,
            processor.image_token_ids,
        )

    return [
        PromptReplacement(
            modality="image",
            target=[processor.image_placeholder_id],
            replacement=get_image_replacement_molmo2,
        ),
        PromptReplacement(
            modality="video",
            target=[processor.video_placeholder_id],
            replacement=get_video_replacement_molmo2,
        ),
    ]
@MULTIMODAL_REGISTRY.register_processor(
Molmo2MultiModalProcessor,
info=Molmo2ProcessingInfo,
dummy_inputs=Molmo2DummyInputsBuilder,
)
class Molmo2ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant
):
# Name rewrites applied to checkpoint weights; ``load_weights`` passes this
# object to ``AutoWeightsLoader.load_weights`` as ``mapper``.
hf_to_vllm_mapper = WeightsMapper(
    # Substring rewrites (old substring -> new substring).
    orig_to_new_substr={
        # vision backbone mapping
        "image_pooling_2d.wq.": "image_pooling_2d.q_proj.",
        "image_pooling_2d.wk.": "image_pooling_2d.k_proj.",
        "image_pooling_2d.wv.": "image_pooling_2d.v_proj.",
        "image_pooling_2d.wo.": "image_pooling_2d.o_proj.",
        "image_projector.w1.": "image_projector.gate_proj.",
        "image_projector.w3.": "image_projector.up_proj.",
        "image_projector.w2.": "image_projector.down_proj.",
        # language backbone mapping
        "att_proj": "qkv_proj",
        "attn_out": "o_proj",
        # q_norm / k_norm map to themselves (names are left unchanged)
        "q_norm": "q_norm",
        "k_norm": "k_norm",
        "ff_proj": "up_gate_proj",
        "ff_out": "down_proj",
        "attn_norm": "input_layernorm",
        "ff_norm": "post_attention_layernorm",
    },
    # Prefix rewrites (old prefix -> new prefix).
    orig_to_new_prefix={
        # vision backbone mapping
        "model.vision_backbone.": "vision_backbone.",
        # language backbone mapping
        "model.transformer.blocks.": "model.layers.",
        "model.transformer.ln_f.": "model.norm.",
    },
)
# Packed module name -> list of component module names.
# NOTE(review): presumably consumed by the LoRA / quantization support
# declared on this class (SupportsLoRA, SupportsQuant) -- confirm.
packed_modules_mapping = {
    "qkv_proj": ["qkv_proj"],
    "up_gate_proj": ["up_gate_proj"],  # language model
    "merged_qkv": ["wq", "wk", "wv"],  # vision backbone
    "merged_kv": ["k_proj", "v_proj"],  # image_pooling_2d
    "merged_linear": ["gate_proj", "up_proj"],  # image_projector
}
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return IMAGE_PROMPT
if modality.startswith("video"):
return VIDEO_PROMPT
raise ValueError("Only image or video modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the vision backbone, the text model and the LM head.

    Args:
        vllm_config: Engine configuration; the HF config, quantization
            config and multimodal config are read from it.
        prefix: Prefix prepended to the names of the sub-modules.
    """
    super().__init__()
    model_config = vllm_config.model_config
    config = model_config.hf_config
    quant_config = vllm_config.quant_config

    self.config = config
    self.multimodal_config = model_config.multimodal_config

    # Copy exactly the dataclass fields of VitConfig / AdapterConfig out of
    # the corresponding sections of the HF config.
    vit_config = VitConfig(
        **{
            field.name: getattr(config.vit_config, field.name)
            for field in fields(VitConfig)
        }
    )
    adapter_config = AdapterConfig(
        **{
            field.name: getattr(config.adapter_config, field.name)
            for field in fields(AdapterConfig)
        }
    )

    self.vision_backbone = Molmo2VisionBackbone(
        vit_config,
        adapter_config,
        quant_config,
        prefix=maybe_prefix(prefix, "vision_backbone"),
    )
    self.model = Molmo2TextModel(
        vllm_config=vllm_config,
        prefix=maybe_prefix(prefix, "model"),
    )

    self.img_patch_id = config.image_patch_id

    # The text sub-config is stored under either of two attribute names.
    hf_text_config = (
        config.text_config if hasattr(config, "text_config") else config.llm_config
    )

    self.lm_head = ParallelLMHead(
        hf_text_config.vocab_size,
        hf_text_config.hidden_size,
        quant_config=quant_config,
    )
    self.logits_processor = LogitsProcessor(hf_text_config.vocab_size)

    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors
    )
@property
def dtype(self):
    """Return the dtype of the first parameter from ``self.parameters()``."""
    first_parameter = next(self.parameters())
    return first_parameter.dtype
def _parse_and_validate_image_input(
self,
**kwargs: object,
) -> Molmo2ImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is None:
return None
token_pooling = kwargs.pop("image_token_pooling", None)
num_pooled_patches = kwargs.pop("image_num_pooled_patches", None)
num_patches = kwargs.pop("image_num_patches", None)
image_tokens = kwargs.pop("image_tokens", None)
num_image_tokens = kwargs.pop("num_image_tokens", None)
accum_patches = [0] + num_patches.cumsum(dim=0)[:-1].tolist()
patch_offset = 0
new_token_pooling = token_pooling.clone()
for i, n in enumerate(num_pooled_patches):
cur_slice = token_pooling[patch_offset : patch_offset + n]
index_offset = int(accum_patches[i])
new_token_pooling[patch_offset : patch_offset + n] = torch.where(
cur_slice >= 0,
cur_slice + index_offset,
cur_slice,
)
patch_offset += n
return Molmo2ImageInputs(
pixel_values=pixel_values,
token_pooling=new_token_pooling,
num_pooled_patches=num_pooled_patches,
image_tokens=image_tokens,
num_image_tokens=num_image_tokens,
)
def _parse_and_validate_video_input(
self,
**kwargs: object,
) -> Molmo2VideoInputs | None:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
if pixel_values_videos is None:
return None
token_pooling = kwargs.pop("video_token_pooling", None)
num_pooled_patches = kwargs.pop("video_num_pooled_patches", None)
num_patches = kwargs.pop("video_num_patches", None)
video_tokens = kwargs.pop("video_tokens", None)
num_video_tokens = kwargs.pop("num_video_tokens", None)
accum_patches = [0] + num_patches.cumsum(dim=0)[:-1].tolist()
patch_offset = 0
new_token_pooling = token_pooling.clone()
for i, n in enumerate(num_pooled_patches):
cur_slice = token_pooling[patch_offset : patch_offset + n]
index_offset = int(accum_patches[i])
new_token_pooling[patch_offset : patch_offset + n] = torch.where(
cur_slice >= 0,
cur_slice + index_offset,
cur_slice,
)
patch_offset += n
return Molmo2VideoInputs(
pixel_values_videos=pixel_values_videos,
token_pooling=new_token_pooling,
num_pooled_patches=num_pooled_patches,
video_tokens=video_tokens,
num_video_tokens=num_video_tokens,
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
for input_key in kwargs:
if input_key in ("pixel_values",) and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(**kwargs)
if input_key in ("pixel_values_videos",) and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
return modalities
def _process_image_input(
self,
image_input: Molmo2ImageInputs,
) -> tuple[torch.Tensor, ...]:
pixel_values = image_input["pixel_values"]
token_pooling = image_input["token_pooling"]
num_pooled_patches = image_input["num_pooled_patches"]
image_tokens = image_input["image_tokens"]
num_image_tokens = image_input["num_image_tokens"]
image_features_flat = self.vision_backbone(
images=pixel_values.unsqueeze(0),
token_pooling=token_pooling.unsqueeze(0),
)
assert len(image_features_flat) == num_pooled_patches.sum()
image_features_list = image_features_flat.split(
num_pooled_patches.tolist(), dim=0
)
image_tokens_list = image_tokens.split(num_image_tokens.tolist(), dim=0)
out = []
for image_features_i, image_tokens_i in zip(
image_features_list, image_tokens_list
):
out_features = self.get_language_model().embed_input_ids(image_tokens_i)
is_image_patch = image_tokens_i == self.img_patch_id
out_features[is_image_patch] = image_features_i
out.append(out_features)
return tuple(out)
def _process_video_input(
self,
video_input: Molmo2VideoInputs,
) -> tuple[torch.Tensor, ...]:
pixel_values_videos = video_input["pixel_values_videos"]
token_pooling = video_input["token_pooling"]
num_pooled_patches = video_input["num_pooled_patches"]
video_tokens = video_input["video_tokens"]
num_video_tokens = video_input["num_video_tokens"]
image_features_flat = self.vision_backbone(
images=pixel_values_videos.unsqueeze(0),
token_pooling=token_pooling.unsqueeze(0),
)
assert len(image_features_flat) == num_pooled_patches.sum()
image_features_list = image_features_flat.split(
num_pooled_patches.tolist(), dim=0
)
video_tokens_list = video_tokens.split(num_video_tokens.tolist(), dim=0)
out = []
for image_features_i, video_tokens_i in zip(
image_features_list, video_tokens_list
):
out_features = self.get_language_model().embed_input_ids(video_tokens_i)
is_image_patch = video_tokens_i == self.img_patch_id
out_features[is_image_patch] = image_features_i
out.append(out_features)
return tuple(out)
def get_language_model(self) -> torch.nn.Module:
    """Return the text model stored in ``self.model``."""
    language_model = self.model
    return language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
    """Compute the embeddings of all multimodal items in ``kwargs``.

    Returns:
        An empty list when no image or video input is present, otherwise a
        tuple of per-item embeddings concatenated in the order in which the
        modalities were parsed.
    """
    parsed_inputs = self._parse_and_validate_multimodal_inputs(**kwargs)
    if not parsed_inputs:
        return []

    embeddings: tuple[torch.Tensor, ...] = ()
    for modality, parsed_input in parsed_inputs.items():
        if modality == "images":
            embeddings = embeddings + self._process_image_input(parsed_input)
        elif modality == "videos":
            embeddings = embeddings + self._process_video_input(parsed_input)
    return embeddings
def embed_input_ids(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: MultiModalEmbeddings | None = None,
    *,
    is_multimodal: torch.Tensor | None = None,
    handle_oov_mm_token: bool = False,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in the multimodal embeddings.

    Args:
        input_ids: Token ids to embed.
        multimodal_embeddings: Embeddings to merge in; when ``None`` or
            empty the text embeddings are returned unchanged.
        is_multimodal: Required whenever ``multimodal_embeddings`` is
            non-empty; passed to the text embedding and to the merge.
        handle_oov_mm_token: Forwarded to the text embedding step.

    Returns:
        The input embeddings.

    Raises:
        ValueError: If embeddings are given without ``is_multimodal``.
    """
    text_embeds = self._embed_text_input_ids(
        input_ids,
        self.get_language_model().embed_input_ids,
        is_multimodal=is_multimodal,
        handle_oov_mm_token=handle_oov_mm_token,
    )

    nothing_to_merge = (
        multimodal_embeddings is None or len(multimodal_embeddings) == 0
    )
    if nothing_to_merge:
        return text_embeds

    if is_multimodal is None:
        raise ValueError(
            "`embed_input_ids` now requires `is_multimodal` arg, "
            "please update your model runner according to "
            "https://github.com/vllm-project/vllm/pull/16229."
        )

    return _merge_multimodal_embeddings(
        inputs_embeds=text_embeds,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_multimodal,
    )
def forward(
    self,
    input_ids: torch.LongTensor,
    positions: torch.LongTensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor:
    """Run the text model and return its hidden states.

    ``inputs_embeds`` is discarded whenever ``intermediate_tensors`` is
    given. Extra keyword arguments are forwarded to the text model.
    """
    effective_embeds = None if intermediate_tensors is not None else inputs_embeds
    return self.model(
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds=effective_embeds,
        **kwargs,
    )
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Return the result of the logits processor for ``hidden_states``."""
    return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load checkpoint weights into this module.

    The two ``wte`` embedding tensors are first merged into one
    ``model.embed_tokens.weight`` entry; the names are then rewritten with
    ``hf_to_vllm_mapper`` while loading.

    Returns:
        Whatever ``AutoWeightsLoader.load_weights`` returns.
    """
    loader = AutoWeightsLoader(self)
    merged_weights = _get_weights_with_merged_embedding(weights)
    return loader.load_weights(merged_weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
    """
    Get the module prefix in multimodal models
    """
    module_prefixes = {
        "language_model": "model",
        "connector": "vision_backbone.image_projector",
        "tower_model": "vision_backbone",
    }
    return MultiModelKeys.from_string_field(**module_prefixes)
def _get_weights_with_merged_embedding(
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
embedding_weights = {}
for name, weight in weights:
if "wte.embedding" in name:
embedding_weights["embedding"] = weight
elif "wte.new_embedding" in name:
embedding_weights["new_embedding"] = weight
else:
yield (name, weight)
# this is compatible with most of quantization,
# because they won't quantize embed_tokens
if "embedding" not in embedding_weights or "new_embedding" not in embedding_weights:
raise ValueError(
"Checkpoint is missing 'wte.embedding' or "
"'wte.new_embedding' weights required for Molmo2."
)
embedding_weights = torch.cat(
[embedding_weights["embedding"], embedding_weights["new_embedding"]],
dim=0,
)
yield ("model.embed_tokens.weight", embedding_weights)
...@@ -384,6 +384,7 @@ _MULTIMODAL_MODELS = { ...@@ -384,6 +384,7 @@ _MULTIMODAL_MODELS = {
"Mistral3ForConditionalGeneration", "Mistral3ForConditionalGeneration",
), ),
"MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
"Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
"NVLM_D": ("nvlm_d", "NVLM_D_Model"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
"Ovis": ("ovis", "Ovis"), "Ovis": ("ovis", "Ovis"),
"Ovis2_5": ("ovis2_5", "Ovis2_5"), "Ovis2_5": ("ovis2_5", "Ovis2_5"),
......
...@@ -386,6 +386,21 @@ class PromptUpdateDetails(Generic[_S]): ...@@ -386,6 +386,21 @@ class PromptUpdateDetails(Generic[_S]):
return PromptUpdateDetails(full=seq, is_embed=is_embed) return PromptUpdateDetails(full=seq, is_embed=is_embed)
@staticmethod
def select_token_ids(
    seq: _S,
    embed_token_ids: list[int],
) -> "PromptUpdateDetails[_S]":
    """Build details whose embed mask selects the given token ids.

    Args:
        seq: The full replacement sequence.
        embed_token_ids: Token ids to be marked in the mask.

    Returns:
        A ``PromptUpdateDetails`` with ``full=seq`` and an ``is_embed``
        callable returning a boolean tensor that is ``True`` at the
        positions whose token id is in ``embed_token_ids``.
    """

    def is_embed(tokenizer: TokenizerLike | None, full: PromptSeq) -> torch.Tensor:
        sequence_ids = torch.tensor(_seq2tokens(tokenizer, full))
        selected_ids = torch.tensor(embed_token_ids)
        return torch.isin(sequence_ids, selected_ids)

    return PromptUpdateDetails(full=seq, is_embed=is_embed)
PromptUpdateInfo: TypeAlias = PromptSeq | PromptUpdateDetails PromptUpdateInfo: TypeAlias = PromptSeq | PromptUpdateDetails
""" """
......
...@@ -6,7 +6,7 @@ from abc import abstractmethod ...@@ -6,7 +6,7 @@ from abc import abstractmethod
from functools import partial from functools import partial
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, cast
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
...@@ -439,6 +439,324 @@ class OpenCVDynamicVideoBackend(OpenCVVideoBackend): ...@@ -439,6 +439,324 @@ class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
return frames, metadata return frames, metadata
@VIDEO_LOADER_REGISTRY.register("molmo2")
class Molmo2VideoBackend(VideoLoader):
def get_cv2_video_api(self):
    """Return the first usable OpenCV stream-buffered video backend.

    A backend is usable when OpenCV reports it as available and it is
    either built in or a plugin whose ``(abi, api)`` version is at least
    ``(1, 2)``.

    Returns:
        The backend identifier, or ``None`` when no backend qualifies.
    """
    import cv2.videoio_registry as vr

    for backend in vr.getStreamBufferedBackends():
        if not vr.hasBackend(backend):
            continue
        if not vr.isBackendBuiltIn(backend):
            _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
            # Lexicographic comparison: abi < 1, or abi == 1 and api < 2.
            if (abi, api) < (1, 2):
                continue
        return backend
    return None
@classmethod
def get_candidate_target_fps(
    cls,
    video_fps: float,
    sampling_fps: float,
    max_fps: float = 8.0,
) -> list[float]:
    """
    Return the multiples of `sampling_fps` that evenly divide `video_fps`
    and do not exceed `max_fps`, in increasing order.

    All three rates are truncated to integers before they are used; the
    candidates are returned as floats.

    Examples:
        >>> Molmo2VideoBackend.get_candidate_target_fps(video_fps=6, sampling_fps=2)
        [2.0, 6.0]
        >>> Molmo2VideoBackend.get_candidate_target_fps(video_fps=5, sampling_fps=1)
        [1.0, 5.0]
        >>> Molmo2VideoBackend.get_candidate_target_fps(video_fps=2, sampling_fps=2)
        [2.0]
        >>> Molmo2VideoBackend.get_candidate_target_fps(video_fps=5, sampling_fps=2)
        Traceback (most recent call last):
            ...
        ValueError: sampling_fps=2 must divide video_fps=5.

    Raises:
        ValueError: If `sampling_fps` is None, if either truncated rate is
            not positive, or if `sampling_fps` does not divide `video_fps`.
    """
    # Validate before converting: int(None) raises TypeError, which would
    # make this check unreachable if it came after the conversion.
    if sampling_fps is None:
        raise ValueError("sampling_fps must be provided")

    video_fps = int(video_fps)
    sampling_fps = int(sampling_fps)
    max_fps = int(max_fps)

    if video_fps <= 0 or sampling_fps <= 0:
        raise ValueError(
            "video_fps and sampling_fps must be positive "
            f"(got {video_fps}, {sampling_fps})"
        )
    if video_fps % sampling_fps != 0:
        raise ValueError(
            f"sampling_fps={sampling_fps} must divide video_fps={video_fps}."
        )

    # Candidates grow monotonically, so capping the range at max_fps is the
    # same as stopping at the first candidate above it.
    upper_bound = min(video_fps, max_fps)
    return [
        float(candidate)
        for candidate in range(sampling_fps, upper_bound + 1, sampling_fps)
        if video_fps % candidate == 0
    ]
@classmethod
def get_target_fps(
cls,
video_fps: float,
max_frames: int,
total_frames: int,
frame_sample_mode: str,
candidate_target_fps: list[float],
) -> float | None:
"""
Get the target fps that best spans the videoand has the most frames sampled
"""
num_frames_sampled = 0
selected_target_fps = None
for target_fps in candidate_target_fps:
step_size = max(int(video_fps / target_fps), 1)
num_frames_sampled_at_fps = int(total_frames / step_size)
if num_frames_sampled == 0:
if (
"uniform" in frame_sample_mode
and num_frames_sampled_at_fps > max_frames
):
break
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
else:
# the candidate sampling fps increases so frame count can't decrease
assert num_frames_sampled <= num_frames_sampled_at_fps
if num_frames_sampled_at_fps > max_frames:
# choose the sampling fps that spans the video
continue
elif num_frames_sampled_at_fps > num_frames_sampled:
# both are less than max_frames; choose the one with higher
# density of frames sampled
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
return selected_target_fps
@classmethod
def get_frame_times_and_chosen_fps(
    cls,
    selected_target_fps: float | None,
    total_frames: int,
    max_frames: int,
    video_fps: float,
) -> tuple[float | None, npt.NDArray]:
    """Compute the frame indices to read for the selected fps.

    Without a selected fps, `max_frames` indices are spread evenly over
    `[0, total_frames)`. Otherwise every `step`-th frame is taken, where
    `step` is `video_fps / selected_target_fps` truncated and at least 1.
    The result is cut to the first `max_frames` indices if it is longer.

    Returns:
        `(selected_target_fps, frame_indices)`.
    """
    if selected_target_fps is not None:
        stride = max(int(video_fps / selected_target_fps), 1)
        picked = np.arange(0, total_frames, stride)
    else:
        picked = np.linspace(
            0, total_frames, max_frames, endpoint=False, dtype=int
        )

    if len(picked) > max_frames:
        picked = picked[:max_frames]
    return selected_target_fps, picked
@classmethod
def sample_times(
    cls,
    duration: float,
    max_frames: int,
    frame_sample_mode: str,
    max_fps: int | None,
    candidate_target_fps: list[float] | None = None,
    **kwargs,
) -> npt.NDArray:
    """Compute the timestamps (in seconds) at which to sample frames.

    Modes:
        "fps": walk `candidate_target_fps` in order and keep the last rate
            for which `max_frames` frames still cover `duration`; return
            the multiples of `1 / rate` that lie below `duration`.
        "uniform_last_frame": when `max_fps` is given and `max_frames`
            frames at that rate cover `duration`, sample every
            `1 / max_fps` seconds and append `duration`; otherwise return
            `max_frames` evenly spaced times from 0 to `duration` inclusive.

    Raises:
        NotImplementedError: For any other `frame_sample_mode`.
    """
    if frame_sample_mode == "fps":
        assert candidate_target_fps is not None
        # Try larger and larger FPSs until we hit one that can't span the video
        sampling_fps = candidate_target_fps[0]
        for candidate_fps in candidate_target_fps[1:]:
            if max_frames / candidate_fps < duration:
                break
            sampling_fps = candidate_fps
        all_times = np.arange(0, max_frames) / sampling_fps
        return all_times[all_times < duration]

    if frame_sample_mode != "uniform_last_frame":
        raise NotImplementedError(frame_sample_mode)

    # -1 to include the last frame
    rate_covers_video = max_fps is not None and not (
        (max_frames - 1) / max_fps < duration
    )
    if not rate_covers_video:
        return np.linspace(
            0, duration, num=max_frames, endpoint=True, dtype=np.float64
        )

    times = np.arange(0.0, stop=duration, step=1 / max_fps)
    times = np.concatenate([times, [duration]], axis=0)
    assert len(times) <= max_frames
    return times
@classmethod
def _sample_frames(
cls,
total_num_frames: int,
video_fps: float,
duration: float,
frame_sample_mode: str,
num_frames: int,
max_fps: int,
sampling_fps: int,
) -> npt.NDArray:
if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
if total_num_frames <= 2:
indices = np.arange(total_num_frames).astype(int)
elif duration > (num_frames - 1) / max_fps: # -1 to include the last frame
# uniform fallback
indices = np.linspace(
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
else:
float_indices = np.arange(
0.0,
stop=total_num_frames - 1,
step=float(video_fps / max_fps),
)
if np.round(float_indices[-1]) != total_num_frames - 1:
float_indices = np.concatenate(
[float_indices, [total_num_frames - 1]], axis=0
)
indices = np.round(float_indices).astype(int)
assert indices[-1] < total_num_frames
assert len(float_indices) <= num_frames
elif frame_sample_mode == "uniform_last_frame":
indices = np.linspace(
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
elif frame_sample_mode == "fps":
candidate_target_fps = cls.get_candidate_target_fps(video_fps, sampling_fps)
selected_target_fps = cls.get_target_fps(
video_fps,
num_frames,
total_num_frames,
frame_sample_mode,
candidate_target_fps,
)
_, indices = cls.get_frame_times_and_chosen_fps(
selected_target_fps,
total_num_frames,
num_frames,
video_fps,
)
else:
raise NotImplementedError(frame_sample_mode)
return indices
@classmethod
def load_bytes_opencv(
    cls,
    data: bytes,
    frame_sample_mode: str | None = None,
    num_frames: int = -1,
    max_fps: int = 2,
    sampling_fps: int = 2,
    **kwargs,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """Decode an in-memory video with OpenCV and return frames and metadata.

    Args:
        data: Encoded video bytes.
        frame_sample_mode: Sampling mode passed to ``_sample_frames``. With
            ``None`` every frame is read.
        num_frames: Frame budget passed to ``_sample_frames``.
        max_fps: Maximum sampling rate passed to ``_sample_frames``.
        sampling_fps: Base sampling rate passed to ``_sample_frames``.
        **kwargs: Accepted and ignored.

    Returns:
        ``(frames, metadata)``. ``metadata`` holds ``total_num_frames``,
        ``fps``, ``duration``, ``video_backend`` and ``do_sample_frames``.
        ``frames_indices`` is always present when frames were sampled here;
        when every frame was requested it is present only if some frames
        could not be read.

    Raises:
        ValueError: If OpenCV cannot open the video stream.
    """
    import cv2

    backend = cls().get_cv2_video_api()
    cap = cv2.VideoCapture(BytesIO(data), backend, [])
    # Release the capture on every path (including errors); frames are
    # already copied out by ``_read_frames`` before the handle is closed.
    try:
        if not cap.isOpened():
            raise ValueError("Could not open video stream")

        total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        # A non-positive fps gives no usable duration; report 0 instead.
        duration = total_frames_num / original_fps if original_fps > 0 else 0

        if frame_sample_mode is None:
            # Use the transformers.video_utils.VideoMetadata format
            frame_idx = list(range(0, total_frames_num))
            frame_idx_set = set(frame_idx)
            frames, valid_num_frames, valid_frame_indices = cls._read_frames(
                cap, frame_idx_set, total_frames_num, max(frame_idx)
            )
            # Sampling is left to the processor only if every frame was read.
            do_sample_frames = valid_num_frames == total_frames_num
            metadata = {
                "total_num_frames": total_frames_num,
                "fps": original_fps,
                "duration": duration,
                "video_backend": "opencv",
                "do_sample_frames": do_sample_frames,
            }
            if not do_sample_frames:
                metadata["frames_indices"] = valid_frame_indices
            return frames, metadata

        frame_idx = cls._sample_frames(
            total_frames_num,
            original_fps,
            duration,
            frame_sample_mode,
            num_frames,
            max_fps,
            sampling_fps,
        ).tolist()

        frames, valid_num_frames, valid_frame_indices = cls._read_frames(
            cap,
            set(frame_idx),
            len(frame_idx),
            total_frames_num - 1,
        )

        metadata = {
            "total_num_frames": total_frames_num,
            "fps": original_fps,
            "duration": duration,
            "video_backend": "opencv",
            "frames_indices": valid_frame_indices,
            "do_sample_frames": False,
        }
        return frames, metadata
    finally:
        cap.release()
@classmethod
def load_bytes(
cls,
data: bytes,
num_frames: int = -1,
**kwargs,
) -> tuple[npt.NDArray, dict[str, Any]]:
frame_sample_mode = cast(str | None, kwargs.pop("frame_sample_mode", None))
max_fps = cast(int, kwargs.pop("max_fps", 2))
sampling_fps = cast(int, kwargs.pop("sampling_fps", 2))
out = cls.load_bytes_opencv(
data,
frame_sample_mode,
num_frames,
max_fps,
sampling_fps,
**kwargs,
)
return out
class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]): class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
def __init__( def __init__(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment