"docs/source/usage/compatibility_matrix.rst" did not exist on "b4be5a8adba95020187ae3cb43a7db7eef20c0ff"
Unverified Commit f967e51f authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Model] Initialize support for Deepseek-VL2 models (#11578)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 43f3d9e6
......@@ -52,6 +52,7 @@ steps:
- tests/worker
- tests/standalone_tests/lazy_torch_compile.py
commands:
- pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git # Used by multimodal processing test
- python3 standalone_tests/lazy_torch_compile.py
- pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine
......
......@@ -610,6 +610,13 @@ See [this page](#generative-models) for more information on how to use generativ
-
- ✅︎
- ✅︎
* - `DeepseekVLV2ForCausalLM`
- DeepSeek-VL2
- T + I<sup>+</sup>
- `deepseek-ai/deepseek-vl2-tiny`(WIP), `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. (see note)
-
- ✅︎
- ✅︎
* - `FuyuForCausalLM`
- Fuyu
- T + I
......@@ -755,8 +762,19 @@ See [this page](#generative-models) for more information on how to use generativ
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.
````{note}
The `deepseek-ai/deepseek-vl2-tiny` is not supported yet.
To use `DeepSeek-VL2` series models, you need to install a forked version of the `deepseek_vl2` package:
```shell
pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git
```
Besides, to run `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
````
```{note}
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
```
```{note}
......
......@@ -66,6 +66,23 @@ def run_chameleon(question: str, modality: str):
return llm, prompt, stop_token_ids
# Deepseek-VL2
def run_deepseek_vl2(question: str, modality: str):
    """Set up the single-image DeepSeek-VL2 example.

    Only the "image" modality is accepted. Returns the engine, the
    formatted prompt and the stop token ids (``None`` for this model).
    """
    assert modality == "image"

    # The architecture override is passed explicitly so that vLLM selects
    # DeepseekVLV2ForCausalLM for this checkpoint.
    engine = LLM(
        model="deepseek-ai/deepseek-vl2-small",
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
        hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
    )

    chat_prompt = f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
    return engine, chat_prompt, None
# Fuyu
def run_fuyu(question: str, modality: str):
assert modality == "image"
......@@ -498,6 +515,7 @@ model_example_map = {
"aria": run_aria,
"blip-2": run_blip2,
"chameleon": run_chameleon,
"deepseek_vl_v2": run_deepseek_vl2,
"fuyu": run_fuyu,
"glm4v": run_glm4v,
"h2ovl_chat": run_h2ovl,
......
......@@ -54,6 +54,28 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
)
def load_deepseek_vl2(question: str,
                      image_urls: List[str]) -> ModelRequestData:
    """Prepare a multi-image DeepSeek-VL2 request.

    Args:
        question: Text appended after the image placeholders.
        image_urls: URLs of the images to attach; their count is also used
            as the per-prompt image limit of the engine.

    Returns:
        A ``ModelRequestData`` holding the engine, the formatted prompt and
        the fetched images. No stop token ids or chat template are set.
    """
    model_name = "deepseek-ai/deepseek-vl2-small"

    llm = LLM(model=model_name,
              max_model_len=4096,
              max_num_seqs=2,
              hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
              limit_mm_per_prompt={"image": len(image_urls)})

    # One numbered "image_<i>:<image>" line per input image, counted from 1.
    placeholder = "".join(f"image_{i}:<image>\n"
                          for i in range(1, len(image_urls) + 1))
    prompt = f"<|User|>: {placeholder}{question}\n\n<|Assistant|>:"

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=None,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )
def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
model_name = "h2oai/h2ovl-mississippi-2b"
......@@ -372,6 +394,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
model_example_map = {
"aria": load_aria,
"deepseek_vl2": load_deepseek_vl2,
"h2ovl_chat": load_h2onvl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
......
......@@ -188,6 +188,33 @@ VLM_TEST_SETTINGS = {
max_tokens=8,
dtype="bfloat16",
),
"deepseek_vl_v2": VLMTestInfo(
models=["deepseek-ai/deepseek-vl2-small"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
dtype="bfloat16",
prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "<image>\nWhat's the color of the stop sign and car?",
"cherry_blossom": "<image>\nWhat's the color of the tower?",
}),
multi_image_prompt="image_1:<image>\nimage_2:<image>\nDescribe the two images shortly.", # noqa: E501
vllm_runner_kwargs={"hf_overrides": {"architectures": ["DeepseekVLV2ForCausalLM"]}}, # noqa: E501
image_size_factors=[(0.10, 0.15)],
patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner,
postprocess_inputs=model_utils.cast_dtype_post_processor("images"),
hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output,
stop_str=["<|end▁of▁sentence|>", "<|begin▁of▁sentence|>"], # noqa: E501
num_logprobs=5,
marks=[
pytest.mark.skipif(
not is_flash_attn_2_available(),
reason="Model needs flash-attn for numeric convergence.",
),
large_gpu_mark(min_gb=48),
],
),
"fuyu": VLMTestInfo(
models=["adept/fuyu-8b"],
test_type=VLMTestType.IMAGE,
......
......@@ -183,6 +183,14 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput,
####### Post-processors for HF outputs
def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput,
                                model: str) -> RunnerOutput:
    """Cut the HF output text at the end-of-sentence marker.

    The text is only touched when it ends with the marker; in that case
    everything from the first occurrence of the marker onwards is removed.
    Token ids and logprobs are passed through unchanged. ``model`` is not
    used.
    """
    eos_marker = "<|end▁of▁sentence|>"
    token_ids, text, logprobs = hf_output

    if not text.endswith(eos_marker):
        return token_ids, text, logprobs

    # Keep only the part that precedes the first marker.
    head, _, _ = text.partition(eos_marker)
    return token_ids, head, logprobs
def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output
......@@ -261,6 +269,34 @@ def qwen_prompt_path_encoder(
####### Model-specific HuggingFace runner patchers
def deepseekvl2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    """Patches and returns an instance of the HfRunner to use for
    DeepSeek-VL2.

    Two things are replaced on ``hf_model``:

    * ``processor`` is wrapped so that it accepts ``text``/``images``
      keyword arguments, forwards the text as ``prompt`` to the original
      processor, and returns a ``BatchEncoding`` of PyTorch tensors.
    * ``model.get_output_embeddings`` is overridden to return
      ``model.language.model.embed_tokens``.
    """
    hf_processor = hf_model.processor

    def processor(*args, text="", images=None, **kwargs):
        # The wrapped processor is always given a list of images.
        if isinstance(images, Image):
            images = [images]
        # raw_inputs is a custom class instead of dict or BatchFeature
        raw_inputs = hf_processor(
            *args,
            prompt=text,
            images=images,
            **kwargs,
        )
        # Drop the "seq_lens" and "sft_format" entries before building the
        # BatchEncoding.
        # NOTE(review): presumably these are not tensor-convertible
        # model inputs; confirm against the deepseek_vl2 processor output.
        tensor_inputs = {
            k: raw_inputs[k]
            for k in raw_inputs.keys()  # noqa
            if k not in ("seq_lens", "sft_format")
        }
        return BatchEncoding(data=tensor_inputs, tensor_type="pt")

    hf_model.processor = processor
    hf_model.model.get_output_embeddings = lambda: \
        hf_model.model.language.model.embed_tokens
    return hf_model
def glm_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for GLM4."""
hf_processor = hf_model.processor
......
......@@ -179,6 +179,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
is_available_online=False),
# TODO(Isotr0py): Use deepseek-vl2-tiny for test after it's supported
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-small"), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
......
......@@ -26,6 +26,9 @@ def test_can_initialize(model_arch):
# Avoid OOM
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_vl_v2":
hf_config.update({"architectures": ["DeepseekVLV2ForCausalLM"]})
if hasattr(hf_config, "text_config"):
text_config: PretrainedConfig = hf_config.text_config
else:
......
......@@ -403,8 +403,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat", "NVLM_D",
"h2ovl_chat"):
if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
"NVLM_D", "h2ovl_chat"):
return "<image>"
if model_type == "mllama":
return "<|image|>"
......
......@@ -243,7 +243,11 @@ class DeepseekV2Attention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
rope_scaling["rope_type"] = 'deepseek_yarn'
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.use_normal_rope = False
else:
self.use_normal_rope = True
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
......@@ -298,7 +302,18 @@ class DeepseekV2Attention(nn.Module):
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]
if self.use_normal_rope:
seq_len = positions.size(0)
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe = q_pe.reshape(seq_len, -1)
k_pe = k_pe.reshape(seq_len, -1)
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
if self.use_normal_rope:
q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
......@@ -355,6 +370,7 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
......
......@@ -251,7 +251,11 @@ class DeepseekV3Attention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
rope_scaling["rope_type"] = 'deepseek_yarn'
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.use_normal_rope = False
else:
self.use_normal_rope = True
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
......@@ -306,7 +310,18 @@ class DeepseekV3Attention(nn.Module):
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]
if self.use_normal_rope:
seq_len = positions.size(0)
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe = q_pe.reshape(seq_len, -1)
k_pe = k_pe.reshape(seq_len, -1)
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
if self.use_normal_rope:
q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
......@@ -583,7 +598,8 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
continue
# TODO(simon): support nextn predict layers
if self.config.num_nextn_predict_layers > 0:
if hasattr(self.config, "num_nextn_predict_layers"
) and self.config.num_nextn_predict_layers > 0:
assert self.config.num_nextn_predict_layers == 1
layer_idx = self.config.num_hidden_layers
if name.startswith(f"model.layers.{layer_idx}"):
......
This diff is collapsed.
......@@ -657,7 +657,7 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
# TODO: refactor this vision model
# TODO: refactor vision model through timm wrapper from transformers
try:
import timm
except ImportError:
......
......@@ -149,6 +149,7 @@ _MULTIMODAL_MODELS = {
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
......
......@@ -23,8 +23,9 @@ from vllm.logger import init_logger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
DbrxConfig, EAGLEConfig,
ExaoneConfig, H2OVLChatConfig,
DbrxConfig, DeepseekVLV2Config,
EAGLEConfig, ExaoneConfig,
H2OVLChatConfig,
InternVLChatConfig, JAISConfig,
MedusaConfig, MllamaConfig,
MLPSpeculatorConfig, MPTConfig,
......@@ -54,6 +55,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"chatglm": ChatGLMConfig,
"cohere2": Cohere2Config,
"dbrx": DbrxConfig,
"deepseek_vl_v2": DeepseekVLV2Config,
"mpt": MPTConfig,
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
......
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.cohere2 import Cohere2Config
from vllm.transformers_utils.configs.dbrx import DbrxConfig
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
from vllm.transformers_utils.configs.eagle import EAGLEConfig
from vllm.transformers_utils.configs.exaone import ExaoneConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
......@@ -25,6 +26,7 @@ __all__ = [
"ChatGLMConfig",
"Cohere2Config",
"DbrxConfig",
"DeepseekVLV2Config",
"MPTConfig",
"RWConfig",
"H2OVLChatConfig",
......
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py#L115-L268
from typing import Tuple
from transformers.configuration_utils import PretrainedConfig
class VisionEncoderConfig(PretrainedConfig):
    """Configuration of the vision encoder referenced by DeepseekVLV2Config.

    Every constructor argument is stored unchanged as an attribute of the
    same name; remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """
    model_type: str = "vision"

    # Class-level defaults mirroring the constructor defaults below.
    model_name: str = "vit_so400m_patch14_siglip_384.webli"
    image_size: int = 384
    # NOTE(review): default model_name says "patch14" while patch_size
    # defaults to 16 -- confirm that real checkpoints always override this.
    patch_size: int = 16
    width: int = 1024
    layers: int = 24
    heads: int = 16
    mlp_ratio: int = 4
    global_pool: str = "map"
    ignore_head: bool = True
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False
    # The three fields below are not explicit constructor parameters; if
    # supplied they arrive through **kwargs.
    weight_init: str = "skip"
    deterministic: bool = False
    num_recomputing_layers: int = 0

    def __init__(self,
                 model_name: str = "vit_so400m_patch14_siglip_384.webli",
                 image_size: int = 384,
                 patch_size: int = 16,
                 width: int = 1024,
                 layers: int = 24,
                 heads: int = 16,
                 mlp_ratio: int = 4,
                 global_pool: str = "map",
                 ignore_head: bool = True,
                 class_token: bool = False,
                 num_classes: int = 0,
                 use_checkpoint: bool = False,
                 **kwargs):
        self.model_name = model_name
        self.image_size = image_size
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.global_pool = global_pool
        self.ignore_head = ignore_head
        self.class_token = class_token
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint

        # Base-class initialisation runs last, after the fields are set.
        super().__init__(**kwargs)
class MlpProjectorConfig(PretrainedConfig):
    """Configuration of the MLP projector referenced by DeepseekVLV2Config.

    Every constructor argument is stored unchanged as an attribute of the
    same name; remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """
    model_type = "mlp_projector"

    # Class-level defaults mirroring the constructor defaults below.
    projector_type: str = "downsample_mlp_gelu"
    input_dim: int = 1152
    n_embed: int = 2048
    depth: int = 2
    mlp_ratio: int = 1
    downsample_ratio: int = 2
    # Not an explicit constructor parameter; if supplied it arrives through
    # **kwargs.
    token_pooling: bool = False

    def __init__(self,
                 projector_type: str = "downsample_mlp_gelu",
                 input_dim: int = 1152,
                 n_embed: int = 2048,
                 depth: int = 2,
                 mlp_ratio: int = 1,
                 downsample_ratio: int = 2,
                 **kwargs):
        self.projector_type = projector_type
        self.input_dim = input_dim
        self.n_embed = n_embed
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.downsample_ratio = downsample_ratio

        # Base-class initialisation runs last, after the fields are set.
        super().__init__(**kwargs)
class DeepseekV2Config(PretrainedConfig):
    """Configuration of the DeepSeek-V2 language model.

    DeepseekVLV2Config instantiates this class from its ``language_config``
    dictionary. Apart from the two normalisations noted inline
    (``num_key_value_heads`` and ``rms_norm_eps``), every argument is stored
    unchanged as an attribute of the same name. The token ids,
    ``tie_word_embeddings`` and any extra keyword arguments are handed to
    ``PretrainedConfig.__init__``.
    """

    model_type = "deepseek_v2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size=1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts=None,
        n_routed_experts=None,
        ep_size=1,
        routed_scaling_factor=1.0,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        # NOTE(review): the default is spelled 'gready' (sic); confirm
        # whether any consumer compares this value against "greedy".
        topk_method='gready',
        n_group=None,
        topk_group=None,
        num_experts_per_tok=None,
        moe_layer_freq=1,
        first_k_dense_replace=0,
        norm_topk_prob=False,
        scoring_func='softmax',
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        use_mla=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        # (a missing KV-head count falls back to the attention-head count)
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        # Coerced to float, so non-float inputs such as strings are accepted.
        self.rms_norm_eps = float(rms_norm_eps)
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_mla = use_mla

        # Base-class initialisation runs last, after the fields are set.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
class DeepseekVLV2Config(PretrainedConfig):
    """Top-level configuration of DeepSeek-VL2.

    The nested ``vision_config``, ``projector_config`` and
    ``language_config`` entries are read from ``kwargs`` as dictionaries
    (each defaulting to an empty dict) and expanded into
    ``VisionEncoderConfig``, ``MlpProjectorConfig`` and ``DeepseekV2Config``
    respectively. The language model configuration is exposed as
    ``text_config``.
    """

    model_type = "deepseek_vl_v2"
    vision_config: VisionEncoderConfig
    projector_config: MlpProjectorConfig

    tile_tag: str = "2D"
    global_view_pos: str = "head"
    candidate_resolutions: Tuple[Tuple[int, int], ...] = ((384, 384), )

    def __init__(self,
                 # NOTE(review): this default ("tile_tag") differs from the
                 # class-level default ("2D") -- confirm which is intended.
                 tile_tag: str = "tile_tag",
                 global_view_pos: str = "head",
                 candidate_resolutions: Tuple[Tuple[int, int],
                                              ...] = ((384, 384), ),
                 **kwargs):
        # The nested config dictionaries stay in kwargs, so the base class
        # receives them as well before they are expanded below.
        super().__init__(**kwargs)

        vision_config = kwargs.get("vision_config", {})
        self.vision_config = VisionEncoderConfig(**vision_config)

        projector_config = kwargs.get("projector_config", {})
        self.projector_config = MlpProjectorConfig(**projector_config)

        # The "language_config" entry is stored under the name text_config.
        language_config = kwargs.get("language_config", {})
        self.text_config = DeepseekV2Config(**language_config)

        self.tile_tag = tile_tag
        self.global_view_pos = global_view_pos
        self.candidate_resolutions = candidate_resolutions
        # Mirror the language model's vocabulary size on the top level.
        self.vocab_size = self.text_config.vocab_size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment