Unverified commit a8023891, authored by Zijian Zhang and committed by GitHub

model: support NVILA and NVILA Lite (#10399)

parent 0103f374
...@@ -44,6 +44,7 @@ in the GitHub search bar.
| **GLM-4.5V** (106B) / **GLM-4.1V** (9B) | `zai-org/GLM-4.5V` | GLM-4.5V and GLM-4.1V-Thinking: Towards Versatile Multimodal Reasoning with Scalable Reinforcement Learning | Use `--chat-template glm-4v` |
| **DotsVLM** (General/OCR) | `rednote-hilab/dots.vlm1.inst` | RedNote's vision-language model built on a 1.2B vision encoder and the DeepSeek V3 LLM, featuring a NaViT vision encoder trained from scratch with dynamic resolution support and enhanced OCR capabilities through structured image data training. | |
| **DotsVLM-OCR** | `rednote-hilab/dots.ocr` | Specialized OCR variant of DotsVLM optimized for optical character recognition tasks with enhanced text extraction and document understanding capabilities. | Don't use `--trust-remote-code` |
| **NVILA** (8B, 15B, Lite-2B, Lite-8B, Lite-15B) | `Efficient-Large-Model/NVILA-8B` | NVILA explores the full-stack efficiency of multimodal model design, achieving cheaper training, faster deployment, and better performance. | Use `--chat-template chatml` (see example below) |
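The new row can be exercised like the other entries in this table. The following is an illustrative sketch only (the port, placeholder image URL, and client code are assumptions for demonstration, not part of this commit): launch a server with `python -m sglang.launch_server --model-path Efficient-Large-Model/NVILA-8B --chat-template chatml`, then query its OpenAI-compatible endpoint:

```python
# Illustrative client call against a locally launched sglang server (port 30000 assumed).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="Efficient-Large-Model/NVILA-8B",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image in one sentence."},
                # Placeholder URL; replace with a real, reachable image.
                {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
            ],
        }
    ],
)
print(response.choices[0].message.content)
```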
## Usage Notes
......
...@@ -914,12 +914,13 @@ multimodal_model_archs = [
    "InternVLChatModel",
    "InternS1ForConditionalGeneration",
    "Phi4MMForCausalLM",
    "Step3VLForConditionalGeneration",
    "POINTSV15ChatModel",
    "DotsVLMForCausalLM",
    "DotsOCRForCausalLM",
    "Sarashina2VisionForCausalLM",
    "NVILAForConditionalGeneration",
    "NVILALiteForConditionalGeneration",
    "DeepseekOCRForCausalLM",
]
......
import itertools
import math
from collections.abc import Iterable
from typing import Any
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
import sglang.srt.managers.mm_utils as mm_utils
import sglang.srt.model_loader.weight_utils as weight_utils
import sglang.srt.utils as utils
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
MM_HIDDEN_SIZE = 3456
class NVILAConfig(PretrainedConfig):
model_type = "nvila"
sub_configs = {
"text_config": Qwen2Config,
"vision_config": SiglipVisionConfig,
}
_auto_class = "AutoConfig"
def __init__(
self,
*,
text_config: dict[str, Any] | None = None,
vision_config: dict[str, Any] | None = None,
image_token_id: int | None = None,
video_token_id: int | None = None,
**kwargs,
):
self.text_config = (
Qwen2Config(**text_config) if text_config is not None else Qwen2Config()
)
self.vision_config = (
SiglipVisionConfig(**vision_config)
if vision_config is not None
else SiglipVisionConfig()
)
self.image_token_id = image_token_id if image_token_id is not None else -1
self.video_token_id = video_token_id if video_token_id is not None else -1
super().__init__(**kwargs)
class NVILAMultiModalProjectorDownsampleBlock(nn.Module):
def forward(self, x: Tensor) -> Tensor:
batch_size, sequence_length, hidden_size = x.shape
feat_size = math.isqrt(sequence_length)
features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
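        # If the token grid has an odd side length, pad one row/column on the
        # bottom/right so the 2x2 grouping below divides evenly.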
pad_after = feat_size % 2
if pad_after > 0:
features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
feat_size = feat_size + pad_after
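        # Merge each 2x2 window of patch tokens into a single token:
        # (batch, seq, hidden) -> (batch, ~seq/4, 4 * hidden).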
features = features.reshape(
batch_size, feat_size // 2, 2, feat_size // 2, 2, hidden_size
)
features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
features = features.reshape(batch_size, -1, 4 * hidden_size)
return features
class NVILAMultiModalProjector(nn.Module):
def __init__(self, config: NVILAConfig):
super().__init__()
self.layers = nn.Sequential(
NVILAMultiModalProjectorDownsampleBlock(),
nn.LayerNorm(MM_HIDDEN_SIZE * 4),
nn.Linear(MM_HIDDEN_SIZE * 4, config.text_config.hidden_size),
nn.GELU(),
nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size),
)
def forward(self, x: Tensor) -> Tensor:
return self.layers(x)
class NVILAForConditionalGeneration(nn.Module):
def __init__(
self,
config: NVILAConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.vision_tower = SiglipVisionModel(config.vision_config)
self.mm_projector = NVILAMultiModalProjector(config)
self.llm = Qwen2ForCausalLM(
config=config.text_config,
quant_config=quant_config,
prefix=utils.add_prefix("llm", prefix),
)
def forward(
self,
input_ids: Tensor,
positions: Tensor,
forward_batch: ForwardBatch,
get_embedding: bool = False,
) -> LogitsProcessorOutput:
output = mm_utils.general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.llm,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
Modality.VIDEO: self.get_image_feature,
},
get_embedding=get_embedding,
positions=positions,
)
assert isinstance(output, LogitsProcessorOutput)
return output
def get_image_feature(self, mm_input: list[MultimodalDataItem]) -> Tensor:
block_sizes = (
list(
itertools.chain.from_iterable(
x.block_sizes for x in mm_input if hasattr(x, "block_sizes")
)
)
or None
)
pixel_values = torch.cat([torch.tensor(x.feature) for x in mm_input], dim=0)
vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
pixel_values.to(
device=self.vision_tower.device, dtype=self.vision_tower.dtype
),
output_hidden_states=True,
)
assert vision_tower_output.hidden_states is not None
vision_features: Tensor = vision_tower_output.hidden_states[-2]
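        # Patch features come from the penultimate SigLIP layer. Reassemble the
        # per-scale tiles of each image into full feature maps (dynamic S2 over
        # the 448/896/1344 scales) before projection.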
vision_features_list, block_sizes = merge_features_for_dynamic_s2(
vision_features,
block_sizes=(
block_sizes
if block_sizes is not None
else [None] * vision_features.shape[0]
),
resize_output_to_scale_idx=-1,
scales=[448, 896, 1344],
)
vision_features_list = [
split_chessboard(x, block_size[0], block_size[1])
for x, block_size in zip(vision_features_list, block_sizes)
]
vision_features = torch.cat(
[einops.rearrange(x, "b c h w -> b (h w) c") for x in vision_features_list]
)
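        # The projector (2x2 token downsample + MLP) maps the concatenated
        # multi-scale channels to the LLM hidden size, per tile.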
vision_features = self.mm_projector(vision_features)
vision_features_list = list(
vision_features.split(
[block_size[0] * block_size[1] for block_size in block_sizes], dim=0
)
)
vision_features_list = [
merge_chessboard(x, block_size[0], block_size[1])
for x, block_size in zip(vision_features_list, block_sizes)
]
vision_features = torch.stack(
[einops.rearrange(x, "1 c h w -> (h w) c") for x in vision_features_list]
)
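        # Flatten all per-image token grids into a single
        # (num_image_tokens, hidden_size) stream for embedding substitution.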
vision_features = einops.rearrange(vision_features, "n p d -> (n p) d")
return vision_features
def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> None:
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if name.startswith("llm."):
self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
else:
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", weight_utils.default_weight_loader
)
weight_loader(param, loaded_weight)
def pad_input_ids(
self, input_ids: list[int], mm_inputs: MultimodalInputs
) -> list[int]:
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
return pattern.pad_input_tokens(input_ids, mm_inputs)
def merge_chessboard(x, num_split_h, num_split_w):
"""
x: b * n * c or b * h * w * c
out: b * c * h * w
Assuming x contains num_split**2 sub-squares concatenated along batch dimension, merge the sub-squares back to the original whole square.
"""
B = x.shape[0]
if x.dim() == 3:
N = x.shape[1]
x = einops.rearrange(
x, "b (h w) c -> b c h w", h=math.isqrt(N), w=math.isqrt(N)
)
assert B % (num_split_h * num_split_w) == 0
b = B // (num_split_h * num_split_w)
x_merge = torch.cat(
[
torch.cat(
[
x[(i * num_split_w + j) * b : (i * num_split_w + j + 1) * b]
for j in range(num_split_w)
],
dim=-1,
)
for i in range(num_split_h)
],
dim=-2,
)
return x_merge
def merge_features_for_dynamic_s2(
image_features, block_sizes, *, scales, resize_output_to_scale_idx
):
image_features_each_image = []
new_block_sizes = []
block_cnt = 0
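    # Each block_sizes entry is the (rows, cols) tile grid of one image at the largest
    # scale; None marks an image that was encoded as a single tile.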
for block_size_each_image in block_sizes:
if block_size_each_image is None:
cur_features = image_features[block_cnt : block_cnt + 1]
cur_features = einops.rearrange(
cur_features,
"1 (h w) c -> 1 c h w",
h=math.isqrt(cur_features.shape[1]),
)
cur_features = cur_features.repeat(1, len(scales), 1, 1)
image_features_each_image.append(cur_features)
new_block_sizes.append((1, 1))
block_cnt += 1
else:
cur_features_each_scale = []
for scale in scales[:-1]:
num_blocks_this_scale = (scale // scales[0]) ** 2
cur_features_each_scale.append(
merge_chessboard(
image_features[block_cnt : block_cnt + num_blocks_this_scale],
num_split_h=scale // scales[0],
num_split_w=scale // scales[0],
)
) # 1 * C * H * W
block_cnt += num_blocks_this_scale
num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
cur_features_each_scale.append(
merge_chessboard(
image_features[block_cnt : block_cnt + num_blocks_last_scale],
num_split_h=block_size_each_image[0],
num_split_w=block_size_each_image[1],
)
) # 1 * C * H * W
block_cnt += num_blocks_last_scale
# resize and concat features from different scales
output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
cur_features = torch.cat(
[
F.interpolate(
cur_features_each_scale[i].to(torch.float32),
size=output_size,
mode="area",
).to(cur_features_each_scale[i].dtype)
for i in range(len(cur_features_each_scale))
],
dim=1,
)
image_features_each_image.append(cur_features)
if (
resize_output_to_scale_idx == len(scales) - 1
or resize_output_to_scale_idx == -1
):
new_block_sizes.append(block_size_each_image)
else:
new_block_sizes.append(
(
scales[resize_output_to_scale_idx] // scales[0],
scales[resize_output_to_scale_idx] // scales[0],
)
)
assert block_cnt == len(
image_features
), f"The number of blocks ({block_cnt}) does not match length of image_features ({len(image_features)})!"
return image_features_each_image, new_block_sizes
def split_chessboard(x, num_split_h, num_split_w):
"""
x: b * c * h * w
out: b * c * h * w
Deividing x into num_split**2 sub-squares, and concatenate all the sub-squares on the batch dimension
"""
B, C, H, W = x.shape
assert H % num_split_h == 0 and W % num_split_w == 0
h, w = H // num_split_h, W // num_split_w
x_split = torch.cat(
[
x[:, :, i * h : (i + 1) * h, j * w : (j + 1) * w]
for i in range(num_split_h)
for j in range(num_split_w)
],
dim=0,
)
return x_split
EntryClass = [NVILAForConditionalGeneration]
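For intuition about the tiling helpers above, here is a minimal, self-contained PyTorch sketch (simplified reimplementations with made-up tensor sizes, not code from this commit) showing that splitting a feature map into a tile-major batch and merging it back is lossless:

```python
# Standalone illustration: split a feature map into a 2x3 grid of tiles stacked on the
# batch dimension, then merge them back and recover the original tensor exactly.
import torch


def split_chessboard(x: torch.Tensor, nh: int, nw: int) -> torch.Tensor:
    b, c, H, W = x.shape
    h, w = H // nh, W // nw
    tiles = [
        x[:, :, i * h : (i + 1) * h, j * w : (j + 1) * w]
        for i in range(nh)
        for j in range(nw)
    ]
    return torch.cat(tiles, dim=0)  # (b * nh * nw, c, h, w), tile-major order


def merge_chessboard(x: torch.Tensor, nh: int, nw: int) -> torch.Tensor:
    b = x.shape[0] // (nh * nw)
    rows = [
        torch.cat([x[(i * nw + j) * b : (i * nw + j + 1) * b] for j in range(nw)], dim=-1)
        for i in range(nh)
    ]
    return torch.cat(rows, dim=-2)  # (b, c, H, W)


feature_map = torch.randn(2, 8, 12, 18)
tiles = split_chessboard(feature_map, nh=2, nw=3)  # -> (12, 8, 6, 6)
assert torch.equal(merge_chessboard(tiles, nh=2, nw=3), feature_map)
```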
import math
from collections.abc import Iterable
from typing import Any
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
...@@ -13,8 +15,7 @@ from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
import sglang.srt.managers.mm_utils as mm_utils
import sglang.srt.model_loader.weight_utils as weight_utils
import sglang.srt.utils as utils
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
from sglang.srt.managers.schedule_batch import (
...@@ -25,88 +26,46 @@ from sglang.srt.managers.schedule_batch import (
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
MM_HIDDEN_SIZE = 1152
class NVILALiteConfig(PretrainedConfig):
    model_type = "nvila_lite"
    sub_configs = {
        "text_config": Qwen2Config,
        "vision_config": SiglipVisionConfig,
    }
    _auto_class = "AutoConfig"
    def __init__(
        self,
        *,
        text_config: dict[str, Any] | None = None,
        vision_config: dict[str, Any] | None = None,
        image_token_id: int | None = None,
        video_token_id: int | None = None,
        **kwargs,
    ):
        self.text_config = (
            Qwen2Config(**text_config) if text_config is not None else Qwen2Config()
        )
        self.vision_config = (
            SiglipVisionConfig(**vision_config)
            if vision_config is not None
            else SiglipVisionConfig()
        )
        self.image_token_id = image_token_id if image_token_id is not None else -1
        self.video_token_id = video_token_id if video_token_id is not None else -1
        super().__init__(**kwargs)
class NVILALiteMultiModalProjectorDownsampleBlock(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        batch_size, sequence_length, hidden_size = x.shape
        feat_size = math.isqrt(sequence_length)
        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
...@@ -124,97 +83,43 @@ class DownSample3x3BlockFix(nn.Module):
        return features
class NVILALiteMultiModalProjector(nn.Module):
    def __init__(self, config: NVILALiteConfig):
        super().__init__()
        self.layers = nn.Sequential(
            NVILALiteMultiModalProjectorDownsampleBlock(),
            nn.LayerNorm(MM_HIDDEN_SIZE * 9),
            nn.Linear(MM_HIDDEN_SIZE * 9, MM_HIDDEN_SIZE * 3),
            nn.GELU(),
            nn.LayerNorm(MM_HIDDEN_SIZE * 3),
            nn.Linear(MM_HIDDEN_SIZE * 3, config.text_config.hidden_size),
            nn.GELU(),
            nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size),
        )
    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)
class NVILALiteForConditionalGeneration(nn.Module):
    def __init__(
        self,
        config: NVILALiteConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.vision_tower = SiglipVisionModel(config.vision_config)
        self.mm_projector = NVILALiteMultiModalProjector(config)
        self.llm = Qwen2ForCausalLM(
            config=config.text_config,
            quant_config=quant_config,
            prefix=utils.add_prefix("llm", prefix),
        )
    def forward(
        self,
...@@ -229,40 +134,34 @@ class VILAForConditionalGeneration(nn.Module):
            language_model=self.llm,
            data_embedding_funcs={
                Modality.IMAGE: self.get_image_feature,
                Modality.VIDEO: self.get_image_feature,
            },
            get_embedding=get_embedding,
            positions=positions,
        )
        assert isinstance(output, LogitsProcessorOutput)
        return output
    def get_image_feature(self, mm_input: list[MultimodalDataItem]) -> Tensor:
        pixel_values = torch.cat([torch.tensor(x.feature) for x in mm_input], dim=0)
        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
            pixel_values,
            output_hidden_states=True,
        )
        assert vision_tower_output.hidden_states is not None
        vision_features = vision_tower_output.hidden_states[-2]
        vision_features = self.mm_projector(vision_features)
        vision_features = einops.rearrange(vision_features, "n p d -> (n p) d")
        return vision_features
    def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> None:
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
...@@ -276,31 +175,10 @@ class VILAForConditionalGeneration(nn.Module):
            weight_loader(param, loaded_weight)
    def pad_input_ids(
        self, input_ids: list[int], mm_inputs: MultimodalInputs
    ) -> list[int]:
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)
EntryClass = [NVILALiteForConditionalGeneration]
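As a rough shape walkthrough of the NVILA-Lite projector path: the downsample body is elided by the hunk above, so the 3x3 grouping and the padding rule in this sketch are assumptions inferred from the `MM_HIDDEN_SIZE * 9` input width and the former `DownSample3x3BlockFix` name, and the LLM hidden size of 1536 is likewise assumed. It is an illustration, not the commit's implementation.

```python
# Sketch of the NVILA-Lite projector: a 3x3 spatial downsample turns every 3x3 window of
# SigLIP tokens into one token with 9x the channels, then an MLP maps 1152*9 -> 1152*3 ->
# LLM hidden size (assumed 1536 here).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

MM_HIDDEN_SIZE = 1152
LLM_HIDDEN_SIZE = 1536  # assumed Qwen2 hidden size for a small NVILA-Lite variant


class Downsample3x3(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, h = x.shape
        f = math.isqrt(s)
        x = x.reshape(b, f, f, h)
        pad = (3 - f % 3) % 3  # assumed: pad bottom/right so the grid divides by 3
        if pad:
            x = F.pad(x, (0, 0, 0, pad, 0, pad))
            f += pad
        x = x.reshape(b, f // 3, 3, f // 3, 3, h).permute(0, 1, 3, 2, 4, 5)
        return x.reshape(b, -1, 9 * h)


projector = nn.Sequential(
    Downsample3x3(),
    nn.LayerNorm(MM_HIDDEN_SIZE * 9),
    nn.Linear(MM_HIDDEN_SIZE * 9, MM_HIDDEN_SIZE * 3),
    nn.GELU(),
    nn.LayerNorm(MM_HIDDEN_SIZE * 3),
    nn.Linear(MM_HIDDEN_SIZE * 3, LLM_HIDDEN_SIZE),
    nn.GELU(),
    nn.Linear(LLM_HIDDEN_SIZE, LLM_HIDDEN_SIZE),
)

tokens = torch.randn(1, 27 * 27, MM_HIDDEN_SIZE)  # e.g. a 27x27 grid of patch tokens
print(projector(tokens).shape)  # -> torch.Size([1, 81, 1536])
```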
...@@ -185,6 +185,7 @@ class BaseMultimodalProcessor(ABC):
"aspect_ratio_mask": Modality.IMAGE,
"num_patches": Modality.IMAGE,
"patch_pixel_values": Modality.IMAGE,
"block_sizes": Modality.IMAGE,
# Audio-related attributes
"audio_features": Modality.AUDIO,
"audio_feature_lens": Modality.AUDIO,
......
from typing import Any
import torch.nn as nn
from transformers.configuration_utils import PretrainedConfig
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.models.nvila import NVILAForConditionalGeneration
from sglang.srt.models.nvila_lite import NVILALiteForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)
from sglang.srt.server_args import ServerArgs
NUM_VIDEO_FRAMES = 8
class NVILAMultimodalProcessor(BaseMultimodalProcessor):
    models: list[type[nn.Module]] = [
        NVILAForConditionalGeneration,
        NVILALiteForConditionalGeneration,
    ]
    def __init__(
        self,
        hf_config: PretrainedConfig,
        server_args: ServerArgs,
        _processor: ProcessorMixin,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
        self._processor: ProcessorMixin
        tokenizer: PreTrainedTokenizerBase = getattr(self._processor, "tokenizer")
        self.mm_tokens = MultimodalSpecialTokens(
            image_token=tokenizer.image_token,
            image_token_id=hf_config.image_token_id,
            video_token=tokenizer.video_token,
            video_token_id=hf_config.video_token_id,
        ).build(_processor)
    async def process_mm_data_async(
        self,
        image_data,
        audio_data,
        input_text,
        request_obj: GenerateReqInput,
        **kwargs,
    ) -> dict[str, Any] | None:
        base_output = self.load_mm_data(
            prompt=input_text,
            multimodal_tokens=self.mm_tokens,
            image_data=request_obj.image_data,  # type: ignore
            video_data=request_obj.video_data,  # type: ignore
        )
        for i, video in enumerate(base_output.videos):  # type: ignore
            base_output.videos[i] = [x.asnumpy() for x in video]  # type: ignore
        mm_items, input_ids, _ = self.process_and_combine_mm_data(
            base_output,
            self.mm_tokens,
            do_sample_frames=True,
            num_frames=NUM_VIDEO_FRAMES,
        )
        return {
......
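The processor requests frame sampling from the underlying HuggingFace processor via `do_sample_frames=True` and `num_frames=NUM_VIDEO_FRAMES`. Purely as an illustration of what selecting 8 evenly spaced frames looks like (this helper is hypothetical and is not how sglang or the HF processor implements it):

```python
# Hypothetical sketch of uniform frame sampling down to NUM_VIDEO_FRAMES indices.
import numpy as np

NUM_VIDEO_FRAMES = 8


def sample_frame_indices(total_frames: int, num_frames: int = NUM_VIDEO_FRAMES) -> list[int]:
    if total_frames <= num_frames:
        return list(range(total_frames))
    return np.linspace(0, total_frames - 1, num_frames).round().astype(int).tolist()


print(sample_frame_indices(100))  # [0, 14, 28, 42, 57, 71, 85, 99]
```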
...@@ -21,9 +21,12 @@ MODEL_THRESHOLDS = {
        0.330, 56.1
    ),
    ModelLaunchSettings("deepseek-ai/Janus-Pro-7B"): ModelEvalMetrics(0.285, 40.3),
    ModelLaunchSettings("Efficient-Large-Model/NVILA-8B-hf"): ModelEvalMetrics(
        0.270, 56.7
    ),
    ModelLaunchSettings("Efficient-Large-Model/NVILA-Lite-2B-hf"): ModelEvalMetrics(
        0.270, 23.8
    ),
    ModelLaunchSettings("google/gemma-3-4b-it"): ModelEvalMetrics(0.360, 10.9),
    ModelLaunchSettings("google/gemma-3n-E4B-it"): ModelEvalMetrics(0.360, 17.7),
    ModelLaunchSettings("mistral-community/pixtral-12b"): ModelEvalMetrics(0.360, 16.6),
......