Unverified commit 51d7c6a2, authored by Michael Goin, committed by GitHub
Browse files

[Model] Support Mistral3 in the HF Transformers format (#15505)


Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent f3aca1ee
......@@ -865,6 +865,13 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
* ✅︎
- * `Mistral3ForConditionalGeneration`
* Mistral3
* T + I<sup>+</sup>
* `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc.
*
* ✅︎
*
- * `MllamaForConditionalGeneration`
* Llama 3.2
* T + I<sup>+</sup>
......
......@@ -498,6 +498,29 @@ def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData:
return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6")
# Mistral-3 HF-format
def run_mistral3(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image-inference request for Mistral3 in the HF format.

    Args:
        questions: Question texts. Each one becomes its own prompt of the
            form ``<s>[INST]{question}\\n[IMG][/INST]``.
        modality: Input modality; only ``"image"`` is accepted.

    Returns:
        A ``ModelRequestData`` carrying the engine arguments and one prompt
        per question.
    """
    assert modality == "image"

    # NOTE: Need L40 (or equivalent) to avoid OOM
    engine_args = EngineArgs(
        model="mistralai/Mistral-Small-3.1-24B-Instruct-2503",
        max_model_len=8192,
        max_num_seqs=2,
        tensor_parallel_size=2,
        # `args` is not a parameter of this function; it is read from the
        # enclosing scope (presumably the parsed CLI arguments -- verify).
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # A single "[IMG]" marker follows each question inside the [INST] block.
    image_prompts = []
    for text in questions:
        image_prompts.append(f"<s>[INST]{text}\n[IMG][/INST]")

    return ModelRequestData(engine_args=engine_args, prompts=image_prompts)
# LLama 3.2
def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
......@@ -859,6 +882,7 @@ model_example_map = {
"mantis": run_mantis,
"minicpmo": run_minicpmo,
"minicpmv": run_minicpmv,
"mistral3": run_mistral3,
"mllama": run_mllama,
"molmo": run_molmo,
"NVLM_D": run_nvlm_d,
......
......@@ -218,6 +218,28 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
)
def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
    """Prepare a multi-image request for Mistral3 in the HF format.

    Args:
        question: The question text placed at the start of the ``[INST]``
            block.
        image_urls: URLs of the images to attach. One ``[IMG]`` marker is
            emitted per URL, and the per-prompt image limit is set to the
            number of URLs.

    Returns:
        A ``ModelRequestData`` carrying the engine arguments, the prompt,
        and the images fetched from ``image_urls``.
    """
    num_images = len(image_urls)

    # Adjust this as necessary to fit in GPU
    engine_args = EngineArgs(
        model="mistralai/Mistral-Small-3.1-24B-Instruct-2503",
        max_model_len=8192,
        max_num_seqs=2,
        tensor_parallel_size=2,
        limit_mm_per_prompt={"image": num_images},
    )

    # All image markers go after the question, separated by a newline.
    image_markers = "[IMG]" * num_images
    prompt = f"<s>[INST]{question}\n{image_markers}[/INST]"

    images = [fetch_image(image_url) for image_url in image_urls]
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=images,
    )
def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
......@@ -509,6 +531,7 @@ model_example_map = {
"h2ovl_chat": load_h2ovl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
"mistral3": load_mistral3,
"mllama": load_mllama,
"NVLM_D": load_nvlm_d,
"phi3_v": load_phi3v,
......
......@@ -297,6 +297,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501
trust_remote_code=True),
"Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501
min_transformers_version="4.50", # noqa: E501
extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
max_transformers_version="4.48",
transformers_version_reason="Use of private method which no longer exists.", # noqa: E501
......
......@@ -487,7 +487,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<|endoftext10|>" # 200010 (see vocab.json in hf model)
if model_type in ("minicpmo", "minicpmv"):
return "(<image>./</image>)"
if model_type in ("blip-2", "fuyu", "paligemma", "pixtral"):
if model_type in ("blip-2", "fuyu", "paligemma", "pixtral",
"mistral3"):
# These models do not use image tokens in the prompt
return None
if model_type == "qwen":
......
This diff is collapsed.
......@@ -979,7 +979,8 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
return self.vision_config.image_size
def get_patch_size(self) -> int:
    """Return the effective patch size of the vision encoder.

    The value is ``vision_config.patch_size`` multiplied by
    ``vision_config.spatial_merge_size``. The stale
    ``return self.vision_config.patch_size`` that preceded this expression
    has been removed; it made the multiplication unreachable.

    Returns:
        The product of ``patch_size`` and ``spatial_merge_size``.
    """
    config = self.vision_config
    return config.patch_size * config.spatial_merge_size
def get_patch_grid_length(self) -> int:
image_size, patch_size = self.get_image_size(), self.get_patch_size()
......@@ -1001,8 +1002,8 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
ratio = max(image_width / max_width, image_height / max_height)
if ratio > 1:
image_width = int(math.ceil(image_width / ratio))
image_height = int(math.ceil(image_height / ratio))
image_width = int(math.floor(image_width / ratio))
image_height = int(math.floor(image_height / ratio))
nrows, ncols = _get_pixtral_hf_num_image_tokens(
(image_height, image_width),
......
......@@ -177,6 +177,7 @@ _MULTIMODAL_MODELS = {
"MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501
"MiniCPMO": ("minicpmo", "MiniCPMO"),
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501
"MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501
......
......@@ -69,6 +69,9 @@ def get_vision_encoder_info(
if isinstance(vision_config, CLIPVisionConfig):
return CLIPEncoderInfo(vision_config)
if isinstance(vision_config, PixtralVisionConfig):
# Need to sneak in spatial_merge_size for Mistral3
vision_config.spatial_merge_size = getattr(hf_config,
"spatial_merge_size", 1)
return PixtralHFEncoderInfo(vision_config)
if isinstance(vision_config, SiglipVisionConfig):
return SiglipEncoderInfo(vision_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment