"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4b557132ce955d58fd84572c03e79f43bdc91450"
Unverified commit 88568c01, authored by 996_icu and committed by GitHub

[model] Support POINTSV15Chat (#9651)


Co-authored-by: josephyou <josephyou@tencent.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: root <root@TENCENT64.site>
parent 904655c5
@@ -917,6 +917,7 @@ multimodal_model_archs = [
"Phi4MMForCausalLM",
"VILAForConditionalGeneration",
"Step3VLForConditionalGeneration",
"POINTSV15ChatModel",
"DotsVLMForCausalLM",
"DotsOCRForCausalLM",
"Sarashina2VisionForCausalLM",
...
# New file: sglang/srt/configs/points_v15_chat.py
from typing import Optional, Union

from transformers import PretrainedConfig, Qwen2Config
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig


class POINTSV15ChatConfig(PretrainedConfig):
model_type = "pointsv1.5_chat"
def __init__(
self,
vision_config: Optional[Union[dict, Qwen2VLVisionConfig]] = None,
llm_config: Optional[Union[dict, Qwen2Config]] = None,
**kwargs,
):
super().__init__(**kwargs)
if vision_config is None:
vision_config = Qwen2VLVisionConfig()
elif isinstance(vision_config, dict):
vision_config = Qwen2VLVisionConfig(**vision_config)
self.vision_config = vision_config
if llm_config is None:
llm_config = Qwen2Config()
elif isinstance(llm_config, dict):
llm_config = Qwen2Config(**llm_config)
self.llm_config = llm_config
self.hidden_size = self.llm_config.hidden_size
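A minimal usage sketch of the config above (illustrative, not part of the commit): the nested configs can be passed either as plain dicts, as parsed from config.json, or as ready-made config objects, and hidden_size is mirrored from the language-model config. The concrete values below are assumptions chosen only for the example.

from transformers import Qwen2Config

from sglang.srt.configs.points_v15_chat import POINTSV15ChatConfig

cfg = POINTSV15ChatConfig(
    vision_config={"depth": 32},               # dict is coerced to Qwen2VLVisionConfig
    llm_config=Qwen2Config(hidden_size=3584),  # config objects are used as-is
)
assert cfg.vision_config.depth == 32
assert cfg.hidden_size == cfg.llm_config.hidden_size == 3584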
# New file: sglang/srt/models/points_v15_chat.py
import copy
from typing import Iterable, List, Optional, Set, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from sglang.srt.configs.points_v15_chat import POINTSV15ChatConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.models.qwen2_vl import Qwen2VisionPatchMerger, Qwen2VisionTransformer
from sglang.srt.utils import add_prefix


# Vision tower for POINTS-V1.5: a Qwen2-VL (NaViT-style) vision transformer whose
# forward() returns per-patch hidden states without applying the final patch merger;
# POINTSV15ChatModel applies its own vision_projector instead.
class Qwen2VisionTransformerForNavitPOINTS(Qwen2VisionTransformer):
def __init__(
self,
vision_config: POINTSV15ChatConfig,
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(
vision_config,
norm_eps=norm_eps,
quant_config=quant_config,
prefix=prefix,
)
def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor,
) -> torch.Tensor:
# patchify
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)
# compute position embedding
rotary_pos_emb = self.rot_pos_emb(grid_thw)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# transformer blocks
x = x.unsqueeze(1)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
return x
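# Worked example (illustrative) of the cu_seqlens computation in forward() above:
#   grid_thw = [[1, 4, 6], [1, 2, 2]]   -> per-frame patch counts 4*6=24 and 2*2=4
#   repeat_interleave by t (one frame each) -> [24, 4]
#   cumsum                                   -> [24, 28]
#   F.pad(..., (1, 0))                       -> [0, 24, 28]
# i.e. cumulative patch offsets marking where each image's tokens start and end,
# which the attention blocks use to keep separate images from attending to each other.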


# POINTS-V1.5 Chat: a Qwen2-VL vision encoder plus patch-merger projector feeding a
# Qwen2 language model; registered via EntryClass at the bottom of this file.
class POINTSV15ChatModel(nn.Module):
def __init__(
self,
config: POINTSV15ChatConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
config.llm_config._attn_implementation = "flash_attention_2"
config._attn_implementation_autoset = False
self.config = config
self.quant_config = quant_config
llm_config = copy.deepcopy(config.llm_config)
llm_config.architectures = ["Qwen2ForCausalLM"]
self.llm = Qwen2ForCausalLM(
config=llm_config,
quant_config=quant_config,
prefix=add_prefix("llm", prefix),
)
self.vision_encoder = Qwen2VisionTransformerForNavitPOINTS(
config.vision_config,
quant_config=quant_config,
prefix=add_prefix("vision_encoder", prefix),
)
self.vision_projector = Qwen2VisionPatchMerger(
d_model=config.llm_config.hidden_size,
context_dim=1280,
quant_config=quant_config,
prefix=add_prefix("vision_projector", prefix),
)
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
self.vision_encoder.dtype
)
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
image_features = self.vision_encoder(pixel_values, grid_thw=image_grid_thw)
image_features = self.vision_projector(image_features)
return image_features
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
get_embedding: bool = False,
):
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.llm,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
positions=positions,
)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "vision_encoder" in name:
# adapt to VisionAttention
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
try:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
except KeyError:
print(params_dict.keys())
raise
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
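# Weight-name mapping walkthrough (illustrative): a checkpoint tensor named, e.g.,
#   "llm.model.layers.0.self_attn.q_proj.weight"
# is rewritten by stacked_params_mapping to
#   "llm.model.layers.0.self_attn.qkv_proj.weight" and loaded as shard "q", while
# vision weights such as "vision_encoder.blocks.0.attn.qkv.weight" are renamed to
# "...attn.qkv_proj.weight" to match sglang's VisionAttention parameter names.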
EntryClass = [POINTSV15ChatModel]
# New file (path inferred from the processors package): sglang/srt/multimodal/processors/points_v15_chat.py
# Copied from qwen_vl.py and adapted for POINTS-V1.5 Chat.
import asyncio
from typing import List, Union
from PIL import Image
from sglang.srt.models.points_v15_chat import POINTSV15ChatModel
from sglang.srt.multimodal.processors.qwen_vl import (
Qwen2_5VLImageProcessor,
resize_image_async,
)


# Reuses the Qwen2.5-VL multimodal processing pipeline for POINTS-V1.5 Chat.
class POINTSV15ChatProcessor(Qwen2_5VLImageProcessor):
models = [POINTSV15ChatModel]
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
# Compatible with POINTSV15Chat
hf_config.vision_start_token_id = None
hf_config.vision_end_token_id = None
hf_config.video_token_id = None
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
)
if base_output.images and isinstance(base_output.images[0], Image.Image):
resize_tasks = [resize_image_async(image) for image in base_output.images]
base_output.images = await asyncio.gather(*resize_tasks)
mm_items, input_ids, _ = self.process_and_combine_mm_data(
base_output, self.mm_tokens
)
return {
"input_ids": input_ids.tolist(),
"mm_items": mm_items,
"im_token_id": self.mm_tokens.image_token_id,
}
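With the config, model, and processor above in place, the sketch below is a rough, unverified example of exercising the model through sglang's offline Engine API. The checkpoint path and image file are placeholders, the prompt layout follows the points-v15-chat template registered further down, and the exact generate()/image_data behavior should be checked against the sglang version in use; none of this is part of the commit.

import sglang as sgl

# Placeholder checkpoint path (assumption): a POINTS-V1.5 Chat checkpoint with
# llm.* / vision_encoder.* / vision_projector.* weights, as handled by load_weights above.
llm = sgl.Engine(model_path="/path/to/points-v15-chat", trust_remote_code=True)

# Prompt laid out to match the points-v15-chat conversation template.
prompt = (
    "<|im_start|>user\n"
    "<|vision_start|><|image_pad|><|vision_end|>Describe this image.<|im_end|>\n"
    "<|im_start|>assistant\n"
)
out = llm.generate(prompt, {"max_new_tokens": 64}, image_data="example.jpg")
print(out["text"])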
@@ -960,6 +960,19 @@ register_conv_template(
)
)
register_conv_template(
Conversation(
name="points-v15-chat",
system_message="",
system_template="",
roles=("<|im_start|>user", "<|im_start|>assistant"),
sep="<|im_end|>\n",
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
stop_str=["<|im_end|>"],
image_token="<|vision_start|><|image_pad|><|vision_end|>",
video_token="<|vision_start|><|video_pad|><|vision_end|>",
)
)
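# Illustrative rendering (not part of the diff): with the ADD_NEW_LINE_SINGLE separator
# style, a single user turn containing one image is assumed to produce roughly:
#   <|im_start|>user
#   <|vision_start|><|image_pad|><|vision_end|>Describe this image.<|im_end|>
#   <|im_start|>assistant
# i.e. role + newline + message, terminated by sep ("<|im_end|>\n"), with the empty
# system message contributing nothing.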
MODEL_TYPE_TO_TEMPLATE = {
"internvl_chat": "internvl-2-5",
@@ -971,6 +984,12 @@ MODEL_TYPE_TO_TEMPLATE = {
}
@register_conv_template_matching_function
def match_points_v15_chat(model_path: str):
if re.search(r"points", model_path, re.IGNORECASE):
return "points-v15-chat"
def get_model_type(model_path: str) -> Optional[str]:
config_path = os.path.join(model_path, "config.json")
if not os.path.exists(config_path):
...
@@ -111,6 +111,12 @@ def get_hf_text_config(config: PretrainedConfig):
# if transformers config doesn't align with this assumption.
assert hasattr(config.text_config, "num_attention_heads")
return config.text_config
if hasattr(config, "llm_config"):
# PointsV1.5 Chat Model
assert hasattr(config.llm_config, "num_attention_heads")
return config.llm_config
if hasattr(config, "language_config"): if hasattr(config, "language_config"):
return config.language_config return config.language_config
if hasattr(config, "thinker_config"): if hasattr(config, "thinker_config"):
......
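# Net effect (illustrative): for POINTSV15ChatConfig, get_hf_text_config() now returns
# config.llm_config (a Qwen2Config), so shared code that reads num_attention_heads or
# hidden_size sees the language-model half of the model.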