Unverified Commit 0b6d5262 authored by Collin McCarthy's avatar Collin McCarthy Committed by GitHub
Browse files

Support temporal compression for Nemotron-3-VL videos (#36808)


Signed-off-by: default avatarCollin McCarthy <cmccarthy@nvidia.com>
parent d3cc3795
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# -------------------------------------------------------- # --------------------------------------------------------
import copy import copy
import math
import warnings import warnings
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
...@@ -77,6 +78,7 @@ from vllm.renderers import TokenizeParams ...@@ -77,6 +78,7 @@ from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers import cached_tokenizer_from_config
from vllm.transformers_utils.configs.radio import RadioConfig from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.transformers_utils.processors.internvl import get_internvl_target_ratios
from vllm.transformers_utils.processors.nano_nemotron_vl import ( from vllm.transformers_utils.processors.nano_nemotron_vl import (
AUDIO_CONTEXT, AUDIO_CONTEXT,
IMG_CONTEXT, IMG_CONTEXT,
...@@ -85,7 +87,7 @@ from vllm.transformers_utils.processors.nano_nemotron_vl import ( ...@@ -85,7 +87,7 @@ from vllm.transformers_utils.processors.nano_nemotron_vl import (
BaseNanoNemotronVLProcessor, BaseNanoNemotronVLProcessor,
DynamicResolutionImageTiler, DynamicResolutionImageTiler,
NanoNemotronVLProcessor, NanoNemotronVLProcessor,
get_internvl_target_ratios, get_video_target_size_and_feature_size,
) )
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -295,10 +297,13 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo): ...@@ -295,10 +297,13 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
max_videos = mm_counts.get("video", 0) max_videos = mm_counts.get("video", 0)
processor = self.get_hf_processor() # we get the CustomProcessor here processor = self.get_hf_processor() # we get the CustomProcessor here
T = processor.video_temporal_patch_size
max_image_tokens = self.get_max_image_tokens() * max_images max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token tokens_per_tubelet = processor.num_video_token
max_frames_per_video = max_total_frames // max(max_videos, 1) max_total_tubelets = (seq_len - max_image_tokens) // tokens_per_tubelet
max_tubelets_per_video = max_total_tubelets // max(max_videos, 1)
max_frames_per_video = max_tubelets_per_video * T
return max(max_frames_per_video, 1) return max(max_frames_per_video, 1)
def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor: def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor:
...@@ -589,28 +594,49 @@ class NanoNemotronVLMultiModalProcessor( ...@@ -589,28 +594,49 @@ class NanoNemotronVLMultiModalProcessor(
video_num_patches = [] video_num_patches = []
def get_video_replacement_internvl(item_idx: int): def get_video_replacement_internvl(item_idx: int):
feature_size = hf_processor.num_image_token
video, metadata = mm_items["video"][item_idx] video, metadata = mm_items["video"][item_idx]
patch_size = hf_processor.config.patch_size
downsample_ratio = hf_processor.config.downsample_ratio
target_patches = hf_processor.video_target_num_patches
if target_patches is not None and video is not None and video.shape[0] > 0:
orig_h, orig_w = video.shape[1], video.shape[2]
_, _, feature_size = get_video_target_size_and_feature_size(
orig_w=orig_w,
orig_h=orig_h,
target_patches=target_patches,
maintain_aspect_ratio=hf_processor.video_maintain_aspect_ratio,
patch_size=patch_size,
downsample_ratio=downsample_ratio,
)
else:
feature_size = hf_processor.num_image_token
num_patches = video_num_patches[item_idx] num_patches = video_num_patches[item_idx]
if num_patches is not None: if num_patches is not None:
assert isinstance(num_patches, int) assert isinstance(num_patches, int)
T = hf_processor.video_temporal_patch_size
if T > 1 and num_patches is not None:
num_tubelets = math.ceil(num_patches / T)
else:
num_tubelets = num_patches
video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
if video_pruning_rate is not None and video_pruning_rate > 0.0: if video_pruning_rate is not None and video_pruning_rate > 0.0:
# Start of EVS-specific code # Start of EVS-specific code
num_tokens = compute_retained_tokens_count( num_tokens = compute_retained_tokens_count(
tokens_per_frame=feature_size, tokens_per_frame=feature_size,
num_frames=num_patches, num_frames=num_tubelets,
q=video_pruning_rate, q=video_pruning_rate,
) )
# Here we just need placeholders that won't actually be replaced - # Here we just need placeholders that won't actually be replaced -
# we just need to make sure the total number of tokens is correct # we just need to make sure the total number of tokens is correct
# assign all tokens to the first frame # assign all tokens to the first frame
tokens_per_frame = [num_tokens] + [0] * (num_patches - 1) tokens_per_frame = [num_tokens] + [0] * (num_tubelets - 1)
# End of EVS-specific code # End of EVS-specific code
else: else:
tokens_per_frame = [feature_size] * num_patches tokens_per_frame = [feature_size] * num_tubelets
frame_duration_ms = int(1000 / metadata["fps"]) frame_duration_ms = int(1000 / metadata["fps"])
return hf_processor.get_video_repl( return hf_processor.get_video_repl(
...@@ -621,6 +647,7 @@ class NanoNemotronVLMultiModalProcessor( ...@@ -621,6 +647,7 @@ class NanoNemotronVLMultiModalProcessor(
img_start_token_ids=hf_processor._img_start_token_ids, img_start_token_ids=hf_processor._img_start_token_ids,
img_end_token_ids=hf_processor._img_end_token_ids, img_end_token_ids=hf_processor._img_end_token_ids,
img_context_token_ids=hf_processor._img_context_token_ids, img_context_token_ids=hf_processor._img_context_token_ids,
video_temporal_patch_size=T,
) )
if self.info.supports_video: if self.info.supports_video:
...@@ -745,15 +772,39 @@ class NanoNemotronVLDummyInputsBuilder( ...@@ -745,15 +772,39 @@ class NanoNemotronVLDummyInputsBuilder(
if self.info.supports_video: if self.info.supports_video:
config = self.info.get_hf_config() config = self.info.get_hf_config()
image_size: int = config.force_image_size image_size: int = config.force_image_size
processor = self.info.get_hf_processor()
# When video_target_num_patches is set the per-frame pixel
# resolution can exceed image_size. Use the actual target
# dimensions so that profiling sees the correct upper bound.
if processor.video_target_num_patches is not None:
target_w, target_h, _ = get_video_target_size_and_feature_size(
orig_w=image_size,
orig_h=image_size,
target_patches=processor.video_target_num_patches,
maintain_aspect_ratio=processor.video_maintain_aspect_ratio,
patch_size=config.patch_size,
downsample_ratio=config.downsample_ratio,
)
video_width, video_height = target_w, target_h
else:
video_width, video_height = image_size, image_size
target_num_frames = self.info.get_num_frames_with_most_features( target_num_frames = self.info.get_num_frames_with_most_features(
seq_len, mm_counts seq_len, mm_counts
) )
mm_config = self.info.ctx.get_mm_config()
if num_frames := mm_config.media_io_kwargs.get("video", {}).get(
"num_frames"
):
assert num_frames > 0
target_num_frames = num_frames
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
video_overrides = mm_options.get("video") video_overrides = mm_options.get("video")
dummy_video = { dummy_video = {
"video": self._get_dummy_videos( "video": self._get_dummy_videos(
width=image_size, width=video_width,
height=image_size, height=video_height,
num_frames=target_num_frames, num_frames=target_num_frames,
num_videos=num_videos, num_videos=num_videos,
overrides=video_overrides, overrides=video_overrides,
...@@ -790,6 +841,9 @@ class NanoNemotronVLDummyInputsBuilder( ...@@ -790,6 +841,9 @@ class NanoNemotronVLDummyInputsBuilder(
class NemotronH_Nano_VL_V2( class NemotronH_Nano_VL_V2(
nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning
): ):
requires_sequential_video_encoding = True
"""Temporarily needed for dynamic res video w/ conv3d, doesn't support bs>1 yet"""
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"): if modality.startswith("image"):
...@@ -817,6 +871,11 @@ class NemotronH_Nano_VL_V2( ...@@ -817,6 +871,11 @@ class NemotronH_Nano_VL_V2(
self.image_tag_type = config.image_tag_type self.image_tag_type = config.image_tag_type
self.video_pruning_rate = multimodal_config.video_pruning_rate self.video_pruning_rate = multimodal_config.video_pruning_rate
vision_config = getattr(config, "vision_config", config)
self.video_temporal_patch_size: int = getattr(
vision_config, "video_temporal_patch_size", 1
)
with self._mark_language_model(vllm_config): with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
...@@ -838,11 +897,12 @@ class NemotronH_Nano_VL_V2( ...@@ -838,11 +897,12 @@ class NemotronH_Nano_VL_V2(
mlp1 = nn.Sequential( mlp1 = nn.Sequential(
RMSNorm( RMSNorm(
hidden_size=vit_hidden_size * int(1 / self.downsample_ratio) ** 2, hidden_size=vit_hidden_size
* int(round(1 / self.downsample_ratio)) ** 2,
eps=1e-5, eps=1e-5,
), ),
nn.Linear( nn.Linear(
vit_hidden_size * int(1 / self.downsample_ratio) ** 2, vit_hidden_size * int(round(1 / self.downsample_ratio)) ** 2,
vision_projection_hidden_size, vision_projection_hidden_size,
bias=False, bias=False,
), ),
...@@ -958,19 +1018,37 @@ class NemotronH_Nano_VL_V2( ...@@ -958,19 +1018,37 @@ class NemotronH_Nano_VL_V2(
vit_embeds = self.mlp1(vit_embeds) vit_embeds = self.mlp1(vit_embeds)
return vit_embeds return vit_embeds
def extract_feature(self, pixel_values: torch.Tensor): def extract_feature(
self,
pixel_values: torch.Tensor,
num_frames: int | None = None,
) -> torch.Tensor:
# Process images in a micro-batch of at most 128 frames per call # Process images in a micro-batch of at most 128 frames per call
# This is done on purpose to ensure peak GPU ram usage of huge batch # This is done on purpose to ensure peak GPU ram usage of huge batch
# (namely for really long videos with EVS ON) won't cause any problems # (namely for really long videos with EVS ON) won't cause any problems
# as we don't support chunked prefill for video media # as we don't support chunked prefill for video media
micro_batch_size = 128 # When num_frames is provided and temporal_patch_size > 1, consecutive
n = pixel_values.shape[0] # frames are grouped into tubelets — the batch size must be a multiple
# of T so chunk boundaries don't split a tubelet.
N, _C, H, W = pixel_values.shape
T = self.video_temporal_patch_size if num_frames is not None else 1
micro_batch_size = 128 - (128 % T)
patch_size = self.patch_size
H_patches = H // patch_size
W_patches = W // patch_size
vit_embeds_list = [] vit_embeds_list = []
for i in range(0, n, micro_batch_size): for i in range(0, N, micro_batch_size):
_, vit_embeds = self.vision_model(pixel_values[i : i + micro_batch_size]) chunk = pixel_values[i : i + micro_batch_size]
if num_frames is not None and T > 1:
_, vit_embeds = self.vision_model(chunk, num_frames=chunk.shape[0])
else:
_, vit_embeds = self.vision_model(chunk)
vit_embeds = vit_embeds.to(dtype=torch.bfloat16) vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
h = w = int(vit_embeds.shape[1] ** 0.5) vit_embeds = vit_embeds.reshape(
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) vit_embeds.shape[0], H_patches, W_patches, -1
)
vit_embeds = self.pixel_shuffle( vit_embeds = self.pixel_shuffle(
vit_embeds, scale_factor=self.downsample_ratio vit_embeds, scale_factor=self.downsample_ratio
) )
...@@ -1042,16 +1120,21 @@ class NemotronH_Nano_VL_V2( ...@@ -1042,16 +1120,21 @@ class NemotronH_Nano_VL_V2(
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
"""Process video input and create final embeddings with video content """Process video input and create final embeddings with video content
and indicator tokens.""" and indicator tokens."""
# Get video embeddings using the same processing as images T = self.video_temporal_patch_size
video_embeddings = self._process_image_input(video_input)
if T > 1:
video_embeddings = self._extract_video_embeddings_temporal(video_input)
else:
video_embeddings = self._process_image_input(video_input)
final_video_embeddings: tuple[torch.Tensor, ...] = () final_video_embeddings: tuple[torch.Tensor, ...] = ()
image_rows = image_cols = self.config.force_image_size
downsample_ratio = self.config.downsample_ratio downsample_ratio = self.config.downsample_ratio
patch_size = self.config.patch_size patch_size = self.config.patch_size
rows = int(image_rows * downsample_ratio // patch_size) pixel_values = video_input["pixel_values_flat"]
cols = int(image_cols * downsample_ratio // patch_size) frame_h, frame_w = pixel_values.shape[-2], pixel_values.shape[-1]
rows = int(frame_h * downsample_ratio // patch_size)
cols = int(frame_w * downsample_ratio // patch_size)
video_pruning_rate = self.video_pruning_rate video_pruning_rate = self.video_pruning_rate
video_num_frames = video_input["num_patches"].tolist() video_num_frames = video_input["num_patches"].tolist()
video_frames_indices = video_input["frames_indices"].split(video_num_frames) video_frames_indices = video_input["frames_indices"].split(video_num_frames)
...@@ -1062,13 +1145,14 @@ class NemotronH_Nano_VL_V2( ...@@ -1062,13 +1145,14 @@ class NemotronH_Nano_VL_V2(
num_frames = video_num_frames[i] num_frames = video_num_frames[i]
frames_indices = video_frames_indices[i].tolist() frames_indices = video_frames_indices[i].tolist()
frame_duration_ms = video_input["frame_duration_ms"][i].item() frame_duration_ms = video_input["frame_duration_ms"][i].item()
assert single_video_embeddings.shape[0] % num_frames == 0 num_tubelets = math.ceil(num_frames / T) if T > 1 else num_frames
assert single_video_embeddings.shape[0] % num_tubelets == 0
if video_pruning_rate is not None and video_pruning_rate > 0.0: if video_pruning_rate is not None and video_pruning_rate > 0.0:
# Start of EVS-specific code # Start of EVS-specific code
retention_mask = compute_retention_mask( retention_mask = compute_retention_mask(
single_video_embeddings, single_video_embeddings,
video_size_thw=(num_frames, rows, cols), video_size_thw=(num_tubelets, rows, cols),
spatial_merge_size=1, spatial_merge_size=1,
q=video_pruning_rate, q=video_pruning_rate,
) )
...@@ -1077,14 +1161,14 @@ class NemotronH_Nano_VL_V2( ...@@ -1077,14 +1161,14 @@ class NemotronH_Nano_VL_V2(
single_video_embeddings = single_video_embeddings[retention_mask] single_video_embeddings = single_video_embeddings[retention_mask]
# calculate the actual number of retained tokens per frame # calculate the actual number of retained tokens per frame
retention_mask_thw = retention_mask.reshape(num_frames, rows, cols) retention_mask_thw = retention_mask.reshape(num_tubelets, rows, cols)
num_tokens_per_frame = ( num_tokens_per_frame = (
retention_mask_thw.sum(dim=(1, 2)).long().tolist() retention_mask_thw.sum(dim=(1, 2)).long().tolist()
) )
# End of EVS-specific code # End of EVS-specific code
else: else:
feature_size = single_video_embeddings.shape[0] // num_frames feature_size = single_video_embeddings.shape[0] // num_tubelets
num_tokens_per_frame = [feature_size] * num_frames num_tokens_per_frame = [feature_size] * num_tubelets
final_video_embeddings += ( final_video_embeddings += (
self._create_final_video_embeddings( self._create_final_video_embeddings(
...@@ -1092,11 +1176,36 @@ class NemotronH_Nano_VL_V2( ...@@ -1092,11 +1176,36 @@ class NemotronH_Nano_VL_V2(
num_tokens_per_frame, num_tokens_per_frame,
frames_indices, frames_indices,
frame_duration_ms, frame_duration_ms,
video_temporal_patch_size=T,
), ),
) )
return final_video_embeddings return final_video_embeddings
def _extract_video_embeddings_temporal(
self, video_input: NanoNemotronVLVideoPixelInputs
) -> tuple[torch.Tensor, ...]:
"""Extract per-video embeddings with temporal compression.
Each video is processed separately through extract_feature with
num_frames, which uses the fixed-resolution temporal path in RADIO
(no attention mask, flash attention).
"""
pixel_values = video_input["pixel_values_flat"]
num_frames_per_video = video_input["num_patches"].tolist()
hidden_size = self.config.text_config.hidden_size
results: list[torch.Tensor] = []
frame_offset = 0
for nf in num_frames_per_video:
video_frames = pixel_values[frame_offset : frame_offset + nf]
frame_offset += nf
vit_embeds = self.extract_feature(video_frames, num_frames=nf)
results.append(vit_embeds.view(-1, hidden_size))
return tuple(results)
def _process_audio_input( def _process_audio_input(
self, audio_input: NanoNemotronVLAudioFeatureInputs self, audio_input: NanoNemotronVLAudioFeatureInputs
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
...@@ -1134,6 +1243,7 @@ class NemotronH_Nano_VL_V2( ...@@ -1134,6 +1243,7 @@ class NemotronH_Nano_VL_V2(
num_tokens_per_frame: list[int], num_tokens_per_frame: list[int],
frames_indices: list[int], frames_indices: list[int],
frame_duration_ms: int, frame_duration_ms: int,
video_temporal_patch_size: int = 1,
) -> torch.Tensor: ) -> torch.Tensor:
"""Create final embeddings that combine video embeddings with """Create final embeddings that combine video embeddings with
text embeddings of indicator tokens. text embeddings of indicator tokens.
...@@ -1161,6 +1271,7 @@ class NemotronH_Nano_VL_V2( ...@@ -1161,6 +1271,7 @@ class NemotronH_Nano_VL_V2(
img_start_token_ids=self._img_start_token_ids, img_start_token_ids=self._img_start_token_ids,
img_end_token_ids=self._img_end_token_ids, img_end_token_ids=self._img_end_token_ids,
img_context_token_ids=self._img_context_token_ids, img_context_token_ids=self._img_context_token_ids,
video_temporal_patch_size=video_temporal_patch_size,
) )
# video_repl.full is a list of token IDs # video_repl.full is a list of token IDs
...@@ -1207,8 +1318,27 @@ class NemotronH_Nano_VL_V2( ...@@ -1207,8 +1318,27 @@ class NemotronH_Nano_VL_V2(
else: else:
frames_indices = torch.cat([f.flatten() for f in frames_indices], dim=0) frames_indices = torch.cat([f.flatten() for f in frames_indices], dim=0)
frame_duration_ms = frame_duration_ms.flatten() if torch.is_tensor(frame_duration_ms):
expected_h = expected_w = self.config.force_image_size frame_duration_ms = frame_duration_ms.flatten()
else:
frame_duration_ms = torch.cat(
[f.flatten() for f in frame_duration_ms], dim=0
)
if (
torch.is_tensor(pixel_values_flat_video)
and pixel_values_flat_video.ndim == 5
):
# batched._reduce_data stacked same-shape videos into
# [num_videos, nf, 3, H, W]; unstack back to a list so the
# same-H,W cat path below handles it uniformly.
pixel_values_flat_video = list(pixel_values_flat_video)
if not torch.is_tensor(pixel_values_flat_video):
pixel_values_flat_video = torch.cat(pixel_values_flat_video, dim=0)
expected_h = pixel_values_flat_video.shape[-2]
expected_w = pixel_values_flat_video.shape[-1]
num_frames = video_num_patches[0].item() num_frames = video_num_patches[0].item()
resolve_bindings = {"h": expected_h, "w": expected_w, "f": num_frames} resolve_bindings = {"h": expected_h, "w": expected_w, "f": num_frames}
...@@ -1361,8 +1491,7 @@ class NemotronH_Nano_VL_V2( ...@@ -1361,8 +1491,7 @@ class NemotronH_Nano_VL_V2(
self.language_model.load_weights(llm_weights) self.language_model.load_weights(llm_weights)
self.vision_model.load_weights(vision_weights) self.vision_model.load_weights(vision_weights)
if self.sound_encoder is not None: if self.sound_encoder is not None and len(sound_weights) > 0:
assert len(sound_weights) > 0
self.sound_encoder.load_weights(sound_weights) self.sound_encoder.load_weights(sound_weights)
def get_vit_model_from_radio_config(self, hf_config): def get_vit_model_from_radio_config(self, hf_config):
...@@ -1375,12 +1504,23 @@ class NemotronH_Nano_VL_V2( ...@@ -1375,12 +1504,23 @@ class NemotronH_Nano_VL_V2(
image_size = preferred_resolution[0] if preferred_resolution else 224 image_size = preferred_resolution[0] if preferred_resolution else 224
patch_size = getattr(hf_config_vision, "patch_size", 16) patch_size = getattr(hf_config_vision, "patch_size", 16)
# video_temporal_patch_size and separate_video_embedder are
# top-level vision_config attributes, not inside args.
video_temporal_patch_size = getattr(
hf_config_vision, "video_temporal_patch_size", 1
)
separate_video_embedder = getattr(
hf_config_vision, "separate_video_embedder", True
)
radio_config = RadioConfig( radio_config = RadioConfig(
model_name=model_name, model_name=model_name,
image_size=image_size, image_size=image_size,
patch_size=patch_size, patch_size=patch_size,
norm_mean=hf_config.norm_mean, norm_mean=hf_config.norm_mean,
norm_std=hf_config.norm_std, norm_std=hf_config.norm_std,
video_temporal_patch_size=video_temporal_patch_size,
separate_video_embedder=separate_video_embedder,
**hf_config_vision.args, **hf_config_vision.args,
) )
......
...@@ -123,6 +123,8 @@ class ViTPatchGenerator(nn.Module): ...@@ -123,6 +123,8 @@ class ViTPatchGenerator(nn.Module):
register_multiple: int | None = None, register_multiple: int | None = None,
num_registers: int | None = None, num_registers: int | None = None,
patch_bias: bool = False, patch_bias: bool = False,
temporal_patch_size: int = 1,
separate_video_embedder: bool = True,
device=None, device=None,
dtype=None, dtype=None,
): ):
...@@ -148,6 +150,7 @@ class ViTPatchGenerator(nn.Module): ...@@ -148,6 +150,7 @@ class ViTPatchGenerator(nn.Module):
self.patch_size = patch_size self.patch_size = patch_size
self.abs_pos = abs_pos self.abs_pos = abs_pos
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.temporal_patch_size = temporal_patch_size
self.num_rows = max_input_dims[0] // patch_size self.num_rows = max_input_dims[0] // patch_size
self.num_cols = max_input_dims[1] // patch_size self.num_cols = max_input_dims[1] // patch_size
...@@ -160,6 +163,21 @@ class ViTPatchGenerator(nn.Module): ...@@ -160,6 +163,21 @@ class ViTPatchGenerator(nn.Module):
patch_size, embed_dim, bias=patch_bias, **factory patch_size, embed_dim, bias=patch_bias, **factory
) )
if temporal_patch_size > 1:
if not separate_video_embedder:
raise NotImplementedError(
"Only separate_video_embedder=True is supported for"
" temporal compression (temporal_patch_size > 1)"
)
self.video_embedder = ViTPatchLinear(
patch_size,
embed_dim,
bias=patch_bias,
temporal_patch_size=temporal_patch_size,
**factory,
)
self._video_embedder_loaded = False
if abs_pos: if abs_pos:
scale = embed_dim**-0.5 scale = embed_dim**-0.5
self.pos_embed = nn.Parameter( self.pos_embed = nn.Parameter(
...@@ -196,6 +214,60 @@ class ViTPatchGenerator(nn.Module): ...@@ -196,6 +214,60 @@ class ViTPatchGenerator(nn.Module):
return patches, pos_enc return patches, pos_enc
return patches return patches
def forward_video(self, x: torch.Tensor) -> torch.Tensor:
"""Process video frames with temporal compression.
Groups T consecutive frames into tubelets before embedding.
Args:
x: [num_frames, 3, H, W] tensor of video frames
Returns:
Embedded patches with temporal compression applied.
"""
if not self._video_embedder_loaded:
raise ValueError(
"Temporal compression (video_temporal_patch_size > 1) requires "
"video_embedder weights, but they were never loaded. "
"Ensure the checkpoint was trained with temporal compression."
)
T = self.temporal_patch_size
input_size = x.shape[2:]
patches = self.im_to_patches(x) # [N, num_patches, 3*P*P]
num_frames, num_spatial, feat_dim = patches.shape
# Pad to a multiple of T by repeating the last frame so that
# all tubelets have exactly T frames.
num_pad_frames = (-num_frames) % T
if num_pad_frames > 0:
last_frame_dup = patches[-1:].expand(num_pad_frames, -1, -1)
patches = torch.cat([patches, last_frame_dup], dim=0)
# Group T frames per tubelet: for each spatial position, concatenate
# features across T consecutive frames; order follows Megatron training
num_frames_padded = patches.shape[0]
num_tublets = num_frames_padded // T
patches = rearrange(
patches,
"(tubelets frames) spatial feat -> tubelets spatial (frames feat)",
tubelets=num_tublets,
frames=T,
spatial=num_spatial,
feat=feat_dim,
)
patches = self.video_embedder(patches)
patches, pos_enc = self.apply_pos_enc(patches, input_size=input_size)
patches = self.cls_token(patches)
patches = self.patch_normalizer(patches)
if self.return_pos_enc:
return patches, pos_enc
return patches
def apply_pos_enc_dynamic( def apply_pos_enc_dynamic(
self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]] self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> tuple[torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
...@@ -381,66 +453,21 @@ class ViTPatchGenerator(nn.Module): ...@@ -381,66 +453,21 @@ class ViTPatchGenerator(nn.Module):
return pos_embed return pos_embed
if self.cpe_mode: if self.cpe_mode:
if self.training: max_dim = max(input_dims)
min_scale = math.sqrt(0.1) pos_embed = F.interpolate(
scale = ( pos_embed.float(),
torch.rand(batch_size, 1, 1, device=pos_embed.device) size=(max_dim, max_dim),
* (1 - min_scale) align_corners=False,
+ min_scale mode="bilinear",
) ).to(pos_embed.dtype)
aspect_min = math.log(3 / 4)
aspect_max = -aspect_min
aspect = torch.exp(
torch.rand(batch_size, 1, 1, device=pos_embed.device)
* (aspect_max - aspect_min)
+ aspect_min
)
scale_x = scale * aspect
scale_y = scale * (1 / aspect)
scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)
pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (
1 - scale_xy
)
lin_x = torch.linspace( pos_embed = window_select(pos_embed)
0, 1, steps=input_dims[1], device=pos_embed.device
)[None, None].expand(batch_size, input_dims[0], -1)
lin_y = torch.linspace(
0, 1, steps=input_dims[0], device=pos_embed.device
)[None, :, None].expand(batch_size, -1, input_dims[1])
lin_xy = torch.stack([lin_x, lin_y], dim=-1)
grid_xy = lin_xy * scale_xy + pos_xy
# Convert to [-1, 1] range
grid_xy.mul_(2).sub_(1)
pos_embed = F.grid_sample(
pos_embed.float().expand(batch_size, -1, -1, -1),
grid=grid_xy,
mode="bilinear",
padding_mode="zeros",
align_corners=True,
).to(pos_embed.dtype)
else:
max_dim = max(input_dims)
pos_embed = F.interpolate(
pos_embed.float(),
size=(max_dim, max_dim),
align_corners=True,
mode="bilinear",
).to(pos_embed.dtype)
pos_embed = window_select(pos_embed)
else: else:
pos_embed = window_select(pos_embed) pos_embed = window_select(pos_embed)
if pos_embed.shape[-2:] != input_dims: if pos_embed.shape[-2:] != input_dims:
pos_embed = F.interpolate( pos_embed = F.interpolate(
pos_embed.float(), size=input_dims, align_corners=True, mode="bilinear" pos_embed.float(), size=input_dims, align_corners=False, mode="bilinear"
).to(pos_embed.dtype) ).to(pos_embed.dtype)
pos_embed = pos_embed.flatten(2).permute(0, 2, 1) pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
...@@ -473,9 +500,19 @@ class Im2Patches(nn.Module): ...@@ -473,9 +500,19 @@ class Im2Patches(nn.Module):
class ViTPatchLinear(nn.Linear): class ViTPatchLinear(nn.Linear):
def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory): def __init__(
super().__init__(3 * (patch_size**2), embed_dim, bias=bias, **factory) self,
patch_size: int,
embed_dim: int,
bias: bool = False,
temporal_patch_size: int = 1,
**factory,
):
super().__init__(
3 * temporal_patch_size * (patch_size**2), embed_dim, bias=bias, **factory
)
self.patch_size = patch_size self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
...@@ -560,6 +597,7 @@ class RadioInternVisionModel(nn.Module): ...@@ -560,6 +597,7 @@ class RadioInternVisionModel(nn.Module):
max_img_size = int( max_img_size = int(
round(config.cpe_max_size / config.patch_size) * config.patch_size round(config.cpe_max_size / config.patch_size) * config.patch_size
) )
self.temporal_patch_size = config.video_temporal_patch_size
unique_teachers = set(t["name"] for t in config.teachers) unique_teachers = set(t["name"] for t in config.teachers)
self.patch_generator = ViTPatchGenerator( self.patch_generator = ViTPatchGenerator(
config.patch_size, config.patch_size,
...@@ -569,6 +607,8 @@ class RadioInternVisionModel(nn.Module): ...@@ -569,6 +607,8 @@ class RadioInternVisionModel(nn.Module):
cls_token=True, cls_token=True,
num_cls_tokens=len(unique_teachers) if config.cls_token_per_teacher else 1, num_cls_tokens=len(unique_teachers) if config.cls_token_per_teacher else 1,
register_multiple=config.register_multiple, register_multiple=config.register_multiple,
temporal_patch_size=self.temporal_patch_size,
separate_video_embedder=config.separate_video_embedder,
) )
self.encoder = RadioVisionEncoder( self.encoder = RadioVisionEncoder(
...@@ -593,33 +633,68 @@ class RadioInternVisionModel(nn.Module): ...@@ -593,33 +633,68 @@ class RadioInternVisionModel(nn.Module):
def inter_image_mask_metadata( def inter_image_mask_metadata(
self, imgs_sizes: list[tuple[int, int]], device: torch.device self, imgs_sizes: list[tuple[int, int]], device: torch.device
) -> MaskMetadata: ) -> MaskMetadata:
"""Build mask metadata from image pixel sizes. Adds num_skip to each
sequence length (cls/register tokens) to match patch generator output."""
patch_size = self.patch_generator.patch_size patch_size = self.patch_generator.patch_size
num_skip = self.patch_generator.num_skip num_skip = self.patch_generator.num_skip
seq_lens = calc_seq_lens(imgs_sizes, patch_size) seq_lens = calc_seq_lens(imgs_sizes, patch_size)
adjusted = [s + num_skip for s in seq_lens] adjusted = [s + num_skip for s in seq_lens]
return self._inter_image_mask_metadata_from_seq_lens(adjusted, device=device)
def _inter_image_mask_metadata_from_seq_lens(
self, seq_lens: list[int], device: torch.device
) -> MaskMetadata:
"""Build mask metadata from actual sequence lengths (already including
cls/register tokens, i.e. patch_count + num_skip per item).
Use inter_image_mask_metadata() when you only have imgs_sizes."""
assert len(seq_lens) > 0
cu_seqlens = torch.tensor( cu_seqlens = torch.tensor(
list(accumulate(adjusted, initial=0)), dtype=torch.int32, device=device list(accumulate(seq_lens, initial=0)), dtype=torch.int32, device=device
) )
# Keep max_seqlen on CPU to avoid .item() sync # Keep max_seqlen on CPU to avoid .item() sync
# See: https://github.com/vllm-project/vllm/blob/20b6b01/vllm/v1/attention/ops/vit_attn_wrappers.py#L48 # See: https://github.com/vllm-project/vllm/blob/20b6b01/vllm/v1/attention/ops/vit_attn_wrappers.py#L48
max_seqlen = torch.tensor(max(adjusted), dtype=torch.int32) max_seqlen = torch.tensor(max(seq_lens), dtype=torch.int32)
return MaskMetadata(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) return MaskMetadata(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
imgs_sizes: list[tuple[int, int]] | None = None, imgs_sizes: list[tuple[int, int]] | None = None,
num_frames: int | None = None,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes) T = self.temporal_patch_size
# Build packed-sequence metadata for MMEncoderAttention when needed.
mask_meta = None mask_meta = None
if imgs_sizes is not None: packed_batch_size = None # Original batch size before packing
assert len(imgs_sizes) > 0
# Dynamic resolution: process each image as an independent sequence. if num_frames is not None and T > 1:
mask_meta = self.inter_image_mask_metadata( # Conv3d video: all tubelets have the same sequence length.
imgs_sizes, device=hidden_states.device # Pack [num_tubelets, seq_per_tubelet, hidden] → [1, total, hidden]
hidden_states = self.patch_generator.forward_video(x)
packed_batch_size, seq_per_tubelet, hidden_dim = hidden_states.shape
hidden_states = hidden_states.reshape(1, -1, hidden_dim)
mask_meta = self._inter_image_mask_metadata_from_seq_lens(
[seq_per_tubelet] * packed_batch_size, device=hidden_states.device
) )
else:
# Images for any model, or video for non-conv3d model
hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
if imgs_sizes is not None and len(imgs_sizes) > 1:
# Dynamic resolution w/ > 1 image, create attn mask
mask_meta = self.inter_image_mask_metadata(
imgs_sizes, device=hidden_states.device
)
encoder_outputs = self.encoder(inputs_embeds=hidden_states, mask_meta=mask_meta) encoder_outputs = self.encoder(inputs_embeds=hidden_states, mask_meta=mask_meta)
# Unpack back to original batch shape if we packed for video
if packed_batch_size is not None:
encoder_outputs = encoder_outputs.reshape(
packed_batch_size, seq_per_tubelet, -1
)
return encoder_outputs return encoder_outputs
...@@ -663,8 +738,13 @@ class RadioModel(nn.Module): ...@@ -663,8 +738,13 @@ class RadioModel(nn.Module):
pixel_embeds: torch.Tensor | None = None, pixel_embeds: torch.Tensor | None = None,
*, *,
imgs_sizes: list[tuple[int, int]] | None = None, imgs_sizes: list[tuple[int, int]] | None = None,
num_frames: int | None = None,
) -> tuple[torch.FloatTensor, torch.FloatTensor]: ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
y = self.model(pixel_values, imgs_sizes=imgs_sizes) y = self.model(
pixel_values,
imgs_sizes=imgs_sizes,
num_frames=num_frames,
)
return self._extract_final(y, imgs_sizes=imgs_sizes) return self._extract_final(y, imgs_sizes=imgs_sizes)
def load_weights(self, weights) -> set[str]: def load_weights(self, weights) -> set[str]:
...@@ -714,6 +794,9 @@ class RadioModel(nn.Module): ...@@ -714,6 +794,9 @@ class RadioModel(nn.Module):
weight_loader(param, weight) weight_loader(param, weight)
loaded_params.add(vllm_key) loaded_params.add(vllm_key)
if "model.patch_generator.video_embedder.weight" in loaded_params:
self.model.patch_generator._video_embedder_loaded = True
return loaded_params return loaded_params
def _extract_final( def _extract_final(
......
...@@ -47,6 +47,14 @@ class RadioConfig(PretrainedConfig): ...@@ -47,6 +47,14 @@ class RadioConfig(PretrainedConfig):
teachers: A list of teacher model configurations. Each teacher configuration is teachers: A list of teacher model configurations. Each teacher configuration is
a dict with keys like "name" and some may have "use_summary". a dict with keys like "name" and some may have "use_summary".
cls_token_per_teacher: Whether to use a separate CLS token for each teacher. cls_token_per_teacher: Whether to use a separate CLS token for each teacher.
video_temporal_patch_size: Number of consecutive video frames grouped into
a single tubelet for temporal compression. Default 1 (no compression).
When > 1, a dedicated video_embedder (3*T*P*P -> hidden) is created
alongside the image embedder (3*P*P -> hidden).
separate_video_embedder: When True and video_temporal_patch_size > 1, use a
dedicated video patch embedder (3*T*P*P -> hidden) separate from the
image embedder (3*P*P -> hidden). When False, a single embedder with
input size 3*T*P*P is used for both (images are duplicated T times).
""" """
model_type = "radio" model_type = "radio"
...@@ -68,6 +76,8 @@ class RadioConfig(PretrainedConfig): ...@@ -68,6 +76,8 @@ class RadioConfig(PretrainedConfig):
register_multiple: int | None = None, register_multiple: int | None = None,
teachers: list[dict[str, Any]] | None = None, teachers: list[dict[str, Any]] | None = None,
cls_token_per_teacher: bool = False, cls_token_per_teacher: bool = False,
video_temporal_patch_size: int = 1,
separate_video_embedder: bool = True,
**kwargs, **kwargs,
): ):
self.model_name = model_name self.model_name = model_name
...@@ -95,4 +105,6 @@ class RadioConfig(PretrainedConfig): ...@@ -95,4 +105,6 @@ class RadioConfig(PretrainedConfig):
self.register_multiple = register_multiple self.register_multiple = register_multiple
self.teachers = teachers if teachers is not None else [] self.teachers = teachers if teachers is not None else []
self.cls_token_per_teacher = cls_token_per_teacher self.cls_token_per_teacher = cls_token_per_teacher
self.video_temporal_patch_size = video_temporal_patch_size
self.separate_video_embedder = separate_video_embedder
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -11,6 +11,7 @@ import math ...@@ -11,6 +11,7 @@ import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property
from typing import Any, TypeVar from typing import Any, TypeVar
import einops import einops
...@@ -43,6 +44,12 @@ AUDIO_CONTEXT = "<so_embedding>" ...@@ -43,6 +44,12 @@ AUDIO_CONTEXT = "<so_embedding>"
# MAX_FRAMES = 16 # MAX_FRAMES = 16
DEFAULT_NUM_TILES = 12 DEFAULT_NUM_TILES = 12
# Configure PIL to handle large images without warnings
# This prevents DecompressionBombWarning for legitimate large images
Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
# Alternative: Set a specific higher limit
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
def calculate_timestamps( def calculate_timestamps(
indices: list[int] | torch.Tensor, indices: list[int] | torch.Tensor,
...@@ -138,19 +145,110 @@ def image_to_pixel_values( ...@@ -138,19 +145,110 @@ def image_to_pixel_values(
return pixel_values return pixel_values
def _compute_aspect_preserving_size(
orig_w: int,
orig_h: int,
target_num_patches: int,
patch_size: int,
downsample_ratio: float,
) -> tuple[int, int]:
"""Compute target pixel dimensions that preserve aspect ratio.
Mirrors Megatron-LM image_processing.py video frame resizing:
target area in patch-grid space is *target_num_patches*, distributed
according to the source aspect ratio, then snapped to a multiple of
the required divisor (2 for pixel-shuffle).
"""
aspect_wh = orig_w / max(orig_h, 1)
ph = round(math.sqrt(target_num_patches / aspect_wh))
pw = round(math.sqrt(target_num_patches * aspect_wh))
ph = max(ph, 1)
pw = max(pw, 1)
reduction_factor = int(round(1 / downsample_ratio))
required_divisor = reduction_factor # 2 for pixel-shuffle
if required_divisor > 1:
rem_h = ph % required_divisor
rem_w = pw % required_divisor
ph_up = ph + (required_divisor - rem_h if rem_h else 0)
ph_down = ph - rem_h
pw_up = pw + (required_divisor - rem_w if rem_w else 0)
pw_down = pw - rem_w
if ph_up * pw_up <= target_num_patches:
ph, pw = ph_up, pw_up
else:
ph = max(required_divisor, ph_down)
pw = max(required_divisor, pw_down)
return pw * patch_size, ph * patch_size # (width, height) in pixels
def get_video_target_size_and_feature_size(
orig_w: int,
orig_h: int,
target_patches: int,
maintain_aspect_ratio: bool,
patch_size: int,
downsample_ratio: float,
) -> tuple[int, int, int]:
"""Compute target (width, height) and feature_size for video resize and token count.
Used by video_to_pixel_values (resize) and get_video_replacement_internvl
(seq length calc) so both use the same dimensions.
"""
if maintain_aspect_ratio:
target_w, target_h = _compute_aspect_preserving_size(
orig_w=orig_w,
orig_h=orig_h,
target_num_patches=target_patches,
patch_size=patch_size,
downsample_ratio=downsample_ratio,
)
else:
reduction_factor = int(round(1 / downsample_ratio))
side = int(math.sqrt(target_patches))
side = max(reduction_factor, (side // reduction_factor) * reduction_factor)
target_w = side * patch_size
target_h = side * patch_size
feature_size = int((target_h // patch_size) * downsample_ratio) * int(
(target_w // patch_size) * downsample_ratio
)
return target_w, target_h, feature_size
def video_to_pixel_values( def video_to_pixel_values(
video: npt.NDArray, video: npt.NDArray,
*, *,
input_size: int, input_size: int,
max_num_tiles: int = 1, video_target_num_patches: int | None = None,
use_thumbnail: bool, video_maintain_aspect_ratio: bool = False,
patch_size: int = 16,
downsample_ratio: float = 0.5,
) -> torch.Tensor: ) -> torch.Tensor:
assert max_num_tiles == 1, "Video modality always uses one tile"
# (num_frames, H, W, C) -> (num_frames, C, H, W) # (num_frames, H, W, C) -> (num_frames, C, H, W)
video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2) video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2)
if video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size: if video_target_num_patches is not None:
# Resize to target patch count (aspect-preserving or square).
orig_h, orig_w = video_tensor.shape[2], video_tensor.shape[3]
target_w, target_h, _ = get_video_target_size_and_feature_size(
orig_w=orig_w,
orig_h=orig_h,
target_patches=video_target_num_patches,
maintain_aspect_ratio=video_maintain_aspect_ratio,
patch_size=patch_size,
downsample_ratio=downsample_ratio,
)
if video_tensor.shape[2] != target_h or video_tensor.shape[3] != target_w:
video_tensor = torch.nn.functional.interpolate(
video_tensor,
size=(target_h, target_w),
mode="bicubic",
align_corners=False,
antialias=True,
)
elif video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size:
video_tensor = torch.nn.functional.interpolate( video_tensor = torch.nn.functional.interpolate(
video_tensor, video_tensor,
size=(input_size, input_size), size=(input_size, input_size),
...@@ -645,9 +743,9 @@ class BaseNanoNemotronVLProcessor(ABC): ...@@ -645,9 +743,9 @@ class BaseNanoNemotronVLProcessor(ABC):
"which should be a single string" "which should be a single string"
) )
parts = [x for x in re.split(r"(<image>)", text[0]) if x] parts = [x for x in re.split(r"(<image>)", text[0]) if x]
assert parts.count("<image>") == len(pixel_values_lst), ( assert parts.count("<image>") == len(num_tokens_per_image), (
"the number of <image> tokens in the text should be the " f"Expected {len(num_tokens_per_image)} <image> tokens in text "
"same as the number of images" f"but found {parts.count('<image>')}"
) )
for i, (feature_size, num_patches) in enumerate( for i, (feature_size, num_patches) in enumerate(
...@@ -706,6 +804,33 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -706,6 +804,33 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
self.video_token = video_token self.video_token = video_token
self.video_pruning_rate = video_pruning_rate self.video_pruning_rate = video_pruning_rate
# Video params live exclusively in vision_config
vision_config = getattr(config, "vision_config", config)
self.video_temporal_patch_size: int = getattr(
vision_config, "video_temporal_patch_size", 1
)
self.video_maintain_aspect_ratio: bool = getattr(
vision_config, "video_maintain_aspect_ratio", False
)
# Resolve video frame target size: exactly one of video_target_num_patches
# or video_target_img_size may be set (mirrors Megatron's
# DynamicResolutionImageTilingStrategy validation).
target_num_patches = getattr(vision_config, "video_target_num_patches", None)
target_img_size = getattr(vision_config, "video_target_img_size", None)
if target_num_patches is not None and target_img_size is not None:
raise ValueError(
"Exactly one of video_target_num_patches or "
"video_target_img_size must be set, got both"
)
if target_num_patches is not None:
self.video_target_num_patches: int | None = target_num_patches
elif target_img_size is not None:
base_patches = math.ceil(target_img_size / config.patch_size)
self.video_target_num_patches = base_patches * base_patches
else:
self.video_target_num_patches = None
self.audio_extractor: ParakeetExtractor | None = None self.audio_extractor: ParakeetExtractor | None = None
raw_sound_config = getattr(config, "sound_config", None) raw_sound_config = getattr(config, "sound_config", None)
if raw_sound_config is not None: if raw_sound_config is not None:
...@@ -721,6 +846,27 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -721,6 +846,27 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
IMG_CONTEXT, add_special_tokens=False IMG_CONTEXT, add_special_tokens=False
) )
@cached_property
def num_video_token(self) -> int:
"""Token count per video frame, accounting for video_target_num_patches.
When video_target_num_patches is set the per-frame feature count
differs from the image-based num_image_token. We use a square
dummy (1:1) to compute the feature_size because the dummy video is
square and the user confirmed that is acceptable.
"""
if self.video_target_num_patches is not None:
_, _, feature_size = get_video_target_size_and_feature_size(
orig_w=self.image_size,
orig_h=self.image_size,
target_patches=self.video_target_num_patches,
maintain_aspect_ratio=self.video_maintain_aspect_ratio,
patch_size=self.config.patch_size,
downsample_ratio=self.config.downsample_ratio,
)
return feature_size
return self.num_image_token
@property @property
def supports_video(self) -> bool: def supports_video(self) -> bool:
return self.video_token_id is not None return self.video_token_id is not None
...@@ -738,14 +884,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -738,14 +884,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
def _videos_to_pixel_values_lst( def _videos_to_pixel_values_lst(
self, self,
videos: list[npt.NDArray], videos: list[npt.NDArray],
max_num_tiles: int,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
return [ return [
video_to_pixel_values( video_to_pixel_values(
video, video,
input_size=self.image_size, input_size=self.image_size,
max_num_tiles=max_num_tiles, video_target_num_patches=self.video_target_num_patches,
use_thumbnail=self.use_thumbnail, video_maintain_aspect_ratio=self.video_maintain_aspect_ratio,
patch_size=self.config.patch_size,
downsample_ratio=self.config.downsample_ratio,
) )
for video in videos for video in videos
] ]
...@@ -754,7 +901,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -754,7 +901,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
self, self,
text: list[str], text: list[str],
videos: list[tuple[npt.NDArray, dict[str, Any]]], videos: list[tuple[npt.NDArray, dict[str, Any]]],
max_num_tiles: int,
) -> tuple[list[str], dict[str, Any]]: ) -> tuple[list[str], dict[str, Any]]:
if len(videos) == 0 or not self.supports_video: if len(videos) == 0 or not self.supports_video:
return text, {} return text, {}
...@@ -763,7 +909,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -763,7 +909,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
video_metadata_lst = [v[1] for v in videos] video_metadata_lst = [v[1] for v in videos]
pixel_values_lst_video = self._videos_to_pixel_values_lst( pixel_values_lst_video = self._videos_to_pixel_values_lst(
videos_lst, videos_lst,
max_num_tiles=max_num_tiles,
) )
# We use frame duration in milliseconds (as integer) to ensure # We use frame duration in milliseconds (as integer) to ensure
...@@ -788,12 +933,10 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -788,12 +933,10 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
"frame_duration_ms": torch.tensor(frame_duration_ms_lst), "frame_duration_ms": torch.tensor(frame_duration_ms_lst),
} }
image_size: int = self.config.force_image_size
patch_size: int = self.config.patch_size patch_size: int = self.config.patch_size
downsample_ratio = self.config.downsample_ratio downsample_ratio = self.config.downsample_ratio
tokens_in_single_frame = int(
(image_size * image_size // patch_size**2) * (downsample_ratio**2) T = self.video_temporal_patch_size
)
for pixel_values, video_metadata, frames_indices, frame_duration_ms in zip( for pixel_values, video_metadata, frames_indices, frame_duration_ms in zip(
pixel_values_lst_video, pixel_values_lst_video,
...@@ -802,23 +945,28 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -802,23 +945,28 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
frame_duration_ms_lst, frame_duration_ms_lst,
): ):
num_frames = pixel_values.shape[0] num_frames = pixel_values.shape[0]
frame_h, frame_w = pixel_values.shape[-2], pixel_values.shape[-1]
tokens_in_single_frame = int(
(frame_h * frame_w // patch_size**2) * (downsample_ratio**2)
)
num_tubelets = math.ceil(num_frames / T) if T > 1 else num_frames
if self.video_pruning_rate is not None and self.video_pruning_rate > 0.0: if self.video_pruning_rate is not None and self.video_pruning_rate > 0.0:
# Start of EVS-specific code # Start of EVS-specific code
num_tokens = compute_retained_tokens_count( num_tokens = compute_retained_tokens_count(
tokens_per_frame=tokens_in_single_frame, tokens_per_frame=tokens_in_single_frame,
num_frames=num_frames, num_frames=num_tubelets,
q=self.video_pruning_rate, q=self.video_pruning_rate,
) )
# Here we just need placeholders that won't actually be replaced - # Here we just need placeholders that won't actually be replaced -
# we just need to make sure the total number of tokens is correct # we just need to make sure the total number of tokens is correct
# assign all tokens to the first frame # assign all tokens to the first frame
tokens_per_frame = [num_tokens] + [0] * (num_frames - 1) tokens_per_frame = [num_tokens] + [0] * (num_tubelets - 1)
# End of EVS-specific code # End of EVS-specific code
else: else:
tokens_per_frame = [tokens_in_single_frame] * num_frames tokens_per_frame = [tokens_in_single_frame] * num_tubelets
video_repl = self.get_video_repl( video_repl = self.get_video_repl(
tokens_per_frame=tokens_per_frame, tokens_per_frame=tokens_per_frame,
...@@ -828,6 +976,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -828,6 +976,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
img_start_token_ids=self._img_start_token_ids, img_start_token_ids=self._img_start_token_ids,
img_end_token_ids=self._img_end_token_ids, img_end_token_ids=self._img_end_token_ids,
img_context_token_ids=self._img_context_token_ids, img_context_token_ids=self._img_context_token_ids,
video_temporal_patch_size=T,
) )
# video_repl.full is a list of token IDs # video_repl.full is a list of token IDs
...@@ -908,7 +1057,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -908,7 +1057,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
text, video_inputs = self._preprocess_video( text, video_inputs = self._preprocess_video(
text=text, text=text,
videos=videos, videos=videos,
max_num_tiles=1,
) )
text, audio_inputs = self._preprocess_audio( text, audio_inputs = self._preprocess_audio(
...@@ -962,6 +1110,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -962,6 +1110,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
img_start_token_ids: list[int], img_start_token_ids: list[int],
img_end_token_ids: list[int], img_end_token_ids: list[int],
img_context_token_ids: list[int], img_context_token_ids: list[int],
video_temporal_patch_size: int = 1,
) -> PromptUpdateDetails[list[int]]: ) -> PromptUpdateDetails[list[int]]:
""" """
Build prompt replacement for a video. Build prompt replacement for a video.
...@@ -981,31 +1130,60 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -981,31 +1130,60 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
- EVS real (called from get_real_video_repl_for_evs) - different value per frame - EVS real (called from get_real_video_repl_for_evs) - different value per frame
Args: Args:
tokens_per_frame (list[int]): number of tokens per frame tokens_per_frame (list[int]): number of tokens per frame
frames_indices (list[int]): frame indices (one per tubelet when T > 1)
frames_indices (list[int]): orig. frame indices
(one per frame, before tubelet subsampling)
frame_duration_ms (int): duration of each frame in milliseconds frame_duration_ms (int): duration of each frame in milliseconds
tokenizer (HfTokenizer): tokenizer to use for tokenizing frame separators tokenizer (TokenizerLike): tokenizer to use for tokenizing frame separators
img_start_token_ids (list[int]): pre-tokenized IMG_START tokens img_start_token_ids (list[int]): pre-tokenized IMG_START tokens
img_end_token_ids (list[int]): pre-tokenized IMG_END tokens img_end_token_ids (list[int]): pre-tokenized IMG_END tokens
img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens
video_temporal_patch_size (int): temporal patch size for videos
""" """
# TODO: Add support of frame_duration_ms to be None # TODO: Add support of frame_duration_ms to be None
# At preprocessing step we should allow absent / metadata without # At preprocessing step we should allow absent / metadata without
# frames_indices field. # frames_indices field.
timestamps_enabled = frame_duration_ms is not None timestamps_enabled = frame_duration_ms is not None
T = video_temporal_patch_size
if timestamps_enabled: num_frames = len(frames_indices)
if T > 1 and timestamps_enabled:
all_timestamps = calculate_timestamps(frames_indices, frame_duration_ms)
frame_separators = []
for group_idx, i in enumerate(range(0, num_frames, T)):
group_frames = []
for j in range(T): # Every frame in the group
frame_idx = i + j
if frame_idx < num_frames:
# Valid idx (haven't padded to mult. of T yet)
ts = all_timestamps[frame_idx]
frame_str = "Frame" if j == 0 else "frame"
group_frames.append(
f"{frame_str} {frame_idx + 1} sampled at {ts:.2f} seconds"
)
if group_frames:
# Join by `and` if there are >1 frame, otherwise no `and`
# Prepend \n to match training format (except first group)
sep = " and ".join(group_frames) + ": "
if group_idx > 0:
sep = "\n" + sep
frame_separators.append(sep)
elif timestamps_enabled:
timestamps = calculate_timestamps(frames_indices, frame_duration_ms) timestamps = calculate_timestamps(frames_indices, frame_duration_ms)
assert len(timestamps) == len(tokens_per_frame), ( assert len(timestamps) == len(tokens_per_frame), (
"timestamps and tokens_per_frame must have the same length" "timestamps and tokens_per_frame must have the same length"
) )
frame_separators = [ frame_separators = [
f"Frame {i + 1} sampled at {timestamp:.2f} seconds: " ("\n" if i > 0 else "")
+ f"Frame {i + 1} sampled at {timestamp:.2f} seconds: "
for i, timestamp in enumerate(timestamps) for i, timestamp in enumerate(timestamps)
] ]
else: else:
frame_separators = [ frame_separators = [
f"Frame {i + 1}: " for i, _ in enumerate(tokens_per_frame) ("\n" if i > 0 else "") + f"Frame {i + 1}: "
for i, _ in enumerate(tokens_per_frame)
] ]
# Tokenize frame separator independently # Tokenize frame separator independently
......
...@@ -420,8 +420,9 @@ class GPUModelRunner( ...@@ -420,8 +420,9 @@ class GPUModelRunner(
self.is_multimodal_raw_input_only_model = ( self.is_multimodal_raw_input_only_model = (
model_config.is_multimodal_raw_input_only_model model_config.is_multimodal_raw_input_only_model
) )
# This will be overridden in load_model() # These will be overridden in load_model()
self.is_multimodal_pruning_enabled = False self.is_multimodal_pruning_enabled = False
self.requires_sequential_video_encoding = False
# Set to True after init_routed_experts_capturer() completes. # Set to True after init_routed_experts_capturer() completes.
# Prevents routed experts code from running during profiling/dummy run. # Prevents routed experts code from running during profiling/dummy run.
self.routed_experts_initialized = False self.routed_experts_initialized = False
...@@ -2625,17 +2626,23 @@ class GPUModelRunner( ...@@ -2625,17 +2626,23 @@ class GPUModelRunner(
): ):
batch_outputs: MultiModalEmbeddings batch_outputs: MultiModalEmbeddings
# EVS-related change. # EVS and dynamic res video related change.
# (ekhvedchenia): Temporary hack to limit peak memory usage when # (ekhvedchenia): Temporary hack to limit peak memory usage when
# processing multimodal data. This solves the issue with scheduler # processing multimodal data. This solves the issue with scheduler
# putting too many video samples into a single batch. Scheduler # putting too many video samples into a single batch. Scheduler
# uses pruned vision tokens count to compare it versus compute # uses pruned vision tokens count to compare it versus compute
# budget which is incorrect (Either input media size or non-pruned # budget which is incorrect (Either input media size or non-pruned
# output vision tokens count should be considered) # output vision tokens count should be considered)
# dynamic res video for nemotron temporarily uses this hack via
# requires_sequential_video_encoding
# because it doesn't yet support video batching.
# TODO(ywang96): Fix memory profiling to take EVS into account and # TODO(ywang96): Fix memory profiling to take EVS into account and
# remove this hack. # remove this hack.
if ( if (
self.is_multimodal_pruning_enabled (
self.is_multimodal_pruning_enabled
or self.requires_sequential_video_encoding
)
and modality == "video" and modality == "video"
and num_items > 1 and num_items > 1
): ):
...@@ -4609,6 +4616,9 @@ class GPUModelRunner( ...@@ -4609,6 +4616,9 @@ class GPUModelRunner(
and mm_config is not None and mm_config is not None
and mm_config.is_multimodal_pruning_enabled() and mm_config.is_multimodal_pruning_enabled()
) )
self.requires_sequential_video_encoding = hasattr(
self.get_model(), "requires_sequential_video_encoding"
) # Temporary hack for dynamic res video w/o support for bs>1 yet
if ( if (
is_mixture_of_experts(self.model) is_mixture_of_experts(self.model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment