Unverified Commit 644d57d5 authored by CSWYF3634076's avatar CSWYF3634076 Committed by GitHub
Browse files

[Model] Add Ernie4.5 VL Model Support (#22514)


Signed-off-by: default avatarwangyafeng <wangyafeng@baidu.com>
parent c905684c
......@@ -616,6 +616,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ |
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ |
| `DonutForConditionalGeneration`<sup>^</sup> | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | |
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ |
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
......
......@@ -173,6 +173,37 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
)
# Ernie4.5-VL
def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={modality: 1},
trust_remote_code=True,
)
if modality == "image":
placeholder = "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>"
elif modality == "video":
placeholder = "Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>"
prompts = [
(
f"<|begin_of_sentence|>User: {question}{placeholder}\n"
"Assistant: <think></think>"
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Florence2
def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
......@@ -1602,6 +1633,7 @@ model_example_map = {
"chameleon": run_chameleon,
"command_a_vision": run_command_a_vision,
"deepseek_vl_v2": run_deepseek_vl2,
"ernie45_vl": run_ernie45_vl,
"florence2": run_florence2,
"fuyu": run_fuyu,
"gemma3": run_gemma3,
......
......@@ -54,3 +54,4 @@ runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10
terratorch==1.1rc2 # required for PrithviMAE test
decord==0.6.0
......@@ -156,6 +156,8 @@ datasets==3.0.2
# mteb
decorator==5.1.1
# via librosa
decord==0.6.0
# via -r requirements/test.in
dill==0.3.8
# via
# datasets
......@@ -493,6 +495,7 @@ numpy==1.26.4
# contourpy
# cupy-cuda12x
# datasets
# decord
# einx
# encodec
# evaluate
......
......@@ -272,6 +272,7 @@ def _test_processing_correctness_one(
"CohereLabs/command-a-vision-07-2025",
"deepseek-ai/deepseek-vl2-tiny",
"naver-clova-ix/donut-base-finetuned-docvqa",
"baidu/ERNIE-4.5-VL-28B-A3B-PT",
"microsoft/Florence-2-base",
"adept/fuyu-8b",
"google/gemma-3-4b-it",
......
......@@ -396,6 +396,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
transformers_version_reason="HF model is not compatible.", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501
trust_remote_code=True),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from .common import apply_rotary_emb_dispatch
from .mrope import MRotaryEmbedding
class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
"""3D rotary positional embedding. 3D is t:time h:height w:width"""
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2:
assert self.mrope_section
section_h = self.mrope_section[0] # 22
section_w = self.mrope_section[1] # 22
section_t = self.mrope_section[2] # 20
assert section_h == section_w
# Split according to [h w h w h w h w... t t t...]
section_cos_t = cos[..., -section_t:]
section_cos_h = cos[..., :section_h + section_w:2]
section_cos_w = cos[..., 1:section_h + section_w:2]
cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[
1], section_cos_w[2]
cos_hw = torch.stack([cos_h, cos_w],
dim=-1).reshape(cos_h.shape[:-1] +
(cos_h.shape[-1] * 2, ))
cos = torch.cat([cos_hw, cos_t], dim=-1)
section_sin_t = sin[..., -section_t:]
section_sin_h = sin[..., :section_h + section_w:2]
section_sin_w = sin[..., 1:section_h + section_w:2]
sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[
1], section_sin_w[2]
sin_hw = torch.stack([sin_h, sin_w],
dim=-1).reshape(sin_h.shape[:-1] +
(sin_h.shape[-1] * 2, ))
sin = torch.cat([sin_hw, sin_t], dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
......@@ -393,6 +393,15 @@ class MRotaryEmbedding(RotaryEmbedding):
context_len=context_len,
seq_len=seq_len,
)
elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
return cls._ernie_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
else:
return cls._vl_get_input_positions_tensor(
input_tokens=input_tokens,
......@@ -513,6 +522,120 @@ class MRotaryEmbedding(RotaryEmbedding):
len(input_tokens)).item()
return llm_positions, mrope_position_delta
@classmethod
def _ernie_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for Ernie VL."""
image_token_id = hf_config.im_patch_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_conv_size = hf_config.spatial_conv_size
temporal_conv_size = hf_config.temporal_conv_size
llm_pos_ids_list: list = []
if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_conv_size, w // spatial_conv_size
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx)
mm_data_idx += 1
elif modality_type == "video":
t, h, w = (
video_grid_thw[mm_data_idx][0],
video_grid_thw[mm_data_idx][1],
video_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (t //
temporal_conv_size,
h //
spatial_conv_size,
w //
spatial_conv_size)
for t_idx in range(llm_grid_t):
t_index = torch.tensor(t_idx).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(
1, -1, 1).expand(1, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(
1, 1, -1).expand(1, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) +
st_idx)
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
return llm_positions, mrope_position_delta
@classmethod
def _vl_get_input_positions_tensor(
cls,
......
This diff is collapsed.
This diff is collapsed.
......@@ -206,6 +206,7 @@ _MULTIMODAL_MODELS = {
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
"Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), # noqa: E501
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
"Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"), # noqa: E501
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment