"docs/vscode:/vscode.git/clone" did not exist on "9e94b9d83808faa9be3fe2f1e1680d7cefc6bfda"
Unverified Commit cd3ac5b7 authored by Netanel Haber, committed by GitHub
Browse files

support dynamic resolution image encoding for Nemotron Nano VL (#32121)


Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
parent 2636d762
...@@ -282,12 +282,14 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -282,12 +282,14 @@ class InternVisionEncoderLayer(nn.Module):
num_dummy_heads: int = 0, num_dummy_heads: int = 0,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_cls: type[InternParallelAttention] = InternParallelAttention,
) -> None: ) -> None:
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type self.norm_type = config.norm_type
self.attn_cls = attn_cls
self.attn = self._init_attn( self.attn = self._init_attn(
config, config,
...@@ -327,7 +329,7 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -327,7 +329,7 @@ class InternVisionEncoderLayer(nn.Module):
use_data_parallel = ( use_data_parallel = (
use_data_parallel or (num_heads + num_dummy_heads) % tp_size != 0 use_data_parallel or (num_heads + num_dummy_heads) % tp_size != 0
) )
return InternParallelAttention( return self.attn_cls(
config, config,
quant_config=quant_config, quant_config=quant_config,
num_dummy_heads=num_dummy_heads, num_dummy_heads=num_dummy_heads,
...@@ -356,10 +358,12 @@ class InternVisionEncoder(nn.Module): ...@@ -356,10 +358,12 @@ class InternVisionEncoder(nn.Module):
num_dummy_heads: int = 0, num_dummy_heads: int = 0,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
layer_cls: type[InternVisionEncoderLayer] = InternVisionEncoderLayer,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.layer_cls = layer_cls
if num_hidden_layers_override is None: if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers num_hidden_layers = config.num_hidden_layers
...@@ -368,7 +372,7 @@ class InternVisionEncoder(nn.Module): ...@@ -368,7 +372,7 @@ class InternVisionEncoder(nn.Module):
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
InternVisionEncoderLayer( self.layer_cls(
config, config,
quant_config, quant_config,
num_dummy_heads=num_dummy_heads, num_dummy_heads=num_dummy_heads,
......
...@@ -8,11 +8,15 @@ ...@@ -8,11 +8,15 @@
# -------------------------------------------------------- # --------------------------------------------------------
import copy import copy
import math
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Annotated, Any, Literal, TypeAlias, TypeVar from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import einops
import numpy.typing as npt import numpy.typing as npt
import regex as re import regex as re
import torch import torch
...@@ -23,6 +27,7 @@ from transformers import BatchFeature, PretrainedConfig, TensorType ...@@ -23,6 +27,7 @@ from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -39,7 +44,7 @@ from vllm.model_executor.models.internvl import ( ...@@ -39,7 +44,7 @@ from vllm.model_executor.models.internvl import (
) )
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
from vllm.model_executor.models.radio import RadioModel from vllm.model_executor.models.radio import RadioModel, calc_seq_lens
from vllm.model_executor.models.utils import ( from vllm.model_executor.models.utils import (
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
...@@ -78,6 +83,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -78,6 +83,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .utils import _merge_multimodal_embeddings from .utils import _merge_multimodal_embeddings
logger = init_logger(__name__)
# Configure PIL to handle large images without warnings # Configure PIL to handle large images without warnings
# This prevents DecompressionBombWarning for legitimate large images # This prevents DecompressionBombWarning for legitimate large images
Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
...@@ -103,11 +109,25 @@ class NanoNemotronVLImagePixelInputs(TensorSchema): ...@@ -103,11 +109,25 @@ class NanoNemotronVLImagePixelInputs(TensorSchema):
- w: Width of each image patch - w: Width of each image patch
""" """
type: Literal["pixel_values"] type: Literal["pixel_values"] = "pixel_values"
pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")] pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
num_patches: Annotated[torch.Tensor, TensorShape("bn")] num_patches: Annotated[torch.Tensor, TensorShape("bn")]
class NanoNemotronVLImagePixelInputsDynamic(TensorSchema):
    """
    Dynamic-resolution image inputs.

    Schema for images that were resized to a per-image patch grid instead of
    being split into fixed-size tiles.

    imgs_sizes: per-image (height, width) in pixels.
    num_tokens_per_image: per-image number of embedding tokens (post downsample).
    """

    # Discriminator that tells this schema apart from the other image input
    # schemas in this file.
    type: Literal["pixel_values_dynamic"] = "pixel_values_dynamic"
    # NOTE(review): presumably the patchified pixels of all images concatenated
    # along "bn" - confirm against the code that builds this tensor.
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bn", "h", "w")]
    imgs_sizes: list[tuple[int, int]]
    num_tokens_per_image: list[int]
class NanoNemotronVLImageEmbeddingInputs(TensorSchema): class NanoNemotronVLImageEmbeddingInputs(TensorSchema):
""" """
Dimensions: Dimensions:
...@@ -121,7 +141,9 @@ class NanoNemotronVLImageEmbeddingInputs(TensorSchema): ...@@ -121,7 +141,9 @@ class NanoNemotronVLImageEmbeddingInputs(TensorSchema):
NanoNemotronVLImageInputs: TypeAlias = ( NanoNemotronVLImageInputs: TypeAlias = (
NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddingInputs NanoNemotronVLImagePixelInputs
| NanoNemotronVLImagePixelInputsDynamic
| NanoNemotronVLImageEmbeddingInputs
) )
...@@ -267,6 +289,329 @@ def calculate_timestamps( ...@@ -267,6 +289,329 @@ def calculate_timestamps(
return timestamps return timestamps
class DynamicResolutionImageTiler:
    """Resize images to patch grids that fit a shared token budget.

    Every image is resized to a single tile whose sides are whole multiples of
    ``patch_size``. The grid of each image is chosen so that the patches of
    all images together fit the budget left in the context window once the
    text prompt is accounted for.
    """

    # Spatial reductions applied after patching. Each enabled stage halves both
    # grid sides, i.e. divides the number of patches by 4.
    CONV_MERGING = False
    PIXEL_SHUFFLE = True
    USE_THUMBNAIL = False

    # Maps ``id(image)`` to the number of embeddings computed for that image by
    # the latest ``compute_params`` call. Shared by all instances (it is only
    # ever mutated, never rebound) and consumed exactly once per image by
    # ``get_cached_feature_size``.
    # NOTE(review): entries that are never consumed stay in the dict, and
    # ``id()`` values can be reused once an image is garbage collected.
    feature_size_cache: dict[int, int] = {}

    def __init__(
        self,
        *,
        max_model_len: int,
        patch_size: int,
        min_num_patches: int,
        max_num_patches: int,
        downsample_ratio: float,
        norm_mean: Sequence[float],
        norm_std: Sequence[float],
        factor_max: float = 1.0,
        use_thumbnail: bool = False,
    ) -> None:
        """Configure the tiler.

        Args:
            max_model_len: Context length used to derive the token budget.
            patch_size: Side of one square patch, in pixels.
            min_num_patches: Lower bound on the patch budget of one image.
            max_num_patches: Upper bound on the patch budget of one image;
                a value <= 0 disables the upper bound.
            downsample_ratio: Fractional reduction ratio; only ``0.5`` passes
                the assertions below.
            norm_mean: Per-channel mean, stored as a ``(3, 1, 1)`` tensor.
            norm_std: Per-channel std, stored as a ``(3, 1, 1)`` tensor.
            factor_max: Upper bound on the resize factor applied to the
                image's patch grid.
            use_thumbnail: Must be ``False``; thumbnails are not supported.
        """
        assert use_thumbnail is False, "use_thumbnail is not supported"
        self._patch_size: int = patch_size
        self._max_model_len = max_model_len
        self._min_num_patches = min_num_patches
        self._max_num_patches = (
            max_num_patches if max_num_patches > 0 else float("inf")
        )
        self._factor_max = factor_max
        self.norm_mean = torch.tensor(norm_mean).reshape(3, 1, 1)
        self.norm_std = torch.tensor(norm_std).reshape(3, 1, 1)
        self._transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.ToTensor(),
            ]
        )
        assert downsample_ratio < 1
        reduction_factor = 1 / downsample_ratio
        assert reduction_factor == 2.0
        # One factor of 2 per enabled reduction stage (bools add up as ints).
        self._downsample_ratio = int(reduction_factor) ** (
            self.PIXEL_SHUFFLE + self.CONV_MERGING
        )
        assert self._downsample_ratio == 2

    def _get_num_embeddings(self, width: int, height: int) -> int:
        """Return the embedding count for an image of ``width`` x ``height`` px."""
        num_patches = (width // self._patch_size) * (height // self._patch_size)
        return num_patches // (self._downsample_ratio**2)

    def width_and_height_for_max_num_tokens_available(
        self,
        target_num_tokens_post_shuffle: int,
    ) -> tuple[int, int]:
        """
        TODO: optimize this so it squeezes closer to target number of tokens.

        Calculate image dimensions that produce approximately `target` tokens after
        pixel_shuffle.

        With pixel_shuffle enabled, each 2x2 patch grid becomes 1 token, so we
        need 4*B patches to get B tokens.

        The result is always a square whose side is derived from the integer
        square root of the target, so it never exceeds the target.

        Examples:
            >>> PATCH_SIZE = 16
            >>> DOWNSAMPLE_RATIO = 0.5
            >>> tiler = DynamicResolutionImageTiler(
            ...     max_model_len=16384,
            ...     patch_size=PATCH_SIZE,
            ...     downsample_ratio=DOWNSAMPLE_RATIO,
            ...     min_num_patches=4,
            ...     max_num_patches=0,
            ...     norm_mean=(0.5, 0.5, 0.5),
            ...     norm_std=(0.5, 0.5, 0.5),
            ... )
            >>> width, height = tiler.width_and_height_for_max_num_tokens_available(
            ...     target_num_tokens_post_shuffle=8192,
            ... )
            >>> assert (width, height) == (2880, 2880)
            >>> assert (width // PATCH_SIZE) * (
            ...     height // PATCH_SIZE
            ... ) // 2**2 == 8100  # tokens post-shuffle
            >>> assert tiler._get_num_embeddings(width=width, height=height) == 8100
        """
        side_pixels = (
            math.isqrt(target_num_tokens_post_shuffle)
            * self._downsample_ratio
            * self._patch_size
        )
        assert isinstance(side_pixels, int) and side_pixels % self._patch_size == 0
        return side_pixels, side_pixels

    def max_num_tokens_available(self, text_prompt_length: int) -> int:
        """Return the tokens left for images next to a prompt of the given length."""
        # NOTE(review): the 4 reserved tokens are presumably wrapper/special
        # tokens around the image placeholders - confirm.
        return self._max_model_len - text_prompt_length - 4

    def _images_to_pixel_values_lst(
        self,
        text_prompt_length: int,
        images: list[Image.Image],
    ) -> tuple[list[torch.Tensor], list[int]]:
        """Resize and convert ``images`` under the budget left by the prompt.

        Returns:
            The per-image ``(C, H, W)`` tensors and, aligned with them, the
            number of embeddings each one produces.
        """
        num_tokens_available = self.max_num_tokens_available(text_prompt_length)
        params_per_image = self.compute_params(images, num_tokens_available)

        # Separate name for the output so the ``images`` argument is not shadowed.
        pixel_values: list[torch.Tensor] = []
        feature_sizes: list[int] = []
        for param in params_per_image:
            for t in self.apply_params(param):
                assert t.ndim == 3, f"{t.ndim=}: expected 3 dim tensor"
                pixel_values.append(t)
                feature_sizes.append(param.num_embeddings)
        return pixel_values, feature_sizes

    @classmethod
    def get_cached_feature_size(cls, image: Image.Image) -> int:
        """Pop and return the embedding count cached for ``image``.

        Raises:
            KeyError: If nothing is cached for this image object, e.g. because
                the value was already consumed.
        """
        # hard assert that we only use the feature size once
        return cls.feature_size_cache.pop(id(image))

    @dataclass
    class DynamicResolutionParams:
        # Source image the parameters were computed for.
        media: Image.Image
        # Number of tiles produced for the image (always 1 here).
        num_tiles: int
        # Number of embeddings the resized image produces.
        num_embeddings: int
        # Target grid as (width, height), counted in patches.
        patch_size: tuple[int, int]

    def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]:
        """Resize ``params.media`` to its target grid and convert it to tensors."""
        resized_img = params.media.resize(
            (
                params.patch_size[0] * self._patch_size,
                params.patch_size[1] * self._patch_size,
            )
        )
        return [self._transform(resized_img)]

    def process_media(
        self,
        media: Image.Image,
        num_tokens_available: int,
    ) -> tuple[DynamicResolutionParams, int]:
        """Process a single media item and return its parameters.

        Args:
            media: The media item to process
            num_tokens_available: Budget for this media, counted in patches
                (i.e. before the pixel-shuffle reduction)

        Returns:
            A tuple of the DynamicResolutionParams for the media and the
            number of patches in its target grid
        """
        current_num_tokens_available = num_tokens_available
        assert isinstance(media, Image.Image), (
            "Dynamic resolution is only supported for image media"
        )
        orig_width, orig_height = media.width, media.height
        closest_patch_height = round(orig_height / self._patch_size + 0.5)
        closest_patch_width = round(orig_width / self._patch_size + 0.5)
        patches = closest_patch_height * closest_patch_width

        factor = min(
            math.sqrt(current_num_tokens_available / patches), self._factor_max
        )
        # floor() can reach 0 for very elongated images under a tight budget;
        # clamp to 1 so the min-patches upscale below cannot divide by zero
        # and the image is never resized to an empty grid.
        target_patch_height = max(1, math.floor(factor * closest_patch_height))
        target_patch_width = max(1, math.floor(factor * closest_patch_width))

        # Grow the grid up to self._min_num_patches, but only when the budget
        # is larger than that minimum.
        if (
            current_num_tokens_available > self._min_num_patches
            and target_patch_height * target_patch_width < self._min_num_patches
        ):
            up_factor = math.sqrt(
                self._min_num_patches / (target_patch_height * target_patch_width)
            )
            target_patch_height = math.ceil(up_factor * target_patch_height)
            target_patch_width = math.ceil(up_factor * target_patch_width)

        # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
        # or by 4 when BOTH are enabled (two successive 2x reductions)
        if self.PIXEL_SHUFFLE or self.CONV_MERGING:
            required_divisor = 4 if (self.PIXEL_SHUFFLE and self.CONV_MERGING) else 2

            rem_h = target_patch_height % required_divisor
            if rem_h != 0:
                inc_h = required_divisor - rem_h
                # Round up when the budget allows it, otherwise round down
                # (but never below one reduction block).
                if (
                    target_patch_height + inc_h
                ) * target_patch_width <= current_num_tokens_available:
                    target_patch_height += inc_h
                else:
                    target_patch_height = max(
                        required_divisor, target_patch_height - rem_h
                    )

            rem_w = target_patch_width % required_divisor
            if rem_w != 0:
                inc_w = required_divisor - rem_w
                if (
                    target_patch_height * (target_patch_width + inc_w)
                    <= current_num_tokens_available
                ):
                    target_patch_width += inc_w
                else:
                    target_patch_width = max(
                        required_divisor, target_patch_width - rem_w
                    )

        # Calculate embeddings for the main dynamic resolution image
        num_embeddings = self._get_num_embeddings(
            target_patch_width * self._patch_size,
            target_patch_height * self._patch_size,
        )
        token_count = target_patch_width * target_patch_height

        num_tiles = 1  # Base dynamic resolution image (no thumbnail support)
        return self.DynamicResolutionParams(
            media=media,
            num_tiles=num_tiles,
            num_embeddings=num_embeddings,
            patch_size=(target_patch_width, target_patch_height),
        ), token_count

    def compute_params(
        self,
        media_list: list[Image.Image],
        num_tokens_available: int | None = None,
    ) -> list[DynamicResolutionParams]:
        """Compute parameters for all media with iterative token budgeting.

        Args:
            media_list: List of media items to process
            num_tokens_available: Total number of tokens available across all
                media. Required despite the ``None`` default.

        Returns:
            List of DynamicResolutionParams for each media item

        Raises:
            TypeError: If ``num_tokens_available`` is ``None``.
            ValueError: If no allocation within budget is found after the
                bounded number of attempts.
        """
        if num_tokens_available is None:
            # Same exception type the arithmetic below used to raise on None,
            # but with an actionable message.
            raise TypeError("num_tokens_available must be an int, got None")

        # Convert the post-reduction token budget into a patch budget.
        num_tokens_available = (
            num_tokens_available
            * (4 if self.PIXEL_SHUFFLE else 1)
            * (4 if self.CONV_MERGING else 1)
        )
        # When the number of available token is too small,
        # allow self._min_num_patches per media and let the sample be truncated.
        num_tokens_available = max(
            num_tokens_available, self._min_num_patches * len(media_list)
        )

        # Clip the number of tokens available per media to >min and <max patches.
        per_media_budget = max(
            min(num_tokens_available, self._max_num_patches), self._min_num_patches
        )
        num_tokens_available_per_media = [per_media_budget] * len(media_list)

        # prevent infinite loop in any case
        for _ in range(10):
            # Step 1: Process each media with current token budget
            params = []
            token_counts = []
            for media, tokens_for_media in zip(
                media_list, num_tokens_available_per_media
            ):
                param, token_count = self.process_media(media, tokens_for_media)
                params.append(param)
                token_counts.append(token_count)
                self.feature_size_cache[id(param.media)] = param.num_embeddings

            # Step 2: Check if total tokens is within budget
            total_tokens = sum(token_counts)
            if total_tokens <= num_tokens_available:
                # We're within budget, return the params
                return params

            # Step 3: We're over budget, need to scale down
            # Calculate scaling factor to get under budget
            scaling_factor = num_tokens_available / total_tokens

            # Recalculate token budgets for each media based on scaling
            # Each media gets a proportional share of the total budget
            scaled_down_num_tokens_available_per_media = [
                max(self._min_num_patches, int(token_count * scaling_factor))
                for token_count in token_counts
            ]
            scaled_down = any(
                scaled < current
                for scaled, current in zip(
                    scaled_down_num_tokens_available_per_media,
                    num_tokens_available_per_media,
                )
            )
            # If there wasn't scaling down, we're stuck with min_num_patches per media,
            # else try with the scaled down num_tokens_available_per_media.
            if not scaled_down:
                num_tokens_available_per_media = [self._min_num_patches] * len(
                    media_list
                )
            else:
                num_tokens_available_per_media = (
                    scaled_down_num_tokens_available_per_media
                )

        ctx = f"{params=} {total_tokens=} {num_tokens_available=}"
        raise ValueError(
            f"Should be unreachable - `return params` above must be reached: {ctx}"
        )

    @staticmethod
    def stack(images: list[torch.Tensor], patch_size: int) -> torch.Tensor:
        """Patchify ``(C, H, W)`` images and concatenate them into one sequence.

        Each image becomes ``(num_patches, C * patch_size * patch_size)``; the
        images are concatenated along the patch axis and a leading axis of
        size 1 is added.
        """
        assert len(images) > 0, "No images to stack"

        def rearrange_img(x):
            py = x.shape[-2] // patch_size
            px = x.shape[-1] // patch_size
            return einops.rearrange(
                x,
                "c (py yy) (px xx) -> (py px) (c yy xx)",
                py=py,
                yy=patch_size,
                px=px,
                xx=patch_size,
            )

        imgs = [rearrange_img(img) for img in images]
        return torch.cat(imgs, dim=0).unsqueeze(0)
class BaseNanoNemotronVLProcessor(ABC): class BaseNanoNemotronVLProcessor(ABC):
""" """
This model doesn't define its own HF processor, This model doesn't define its own HF processor,
...@@ -281,6 +626,7 @@ class BaseNanoNemotronVLProcessor(ABC): ...@@ -281,6 +626,7 @@ class BaseNanoNemotronVLProcessor(ABC):
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
*args, *args,
max_model_len: int,
max_num_tiles: int | None = None, max_num_tiles: int | None = None,
**kwargs, **kwargs,
) -> None: ) -> None:
...@@ -292,15 +638,32 @@ class BaseNanoNemotronVLProcessor(ABC): ...@@ -292,15 +638,32 @@ class BaseNanoNemotronVLProcessor(ABC):
self.max_num_tiles = max_num_tiles or DEFAULT_NUM_TILES self.max_num_tiles = max_num_tiles or DEFAULT_NUM_TILES
image_size: int = config.force_image_size image_size: int = config.force_image_size
patch_size: int = config.patch_size patch_size: int = config.patch_size
downsample_ratio: int = config.downsample_ratio
self.num_image_token = int( self.num_image_token = int(
(image_size // patch_size) ** 2 * (config.downsample_ratio**2) (image_size // patch_size) ** 2 * (downsample_ratio**2)
) )
self.image_size = image_size self.image_size = image_size
self.use_thumbnail: bool = config.use_thumbnail self.use_thumbnail: bool = config.use_thumbnail
self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1) self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1)
self.norm_std = torch.Tensor(config.norm_std).reshape(1, 3, 1, 1) self.norm_std = torch.Tensor(config.norm_std).reshape(1, 3, 1, 1)
self.dynamic_tiler: DynamicResolutionImageTiler | None = None
if self.use_dynamic_resolution(config):
self.dynamic_tiler = DynamicResolutionImageTiler(
max_model_len=max_model_len,
patch_size=patch_size,
downsample_ratio=downsample_ratio,
min_num_patches=config.vision_config.args["min_num_patches"],
max_num_patches=config.vision_config.args["max_num_patches"],
norm_mean=config.norm_mean,
norm_std=config.norm_std,
)
@staticmethod
def use_dynamic_resolution(config: PretrainedConfig) -> bool:
    """Tell whether ``config`` selects dynamic-resolution image encoding.

    The decision is based solely on the presence of the ``min_num_patches``
    key in ``config.vision_config.args``.
    """
    vision_args = config.vision_config.args
    return "min_num_patches" in vision_args
@property @property
@abstractmethod @abstractmethod
def image_token_id(self) -> int: def image_token_id(self) -> int:
...@@ -354,36 +717,61 @@ class BaseNanoNemotronVLProcessor(ABC): ...@@ -354,36 +717,61 @@ class BaseNanoNemotronVLProcessor(ABC):
text: list[str], text: list[str],
images: list[Image.Image], images: list[Image.Image],
max_num_tiles: int, max_num_tiles: int,
) -> tuple[list[str], dict[str, torch.Tensor]]: ) -> tuple[list[str], dict[str, Any]]:
if len(images) == 0: if len(images) == 0:
image_inputs = {} image_inputs = {}
return text, image_inputs
if tiler := self.dynamic_tiler:
sans_images = text[0].replace("<image>", "")
text_prompt_length = len(
self.tokenizer(sans_images, add_special_tokens=False).input_ids
)
pixel_values_lst, num_tokens_per_image = tiler._images_to_pixel_values_lst(
text_prompt_length=text_prompt_length,
images=images,
)
imgs_sizes = [(pv.shape[-2], pv.shape[-1]) for pv in pixel_values_lst]
normalized = [
input_conditioner(img, tiler.norm_mean, tiler.norm_std)
for img in pixel_values_lst
]
image_num_patches = torch.tensor([1] * len(num_tokens_per_image))
image_inputs = {
"pixel_values_flat": normalized,
"imgs_sizes": imgs_sizes,
"num_tokens_per_image": num_tokens_per_image,
}
else: else:
pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles) pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
image_num_patches = torch.tensor([len(item) for item in pixel_values_lst])
pixel_values_flat = input_conditioner(
torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
)
image_inputs = { image_inputs = {
"pixel_values_flat": input_conditioner( "pixel_values_flat": pixel_values_flat,
torch.cat(pixel_values_lst), self.norm_mean, self.norm_std "image_num_patches": image_num_patches,
),
"image_num_patches": torch.tensor(
[len(item) for item in pixel_values_lst]
),
} }
num_tokens_per_image = [
self.num_image_token * len(item) for item in pixel_values_lst
]
assert len(text) == 1, ( assert len(text) == 1, (
"hf_processor is called on the output of get_dummy_text, " "hf_processor is called on the output of get_dummy_text, "
"which should be a single string" "which should be a single string"
) )
parts = [x for x in re.split(r"(<image>)", text[0]) if x] parts = [x for x in re.split(r"(<image>)", text[0]) if x]
assert parts.count("<image>") == len(pixel_values_lst), ( assert parts.count("<image>") == len(pixel_values_lst), (
"the number of <image> tokens in the text should be the " "the number of <image> tokens in the text should be the "
"same as the number of images" "same as the number of images"
) )
for i, pixel_values in enumerate(pixel_values_lst): for i, (feature_size, num_patches) in enumerate(
num_patches = pixel_values.shape[0] zip(num_tokens_per_image, image_num_patches, strict=True)
feature_size = num_patches * self.num_image_token ):
image_repl = self.get_image_repl(feature_size, num_patches) image_repl = self.get_image_repl(feature_size, num_patches)
parts[i] = parts[i].replace("<image>", image_repl.full) parts[i] = parts[i].replace("<image>", image_repl.full)
text = ["".join(parts)] text = ["".join(parts)]
return text, image_inputs return text, image_inputs
def _make_batch_input(self, input_item: Any | list[Any] | None = None): def _make_batch_input(self, input_item: Any | list[Any] | None = None):
...@@ -393,6 +781,7 @@ class BaseNanoNemotronVLProcessor(ABC): ...@@ -393,6 +781,7 @@ class BaseNanoNemotronVLProcessor(ABC):
input_item = [input_item] input_item = [input_item]
return input_item return input_item
@abstractmethod
def __call__( def __call__(
self, self,
text: str | list[str] | None = None, text: str | list[str] | None = None,
...@@ -400,23 +789,7 @@ class BaseNanoNemotronVLProcessor(ABC): ...@@ -400,23 +789,7 @@ class BaseNanoNemotronVLProcessor(ABC):
return_tensors: str | TensorType | None = None, return_tensors: str | TensorType | None = None,
max_num_tiles: int | None = None, max_num_tiles: int | None = None,
) -> BatchFeature: ) -> BatchFeature:
# Use default if not provided raise NotImplementedError
if max_num_tiles is None:
max_num_tiles = self.max_num_tiles
text, images = [self._make_batch_input(x) for x in (text, images)]
text, image_inputs = self._preprocess_image(
text=text,
images=images,
max_num_tiles=max_num_tiles,
)
text_inputs = self.tokenizer(text, add_special_tokens=False)
combined_outputs = {**text_inputs, **image_inputs}
return BatchFeature(combined_outputs, tensor_type=return_tensors)
class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
...@@ -431,20 +804,16 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -431,20 +804,16 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
*, *,
max_model_len: int,
max_num_tiles: int | None = None, max_num_tiles: int | None = None,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
dynamic_image_size: bool | None = None,
video_token: str | None = None, video_token: str | None = None,
video_pruning_rate: float | None = None, video_pruning_rate: float | None = None,
) -> None: ) -> None:
super().__init__( super().__init__(
config=config, config=config,
tokenizer=tokenizer, tokenizer=tokenizer,
max_model_len=max_model_len,
max_num_tiles=max_num_tiles, max_num_tiles=max_num_tiles,
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
) )
# add extra video token for video processing # add extra video token for video processing
self.video_token = video_token self.video_token = video_token
...@@ -478,7 +847,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -478,7 +847,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
self, self,
videos: list[npt.NDArray], videos: list[npt.NDArray],
max_num_tiles: int, max_num_tiles: int,
dynamic_image_size: bool | None = None,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
return [ return [
video_to_pixel_values( video_to_pixel_values(
...@@ -495,7 +863,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -495,7 +863,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
text: list[str], text: list[str],
videos: list[tuple[npt.NDArray, dict[str, Any]]], videos: list[tuple[npt.NDArray, dict[str, Any]]],
max_num_tiles: int, max_num_tiles: int,
dynamic_image_size: bool | None = None,
): ):
if len(videos) == 0 or not self.supports_video: if len(videos) == 0 or not self.supports_video:
video_inputs = {} video_inputs = {}
...@@ -505,7 +872,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -505,7 +872,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
pixel_values_lst_video = self._videos_to_pixel_values_lst( pixel_values_lst_video = self._videos_to_pixel_values_lst(
videos_lst, videos_lst,
max_num_tiles=max_num_tiles, max_num_tiles=max_num_tiles,
dynamic_image_size=dynamic_image_size,
) )
# We use frame duration in milliseconds (as integer) to ensure # We use frame duration in milliseconds (as integer) to ensure
...@@ -592,7 +958,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -592,7 +958,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None, videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None,
return_tensors: str | TensorType | None = None, return_tensors: str | TensorType | None = None,
max_num_tiles: int | None = None, max_num_tiles: int | None = None,
dynamic_image_size: bool | None = None,
) -> BatchFeature: ) -> BatchFeature:
# Use default if not provided # Use default if not provided
if max_num_tiles is None: if max_num_tiles is None:
...@@ -612,14 +977,23 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -612,14 +977,23 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
text=text, text=text,
videos=videos, videos=videos,
max_num_tiles=1, max_num_tiles=1,
dynamic_image_size=dynamic_image_size,
) )
text_inputs = self.tokenizer(text, add_special_tokens=False) text_inputs = self.tokenizer(text, add_special_tokens=False)
combined_outputs = {**text_inputs, **image_inputs, **video_inputs} if self.dynamic_tiler is None:
batch = BatchFeature(
return BatchFeature(combined_outputs, tensor_type=return_tensors) {**text_inputs, **video_inputs, **image_inputs},
tensor_type=return_tensors,
)
else:
batch = BatchFeature(
{**text_inputs, **video_inputs}, tensor_type=return_tensors
)
# allow images to be exempt from the BatchFeature validation:
# We will .stack() them in _parse_and_validate_image_input
batch.update(image_inputs)
return batch
def get_image_repl( def get_image_repl(
self, self,
...@@ -722,23 +1096,6 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo): ...@@ -722,23 +1096,6 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None} return {"image": None}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
max_num_tiles: int,
processor: BaseNanoNemotronVLProcessor | None,
) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
max_num_tiles=max_num_tiles,
)
def get_image_size_with_most_features(self, max_num_tiles: int) -> ImageSize: def get_image_size_with_most_features(self, max_num_tiles: int) -> ImageSize:
processor = self.get_hf_processor() processor = self.get_hf_processor()
...@@ -749,11 +1106,8 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo): ...@@ -749,11 +1106,8 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
for wr, hr in target_ratios: for wr, hr in target_ratios:
width, height = base_size * wr, base_size * hr width, height = base_size * wr, base_size * hr
feat_size = self.get_num_image_tokens( feat_size = processor.get_num_image_tokens(
image_width=width, image_width=width, image_height=height, max_num_tiles=max_num_tiles
image_height=height,
max_num_tiles=max_num_tiles,
processor=processor,
) )
if feat_size > largest_feature_size: if feat_size > largest_feature_size:
largest_feature_size = feat_size largest_feature_size = feat_size
...@@ -772,11 +1126,10 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo): ...@@ -772,11 +1126,10 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
max_num_tiles max_num_tiles
) )
return self.get_num_image_tokens( return processor.get_num_image_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
max_num_tiles=max_num_tiles, max_num_tiles=max_num_tiles,
processor=processor,
) )
...@@ -822,6 +1175,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo): ...@@ -822,6 +1175,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
tokenizer=self.get_tokenizer(), tokenizer=self.get_tokenizer(),
video_token=self.get_video_token(), video_token=self.get_video_token(),
video_pruning_rate=self.get_video_pruning_rate(), video_pruning_rate=self.get_video_pruning_rate(),
max_model_len=self.ctx.model_config.max_model_len,
**kwargs, **kwargs,
) )
...@@ -829,19 +1183,29 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo): ...@@ -829,19 +1183,29 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]): class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""Basic image-only MultiModalProcessor for InternVL-style models.""" """Basic image-only MultiModalProcessor for InternVL-style models."""
@cached_property
def is_dynamic_tiler(self) -> bool:
    """Whether the HF processor was built with a dynamic-resolution tiler.

    Evaluated once per instance and cached afterwards.
    """
    hf_processor = self.info.get_hf_processor()
    tiler = hf_processor.dynamic_tiler
    return tiler is not None
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) if self.is_dynamic_tiler:
pixel_values_flat = MultiModalFieldConfig.batched("image")
else:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
pixel_values_flat = MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches
)
return dict( return dict(
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( pixel_values_flat=pixel_values_flat,
"image", image_num_patches
),
image_num_patches=MultiModalFieldConfig.batched("image"), image_num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
num_tokens_per_image=MultiModalFieldConfig.batched("image"),
imgs_sizes=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_updates( def _get_prompt_updates(
...@@ -870,17 +1234,19 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -870,17 +1234,19 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
if isinstance(images, ImageEmbeddingItems): if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx) feature_size = images.get_feature_size(item_idx)
elif tiler := hf_processor.dynamic_tiler:
image = images.get(item_idx)
feature_size = tiler.get_cached_feature_size(image)
else: else:
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
# Extract max_num_tiles from kwargs, default to 12 # Extract max_num_tiles from kwargs, default to 12
max_num_tiles = hf_processor_mm_kwargs.get( max_num_tiles = hf_processor_mm_kwargs.get(
"max_num_tiles", hf_processor.max_num_tiles "max_num_tiles", hf_processor.max_num_tiles
) )
feature_size = self.info.get_num_image_tokens( feature_size = hf_processor.get_num_image_tokens(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
max_num_tiles=max_num_tiles, max_num_tiles=max_num_tiles,
processor=hf_processor,
) )
num_patches = None num_patches = None
...@@ -1017,12 +1383,18 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]): ...@@ -1017,12 +1383,18 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
# Use default max_num_tiles for dummy data generation
max_num_tiles = 12
target_width, target_height = self.info.get_image_size_with_most_features(
max_num_tiles
)
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
if tiler := processor.dynamic_tiler:
budget = tiler.max_num_tokens_available(text_prompt_length=num_images)
target_width, target_height = (
tiler.width_and_height_for_max_num_tokens_available(budget)
)
else:
max_num_tiles = 12
target_width, target_height = self.info.get_image_size_with_most_features(
max_num_tiles
)
image_overrides = mm_options.get("image") if mm_options else None image_overrides = mm_options.get("image") if mm_options else None
...@@ -1181,6 +1553,11 @@ class NemotronH_Nano_VL_V2( ...@@ -1181,6 +1553,11 @@ class NemotronH_Nano_VL_V2(
self._img_context_token_ids = tokenizer.encode( self._img_context_token_ids = tokenizer.encode(
IMG_CONTEXT, add_special_tokens=False IMG_CONTEXT, add_special_tokens=False
) )
self.dynamic_resolution = BaseNanoNemotronVLProcessor.use_dynamic_resolution(
config
)
if self.dynamic_resolution:
logger.info("Dynamic resolution is enabled for NanoNemotronVLProcessor")
def pixel_shuffle(self, x, scale_factor=0.5): def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size() n, w, h, c = x.size()
...@@ -1211,7 +1588,51 @@ class NemotronH_Nano_VL_V2( ...@@ -1211,7 +1588,51 @@ class NemotronH_Nano_VL_V2(
x = x.permute(0, 2, 1, 3).contiguous() x = x.permute(0, 2, 1, 3).contiguous()
return x return x
def pixel_shuffle_dynamic_res(
    self, x: torch.Tensor, *, imgs_sizes: list[tuple[int, int]]
) -> torch.Tensor:
    """Pixel-shuffle a packed sequence holding images of different sizes.

    ``x`` is split along dim -2 into one chunk per entry of ``imgs_sizes``
    (chunk lengths come from ``calc_seq_lens`` with ``self.patch_size``).
    Each chunk is reshaped to its own patch grid, shuffled using
    ``self.downsample_ratio``, flattened back into a sequence, and the
    per-image results are concatenated along dim -2 again.
    """
    ratio = self.downsample_ratio
    patch = self.patch_size
    chunks = torch.split(x, calc_seq_lens(imgs_sizes, patch), dim=-2)

    shuffled = []
    for size, chunk in zip(imgs_sizes, chunks):
        # Restore this image's 2-D patch grid from its flat token run.
        grid = chunk.reshape(chunk.shape[0], size[0] // patch, size[1] // patch, -1)
        n, h, w, c = grid.size()
        new_w = int(w * ratio)
        grid = grid.view(n, h, new_w, int(c / ratio))
        grid = grid.permute(0, 2, 1, 3).contiguous()
        grid = grid.view(n, new_w, int(h * ratio), int(c / (ratio * ratio)))
        if self.ps_version == "v2":
            grid = grid.permute(0, 2, 1, 3).contiguous()
        shuffled.append(grid.reshape(n, -1, grid.shape[-1]))

    return torch.cat(shuffled, dim=-2)
def extract_feature_dynamic(
    self, pixel_values: torch.Tensor, imgs_sizes: list[tuple[int, int]]
):
    """Dynamic resolution extract_feature for images.

    Runs the vision model with the per-image sizes, keeps the second element
    of its output, casts it to bfloat16, pixel-shuffles it per image and
    projects the result through ``self.mlp1``.
    """
    _summary, patch_features = self.vision_model(pixel_values, imgs_sizes=imgs_sizes)
    patch_features = patch_features.to(dtype=torch.bfloat16)
    shuffled = self.pixel_shuffle_dynamic_res(patch_features, imgs_sizes=imgs_sizes)
    return self.mlp1(shuffled)
def extract_feature(self, pixel_values: torch.Tensor):
# Process images in a micro-batch of at most 128 frames per call # Process images in a micro-batch of at most 128 frames per call
# This is done on purpose to ensure peak GPU ram usage of huge batch # This is done on purpose to ensure peak GPU ram usage of huge batch
# (namely for really long videos with EVS ON) won't cause any problems # (namely for really long videos with EVS ON) won't cause any problems
...@@ -1239,36 +1660,39 @@ class NemotronH_Nano_VL_V2( ...@@ -1239,36 +1660,39 @@ class NemotronH_Nano_VL_V2(
def _parse_and_validate_image_input(
    self, **kwargs: object
) -> NanoNemotronVLImageInputs | None:
    """Build the typed image-input object from the raw multimodal kwargs.

    Precomputed ``image_embeds`` take priority and are wrapped as-is.
    Otherwise the remaining kwargs are forwarded to the pixel-input schema;
    in dynamic-resolution mode ``pixel_values_flat`` is first passed through
    ``DynamicResolutionImageTiler.stack`` together with ``self.patch_size``.

    Raises:
        KeyError: in dynamic-resolution mode when ``pixel_values_flat`` is
            not present in ``kwargs``.
    """
    image_embeds = kwargs.pop("image_embeds", None)
    # Compare against None explicitly: truth-testing a tensor with more than
    # one element raises, so `if image_embeds:` is not a safe presence check.
    if image_embeds is not None:
        return NanoNemotronVLImageEmbeddingInputs(
            type="image_embeds",
            data=image_embeds,
        )

    if self.dynamic_resolution:
        pixel_values_flat = DynamicResolutionImageTiler.stack(
            kwargs.pop("pixel_values_flat"), self.patch_size
        )
        return NanoNemotronVLImagePixelInputsDynamic(
            pixel_values_flat=pixel_values_flat, **kwargs
        )

    return NanoNemotronVLImagePixelInputs(**kwargs)
def _process_image_input_dynamic(
    self, image_input: NanoNemotronVLImagePixelInputsDynamic
) -> tuple[torch.Tensor, ...]:
    """Embed dynamic-resolution images and return one tensor per image.

    The packed embeddings are flattened to rows of
    ``config.text_config.hidden_size`` and then split according to
    ``image_input.num_tokens_per_image``. A single image is returned as a
    one-element tuple without splitting.
    """
    packed = self.extract_feature_dynamic(
        image_input.pixel_values_flat, image_input.imgs_sizes
    )
    token_counts = image_input.num_tokens_per_image
    flat = packed.view(-1, self.config.text_config.hidden_size)
    if len(token_counts) == 1:
        return (flat,)
    return flat.split(token_counts)
def _process_image_input(
self, image_input: NanoNemotronVLImagePixelInputs
) -> tuple[torch.Tensor, ...]:
image_embeds = self.extract_feature(image_input["pixel_values_flat"]) image_embeds = self.extract_feature(image_input["pixel_values_flat"])
num_patches = image_input["num_patches"] num_patches = image_input["num_patches"]
...@@ -1470,7 +1894,13 @@ class NemotronH_Nano_VL_V2( ...@@ -1470,7 +1894,13 @@ class NemotronH_Nano_VL_V2(
for modality in modalities: for modality in modalities:
if modality == "images": if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
image_embeddings = self._process_image_input(image_input) if image_input["type"] == "image_embeds":
image_embeddings = image_input["data"]
elif self.dynamic_resolution:
assert image_input["type"] == "pixel_values_dynamic"
image_embeddings = self._process_image_input_dynamic(image_input)
else:
image_embeddings = self._process_image_input(image_input)
multimodal_embeddings += tuple(image_embeddings) multimodal_embeddings += tuple(image_embeddings)
if modality == "videos": if modality == "videos":
video_input = modalities["videos"] video_input = modalities["videos"]
...@@ -1652,33 +2082,6 @@ class NemotronH_Nano_VL_V2( ...@@ -1652,33 +2082,6 @@ class NemotronH_Nano_VL_V2(
if save_to_file and sys.stdout != original_stdout: if save_to_file and sys.stdout != original_stdout:
sys.stdout = original_stdout sys.stdout = original_stdout
def get_model_info(self):
    """
    Get basic model information as a dictionary.

    Parameters are grouped by the first dotted component of their name. For
    each group, ``"params"`` counts parameter tensors and ``"size"`` sums
    their element counts. The memory estimate assumes 2 bytes per element.
    """
    total_params = sum(param.numel() for param in self.parameters())

    component_info = {}
    for full_name, param in self.named_parameters():
        top_level = full_name.split(".")[0]
        stats = component_info.setdefault(top_level, {"params": 0, "size": 0})
        stats["params"] += 1
        stats["size"] += param.numel()

    config_summary = {
        "image_size": getattr(self.config, "force_image_size", None),
        "patch_size": getattr(self.config, "patch_size", None),
        "num_image_token": self.num_image_token,
        "downsample_ratio": self.downsample_ratio,
    }
    return {
        "model_name": "NemotronH_Nano_VL_V2",
        "total_parameters": total_params,
        "memory_estimate_mb": total_params * 2 / (1024**2),  # bfloat16
        "components": component_info,
        "config": config_summary,
    }
def get_vit_model_from_radio_config(self, hf_config): def get_vit_model_from_radio_config(self, hf_config):
hf_config_vision = hf_config.vision_config hf_config_vision = hf_config.vision_config
model_name = hf_config_vision.args.get("model") model_name = hf_config_vision.args.get("model")
......
...@@ -21,7 +21,11 @@ from transformers import PretrainedConfig ...@@ -21,7 +21,11 @@ from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.intern_vit import InternVisionEncoder from vllm.model_executor.models.intern_vit import (
InternParallelAttention,
InternVisionEncoder,
InternVisionEncoderLayer,
)
input_dim_t: TypeAlias = int | tuple[int, int] input_dim_t: TypeAlias = int | tuple[int, int]
norm_t: TypeAlias = tuple[float, float, float] | torch.Tensor norm_t: TypeAlias = tuple[float, float, float] | torch.Tensor
...@@ -43,6 +47,15 @@ to_4tuple = _ntuple(4) ...@@ -43,6 +47,15 @@ to_4tuple = _ntuple(4)
to_ntuple = _ntuple to_ntuple = _ntuple
def calc_seq_len(size: tuple[int, int], patch_size: int) -> int:
    """Return how many whole ``patch_size`` patches tile a ``(height, width)`` image."""
    height, width = size
    rows = height // patch_size
    cols = width // patch_size
    return rows * cols
def calc_seq_lens(sizes: list[tuple[int, int]], patch_size: int) -> list[int]:
    """Return the patch count of every ``(height, width)`` entry, in order."""
    return [calc_seq_len(image_size, patch_size) for image_size in sizes]
class ClsToken(nn.Module): class ClsToken(nn.Module):
def __init__( def __init__(
self, self,
...@@ -164,15 +177,73 @@ class ViTPatchGenerator(nn.Module): ...@@ -164,15 +177,73 @@ class ViTPatchGenerator(nn.Module):
nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity() nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(
patches = self.embed_patches(x) self, x: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:]) ) -> torch.Tensor:
patches = self.cls_token(patches) if imgs_sizes is not None:
patches = self.embedder(x)
patches, pos_enc = self.apply_pos_enc_dynamic(
patches, imgs_sizes=imgs_sizes
)
patches = self.cls_token_dynamic(patches, imgs_sizes=imgs_sizes)
else:
patches = self.embed_patches(x)
patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
patches = self.cls_token(patches)
patches = self.patch_normalizer(patches) patches = self.patch_normalizer(patches)
if self.return_pos_enc: if self.return_pos_enc:
return patches, pos_enc return patches, pos_enc
return patches return patches
def apply_pos_enc_dynamic(
self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> tuple[torch.Tensor, torch.Tensor | None]:
if not self.abs_pos:
return patches, None
current_length = 0
pos_enc_list = []
for size in imgs_sizes:
seq_length = calc_seq_len(size, self.patch_size)
img_patches = patches[:, current_length : current_length + seq_length, :]
pos_enc = self.get_pos_enc(patches.shape[0], input_size=size)
img_patches_with_pos = img_patches + pos_enc
patches = torch.cat(
[
patches[:, :current_length, :],
img_patches_with_pos,
patches[:, current_length + seq_length :, :],
],
dim=1,
)
pos_enc_list.append(pos_enc)
current_length += seq_length
full_pos_enc = torch.cat(pos_enc_list, dim=1) if pos_enc_list else None
return patches, full_pos_enc
def cls_token_dynamic(
    self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> torch.Tensor:
    """Insert the class token(s) in front of every image's patch run.

    ``patches`` is a packed sequence along dim 1; run lengths come from
    ``calc_seq_lens``. The output interleaves the expanded
    ``cls_token.token`` with each image's patches. Positions beyond the
    summed run lengths are not copied. When the class token is disabled,
    ``patches`` is returned unchanged.
    """
    if not self.cls_token.enabled:
        return patches

    # Same view for every image: the token expanded over the batch dimension.
    prefix = self.cls_token.token.unsqueeze(0).expand(patches.shape[0], -1, -1)

    pieces = []
    start = 0
    for seq_len in calc_seq_lens(imgs_sizes, self.patch_size):
        stop = start + seq_len
        pieces += [prefix, patches[:, start:stop, :]]
        start = stop
    return torch.cat(pieces, dim=1)
@property
def apply_cls_token(self):
    """Whether the class-token module (``self.cls_token``) is enabled."""
    return self.cls_token.enabled
...@@ -406,6 +477,66 @@ class ViTPatchLinear(nn.Linear): ...@@ -406,6 +477,66 @@ class ViTPatchLinear(nn.Linear):
self.patch_size = patch_size self.patch_size = patch_size
class RadioParallelAttention(InternParallelAttention):
    """``InternParallelAttention`` whose forward optionally takes an attention mask."""

    def forward(
        self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Attend over ``x``, restricted by ``attn_mask`` when one is given.

        Without a mask this defers to the parent implementation. With a mask
        the QKV projection is split into heads and attention is computed with
        ``F.scaled_dot_product_attention`` before the output projection.
        """
        if attn_mask is None:
            return super().forward(x)

        batch, seq_len, _ = x.shape
        qkv, _ = self.qkv(x)
        query, key, value = qkv.chunk(3, dim=-1)
        if self.qk_normalization:
            query, key = self._apply_qk_norm(query, key)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        head_shape = (batch, seq_len, self.num_heads_per_partition, self.head_dim)
        query = query.view(*head_shape).transpose(1, 2)
        key = key.view(*head_shape).transpose(1, 2)
        value = value.view(*head_shape).transpose(1, 2)

        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, scale=self.scale
        )
        merged = attended.transpose(1, 2).reshape(batch, seq_len, -1)
        projected, _ = self.proj(merged)
        return projected
class RadioVisionEncoderLayer(InternVisionEncoderLayer):
    """Encoder layer that builds ``RadioParallelAttention`` and forwards a mask to it."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, attn_cls=RadioParallelAttention, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ):
        """Apply the attention and MLP residual branches, each scaled by ``ls1``/``ls2``."""
        attn_out = self.attn(self.norm1(hidden_states), attn_mask=attn_mask)
        hidden_states = hidden_states + attn_out * self.ls1

        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out * self.ls2
class RadioVisionEncoder(InternVisionEncoder):
    """Encoder stack built from ``RadioVisionEncoderLayer`` that threads a mask through."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, layer_cls=RadioVisionEncoderLayer, **kwargs)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ):
        """Run every layer in order, passing the same ``attn_mask`` to each."""
        states = inputs_embeds
        for layer in self.layers:
            states = layer(states, attn_mask=attn_mask)
        return states
class RadioInternVisionModel(nn.Module): class RadioInternVisionModel(nn.Module):
packed_modules_mapping = { packed_modules_mapping = {
"qkv": ["qkv"], "qkv": ["qkv"],
...@@ -440,7 +571,7 @@ class RadioInternVisionModel(nn.Module): ...@@ -440,7 +571,7 @@ class RadioInternVisionModel(nn.Module):
register_multiple=config.register_multiple, register_multiple=config.register_multiple,
) )
self.encoder = InternVisionEncoder( self.encoder = RadioVisionEncoder(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
...@@ -459,10 +590,45 @@ class RadioInternVisionModel(nn.Module): ...@@ -459,10 +590,45 @@ class RadioInternVisionModel(nn.Module):
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embeddings return self.embeddings
def create_inter_image_attention_mask(
    self, imgs_sizes: list[tuple[int, int]], device: torch.device
) -> torch.Tensor:
    """Build a block-diagonal boolean mask that isolates each packed image.

    Every image occupies ``calc_seq_len(size) + num_skip`` consecutive
    positions. Entry ``[i, j]`` of the returned square mask is True only when
    positions ``i`` and ``j`` belong to the same image, so tokens cannot
    attend across images.
    """
    generator = self.patch_generator
    extra_tokens = generator.num_skip
    tokens_per_image = [
        seq_len + extra_tokens
        for seq_len in calc_seq_lens(imgs_sizes, generator.patch_size)
    ]

    # repeat_interleave on a counts tensor yields [0]*c0 + [1]*c1 + ...,
    # i.e. the index of the image owning each position.
    owner = torch.repeat_interleave(
        torch.tensor(tokens_per_image, dtype=torch.long)
    ).to(device)
    return owner.unsqueeze(1) == owner.unsqueeze(0)
def forward(
self,
x: torch.Tensor,
imgs_sizes: torch.Tensor | None = None,
) -> torch.FloatTensor:
assert self.patch_generator is not None assert self.patch_generator is not None
hidden_states = self.patch_generator(x) hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
encoder_outputs = self.encoder(inputs_embeds=hidden_states) attn_mask = None
if imgs_sizes is not None and len(imgs_sizes) > 1:
# Dynamic Resolution
attn_mask = self.create_inter_image_attention_mask(
imgs_sizes, device=x.device
)
encoder_outputs = self.encoder(inputs_embeds=hidden_states, attn_mask=attn_mask)
return encoder_outputs return encoder_outputs
...@@ -504,9 +670,11 @@ class RadioModel(nn.Module): ...@@ -504,9 +670,11 @@ class RadioModel(nn.Module):
self, self,
pixel_values: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None,
pixel_embeds: torch.Tensor | None = None, pixel_embeds: torch.Tensor | None = None,
*,
imgs_sizes: torch.Tensor | None = None,
) -> tuple[torch.FloatTensor, torch.FloatTensor]: ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
y = self.model(pixel_values) y = self.model(pixel_values, imgs_sizes=imgs_sizes)
return self._extract_final(y) return self._extract_final(y, imgs_sizes=imgs_sizes)
def load_weights(self, weights) -> set[str]: def load_weights(self, weights) -> set[str]:
loaded_params: set[str] = set() loaded_params: set[str] = set()
...@@ -558,16 +726,32 @@ class RadioModel(nn.Module): ...@@ -558,16 +726,32 @@ class RadioModel(nn.Module):
return loaded_params return loaded_params
def _extract_final( def _extract_final(
self, y: torch.Tensor self, y: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
) -> tuple[torch.FloatTensor, torch.FloatTensor]: ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
# Remove CLS + REGISTERS tokens # Remove CLS + REGISTERS tokens
patch_gen = getattr(self.model, "patch_generator", None) num_skip = self.model.patch_generator.num_skip
if patch_gen is not None: patch_size = self.model.patch_generator.patch_size
all_summary = y[:, : patch_gen.num_cls_tokens] num_cls_tokens = self.model.patch_generator.num_cls_tokens
if self.summary_idxs is not None: if imgs_sizes is None:
bb_summary = all_summary[:, self.summary_idxs] all_summary = y[:, :num_cls_tokens]
else: all_feat = y[:, num_skip:]
bb_summary = all_summary else:
all_feat = y[:, patch_gen.num_skip :] all_patches = []
summaries = []
current_pos = 0
for num_patches in calc_seq_lens(imgs_sizes, patch_size):
patches = y[
:, current_pos + num_skip : current_pos + num_skip + num_patches, :
]
all_patches.append(patches)
summary = y[:, current_pos : current_pos + num_cls_tokens, :]
summaries.append(summary)
current_pos += num_skip + num_patches
all_summary = torch.cat(summaries, dim=1)
all_feat = torch.cat(all_patches, dim=1)
if self.summary_idxs is not None:
bb_summary = all_summary[:, self.summary_idxs]
else:
bb_summary = all_summary
return bb_summary.flatten(1), all_feat return bb_summary.flatten(1), all_feat
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment