Unverified Commit cde384cd authored by qscqesze's avatar qscqesze Committed by GitHub
Browse files

[Model] support MiniMax-VL-01 model (#16328)


Signed-off-by: default avatarqingjun <qingjun@minimaxi.com>
parent 96e06e3c
...@@ -446,6 +446,19 @@ VLM_TEST_SETTINGS = { ...@@ -446,6 +446,19 @@ VLM_TEST_SETTINGS = {
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner, patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
), ),
"minimax_vl_01": VLMTestInfo(
models=["MiniMaxAI/MiniMax-VL-01"],
prompt_formatter=lambda img_prompt: f"<beginning_of_sentence>user: {img_prompt} assistant:<end_of_sentence>", # noqa: E501
img_idx_to_prompt=lambda _: "<image>",
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
max_model_len=8192,
max_num_seqs=4,
dtype="bfloat16",
hf_output_post_proc=model_utils.minimax_vl_01_hf_output,
patch_hf_runner=model_utils.minimax_vl_01_patch_hf_runner,
auto_cls=AutoModelForImageTextToText,
marks=[large_gpu_mark(min_gb=80)],
),
"molmo": VLMTestInfo( "molmo": VLMTestInfo(
models=["allenai/Molmo-7B-D-0924"], models=["allenai/Molmo-7B-D-0924"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
......
...@@ -229,6 +229,14 @@ def minicpmv_trunc_hf_output(hf_output: RunnerOutput, ...@@ -229,6 +229,14 @@ def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
return output_ids, output_str, out_logprobs return output_ids, output_str, out_logprobs
def minimax_vl_01_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output
if output_str.endswith("<end_of_sentence>"):
output_str = output_str.split("<end_of_sentence>")[0]
return output_ids, output_str, out_logprobs
####### Functions for converting image assets to embeddings ####### Functions for converting image assets to embeddings
def get_llava_embeddings(image_assets: _ImageAssets): def get_llava_embeddings(image_assets: _ImageAssets):
return [asset.image_embeds for asset in image_assets] return [asset.image_embeds for asset in image_assets]
...@@ -627,6 +635,17 @@ def minicpmv_26_patch_hf_runner(hf_model: HfRunner) -> HfRunner: ...@@ -627,6 +635,17 @@ def minicpmv_26_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model return hf_model
def minimax_vl_01_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
orig_generate = hf_model.model.generate
def _generate(self, *args, image_sizes=None, **kwargs):
return orig_generate(*args, decode_text=False, **kwargs)
hf_model.model.generate = types.MethodType(_generate, hf_model.model)
return hf_model
def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for Molmo.""" """Patches and returns an instance of the HfRunner to use for Molmo."""
hf_processor = hf_model.processor hf_processor = hf_model.processor
......
# SPDX-License-Identifier: Apache-2.0
import pytest
from PIL import Image
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.parse import ImageSize
from vllm.multimodal.processing import BaseMultiModalProcessor
from ....conftest import _ImageAssets
from ...utils import build_model_context
@pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
image_assets: _ImageAssets,
model_id: str,
num_imgs: int,
):
ctx = build_model_context(
model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
prompt = "<image>" * num_imgs
image = Image.new("RGB", size=(364, 364))
mm_data = {"image": [image] * num_imgs}
processed_inputs = processor.apply(prompt, mm_data, {})
image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs
def _validate_image_prompt_replacements_one(
processor: BaseMultiModalProcessor,
num_imgs: int,
failed_size_excs: list[tuple[ImageSize, Exception]],
image_size: ImageSize,
) -> None:
prompt = "<image>" * num_imgs
image = Image.new("RGB", size=image_size)
mm_data = {"image": [image] * num_imgs}
try:
processed_inputs = processor.apply(prompt, mm_data, {})
image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs
except Exception as exc:
failed_size_excs.append((image_size, exc))
def _test_image_prompt_replacements(
processor,
*,
num_imgs: int,
image_sizes: list[ImageSize],
) -> None:
failed_size_excs = list[tuple[ImageSize, Exception]]()
for size in image_sizes:
_validate_image_prompt_replacements_one(processor, num_imgs,
failed_size_excs, size)
if failed_size_excs:
msg = "Found failing image sizes:" \
+ "\n========\n".join(f"[{size}]\n{exc}"
for size, exc in failed_size_excs)
raise AssertionError(msg)
@pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements_regression(model_id, num_imgs):
ctx = build_model_context(
model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
(488, 183), (2560, 1669)]
image_sizes = [
size for w, h in image_ratios
for size in [ImageSize(w, h), ImageSize(h, w)]
]
_test_image_prompt_replacements(
processor,
num_imgs=num_imgs,
image_sizes=image_sizes,
)
...@@ -337,6 +337,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -337,6 +337,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501 extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501
trust_remote_code=True), trust_remote_code=True),
"MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501
trust_remote_code=True),
"Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501 "Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501
extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501 extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import copy import copy
import math import math
import re import re
from typing import Dict, Iterable, List, Optional, Tuple, Union from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
import torch.distributed import torch.distributed
...@@ -110,7 +110,17 @@ class MiniMaxText01RMSNormTP(CustomOp): ...@@ -110,7 +110,17 @@ class MiniMaxText01RMSNormTP(CustomOp):
variance = tensor_model_parallel_all_reduce( variance = tensor_model_parallel_all_reduce(
variance) / self.tp_world variance) / self.tp_world
x = x * torch.rsqrt(variance + self.variance_epsilon) x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
weight = self.weight
if x.size(-1) != self.weight.size(0):
if self.weight.size(0) < x.size(-1):
repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
full_weight = self.weight.repeat(repeat_count)
weight = full_weight[:x.size(-1)]
else:
weight = self.weight[:x.size(-1)]
x = x.to(orig_dtype) * weight
return x return x
def forward( def forward(
...@@ -421,6 +431,10 @@ class MiniMaxText01LinearAttention(nn.Module): ...@@ -421,6 +431,10 @@ class MiniMaxText01LinearAttention(nn.Module):
attn_metadata): attn_metadata):
hidden = [] hidden = []
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
if _prefill_idx >= len(attn_metadata.query_start_loc):
break
if _prefill_idx >= len(state_indices_tensor):
break
_start = attn_metadata.query_start_loc[_prefill_idx] _start = attn_metadata.query_start_loc[_prefill_idx]
_end = attn_metadata.query_start_loc[_prefill_idx + 1] _end = attn_metadata.query_start_loc[_prefill_idx + 1]
slot_id = state_indices_tensor[_prefill_idx] slot_id = state_indices_tensor[_prefill_idx]
...@@ -443,6 +457,10 @@ class MiniMaxText01LinearAttention(nn.Module): ...@@ -443,6 +457,10 @@ class MiniMaxText01LinearAttention(nn.Module):
hidden.append( hidden.append(
self._decode_infer(q, k, v, kv_cache, state_indices_tensor, self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
attn_metadata)) attn_metadata))
if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
hidden = torch.concat(hidden, dim=0).contiguous() hidden = torch.concat(hidden, dim=0).contiguous()
return hidden return hidden
...@@ -663,6 +681,9 @@ class MiniMaxText01DecoderLayer(nn.Module): ...@@ -663,6 +681,9 @@ class MiniMaxText01DecoderLayer(nn.Module):
self.shared_moe = False self.shared_moe = False
shared_intermediate = getattr(config, 'shared_intermediate_size', 0) shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
if isinstance(shared_intermediate, list):
shared_intermediate = shared_intermediate[
layer_id] if layer_id < len(shared_intermediate) else 0
if shared_intermediate > 0: if shared_intermediate > 0:
self.shared_moe = True self.shared_moe = True
self.shared_mlp = MiniMaxText01MLP( self.shared_mlp = MiniMaxText01MLP(
...@@ -875,6 +896,8 @@ class MiniMaxText01Model(nn.Module): ...@@ -875,6 +896,8 @@ class MiniMaxText01Model(nn.Module):
slots_to_clear = [] slots_to_clear = []
for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)): for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)):
if _prefill_id >= len(seq_id_map):
break
seq_id = seq_id_map[_prefill_id] seq_id = seq_id_map[_prefill_id]
if attn_metadata.context_lens_tensor[ if attn_metadata.context_lens_tensor[
_prefill_id] == 0 and seq_id in seq_to_slot_maps: _prefill_id] == 0 and seq_id in seq_to_slot_maps:
...@@ -886,13 +909,18 @@ class MiniMaxText01Model(nn.Module): ...@@ -886,13 +909,18 @@ class MiniMaxText01Model(nn.Module):
dtype=torch.long) dtype=torch.long)
minimax_cache_tensors[:, slots_tensor, ...] = 0 minimax_cache_tensors[:, slots_tensor, ...] = 0
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(self, def forward(self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None,
intermediate_tensors=None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor: **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
forward_context = get_forward_context() forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
if attn_metadata is None: if attn_metadata is None:
...@@ -901,6 +929,7 @@ class MiniMaxText01Model(nn.Module): ...@@ -901,6 +929,7 @@ class MiniMaxText01Model(nn.Module):
kwargs["request_ids_to_seq_ids"] = {} kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs: if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = [] kwargs["finished_requests_ids"] = []
( (
minimax_cache_tensors, minimax_cache_tensors,
state_indices_tensor, state_indices_tensor,
...@@ -922,15 +951,11 @@ class MiniMaxText01Model(nn.Module): ...@@ -922,15 +951,11 @@ class MiniMaxText01Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
kv_cache_index = 0
minimax_cache_index = 0 minimax_cache_index = 0
attn_metadata.rotary_emb = self.rotary_emb attn_metadata.rotary_emb = self.rotary_emb
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
_caches = None _caches = None
if isinstance(layer.self_attn, MiniMaxText01Attention):
_caches = kv_caches[kv_cache_index]
kv_cache_index += 1
if isinstance(layer.self_attn, MiniMaxText01LinearAttention): if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
current_state_layer = minimax_cache_index current_state_layer = minimax_cache_index
_caches = minimax_cache_params.at_layer_idx( _caches = minimax_cache_params.at_layer_idx(
...@@ -1009,15 +1034,20 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1009,15 +1034,20 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs( return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
batch_size) batch_size)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor: **kwargs) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, self.kv_cache, hidden_states = self.model(input_ids, positions, intermediate_tensors,
intermediate_tensors, inputs_embeds, inputs_embeds, **kwargs)
**kwargs)
return hidden_states return hidden_states
...@@ -1043,8 +1073,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1043,8 +1073,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
}) })
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> None: torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
def which_layer(name: str) -> int: def which_layer(name: str) -> int:
if "layers" in name: if "layers" in name:
...@@ -1108,6 +1139,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1108,6 +1139,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
weight_name, weight_name,
expert_id=expert_id, expert_id=expert_id,
shard_id=shard_id) shard_id=shard_id)
loaded_params.add(name)
break break
else: else:
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
...@@ -1117,6 +1149,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1117,6 +1149,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
default_weight_loader) default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return return
def is_shared_mlp_weight(name: str) -> bool: def is_shared_mlp_weight(name: str) -> bool:
...@@ -1154,6 +1187,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1154,6 +1187,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
else: else:
raise AssertionError( raise AssertionError(
"MLP weight not in [gate_up_proj, down_proj]") "MLP weight not in [gate_up_proj, down_proj]")
loaded_params.add(name)
return return
def is_mha_weight(name: str) -> bool: def is_mha_weight(name: str) -> bool:
...@@ -1170,6 +1204,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1170,6 +1204,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
MiniMaxText01LinearAttention.weight_direct_load) MiniMaxText01LinearAttention.weight_direct_load)
weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return return
def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor, def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
...@@ -1194,6 +1229,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1194,6 +1229,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
default_weight_loader) default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break break
else: else:
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
...@@ -1204,6 +1240,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1204,6 +1240,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
default_weight_loader) default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return return
def is_layer_norm_weight(name: str) -> bool: def is_layer_norm_weight(name: str) -> bool:
...@@ -1219,6 +1256,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1219,6 +1256,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
default_weight_loader) default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return return
def load_basic_weight(name: str, loaded_weight: torch.Tensor, def load_basic_weight(name: str, loaded_weight: torch.Tensor,
...@@ -1230,6 +1268,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1230,6 +1268,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
default_weight_loader) default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return return
for name, loaded_weight in weights: for name, loaded_weight in weights:
...@@ -1258,4 +1297,4 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1258,4 +1297,4 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
continue continue
load_basic_weight(name, loaded_weight, self) load_basic_weight(name, loaded_weight, self)
return return loaded_params
This diff is collapsed.
...@@ -189,6 +189,7 @@ _MULTIMODAL_MODELS = { ...@@ -189,6 +189,7 @@ _MULTIMODAL_MODELS = {
"LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501
"MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501 "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501
"MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"), # noqa: E501
"MiniCPMO": ("minicpmo", "MiniCPMO"), "MiniCPMO": ("minicpmo", "MiniCPMO"),
"MiniCPMV": ("minicpmv", "MiniCPMV"), "MiniCPMV": ("minicpmv", "MiniCPMV"),
"Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501 "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501
......
...@@ -34,11 +34,13 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, ...@@ -34,11 +34,13 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
H2OVLChatConfig, H2OVLChatConfig,
InternVLChatConfig, JAISConfig, InternVLChatConfig, JAISConfig,
KimiVLConfig, MedusaConfig, KimiVLConfig, MedusaConfig,
MllamaConfig, MLPSpeculatorConfig, MiniMaxText01Config,
MPTConfig, NemotronConfig, MiniMaxVL01Config, MllamaConfig,
NVLM_D_Config, RWConfig, MLPSpeculatorConfig, MPTConfig,
SkyworkR1VChatConfig, SolarConfig, NemotronConfig, NVLM_D_Config,
Telechat2Config, UltravoxConfig) RWConfig, SkyworkR1VChatConfig,
SolarConfig, Telechat2Config,
UltravoxConfig)
# yapf: enable # yapf: enable
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname from vllm.utils import resolve_obj_by_qualname
...@@ -73,6 +75,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { ...@@ -73,6 +75,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"exaone": ExaoneConfig, "exaone": ExaoneConfig,
"h2ovl_chat": H2OVLChatConfig, "h2ovl_chat": H2OVLChatConfig,
"internvl_chat": InternVLChatConfig, "internvl_chat": InternVLChatConfig,
"minimax_text_01": MiniMaxText01Config,
"minimax_vl_01": MiniMaxVL01Config,
"nemotron": NemotronConfig, "nemotron": NemotronConfig,
"NVLM_D": NVLM_D_Config, "NVLM_D": NVLM_D_Config,
"solar": SolarConfig, "solar": SolarConfig,
......
...@@ -15,6 +15,8 @@ from vllm.transformers_utils.configs.internvl import InternVLChatConfig ...@@ -15,6 +15,8 @@ from vllm.transformers_utils.configs.internvl import InternVLChatConfig
from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm.transformers_utils.configs.minimax_text_01 import MiniMaxText01Config
from vllm.transformers_utils.configs.minimax_vl_01 import MiniMaxVL01Config
from vllm.transformers_utils.configs.mllama import MllamaConfig from vllm.transformers_utils.configs.mllama import MllamaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.moonvit import MoonViTConfig from vllm.transformers_utils.configs.moonvit import MoonViTConfig
...@@ -39,6 +41,8 @@ __all__ = [ ...@@ -39,6 +41,8 @@ __all__ = [
"MedusaConfig", "MedusaConfig",
"EAGLEConfig", "EAGLEConfig",
"ExaoneConfig", "ExaoneConfig",
"MiniMaxText01Config",
"MiniMaxVL01Config",
"MllamaConfig", "MllamaConfig",
"MLPSpeculatorConfig", "MLPSpeculatorConfig",
"MoonViTConfig", "MoonViTConfig",
......
# SPDX-License-Identifier: Apache-2.0
""" MiniMaxText01 model configuration"""
from transformers.configuration_utils import PretrainedConfig
class MiniMaxText01Config(PretrainedConfig):
model_type = "MiniMaxText01"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=None,
eos_token_id=None,
tie_word_embeddings=False,
rope_theta=1e6,
sliding_window=None,
attention_dropout=0.0,
num_experts_per_tok=2,
num_local_experts=8,
output_router_logits=False,
router_aux_loss_coef=0.001,
router_jitter_noise=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.num_experts_per_tok = num_experts_per_tok
self.num_local_experts = num_local_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.router_jitter_noise = router_jitter_noise
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# SPDX-License-Identifier: Apache-2.0
"""MiniMaxVL01 model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING
from .minimax_text_01 import MiniMaxText01Config
class MiniMaxVL01Config(PretrainedConfig):
model_type = "minimax_vl_01"
def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=32000,
projector_hidden_act="gelu",
vision_feature_select_strategy="default",
vision_feature_layer=-2,
image_grid_pinpoints=None,
tie_word_embeddings=False,
image_seq_length=576,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.image_seq_length = image_seq_length
if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError("vision_feature_select_strategy should " +
"be one of 'default', 'full'." +
f"Got: {vision_feature_select_strategy}")
self.vision_feature_select_strategy = vision_feature_select_strategy
self.vision_feature_layer = vision_feature_layer
image_grid_pinpoints = (
image_grid_pinpoints if image_grid_pinpoints is not None else
[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]])
self.image_grid_pinpoints = image_grid_pinpoints
if isinstance(vision_config, dict):
if "model_type" not in vision_config:
vision_config["model_type"] = "clip_vision_model"
vision_config = CONFIG_MAPPING[vision_config["model_type"]](
**vision_config)
elif vision_config is None:
vision_config = CONFIG_MAPPING["clip_vision_model"](
intermediate_size=4096,
hidden_size=1024,
patch_size=14,
image_size=336,
num_hidden_layers=24,
num_attention_heads=16,
vocab_size=32000,
projection_dim=768,
)
self.vision_config = vision_config
if text_config is not None:
text_config = MiniMaxText01Config(**text_config)
else:
text_config = MiniMaxText01Config()
self.text_config = text_config
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment