"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "e6c6f2c79d2c1e0765d17b4dec83a4a8283342e4"
Commit cd3ac5b7 (unverified), authored by Netanel Haber, committed by GitHub
Browse files

support dynamic resolution image encoding for Nemotron Nano VL (#32121)


Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
parent 2636d762
...@@ -282,12 +282,14 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -282,12 +282,14 @@ class InternVisionEncoderLayer(nn.Module):
num_dummy_heads: int = 0, num_dummy_heads: int = 0,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_cls: type[InternParallelAttention] = InternParallelAttention,
) -> None: ) -> None:
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type self.norm_type = config.norm_type
self.attn_cls = attn_cls
self.attn = self._init_attn( self.attn = self._init_attn(
config, config,
...@@ -327,7 +329,7 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -327,7 +329,7 @@ class InternVisionEncoderLayer(nn.Module):
use_data_parallel = ( use_data_parallel = (
use_data_parallel or (num_heads + num_dummy_heads) % tp_size != 0 use_data_parallel or (num_heads + num_dummy_heads) % tp_size != 0
) )
return InternParallelAttention( return self.attn_cls(
config, config,
quant_config=quant_config, quant_config=quant_config,
num_dummy_heads=num_dummy_heads, num_dummy_heads=num_dummy_heads,
...@@ -356,10 +358,12 @@ class InternVisionEncoder(nn.Module): ...@@ -356,10 +358,12 @@ class InternVisionEncoder(nn.Module):
num_dummy_heads: int = 0, num_dummy_heads: int = 0,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
layer_cls: type[InternVisionEncoderLayer] = InternVisionEncoderLayer,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.layer_cls = layer_cls
if num_hidden_layers_override is None: if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers num_hidden_layers = config.num_hidden_layers
...@@ -368,7 +372,7 @@ class InternVisionEncoder(nn.Module): ...@@ -368,7 +372,7 @@ class InternVisionEncoder(nn.Module):
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
InternVisionEncoderLayer( self.layer_cls(
config, config,
quant_config, quant_config,
num_dummy_heads=num_dummy_heads, num_dummy_heads=num_dummy_heads,
......
...@@ -21,7 +21,11 @@ from transformers import PretrainedConfig ...@@ -21,7 +21,11 @@ from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.intern_vit import InternVisionEncoder from vllm.model_executor.models.intern_vit import (
InternParallelAttention,
InternVisionEncoder,
InternVisionEncoderLayer,
)
input_dim_t: TypeAlias = int | tuple[int, int] input_dim_t: TypeAlias = int | tuple[int, int]
norm_t: TypeAlias = tuple[float, float, float] | torch.Tensor norm_t: TypeAlias = tuple[float, float, float] | torch.Tensor
...@@ -43,6 +47,15 @@ to_4tuple = _ntuple(4) ...@@ -43,6 +47,15 @@ to_4tuple = _ntuple(4)
to_ntuple = _ntuple to_ntuple = _ntuple
def calc_seq_len(size: tuple[int, int], patch_size: int) -> int:
    """Return how many whole ``patch_size`` x ``patch_size`` patches tile an image.

    Args:
        size: Two-element ``(h, w)`` pair. Each element is coerced with
            ``int()`` so that integer-like scalars (for example the elements
            of a tensor row) produce a plain Python ``int`` rather than a
            0-d tensor.
        patch_size: Side length of one square patch.

    Returns:
        ``(h // patch_size) * (w // patch_size)``; any remainder along either
        axis is discarded by the floor division.
    """
    h, w = size
    return (int(h) // patch_size) * (int(w) // patch_size)
def calc_seq_lens(sizes: list[tuple[int, int]], patch_size: int) -> list[int]:
    """Return the per-image patch counts for ``sizes``, in the same order.

    Each entry is produced by ``calc_seq_len`` with the shared ``patch_size``.
    """
    seq_lens: list[int] = []
    for image_size in sizes:
        seq_lens.append(calc_seq_len(image_size, patch_size))
    return seq_lens
class ClsToken(nn.Module): class ClsToken(nn.Module):
def __init__( def __init__(
self, self,
...@@ -164,15 +177,73 @@ class ViTPatchGenerator(nn.Module): ...@@ -164,15 +177,73 @@ class ViTPatchGenerator(nn.Module):
nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity() nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(
patches = self.embed_patches(x) self, x: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:]) ) -> torch.Tensor:
patches = self.cls_token(patches) if imgs_sizes is not None:
patches = self.embedder(x)
patches, pos_enc = self.apply_pos_enc_dynamic(
patches, imgs_sizes=imgs_sizes
)
patches = self.cls_token_dynamic(patches, imgs_sizes=imgs_sizes)
else:
patches = self.embed_patches(x)
patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
patches = self.cls_token(patches)
patches = self.patch_normalizer(patches) patches = self.patch_normalizer(patches)
if self.return_pos_enc: if self.return_pos_enc:
return patches, pos_enc return patches, pos_enc
return patches return patches
def apply_pos_enc_dynamic(
self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> tuple[torch.Tensor, torch.Tensor | None]:
if not self.abs_pos:
return patches, None
current_length = 0
pos_enc_list = []
for size in imgs_sizes:
seq_length = calc_seq_len(size, self.patch_size)
img_patches = patches[:, current_length : current_length + seq_length, :]
pos_enc = self.get_pos_enc(patches.shape[0], input_size=size)
img_patches_with_pos = img_patches + pos_enc
patches = torch.cat(
[
patches[:, :current_length, :],
img_patches_with_pos,
patches[:, current_length + seq_length :, :],
],
dim=1,
)
pos_enc_list.append(pos_enc)
current_length += seq_length
full_pos_enc = torch.cat(pos_enc_list, dim=1) if pos_enc_list else None
return patches, full_pos_enc
def cls_token_dynamic(
    self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> torch.Tensor:
    """Insert the class token(s) in front of every image's run of patches.

    When ``self.cls_token.enabled`` is false the input is returned as is.
    Otherwise ``self.cls_token.token`` is expanded across the batch dimension
    and placed before each per-image slice of ``patches`` (slice lengths come
    from ``calc_seq_lens``); everything is concatenated along dim 1.
    """
    if not self.cls_token.enabled:
        return patches

    batch_size = patches.shape[0]
    # The expanded token is the same for every image, so build it once.
    prefix = self.cls_token.token.unsqueeze(0).expand(batch_size, -1, -1)

    pieces = []
    start = 0
    for num_tokens in calc_seq_lens(imgs_sizes, self.patch_size):
        stop = start + num_tokens
        pieces.extend((prefix, patches[:, start:stop, :]))
        start = stop
    return torch.cat(pieces, dim=1)
@property
def apply_cls_token(self):
    """Expose ``self.cls_token.enabled`` as a read-only attribute."""
    return self.cls_token.enabled
...@@ -406,6 +477,66 @@ class ViTPatchLinear(nn.Linear): ...@@ -406,6 +477,66 @@ class ViTPatchLinear(nn.Linear):
self.patch_size = patch_size self.patch_size = patch_size
class RadioParallelAttention(InternParallelAttention):
    """``InternParallelAttention`` variant that accepts an optional attention mask."""

    def forward(
        self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run attention over ``x``, honoring ``attn_mask`` when one is given.

        With no mask the parent implementation is used unchanged. With a mask,
        attention is computed here through
        ``F.scaled_dot_product_attention`` so the mask can be forwarded.
        """
        if attn_mask is None:
            return super().forward(x)

        batch, seq_len, _ = x.shape
        fused_qkv, _ = self.qkv(x)
        query, key, value = fused_qkv.chunk(3, dim=-1)

        if self.qk_normalization:
            query, key = self._apply_qk_norm(query, key)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return t.view(
                batch, seq_len, self.num_heads_per_partition, self.head_dim
            ).transpose(1, 2)

        attended = F.scaled_dot_product_attention(
            split_heads(query),
            split_heads(key),
            split_heads(value),
            attn_mask=attn_mask,
            scale=self.scale,
        )
        # Merge the heads back into the channel dimension before projecting.
        merged = attended.transpose(1, 2).reshape(batch, seq_len, -1)
        projected, _ = self.proj(merged)
        return projected
class RadioVisionEncoderLayer(InternVisionEncoderLayer):
    """Encoder layer that builds ``RadioParallelAttention`` and threads a mask to it."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, attn_cls=RadioParallelAttention, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ):
        """Apply the attention and MLP residual branches, scaled by ``ls1``/``ls2``."""
        attn_out = self.attn(self.norm1(hidden_states), attn_mask=attn_mask)
        hidden_states = hidden_states + attn_out * self.ls1

        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out * self.ls2
class RadioVisionEncoder(InternVisionEncoder):
    """Encoder stack made of ``RadioVisionEncoderLayer`` layers."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, layer_cls=RadioVisionEncoderLayer, **kwargs)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ):
        """Pass ``inputs_embeds`` through every layer in order, sharing ``attn_mask``."""
        states = inputs_embeds
        for layer in self.layers:
            states = layer(states, attn_mask=attn_mask)
        return states
class RadioInternVisionModel(nn.Module): class RadioInternVisionModel(nn.Module):
packed_modules_mapping = { packed_modules_mapping = {
"qkv": ["qkv"], "qkv": ["qkv"],
...@@ -440,7 +571,7 @@ class RadioInternVisionModel(nn.Module): ...@@ -440,7 +571,7 @@ class RadioInternVisionModel(nn.Module):
register_multiple=config.register_multiple, register_multiple=config.register_multiple,
) )
self.encoder = InternVisionEncoder( self.encoder = RadioVisionEncoder(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
...@@ -459,10 +590,45 @@ class RadioInternVisionModel(nn.Module): ...@@ -459,10 +590,45 @@ class RadioInternVisionModel(nn.Module):
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embeddings return self.embeddings
def create_inter_image_attention_mask(
    self, imgs_sizes: list[tuple[int, int]], device: torch.device
) -> torch.Tensor:
    """Build a block-diagonal boolean mask that isolates each image's tokens.

    Args:
        imgs_sizes: ``(h, w)`` of every image, in packing order.
        device: Device on which the mask is allocated.

    Returns:
        A square ``torch.bool`` tensor whose side is the total token count.
        Entry ``[i, j]`` is ``True`` only when tokens ``i`` and ``j`` belong
        to the same image; everything else stays ``False``.
    """
    patch_size = self.patch_generator.patch_size
    num_skip = self.patch_generator.num_skip

    # Each image occupies its patch tokens plus ``num_skip`` extra tokens.
    token_counts = [
        seq_len + num_skip for seq_len in calc_seq_lens(imgs_sizes, patch_size)
    ]
    total_tokens = sum(token_counts)

    # Start fully masked out, then open one diagonal block per image.
    mask = torch.zeros(total_tokens, total_tokens, dtype=torch.bool, device=device)
    start = 0
    for count in token_counts:
        stop = start + count
        mask[start:stop, start:stop] = True
        start = stop
    return mask
def forward(
self,
x: torch.Tensor,
imgs_sizes: torch.Tensor | None = None,
) -> torch.FloatTensor:
assert self.patch_generator is not None assert self.patch_generator is not None
hidden_states = self.patch_generator(x) hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
encoder_outputs = self.encoder(inputs_embeds=hidden_states) attn_mask = None
if imgs_sizes is not None and len(imgs_sizes) > 1:
# Dynamic Resolution
attn_mask = self.create_inter_image_attention_mask(
imgs_sizes, device=x.device
)
encoder_outputs = self.encoder(inputs_embeds=hidden_states, attn_mask=attn_mask)
return encoder_outputs return encoder_outputs
...@@ -504,9 +670,11 @@ class RadioModel(nn.Module): ...@@ -504,9 +670,11 @@ class RadioModel(nn.Module):
self, self,
pixel_values: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None,
pixel_embeds: torch.Tensor | None = None, pixel_embeds: torch.Tensor | None = None,
*,
imgs_sizes: torch.Tensor | None = None,
) -> tuple[torch.FloatTensor, torch.FloatTensor]: ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
y = self.model(pixel_values) y = self.model(pixel_values, imgs_sizes=imgs_sizes)
return self._extract_final(y) return self._extract_final(y, imgs_sizes=imgs_sizes)
def load_weights(self, weights) -> set[str]: def load_weights(self, weights) -> set[str]:
loaded_params: set[str] = set() loaded_params: set[str] = set()
...@@ -558,16 +726,32 @@ class RadioModel(nn.Module): ...@@ -558,16 +726,32 @@ class RadioModel(nn.Module):
return loaded_params return loaded_params
def _extract_final( def _extract_final(
self, y: torch.Tensor self, y: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
) -> tuple[torch.FloatTensor, torch.FloatTensor]: ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
# Remove CLS + REGISTERS tokens # Remove CLS + REGISTERS tokens
patch_gen = getattr(self.model, "patch_generator", None) num_skip = self.model.patch_generator.num_skip
if patch_gen is not None: patch_size = self.model.patch_generator.patch_size
all_summary = y[:, : patch_gen.num_cls_tokens] num_cls_tokens = self.model.patch_generator.num_cls_tokens
if self.summary_idxs is not None: if imgs_sizes is None:
bb_summary = all_summary[:, self.summary_idxs] all_summary = y[:, :num_cls_tokens]
else: all_feat = y[:, num_skip:]
bb_summary = all_summary else:
all_feat = y[:, patch_gen.num_skip :] all_patches = []
summaries = []
current_pos = 0
for num_patches in calc_seq_lens(imgs_sizes, patch_size):
patches = y[
:, current_pos + num_skip : current_pos + num_skip + num_patches, :
]
all_patches.append(patches)
summary = y[:, current_pos : current_pos + num_cls_tokens, :]
summaries.append(summary)
current_pos += num_skip + num_patches
all_summary = torch.cat(summaries, dim=1)
all_feat = torch.cat(all_patches, dim=1)
if self.summary_idxs is not None:
bb_summary = all_summary[:, self.summary_idxs]
else:
bb_summary = all_summary
return bb_summary.flatten(1), all_feat return bb_summary.flatten(1), all_feat
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment