Unverified Commit d6d21640 authored by 萝卜菜, committed by GitHub

[Feature] Support Deepseek-VL2 (#2798)


Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: Yi Zhang <1109276519@qq.com>
parent 0212d2e2
@@ -32,6 +32,7 @@
- Phi-3-Small
- IBM Granite 3
- Janus-Pro-1B / Janus-Pro-7B
- Deepseek-VL2 / Deepseek-VL2-small
- Gemma 3 (it)
## Embedding Models
...
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
@@ -12,6 +13,7 @@ __all__ = [
"ExaoneConfig",
"ChatGLMConfig",
"DbrxConfig",
"DeepseekVL2Config",
"Qwen2_5_VLConfig", "Qwen2_5_VLConfig",
"Qwen2_5_VLVisionConfig", "Qwen2_5_VLVisionConfig",
"MultiModalityConfig", "MultiModalityConfig",
......
This diff is collapsed.
@@ -135,6 +135,11 @@ class ModelConfig:
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_config.kv_lora_rank
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures:
self.head_dim = 256
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_text_config.kv_lora_rank
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
else:
self.attention_arch = AttentionArch.MHA
@@ -362,6 +367,8 @@ def get_hf_text_config(config: PretrainedConfig):
# if transformers config doesn't align with this assumption.
assert hasattr(config.text_config, "num_attention_heads")
return config.text_config
if hasattr(config, "language_config"):
return config.language_config
else:
return config
@@ -465,6 +472,7 @@ multimodal_model_archs = [
"Qwen2_5_VLForConditionalGeneration",
"MiniCPMV",
"MultiModalityCausalLM",
"DeepseekVL2ForCausalLM",
]
...
@@ -44,6 +44,7 @@ class SeparatorStyle(IntEnum):
CHATGLM3 = auto()
DEEPSEEK_CHAT = auto()
METAMATH = auto()
DeepSeekVL2 = auto()
QWEN2_VL_EMBED = auto()
GEMMA3 = auto()
@@ -75,6 +76,7 @@ class Conversation:
image_data: Optional[List[str]] = None
modalities: Optional[List[str]] = None
stop_token_ids: Optional[int] = None
def get_prompt(self) -> str:
"""Get the prompt for generation."""
@@ -286,6 +288,18 @@ class Conversation:
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.DeepSeekVL2:
seps = [self.sep, self.sep2]
if system_prompt == "" or system_prompt is None:
ret = ""
else:
ret = system_prompt + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.GEMMA3:
ret = system_prompt
for i, (role, message) in enumerate(self.messages):
@@ -617,6 +631,23 @@ register_conv_template(
)
)
register_conv_template(
Conversation(
name="deepseek-vl2",
system_template="{system_message}",
# system_message="You are a helpful assistant. Please answer truthfully and write out your "
# "thinking step by step to be sure you get the right answer.",
system_message="",
roles=("<|User|>", "<|Assistant|>"),
messages=(),
offset=0,
sep_style=SeparatorStyle.DeepSeekVL2,
sep="\n\n",
sep2="<|end▁of▁sentence|>",
stop_str=["User:", "<|end▁of▁sentence|>"],
)
)
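# Illustrative sketch, not part of this diff: how the "deepseek-vl2" template above
# renders a single-image turn. It only uses the Conversation fields registered above
# and the SeparatorStyle.DeepSeekVL2 branch of get_prompt(); the user text is made up.
from sglang.srt.conversation import Conversation, SeparatorStyle

conv = Conversation(
    name="deepseek-vl2-example",
    system_template="{system_message}",
    system_message="",
    roles=("<|User|>", "<|Assistant|>"),
    messages=[
        ("<|User|>", "<image>\nDescribe this image."),
        ("<|Assistant|>", None),
    ],
    offset=0,
    sep_style=SeparatorStyle.DeepSeekVL2,
    sep="\n\n",
    sep2="<|end▁of▁sentence|>",
    stop_str=["User:", "<|end▁of▁sentence|>"],
)
# -> '<|User|>: <image>\nDescribe this image.\n\n<|Assistant|>:'
print(repr(conv.get_prompt()))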
# Reference: https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
register_conv_template(
Conversation(
...
@@ -33,6 +33,7 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_N
from sglang.srt.configs import (
ChatGLMConfig,
DbrxConfig,
DeepseekVL2Config,
ExaoneConfig,
Gemma3Config,
Gemma3TextConfig,
@@ -47,6 +48,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
DbrxConfig.model_type: DbrxConfig,
ExaoneConfig.model_type: ExaoneConfig,
Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
DeepseekVL2Config.model_type: DeepseekVL2Config,
MultiModalityConfig.model_type: MultiModalityConfig,
Gemma3Config.model_type: Gemma3Config,
Gemma3TextConfig.model_type: Gemma3TextConfig,
...
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import asyncio
import math
from typing import List, Union
import torch
from PIL import Image, ImageOps
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
)
from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
class DeepseekVL2ImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
# with contextlib.suppress(ValueError):
# AutoProcessor.register("DeepseekVLV2Processor", DeepseekVLV2Processor)
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<image>"
@staticmethod
def _process_images_task(image, input_text, max_req_input_len):
return get_global_processor().__call__(
conversations=input_text, images=image, max_req_input_len=max_req_input_len
)
async def _process_images(self, image_data, input_text, max_req_input_len):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
DeepseekVL2ImageProcessor._process_images_task,
image_data,
input_text,
max_req_input_len,
)
else:
image_inputs = self._process_images_task(
image_data, input_text, max_req_input_len
)
return image_inputs
async def process_images_async(
self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
):
if not image_data:
return None
if not isinstance(image_data, list):
image_data = [image_data]
images, image_hashes, image_sizes = [], [], []
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
input_ids, image_data, image_token, max_req_input_len
)
base_output.all_frames = [img.convert("RGB") for img in base_output.all_frames]
res = await self._process_images(
base_output.all_frames, base_output.input_text, max_req_input_len
)
pixel_values = res["images"]
input_ids = res["input_ids"]
images_seq_mask = res["images_seq_mask"]
images_spatial_crop = res["images_spatial_crop"]
batched_images_spatial_crop = []
batched_images_spatial_crop.append(images_spatial_crop)
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
return {
"input_ids": input_ids.tolist(),
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"image_seq_mask": images_seq_mask,
"image_spatial_crop": batched_images_spatial_crop,
"modalities": request_obj.modalities or ["image"],
}
ImageProcessorMapping = {
DeepseekVL2ForCausalLM: DeepseekVL2ImageProcessor,
}
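# Rough sketch, not part of this diff: the shape of the dict process_images_async
# returns above, with toy values. Tensor sizes and token ids are assumptions for
# illustration (the 384x384 tile size comes from the SigLIP-384 vision tower used in
# the model file below); only the keys mirror the return statement above.
import torch

example_output = {
    "input_ids": [0, 1, 2],                                 # hypothetical token ids
    "pixel_values": torch.zeros(3, 3, 384, 384),            # 1 global view + 2 local tiles (assumed)
    "image_hashes": [],
    "image_sizes": [],
    "image_seq_mask": torch.tensor([False, True, True]),    # True where image tokens sit
    "image_spatial_crop": torch.tensor([[[2, 1]]]),         # (num_width_tiles, num_height_tiles) per image
    "modalities": ["image"],
}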
@@ -160,8 +160,13 @@ class ImageInputs:
image_grid_thws: List[Tuple[int, int, int]] = None
mrope_position_delta: Optional[torch.Tensor] = None
# deepseek vl2 related
image_seq_mask: Optional[List[torch.Tensor]] = None
image_spatial_crop: Optional[List[torch.Tensor]] = None
# The id of the single-image placeholder token
im_token_id: Optional[torch.Tensor] = None
# All the images in the batch should share the same special image
# bound token ids.
im_start_id: Optional[int] = None
@@ -192,6 +197,8 @@ class ImageInputs:
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
"image_seq_mask",
"image_spatial_crop",
"im_token_id", "im_token_id",
"im_start_id", "im_start_id",
"im_end_id", "im_end_id",
...@@ -228,6 +235,8 @@ class ImageInputs: ...@@ -228,6 +235,8 @@ class ImageInputs:
"aspect_ratio_ids", "aspect_ratio_ids",
"aspect_ratio_mask", "aspect_ratio_mask",
"image_grid_thws", "image_grid_thws",
"image_seq_mask",
"image_spatial_crop",
]
for arg in optional_args:
if getattr(self, arg, None) is not None:
...
@@ -266,6 +266,14 @@ class ModelRunner:
server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True
if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
# TODO: DeepSeek-VL2 does not support radix cache yet; disable radix cache and chunked prefill automatically.
logger.info(
"Automatically turning off --chunked-prefill-size and disabling radix cache for deepseek-vl2."
)
server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")
...
@@ -1021,6 +1021,7 @@ class DeepseekV2Model(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
# Gather
@@ -1035,7 +1036,11 @@
)
dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
hidden_states = self.embed_tokens(input_ids)
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
@@ -1076,8 +1081,10 @@ class DeepseekV2ForCausalLM(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
if self.dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
...
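# Standalone sketch, not part of this diff: the pattern the two changes above add to
# DeepseekV2Model / DeepseekV2ForCausalLM. A toy module shows the idea: if the caller
# already built the embeddings (e.g. with image features spliced in), the embedding
# lookup is skipped and the provided tensor is used directly.
import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids: torch.Tensor, input_embeds: torch.Tensor = None):
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)   # text-only path
        else:
            hidden_states = input_embeds                    # multimodal path: caller supplies embeds
        return hidden_states

m = TinyModel()
ids = torch.tensor([1, 2, 3])
print(m(ids).shape)                                   # torch.Size([3, 8])
print(m(ids, input_embeds=torch.randn(3, 8)).shape)   # torch.Size([3, 8])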
import collections
import itertools
import math
import warnings
from enum import Enum
from functools import partial
from typing import Callable, Iterable, List, Optional, Tuple, Type, Union
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from sglang.srt.configs.deepseekvl2 import (
DeepseekVL2Config,
DeepseekVL2MlpProjectorConfig,
)
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
LinearBase,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
class DeepseekVL2MlpProjector(nn.Module):
def __init__(
self,
config: DeepseekVL2MlpProjectorConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
if config.projector_type == "identity":
modules = nn.Identity()
elif config.projector_type == "linear":
self.layers = nn.ModuleList(
[
ReplicatedLinear(
config.input_dim,
config.n_embed,
quant_config=quant_config,
)
]
)
elif config.projector_type == "mlp_gelu":
mlp_depth = config.depth
self.layers = nn.ModuleList(
[
ReplicatedLinear(
config.input_dim,
config.n_embed,
quant_config=quant_config,
)
]
)
for _ in range(1, mlp_depth):
self.layers.append(nn.GELU())
self.layers.append(
ReplicatedLinear(
config.n_embed,
config.n_embed,
quant_config=quant_config,
)
)
elif config.projector_type == "downsample_mlp_gelu":
mlp_depth = config.depth
mlp_ratio = config.mlp_ratio
self.layers = nn.ModuleList(
[
ReplicatedLinear(
config.input_dim
* config.downsample_ratio
* config.downsample_ratio,
config.n_embed * mlp_ratio,
quant_config=quant_config,
)
]
)
for _ in range(1, mlp_depth - 1):
self.layers.append(nn.GELU())
self.layers.append(
ReplicatedLinear(
config.n_embed * mlp_ratio,
config.n_embed * mlp_ratio,
quant_config=quant_config,
)
)
self.layers.append(nn.GELU())
self.layers.append(
ReplicatedLinear(
config.n_embed * mlp_ratio,
config.n_embed,
quant_config=quant_config,
)
)
else:
raise ValueError(f"Unknown projector type: {config.projector_type}")
if config.token_pooling:
self.token_pooling_layer = ReplicatedLinear(
config.input_dim * 4, config.input_dim, quant_config=quant_config
)
def forward(self, x):
if self.config.token_pooling:
batch_size, wxh, channels = x.shape
w = h = int(wxh**0.5)
x = x.view(batch_size, w, h, channels)
x = x.permute(0, 3, 1, 2)
patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
batch_size, channels, h_patches, w_patches, _, _ = patches.size()
patches = patches.contiguous().view(
batch_size, channels, h_patches * w_patches, -1
)
patches = patches.permute(0, 2, 1, 3).contiguous()
patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
x = self.token_pooling_layer(patches)[0]
elif self.config.projector_type == "downsample_mlp_gelu":
bs, hw, input_dim = x.shape
h = w = int((hw) ** 0.5)
"""compute padding"""
if h % self.config.downsample_ratio:
pad = self.config.downsample_ratio - h % self.config.downsample_ratio
else:
pad = 0
x = x.reshape(bs, h, w, input_dim)
if pad > 0:
x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
"""4 to 1 concat"""
x = x.permute(0, 3, 1, 2) # B, C, H, W
x = F.unfold(
x,
kernel_size=self.config.downsample_ratio,
stride=self.config.downsample_ratio,
padding=0,
) # B, C*4, HW // 4
x = x.permute(0, 2, 1)
for layer in self.layers:
x = layer(x)
if isinstance(x, tuple):
x = x[0]
return x
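# Standalone sketch, not part of this diff: the "4 to 1 concat" step above is a strided
# unfold that merges each downsample_ratio x downsample_ratio block of vision tokens
# into one token with the channels concatenated. Toy sizes below are made up.
import torch
import torch.nn.functional as F

bs, h, w, c = 1, 4, 4, 8                 # a 4x4 grid of 8-dim vision tokens
ratio = 2                                # stands in for config.downsample_ratio
x = torch.randn(bs, h, w, c)
x = x.permute(0, 3, 1, 2)                # B, C, H, W
x = F.unfold(x, kernel_size=ratio, stride=ratio, padding=0)   # B, C*ratio^2, (H*W)/ratio^2
x = x.permute(0, 2, 1)
print(x.shape)                           # torch.Size([1, 4, 32])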
# todo
class DeepseekVL2ForCausalLM(nn.Module):
def __init__(
self,
config: DeepseekVL2Config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# ----------- vision encoder ------------
vision_config = config.vision_config
self.vision = self._init_vision_module(vision_config, quant_config)
# ----------- vl projector ------------
projector_config = config.projector_config
self.projector = DeepseekVL2MlpProjector(projector_config, quant_config)
self.tile_tag = config.tile_tag
self.global_view_pos = config.global_view_pos
embed_std = 1 / torch.sqrt(
torch.tensor(projector_config.n_embed, dtype=torch.float32)
)
if self.tile_tag == "2D":
self.image_newline = nn.Parameter(
torch.randn(projector_config.n_embed) * embed_std
)
self.view_seperator = nn.Parameter(
torch.randn(projector_config.n_embed) * embed_std
)
else:
raise ValueError(f"tile tag should be 2D, but got {self.tile_tag}")
# ----------- language model ------------
language_config = config.language_config
self.language_model = DeepseekV2ForCausalLM(language_config)
def _init_vision_module(
self, vision_config, quant_config: Optional[QuantizationConfig]
) -> nn.Module:
# TODO: refactor vision model through timm wrapper from transformers
try:
import timm
except ImportError as err:
raise ImportError("Please install timm") from err
model = timm.create_model(
"vit_so400m_patch14_siglip_384.webli",
pretrained=False,
num_classes=0,
dynamic_img_size=True,
dynamic_img_pad=True,
)
model = model.to(dtype=torch.get_default_dtype())
return model
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
**kwargs: object,
):
input_embeds = self.language_model.model.embed_tokens(input_ids)
if forward_batch.forward_mode.is_extend() and forward_batch.image_inputs != [
None
]:
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
for idx, image in enumerate(forward_batch.image_inputs):
if image is None:
continue
start_idx = extend_start_loc_cpu[idx]
end_idx = start_idx + extend_seq_lens_cpu[idx]
pixel_values = image.pixel_values.to(
device="cuda", dtype=torch.bfloat16
)
image_seq_mask = image.image_seq_mask.to(device="cuda")
image_spatial_crop = image.image_spatial_crop
input_embeds[start_idx:end_idx] = self.prepare_inputs_embeds(
pixel_values,
image_seq_mask,
image_spatial_crop,
input_embeds[start_idx:end_idx],
)
outputs = self.language_model.forward(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
input_embeds=input_embeds,
)
return outputs
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "up_proj", 1),
("gate_up_proj", "gate_proj", 0),
]
params_dict = dict(self.named_parameters())
weights = list(weights)
for name, loaded_weight in weights:
if "language" in name:
name = name.replace("language.", "")
self.language_model.load_weights([(name, loaded_weight)])
else:
param = params_dict[name]
weights_loader = getattr(param, "weight_loader", default_weight_loader)
weights_loader(param, loaded_weight)
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
return input_ids
def prepare_inputs_embeds(
self,
pixel_values,
images_seq_mask,
images_spatial_crop,
input_embeds,
):
image_feature = self.vision.forward_features(pixel_values)
images_embeds = self.projector(image_feature)
_, hw, n_dim = images_embeds.shape
h = w = int(hw**0.5)
tile_index = 0
images_in_this_batch = []
for jdx in range(images_spatial_crop.shape[1]):
num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
if num_width_tiles == 0 or num_height_tiles == 0:
break
num_tiles_in_image = num_width_tiles * num_height_tiles
# [hw, D]
global_features = images_embeds[tile_index]
# [num_height_tiles * num_width_tiles, hw, D]
local_features = images_embeds[
tile_index + 1 : tile_index + 1 + num_tiles_in_image
]
tile_index += num_tiles_in_image + 1
# format global and local features
# ----------------- global view add newline -----------------
# [hw, D] -> [h, w, D]
global_features = global_features.view(h, w, n_dim)
# [D] -> [h, 1, D]
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
global_features = torch.cat([global_features, new_lines_in_global], dim=1)
# [h, w + 1, D] -> [h * (w + 1), D]
global_features = global_features.view(-1, n_dim)
# ----------------- local view add newline -----------------
# [num_height_tiles * num_width_tiles, h * w, D] ->
# [num_height_tiles * h, num_width_tiles * w, D]
local_features = rearrange(
local_features,
"(th tw) (h w) d -> (th h) (tw w) d",
th=num_height_tiles,
tw=num_width_tiles,
h=h,
w=w,
)
# [D] -> [num_height_tiles * h, 1, D]
new_lines_in_local = repeat(
self.image_newline,
"d -> (th h) 1 d",
th=num_height_tiles,
h=h,
)
# [num_height_tiles * h, num_width_tiles * w + 1, D]
local_features = torch.cat([local_features, new_lines_in_local], dim=1)
# [num_height_tiles * h, num_width_tiles * w + 1, D]
# --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
local_features = local_features.view(-1, n_dim)
# merge global and local tiles
if self.global_view_pos == "head":
global_local_features = torch.cat(
[
global_features,
self.view_seperator[None, :],
local_features,
]
)
else:
global_local_features = torch.cat(
[
local_features,
self.view_seperator[None, :],
global_features,
]
)
images_in_this_batch.append(global_local_features)
if len(images_in_this_batch) > 0:
images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
input_embeds.masked_scatter_(
images_seq_mask.unsqueeze(-1), images_in_this_batch
)
return input_embeds
EntryClass = DeepseekVL2ForCausalLM
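# Standalone sketch, not part of this diff: the local-tile stitching inside
# prepare_inputs_embeds above, on toy tensors. 2x1 tiles of 3x3 tokens are rearranged
# into one 6x3 grid, a "newline" column is appended, and the result is flattened.
import torch
from einops import rearrange, repeat

n_dim, h, w = 4, 3, 3
num_height_tiles, num_width_tiles = 2, 1
local = torch.randn(num_height_tiles * num_width_tiles, h * w, n_dim)
local = rearrange(
    local, "(th tw) (h w) d -> (th h) (tw w) d",
    th=num_height_tiles, tw=num_width_tiles, h=h, w=w,
)                                                            # (6, 3, 4)
newline = repeat(torch.zeros(n_dim), "d -> (th h) 1 d", th=num_height_tiles, h=h)
local = torch.cat([local, newline], dim=1).view(-1, n_dim)   # (6 * 4, 4)
print(local.shape)                                           # torch.Size([24, 4])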
@@ -24,3 +24,6 @@ pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pa
# For compiling xgrammar kernels
pip install cuda-python nvidia-cuda-nvrtc-cu12
# For DeepSeek-VL2
pip install timm
@@ -513,6 +513,30 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
cls.base_url += "/v1"
class TestDeepseekVL2Server(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
cls.model = "deepseek-ai/deepseek-vl2-small"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--chat-template",
"deepseek-vl2",
"--context-length",
"4096",
],
)
cls.base_url += "/v1"
def test_video_chat_completion(self):
pass
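# Hedged sketch, not part of this diff: querying a server launched the way the test
# above does. The port, image URL, and token budget are placeholders, not values from
# the PR; the launch flags mirror the test's other_args.
#
#   python3 -m sglang.launch_server --model-path deepseek-ai/deepseek-vl2-small \
#       --chat-template deepseek-vl2 --trust-remote-code --context-length 4096
import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="sk-123456")
response = client.chat.completions.create(
    model="deepseek-ai/deepseek-vl2-small",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ],
    max_tokens=64,
)
print(response.choices[0].message.content)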
class TestJanusProServer(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
...