Unverified Commit 5380cd7e authored by Kiv Chen's avatar Kiv Chen Committed by GitHub

model(vlm): pixtral (#5084)

parent b2e95f62
...@@ -20,6 +20,7 @@ python3 -m sglang.launch_server \
| **Janus-Pro** (1B, 7B) | `deepseek-ai/Janus-Pro-7B` | `janus-pro` | DeepSeek’s open-source multimodal model capable of both image understanding and generation. Janus-Pro employs a decoupled architecture for separate visual encoding paths, enhancing performance in both tasks. |
| **MiniCPM-V / MiniCPM-o** | `openbmb/MiniCPM-V-2_6` | `minicpmv` | MiniCPM-V (2.6, ~8B) supports image inputs, and MiniCPM-o adds audio/video; these multimodal LLMs are optimized for end-side deployment on mobile/edge devices. |
| **Llama 3.2 Vision** (11B) | `meta-llama/Llama-3.2-11B-Vision-Instruct` | `llama_3_vision` | Vision-enabled variant of Llama 3 (11B) that accepts image inputs for visual question answering and other multimodal tasks. |
| **Pixtral** (12B, 124B) | `mistral-community/pixtral-12b` | `mistral` | Pixtral is a vision-language model from Mistral AI that can process both text and images. |
| **LLaVA** (v1.5 & v1.6) | *e.g.* `liuhaotian/llava-v1.5-13b` | `vicuna_v1.1` | Open vision-chat models that add an image encoder to LLaMA/Vicuna (e.g. LLaMA2 13B) for following multimodal instruction prompts. |
| **LLaVA-NeXT** (8B, 72B) | `lmms-lab/llava-next-72b` | `chatml-llava` | Improved LLaVA models (with an 8B Llama3 version and a 72B version) offering enhanced visual instruction-following and accuracy on multimodal benchmarks. |
| **LLaVA-OneVision** | `lmms-lab/llava-onevision-qwen2-7b-ov` | `chatml-llava` | Enhanced LLaVA variant integrating Qwen as the backbone; supports multiple images (and even video frames) as inputs via an OpenAI Vision API-compatible format. |
......
...@@ -33,9 +33,10 @@ The `hidden_states` folder contains examples on how to extract hidden states usi
* `hidden_states_engine.py`: An example of how to extract hidden states using the Engine API.
* `hidden_states_server.py`: An example of how to extract hidden states using the Server API.

## Multimodal

SGLang supports multimodal inputs for various model architectures. The `multimodal` folder contains examples showing how to use URLs, files, or encoded data to make requests to multimodal models. Examples include querying the [Llava-OneVision](multimodal/llava_onevision_server.py) model (image, multi-image, video), the Llava-backed [Qwen-Llava](multimodal/qwen_llava_server.py) and [Llama3-Llava](multimodal/llama3_llava_server.py) models (image, multi-image), and Mistral AI's [Pixtral](multimodal/pixtral_server.py) (image, multi-image).
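As a minimal sketch (assuming a Pixtral server has already been launched on port 30000, as in `pixtral_server.py`), a non-streaming request to the native `/generate` endpoint looks roughly like this:

```python
import requests

# Assumes: python3 -m sglang.launch_server --model-path mistral-community/pixtral-12b --port=30000
response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "<s>[INST]Describe this image in detail.\n[IMG][/INST]",
        # A URL here; local file paths or base64-encoded data also work.
        "image_data": "https://picsum.photos/id/237/400/300",
        "sampling_params": {"max_new_tokens": 64, "temperature": 0.7},
        "modalities": ["image"],
    },
)
print(response.json()["text"])
```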
## Token In, Token Out
......
...@@ -6,7 +6,7 @@ Usage:
# Endpoint Service CLI:
python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000
python3 llama3_llava_server.py
Output:
"Friends posing for a fun photo with a life-sized teddy bear, creating a playful and memorable moment."
......
...@@ -3,7 +3,7 @@ Usage:
python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8
python3 llava_onevision_server.py
"""
import base64
......
"""
Usage:
# Run a Pixtral model with SGLang:
# HuggingFace:
python -m sglang.launch_server --model-path mistral-community/pixtral-12b --port=30000
# ModelScope:
python -m sglang.launch_server --model-path AI-ModelScope/pixtral-12b --port=30000
# Then test it with:
python pixtral_server.py
This script tests the Pixtral model with both single and multiple images.
"""
import argparse
import asyncio
import json
import aiohttp
import requests
IMAGE_TOKEN_SEP = "\n[IMG]"
ROUTE = "/generate"
async def send_request(url, data, delay=0):
await asyncio.sleep(delay)
async with aiohttp.ClientSession() as session:
async with session.post(url, json=data) as resp:
output = await resp.json()
return output
async def test_concurrent(args):
url = f"{args.host}:{args.port}{ROUTE}"
# Single image test
if args.single_image:
prompt = f"<s>[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]"
image_url = "https://picsum.photos/id/237/400/300"
modality = ["image"]
# Multiple images test
else:
image_urls = [
"https://picsum.photos/id/237/400/300",
"https://picsum.photos/id/27/500/500",
]
prompt = f"<s>[INST]How many photos are there? Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]"
image_url = image_urls
modality = ["multi-images"]
response = await send_request(
url,
{
"text": prompt,
"image_data": image_url,
"sampling_params": {
"max_new_tokens": 100,
"temperature": 0.7,
"top_p": 0.9,
},
"modalities": modality,
},
)
print(f"Response: {response}")
if "text" in response:
print("\nOutput text:", response["text"])
def test_streaming(args):
url = f"{args.host}:{args.port}/generate"
# Single image test
if args.single_image:
prompt = f"<s>[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]"
image_data = "https://picsum.photos/id/237/400/300"
modality = ["image"]
# Multiple images test
else:
image_urls = [
"https://picsum.photos/id/237/400/300",
"https://picsum.photos/id/27/500/500",
]
prompt = f"<s>[INST]How many photos are there? Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]"
image_data = image_urls
modality = ["multi-images"]
pload = {
"text": prompt,
"image_data": image_data,
"sampling_params": {"max_new_tokens": 100, "temperature": 0.7, "top_p": 0.9},
"modalities": modality,
"stream": True,
}
response = requests.post(url, json=pload, stream=True)
print("Streaming response:")
prev = 0
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
print("\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
parser.add_argument(
"--single-image",
action="store_true",
help="Test with single image instead of multiple images",
)
parser.add_argument("--no-stream", action="store_true", help="Don't test streaming")
args = parser.parse_args()
asyncio.run(test_concurrent(args))
if not args.no_stream:
test_streaming(args)
...@@ -6,7 +6,7 @@ Usage:
# Endpoint Service CLI:
python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8
python3 qwen_llava_server.py
Output:
"Two children pose with a large teddy bear, one holding a smaller stuffed bear, in a room with an American flag and potted plants."
......
...@@ -194,6 +194,21 @@ register_chat_template(
)
)
# Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
register_chat_template(
ChatTemplate(
name="mistral",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("[SYSTEM_PROMPT] ", " [/SYSTEM_PROMPT]"),
"user": ("[INST] ", " [/INST]"),
"assistant": ("", " </s><s>"),
},
stop_str=("</s>",),
image_token="[IMG]",
)
)
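# Illustrative note (not part of the source): assuming the template registry simply wraps each
# message in its role prefix/suffix, a one-turn exchange under this template renders roughly as
#   [SYSTEM_PROMPT] <system text> [/SYSTEM_PROMPT][INST] <user text, with [IMG] where images appear> [/INST]<assistant text> </s><s>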
register_chat_template(
ChatTemplate(
name="llama-3-instruct",
...@@ -509,13 +524,19 @@ def match_vicuna(model_path: str):
@register_chat_template_matching_function
def match_llama2_chat(model_path: str):
if re.search(
r"llama-2.*chat|codellama.*instruct",
model_path,
re.IGNORECASE,
):
return "llama-2-chat"
@register_chat_template_matching_function
def match_mistral(model_path: str):
if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
return "mistral"
@register_chat_template_matching_function
def match_llama3_instruct(model_path: str):
if re.search(r"llama-3.*instruct", model_path, re.IGNORECASE):
......
...@@ -545,6 +545,7 @@ multimodal_model_archs = [
"Llama4ForConditionalGeneration",
"LlavaMistralForCausalLM",
"LlavaQwenForCausalLM",
"LlavaForConditionalGeneration",
"LlavaVidForCausalLM",
"MiniCPMO",
"MiniCPMV",
......
...@@ -634,6 +634,20 @@ register_conv_template(
)
)
# reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
register_conv_template(
Conversation(
name="mistral",
system_template="[SYSTEM_PROMPT]\n{system_message}\n[/SYSTEM_PROMPT]\n\n",
roles=("[INST]", "[/INST]"),
sep_style=SeparatorStyle.LLAMA2,
sep=" ",
sep2=" </s><s>",
stop_str=["[INST]", "[/INST]", "[SYSTEM_PROMPT]", "[/SYSTEM_PROMPT]"],
image_token="[IMG]",
)
)
# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
register_conv_template(
Conversation(
...@@ -880,13 +894,19 @@ def match_vicuna(model_path: str):
@register_conv_template_matching_function
def match_llama2_chat(model_path: str):
if re.search(
r"llama-2.*chat|codellama.*instruct",
model_path,
re.IGNORECASE,
):
return "llama-2"
@register_conv_template_matching_function
def match_mistral(model_path: str):
if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
return "mistral"
@register_conv_template_matching_function
def match_deepseek_vl(model_path: str):
if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
......
import asyncio
import importlib
from typing import List, Optional, Union
import numpy as np
from transformers.models.auto.processing_auto import (
PROCESSOR_MAPPING_NAMES as HF_MAPPING_NAMES,
)
import sglang.srt.managers.multimodal_processor as sgl_mm_processor_utils
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.models.llava import (
LlavaForConditionalGeneration,
LlavaLlamaForCausalLM,
LlavaMistralForCausalLM,
LlavaQwenForCausalLM,
...@@ -133,6 +139,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
img_data, aspect_ratio, grid_pinpoints
)
)
res = await asyncio.gather(*res)
for pixel_v, image_h, image_s in res:
pixel_values.append(pixel_v)
...@@ -165,3 +172,42 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
)
],
}
class LlavaMultimodalProcessor(BaseMultimodalProcessor):
"""
This is a wrapper class used to identify the multimodal processor for Llava architecture models.
"""
models = [LlavaForConditionalGeneration]
def _get_sgl_processor_cls(self, model_type: str):
if hf_name := HF_MAPPING_NAMES.get(model_type):
sgl_mm_processor_set = sgl_mm_processor_utils.PROCESSOR_MAPPING.values()
sgl_processor_cls = list(
filter(lambda p: p.__name__ == hf_name, sgl_mm_processor_set)
)
if sgl_processor_cls:
return sgl_processor_cls[0]
raise ValueError(
f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
)
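# Illustrative example (assumption, not from the source): for Pixtral the vision config's
# model_type is "pixtral", which HF's PROCESSOR_MAPPING_NAMES maps to the class name
# "PixtralProcessor"; that name matches sglang's PixtralProcessor, so it is picked up here.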
def __init__(self, hf_config, server_args, _processor):
assert hasattr(hf_config, "vision_config")
assert hasattr(hf_config, "text_config")
self.vision_config = hf_config.vision_config
self.text_config = hf_config.text_config
self.hf_config = hf_config
if vision_type := getattr(self.vision_config, "model_type"):
self.inner = self._get_sgl_processor_cls(vision_type)(
hf_config, server_args, _processor
)
else:
raise ValueError(
f"Required `vision_config.model_type` is not found in hf_config: `{hf_config}`"
)
async def process_mm_data_async(self, *args, **kwargs):
return await self.inner.process_mm_data_async(*args, **kwargs)
import asyncio
import math
from typing import List, Optional, Union
import numpy as np
from transformers import PretrainedConfig
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.models.pixtral import PixtralVisionModel
class PixtralProcessor(BaseMultimodalProcessor):
models = [PixtralVisionModel]
PAD_TOKEN = "<pad>"
IMG_BREAK_TOKEN_ID = 12
IMG_END_TOKEN_ID = 13
def get_patch_grid_size(
self,
*,
image_width: int,
image_height: int,
) -> tuple[int, int]:
max_width = max_height = self.image_size
patch_width = patch_height = self.patch_size
ratio = max(image_width / max_width, image_height / max_height)
if ratio > 1:
image_width = int(math.floor(image_width / ratio))
image_height = int(math.floor(image_height / ratio))
nrows, ncols = _get_pixtral_hf_num_image_tokens(
(image_height, image_width),
(patch_height, patch_width),
)
return ncols, nrows
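# Worked example (illustrative): with image_size=1024 and patch_size=16, a 512x384 image is not
# downscaled (ratio < 1) and yields a 32x24 grid, i.e. get_patch_grid_size(image_width=512,
# image_height=384) == (32, 24), assuming _num_image_tokens returns
# (ceil(height / patch_height), ceil(width / patch_width)).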
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.image_token_id = getattr(
hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
)
# Instantiate the patcher logic helper using the class defined above
self.vision_config = hf_config.vision_config
self.image_size = self.vision_config.image_size
self.patch_size = self.vision_config.patch_size
self.multimodal_tokens = MultimodalSpecialTokens(
image_token=_processor.image_token
)
_processor.tokenizer.add_special_tokens(
{
"pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN),
}
)
async def _resize(self, image):
num_w_tokens, num_h_tokens = self.get_patch_grid_size(
image_width=image.size[0],
image_height=image.size[1],
)
new_size = (num_w_tokens * self.patch_size, num_h_tokens * self.patch_size)
return image.resize(new_size)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
mm_data = self.load_mm_data(
prompt=input_text,
multimodal_tokens=self.multimodal_tokens,
max_req_input_len=kwargs.get("max_req_input_len", 4096),
image_data=image_data,
return_text=True,
)
if mm_data.images:
resize_tasks = [self._resize(image) for image in mm_data.images]
mm_data.images = await asyncio.gather(*resize_tasks)
processor_output = self.process_mm_data(
input_text=mm_data.input_text,
images=mm_data.images,
)
if "pixel_values" in processor_output:
mm_items = [
MultimodalDataItem(
pixel_values=processor_output["pixel_values"],
image_sizes=processor_output["image_sizes"],
modality=Modality.IMAGE,
)
]
input_ids = processor_output["input_ids"].view(-1).tolist()
processor_output.update(
input_ids=input_ids,
mm_items=mm_items,
# there's no im_start_id for pixtral, only im_token and im_end_token
im_end_id=self.IMG_END_TOKEN_ID,
im_token_id=self.image_token_id,
)
return processor_output
...@@ -15,7 +15,8 @@
import math
import re
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
import numpy as np
import torch
...@@ -28,10 +29,18 @@ from transformers import (
Qwen2Config,
SiglipVisionModel,
)
from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
# imported last, and only the module symbol, to avoid a circular import
import sglang.srt.models as sgl_models
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import general_mm_embed_routine
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,
...@@ -42,7 +51,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix, flatten_nested_list, logger
class LlavaBaseForCausalLM(nn.Module):
...@@ -114,7 +123,16 @@ class LlavaBaseForCausalLM(nn.Module):
image_inputs.image_offsets = offset_list
return input_ids
def encode_images(
self, pixel_values: Union[torch.Tensor, List[torch.Tensor]]
) -> torch.Tensor:
"""
encode images by vision tower and multimodal projector
Args:
pixel_values: torch.Tensor or List[torch.Tensor]: each tensor for an input image
Returns:
torch.Tensor: encoded image features from the input image; if multiple, flattened by seq_len axis
"""
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden states.
...@@ -583,4 +601,229 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
)
class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
"""
An adaptor class to enable support for multiple mmlm such as mistral-community/pixtral-12b
It follows the structure of (vision_tower, multi_modal_projector, language_model)
Once a model config is loaded, text_config and vision_config will be extracted, and
LlavaForConditionalGeneration will load the language_model and vision_tower models
according to config.
"""
MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
if hasattr(self.vision_tower, "pad_input_ids"):
return self.vision_tower.pad_input_ids(input_ids, image_inputs)
else:
return super().pad_input_ids(input_ids, image_inputs)
def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoModel):
"""
Get the SGLang model implementation class according to config.
Args:
config: The config object of the model.
auto_model_type: The type of the auto model.
Returns:
The SGLang model implementation class.
"""
config_cls_name = config.__class__.__name__
arch_name_mapping = self._config_cls_name_to_arch_name_mapping(auto_model_type)
if arch := arch_name_mapping.get(config_cls_name):
if isinstance(arch, tuple):
arch = arch[0]
logger.warning(
f"Multiple {auto_model_type.__name__} models found for submodule config `{config_cls_name}`, defaulting to [0]: {arch.__name__}"
)
try:
return sgl_models.registry.ModelRegistry.resolve_model_cls(arch)[0]
except Exception as e:
raise ValueError(
f"{auto_model_type.__name__} found a corresponding model `{arch}` for config class `{config_cls_name}`, but failed to load it from SGLang ModelRegistry. \n{e}"
)
else:
raise ValueError(
f"{auto_model_type.__name__} cannot find a corresponding model for config class `{config_cls_name}`"
)
@lru_cache
def _config_cls_name_to_arch_name_mapping(
self, auto_model_type: Type[AutoModel]
) -> Dict[str, str]:
mapping = {}
for config_cls, archs in auto_model_type._model_mapping.items():
if isinstance(archs, tuple):
mapping[config_cls.__name__] = tuple(arch.__name__ for arch in archs)
else:
mapping[config_cls.__name__] = archs.__name__
return mapping
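# e.g. for AutoModelForCausalLM this yields entries such as
# {"MistralConfig": "MistralForCausalLM", "Qwen2Config": "Qwen2ForCausalLM", ...}
# (illustrative; the exact contents depend on the installed transformers version)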
def __init__(
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
assert hasattr(config, "text_config")
assert hasattr(config, "vision_config")
self.config = config
self.text_config = config.text_config
self.vision_config = config.vision_config
if not hasattr(self.config, "vocab_size"):
self.config.vocab_size = self.config.text_config.vocab_size
if not hasattr(self.config, "image_aspect_ratio"):
self.config.image_aspect_ratio = "anyres"
if not hasattr(self.config, "image_grid_pinpoints"):
# from transformers.models.llava_onevision.configuration_llava_onevision import LlavaOnevisionConfig
# self.config.image_grid_pinpoints = LlavaOnevisionConfig().image_grid_pinpoints
self.config.image_grid_pinpoints = [
[96, 96],
[224, 224],
[384, 384],
[512, 512],
[768, 768],
[1024, 1024],
]
if not hasattr(self.config, "mm_patch_merge_type"):
self.config.mm_patch_merge_type = "flat"
if not hasattr(self.config, "image_token_index"):
self.config.image_token_index = 10
if not hasattr(self.config, "projector_hidden_act"):
self.config.projector_hidden_act = "gelu"
self.vision_feature_layer = getattr(config, "vision_feature_layer", -1)
self.vision_feature_select_strategy = getattr(
config, "vision_feature_select_strategy", "full"
)
self.image_size = self.config.vision_config.image_size
self.patch_size = self.config.vision_config.patch_size
self.mm_patch_merge_type = config.mm_patch_merge_type
self.image_aspect_ratio = config.image_aspect_ratio
self.image_grid_pinpoints = config.image_grid_pinpoints
self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)
language_model_cls = self._get_sgl_model_cls(
config.text_config, AutoModelForCausalLM
)
vision_model_cls = self._get_sgl_model_cls(config.vision_config, AutoModel)
self.language_model = language_model_cls(
config.text_config,
quant_config=quant_config,
prefix=add_prefix("language_model", prefix),
)
self.vision_tower = vision_model_cls(
config.vision_config,
quant_config=quant_config,
prefix=add_prefix("vision_tower", prefix),
)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.language_model.model.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
"""Extract features from image inputs.
Args:
items: List of MultimodalDataItem objects containing image data
Note that an item can be either "image" or "multi-images"
Returns:
torch.Tensor: features from image inputs, concatenated
"""
features = []
for item in items:
# in each item, we assume pixel_values is always batched
pixel_values, image_sizes = item.pixel_values, item.image_sizes
image_outputs = self.vision_tower(
pixel_values, image_sizes, output_hidden_states=True
)
selected_image_feature = image_outputs.hidden_states[
self.vision_feature_layer
]
if self.vision_feature_select_strategy in ["default", "patch"]:
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(
f"Unexpected select feature: {self.vision_feature_select_strategy}"
)
features.append(
self.multi_modal_projector(selected_image_feature.squeeze(0))
)
ret = torch.cat(features, dim=0)
return ret
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
get_embedding: bool = False,
):
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
get_embedding=get_embedding,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
placeholder_tokens=None, # using mm_item.pad_value
positions=positions,
)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""Load weights for LlavaForConditionalGeneration.
Unlike the base class implementation, this one doesn't need to handle
weight name remapping as the weights are already properly structured with
'language_model' and 'vision_tower' prefixes in the safetensors files.
"""
if (
self.vision_feature_select_strategy == "patch"
or self.vision_feature_select_strategy == "full"
):
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
else:
raise ValueError(
f"Unexpected select feature: {self.vision_feature_select_strategy}"
)
# Create dictionaries for direct parameter loading
params_dict = dict(self.named_parameters())
# Load weights directly without remapping
for name, loaded_weight in weights:
for part in ("language_model", "vision_tower"):
if name.startswith(part):
name = name[len(part + ".") :]
getattr(self, part).load_weights([(name, loaded_weight)])
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = [
LlavaLlamaForCausalLM,
LlavaQwenForCausalLM,
LlavaMistralForCausalLM,
LlavaForConditionalGeneration,
]
# Copyright 2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Using mistral-community/pixtral-12b as reference.
"""
import logging
import math
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PixtralVisionConfig, PretrainedConfig
from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding
from transformers.models.pixtral.modeling_pixtral import (
generate_block_attention_mask as _get_pixtral_attention_mask,
)
from transformers.models.pixtral.modeling_pixtral import position_ids_in_meshgrid
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import MergedColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_loader.weight_utils import default_weight_loader
class PixtralHFMLP(nn.Module):
"""MLP for PixtralHFVisionModel using SGLang components."""
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__()
assert config.intermediate_size is not None
# Use MergedColumnParallelLinear for gate_up_proj to handle combined weights
self.gate_up_proj = MergedColumnParallelLinear(
input_size=config.hidden_size,
output_sizes=[config.intermediate_size, config.intermediate_size],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=config.intermediate_size,
output_size=config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up_output, _ = self.gate_up_proj(x)
# Apply SiLU activation and multiply
gate_up = self.act_fn(gate_up_output)
# Project back to hidden size
out, _ = self.down_proj(gate_up)
return out
class PixtralHFTransformerBlock(nn.Module):
"""Transformer block for PixtralHFVisionModel using SGLang components."""
def __init__(
self,
config: PretrainedConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__()
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
# Use SGLang's VisionAttention instead of vLLM's PixtralHFAttention
self.attention = VisionAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
projection_size=config.hidden_size,
use_qkv_parallel=True,
quant_config=quant_config,
dropout=0.0,
use_context_forward=False,
softmax_in_single_precision=False,
flatten_batch=False,
prefix=f"{prefix}.attention",
)
self.feed_forward = PixtralHFMLP(
config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
)
self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
# Ensure hidden_states has the batch dimension [batch, seq_len, hidden_dim]
batch_size, seq_len, hidden_dim = hidden_states.shape
# Apply attention norm - normalize along the last dimension
attn_normalized = self.attention_norm(hidden_states.view(-1, hidden_dim)).view(
batch_size, seq_len, hidden_dim
)
# Pass through attention layer
attention_output = self.attention(
attn_normalized,
attention_mask=attention_mask,
cu_seqlens=None,
position_embeddings=position_embeddings,
)
# Apply first residual connection
hidden_states = hidden_states + attention_output
# Apply feed-forward norm - normalize along the last dimension
ffn_normalized = self.ffn_norm(hidden_states.view(-1, hidden_dim)).view(
batch_size, seq_len, hidden_dim
)
# Pass through feed-forward layer
# First reshape to 2D for the feed-forward network, then reshape back
ffn_output = self.feed_forward(ffn_normalized)
# Apply second residual connection
output = hidden_states + ffn_output
return output
class PixtralHFTransformer(nn.Module):
"""Transformer for PixtralHFVisionModel using SGLang components."""
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()
num_hidden_layers = config.num_hidden_layers
if num_hidden_layers_override is not None:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList(
[
PixtralHFTransformerBlock(
config=config,
layer_id=layer_idx,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
)
for layer_idx in range(num_hidden_layers)
]
)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor],
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
return_all_hidden_states: bool = False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Forward pass through transformer layers.
Args:
x: Input tensor
attention_mask: Optional attention mask
position_embeddings: Optional position embeddings for rotary attention
return_all_hidden_states: Whether to return all hidden states
Returns:
Either the final hidden state, or a list of all hidden states if
return_all_hidden_states is True
"""
# For HF model compatibility, always start with the input
hidden_states = x
all_hidden_states = [hidden_states] if return_all_hidden_states else None
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, attention_mask, position_embeddings)
if return_all_hidden_states:
all_hidden_states.append(hidden_states)
if return_all_hidden_states:
return all_hidden_states
return hidden_states
def resolve_visual_encoder_outputs(
outputs: Union[torch.Tensor, List[torch.Tensor]],
feature_sample_layers: Optional[List[int]],
post_norm: Optional[nn.Module],
num_hidden_layers: int,
) -> torch.Tensor:
"""Resolve outputs from visual encoder based on feature_sample_layers."""
if feature_sample_layers is None:
# Just use the last layer's output
if isinstance(outputs, list):
outputs = outputs[-1]
if post_norm is not None:
outputs = post_norm(outputs)
return outputs
# Handle the case where we want to use specific layers
if not isinstance(outputs, list):
raise ValueError(
"Expected outputs to be a list when feature_sample_layers is provided"
)
# Validate layer indices
for layer_idx in feature_sample_layers:
if layer_idx < 0 or layer_idx > num_hidden_layers:
raise ValueError(
f"Feature sample layer index {layer_idx} is out of range "
f"[0, {num_hidden_layers}]"
)
# Collect outputs from specified layers
selected_outputs = [outputs[layer_idx] for layer_idx in feature_sample_layers]
# Combine the outputs
combined_outputs = torch.cat(selected_outputs, dim=-1)
if post_norm is not None:
combined_outputs = post_norm(combined_outputs)
return combined_outputs
class PixtralHFVisionModel(nn.Module):
"""Hugging Face Pixtral Vision Model implemented using SGLang components."""
DEFAULT_IMAGE_TOKEN_ID = 10
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
return self.input_padder.pad_input_tokens(input_ids, image_inputs)
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_conv = nn.Conv2d(
in_channels=config.num_channels,
out_channels=config.hidden_size,
kernel_size=config.patch_size,
stride=config.patch_size,
bias=False,
)
self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
self.transformer = PixtralHFTransformer(
config,
quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.transformer",
)
# Check that num_hidden_layers is valid
num_hidden_layers = config.num_hidden_layers
if len(self.transformer.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.transformer.layers)} "
"layers."
)
# Initialize patch position embedding
self.image_token_id = image_token_id
self.patch_positional_embedding = PixtralRotaryEmbedding(config)
self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
[self.image_token_id]
)
@property
def dtype(self):
return next(self.parameters()).dtype
@property
def device(self):
return next(self.parameters()).device
def forward(
self,
pixel_values: torch.Tensor,
image_sizes: list[tuple[int, int]],
output_hidden_states: bool = False,
feature_sample_layers: Optional[list[int]] = None,
) -> Union[torch.Tensor, tuple]:
"""
Args:
pixel_values: [batch_size, C, H, W], padded if multiple images
image_sizes: list of (H, W) for each image in the batch
output_hidden_states: Whether to return all hidden states.
feature_sample_layers: Layer indices whose features should be
concatenated and used as the visual encoder output. If none
are provided, the last layer is used.
Returns:
A tuple containing:
- hidden_states: Final model outputs (or selected layers if feature_sample_layers given)
- hidden_states tuple (optional): All hidden states if output_hidden_states=True
"""
# batch patch images
embeds_orig = self.patch_conv(
pixel_values.to(device=self.device, dtype=self.dtype)
)
# crop the embeddings
embeds_2d = [
embed[..., : h // self.patch_size, : w // self.patch_size]
for embed, (h, w) in zip(embeds_orig, image_sizes)
]
# flatten to sequence
embeds_1d = torch.cat([p.flatten(1).T for p in embeds_2d], dim=0)
embeds_featurized = self.ln_pre(embeds_1d).unsqueeze(0)
# positional embeddings
position_ids = position_ids_in_meshgrid(
embeds_2d,
max_width=self.image_size // self.patch_size,
).to(self.device)
# The original PixtralRotaryEmbedding expects 2D input but returns a tuple of tensors (cos, sin)
# These tensors are used by apply_rotary_pos_emb in the transformer blocks
position_embedding = self.patch_positional_embedding(
embeds_featurized, position_ids
)
attention_mask = _get_pixtral_attention_mask(
[p.shape[-2] * p.shape[-1] for p in embeds_2d], embeds_featurized
)
return_all_hidden_states = (
output_hidden_states or feature_sample_layers is not None
)
transformer_outputs = self.transformer(
embeds_featurized, # add batch dimension
attention_mask,
position_embedding,
return_all_hidden_states=return_all_hidden_states,
)
# Store all hidden states if requested
all_hidden_states = None
if isinstance(transformer_outputs, list):
all_hidden_states = transformer_outputs
# Use the last layer by default if feature_sample_layers is not specified
if feature_sample_layers is None:
out = transformer_outputs[-1]
else:
# Resolve outputs based on feature sample layers
out = resolve_visual_encoder_outputs(
transformer_outputs,
feature_sample_layers,
None,
self.config.num_hidden_layers,
)
else:
out = transformer_outputs
# Format return to be compatible with HuggingFace vision models
if output_hidden_states:
return type(
"VisualOutput",
(),
{
"last_hidden_state": out,
"hidden_states": all_hidden_states,
},
)
else:
return out
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
"""Load weights from a HuggingFace checkpoint with proper parameter mapping."""
params_dict = dict(self.named_parameters())
# Each (param_name, weight_name, shard_id) entry loads the checkpoint weight named weight_name
# into the shard_id slice of the merged parameter param_name.
stacked_params_mapping = [
(".attention.qkv_proj", ".attention.q_proj", "q"),
(".attention.qkv_proj", ".attention.k_proj", "k"),
(".attention.qkv_proj", ".attention.v_proj", "v"),
(".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
(".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
]
# Process each weight
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name in name:
# Replace the weight name part with the combined parameter name
transformed_name = name.replace(weight_name, param_name)
if transformed_name in params_dict:
param = params_dict[transformed_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight, shard_id)
break
else:
if ".attention.o_proj" in name:
alt_name = name.replace(".attention.o_proj", ".attention.proj")
if alt_name in params_dict:
name = alt_name
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
class PixtralVisionModel(PixtralHFVisionModel):
pass
# Register the model classes for external access
EntryClass = [PixtralVisionModel]
...@@ -19,7 +19,9 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import transformers
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForVision2Seq,
...@@ -211,7 +213,12 @@ class HFRunner:
# Load the model and tokenizer
if self.model_type == "generation":
config = AutoConfig.from_pretrained(model_path)
if model_archs := getattr(config, "architectures"):
model_cls = getattr(transformers, model_archs[0])
else:
model_cls = AutoModelForCausalLM
self.base_model = model_cls.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=self.trust_remote_code,
......
...@@ -14,14 +14,15 @@
"""
Usage:
To test a specific model locally:
1. Add it to ALL_MODELS, for example, `ModelCase("Qwen/Qwen2-1.5B")`
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels`
"""
import dataclasses
import multiprocessing as mp
import os
import random
import unittest
from typing import List
...@@ -53,8 +54,9 @@ CI_MODELS = [
ModelCase("google/gemma-2-2b"),
]
# the complete set of models to test sglang's generation model
ALL_MODELS = [
*CI_MODELS,
ModelCase("Qwen/Qwen2-1.5B"),
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
...@@ -63,7 +65,7 @@ ALL_OTHER_MODELS = [
"THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True
),
ModelCase("openai-community/gpt2"),
ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True),
ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),
]
...@@ -117,12 +119,13 @@ class TestGenerationModels(CustomTestCase):
debug_text=f"model_path={model_path} prompts={prompts}",
)
@unittest.skipIf(not is_in_ci(), "Local test should run all models")
def test_ci_models(self):
for model_case in CI_MODELS:
for torch_dtype in TORCH_DTYPES:
# Skip long prompts for models that do not have a long context
prompts = DEFAULT_PROMPTS
if model_case.skip_long_prompt:
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
...@@ -131,25 +134,25 @@ class TestGenerationModels(CustomTestCase):
prompts, model_case, torch_dtype
)
@unittest.skipIf(is_in_ci(), "CI only runs selected models for simplicity")
def test_all_models(self):
for model_case in ALL_MODELS:
for torch_dtype in TORCH_DTYPES:
if (
"ONLY_RUN" in os.environ
and os.environ["ONLY_RUN"] != model_case.model_path
):
continue
# Skip long prompts for models that do not have a long context
prompts = DEFAULT_PROMPTS
if model_case.skip_long_prompt:
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
# Assert the logits and output strs are close
self.assert_close_logits_and_output_strs(
prompts, model_case, torch_dtype
)
if __name__ == "__main__":
......
...@@ -642,6 +642,28 @@ class TestMinicpmoServer(TestOpenAIVisionServer):
self._test_audio_ambient_completion()
class TestPixtralServer(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
cls.model = "mistral-community/pixtral-12b"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--mem-fraction-static",
"0.73",
],
)
cls.base_url += "/v1"
def test_video_chat_completion(self):
pass
class TestDeepseekVL2Server(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
......