Unverified Commit 5380cd7e authored by Kiv Chen's avatar Kiv Chen Committed by GitHub
Browse files

model(vlm): pixtral (#5084)

parent b2e95f62
......@@ -20,6 +20,7 @@ python3 -m sglang.launch_server \
| **Janus-Pro** (1B, 7B) | `deepseek-ai/Janus-Pro-7B` | `janus-pro` | DeepSeek’s open-source multimodal model capable of both image understanding and generation. Janus-Pro employs a decoupled architecture for separate visual encoding paths, enhancing performance in both tasks. |
| **MiniCPM-V / MiniCPM-o** | `openbmb/MiniCPM-V-2_6` | `minicpmv` | MiniCPM-V (2.6, ~8B) supports image inputs, and MiniCPM-o adds audio/video; these multimodal LLMs are optimized for end-side deployment on mobile/edge devices. |
| **Llama 3.2 Vision** (11B) | `meta-llama/Llama-3.2-11B-Vision-Instruct` | `llama_3_vision` | Vision-enabled variant of Llama 3 (11B) that accepts image inputs for visual question answering and other multimodal tasks. |
| **Pixtral** (12B, 124B) | `mistral-community/pixtral-12b` | `mistral` | Pixtral is a vision-language model from Mistral AI that can process both text and images. |
| **LLaVA** (v1.5 & v1.6) | *e.g.* `liuhaotian/llava-v1.5-13b` | `vicuna_v1.1` | Open vision-chat models that add an image encoder to LLaMA/Vicuna (e.g. LLaMA2 13B) for following multimodal instruction prompts. |
| **LLaVA-NeXT** (8B, 72B) | `lmms-lab/llava-next-72b` | `chatml-llava` | Improved LLaVA models (with an 8B Llama3 version and a 72B version) offering enhanced visual instruction-following and accuracy on multimodal benchmarks. |
| **LLaVA-OneVision** | `lmms-lab/llava-onevision-qwen2-7b-ov` | `chatml-llava` | Enhanced LLaVA variant integrating Qwen as the backbone; supports multiple images (and even video frames) as inputs via an OpenAI Vision API-compatible format. |
......
......@@ -33,9 +33,10 @@ The `hidden_states` folder contains examples on how to extract hidden states usi
* `hidden_states_engine.py`: An example how to extract hidden states using the Engine API.
* `hidden_states_server.py`: An example how to extract hidden states using the Server API.
## LLaVA-NeXT
## Multimodal
SGLang supports multimodal inputs for various model architectures. The `multimodal` folder contains examples showing how to use urls, files or encoded data to make requests to multimodal models. Examples include querying the [Llava-OneVision](multimodal/llava_onevision_server.py) model (image, multi-image, video), Llava-backed [Qwen-Llava](multimodal/qwen_llava_server.py) and [Llama3-Llava](multimodal/llama3_llava_server.py) models (image, multi-image), and Mistral AI's [Pixtral](multimodal/pixtral_server.py) (image, multi-image).
SGLang support LLaVA-OneVision with single-image, multi-image and video are supported. The folder `llava_onevision` shows how to do this.
## Token In, Token Out
......
......@@ -6,7 +6,7 @@ Usage:
# Endpoint Service CLI:
python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000
python3 http_llama3_llava_test.py
python3 llama3_llava_server.py
Output:
"Friends posing for a fun photo with a life-sized teddy bear, creating a playful and memorable moment."
......
......@@ -3,7 +3,7 @@ Usage:
python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8
python3 http_llava_onevision_test.py
python3 llava_onevision_server.py
"""
import base64
......
"""
Usage:
# Run a Pixtral model with SGLang:
# HuggingFace:
python -m sglang.launch_server --model-path mistral-community/pixtral-12b --port=30000
# ModelScope:
python -m sglang.launch_server --model-path AI-ModelScope/pixtral-12b --port=30000
# Then test it with:
python pixtral_server.py
This script tests Pixtral model with both single and multiple images.
"""
import argparse
import asyncio
import json
import aiohttp
import requests
IMAGE_TOKEN_SEP = "\n[IMG]"
ROUTE = "/generate"
async def send_request(url, data, delay=0):
await asyncio.sleep(delay)
async with aiohttp.ClientSession() as session:
async with session.post(url, json=data) as resp:
output = await resp.json()
return output
async def test_concurrent(args):
url = f"{args.host}:{args.port}{ROUTE}"
# Single image test
if args.single_image:
prompt = f"<s>[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]"
image_url = "https://picsum.photos/id/237/400/300"
modality = ["image"]
# Multiple images test
else:
image_urls = [
"https://picsum.photos/id/237/400/300",
"https://picsum.photos/id/27/500/500",
]
prompt = f"<s>[INST]How many photos are there? Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]"
image_url = image_urls
modality = ["multi-images"]
response = await send_request(
url,
{
"text": prompt,
"image_data": image_url,
"sampling_params": {
"max_new_tokens": 100,
"temperature": 0.7,
"top_p": 0.9,
},
"modalities": modality,
},
)
print(f"Response: {response}")
if "text" in response:
print("\nOutput text:", response["text"])
def test_streaming(args):
url = f"{args.host}:{args.port}/generate"
# Single image test
if args.single_image:
prompt = f"<s>[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]"
image_data = "https://picsum.photos/id/237/400/300"
modality = ["image"]
# Multiple images test
else:
image_urls = [
"https://picsum.photos/id/237/400/300",
"https://picsum.photos/id/27/500/500",
]
prompt = f"<s>[INST]How many photos are there? Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]"
image_data = image_urls
modality = ["multi-images"]
pload = {
"text": prompt,
"image_data": image_data,
"sampling_params": {"max_new_tokens": 100, "temperature": 0.7, "top_p": 0.9},
"modalities": modality,
"stream": True,
}
response = requests.post(url, json=pload, stream=True)
print("Streaming response:")
prev = 0
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
print("\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
parser.add_argument(
"--single-image",
action="store_true",
help="Test with single image instead of multiple images",
)
parser.add_argument("--no-stream", action="store_true", help="Don't test streaming")
args = parser.parse_args()
asyncio.run(test_concurrent(args))
if not args.no_stream:
test_streaming(args)
......@@ -6,7 +6,7 @@ Usage:
# Endpoint Service CLI:
python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8
python3 http_qwen_llava_test.py
python3 qwen_llava_server.py
Output:
"Two children pose with a large teddy bear, one holding a smaller stuffed bear, in a room with an American flag and potted plants."
......
......@@ -194,6 +194,21 @@ register_chat_template(
)
)
# Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
register_chat_template(
ChatTemplate(
name="mistral",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("[SYSTEM_PROMPT] ", " [/SYSTEM_PROMPT]"),
"user": ("[INST] ", " [/INST]"),
"assistant": ("", " </s><s>"),
},
stop_str=("</s>",),
image_token="[IMG]",
)
)
register_chat_template(
ChatTemplate(
name="llama-3-instruct",
......@@ -509,13 +524,19 @@ def match_vicuna(model_path: str):
@register_chat_template_matching_function
def match_llama2_chat(model_path: str):
if re.search(
r"llama-2.*chat|(mistral|mixtral).*instruct|codellama.*instruct",
r"llama-2.*chat|codellama.*instruct",
model_path,
re.IGNORECASE,
):
return "llama-2-chat"
@register_chat_template_matching_function
def match_mistral(model_path: str):
if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
return "mistral"
@register_chat_template_matching_function
def match_llama3_instruct(model_path: str):
if re.search(r"llama-3.*instruct", model_path, re.IGNORECASE):
......
......@@ -545,6 +545,7 @@ multimodal_model_archs = [
"Llama4ForConditionalGeneration",
"LlavaMistralForCausalLM",
"LlavaQwenForCausalLM",
"LlavaForConditionalGeneration",
"LlavaVidForCausalLM",
"MiniCPMO",
"MiniCPMV",
......
......@@ -634,6 +634,20 @@ register_conv_template(
)
)
# reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
register_conv_template(
Conversation(
name="mistral",
system_template="[SYSTEM_PROMPT]\n{system_message}\n[/SYSTEM_PROMPT]\n\n",
roles=("[INST]", "[/INST]"),
sep_style=SeparatorStyle.LLAMA2,
sep=" ",
sep2=" </s><s>",
stop_str=["[INST]", "[/INST]", "[SYSTEM_PROMPT]", "[/SYSTEM_PROMPT]"],
image_token="[IMG]",
)
)
# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
register_conv_template(
Conversation(
......@@ -880,13 +894,19 @@ def match_vicuna(model_path: str):
@register_conv_template_matching_function
def match_llama2_chat(model_path: str):
if re.search(
r"llama-2.*chat|(mistral|mixtral).*instruct|codellama.*instruct",
r"llama-2.*chat|codellama.*instruct",
model_path,
re.IGNORECASE,
):
return "llama-2"
@register_conv_template_matching_function
def match_mistral(model_path: str):
if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
return "mistral"
@register_conv_template_matching_function
def match_deepseek_vl(model_path: str):
if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
......
import asyncio
import importlib
from typing import List, Optional, Union
import numpy as np
from transformers.models.auto.processing_auto import (
PROCESSOR_MAPPING_NAMES as HF_MAPPING_NAMES,
)
import sglang.srt.managers.multimodal_processor as sgl_mm_processor_utils
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.models.llava import (
LlavaForConditionalGeneration,
LlavaLlamaForCausalLM,
LlavaMistralForCausalLM,
LlavaQwenForCausalLM,
......@@ -133,6 +139,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
img_data, aspect_ratio, grid_pinpoints
)
)
res = await asyncio.gather(*res)
for pixel_v, image_h, image_s in res:
pixel_values.append(pixel_v)
......@@ -165,3 +172,42 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
)
],
}
class LlavaMultimodalProcessor(BaseMultimodalProcessor):
"""
This is a wrapper class used to identify the multimodal processor for Llava architecture models.
"""
models = [LlavaForConditionalGeneration]
def _get_sgl_processor_cls(self, model_type: str):
if hf_name := HF_MAPPING_NAMES.get(model_type):
sgl_mm_processor_set = sgl_mm_processor_utils.PROCESSOR_MAPPING.values()
sgl_processor_cls = list(
filter(lambda p: p.__name__ == hf_name, sgl_mm_processor_set)
)
if sgl_processor_cls:
return sgl_processor_cls[0]
raise ValueError(
f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
)
def __init__(self, hf_config, server_args, _processor):
assert hasattr(hf_config, "vision_config")
assert hasattr(hf_config, "text_config")
self.vision_config = hf_config.vision_config
self.text_config = hf_config.text_config
self.hf_config = hf_config
if vision_type := getattr(self.vision_config, "model_type"):
self.inner = self._get_sgl_processor_cls(vision_type)(
hf_config, server_args, _processor
)
else:
raise ValueError(
f"Required `vision_config.model_type` is not found in hf_config: `{hf_config}`"
)
async def process_mm_data_async(self, *args, **kwargs):
return await self.inner.process_mm_data_async(*args, **kwargs)
import asyncio
import math
from typing import List, Optional, Union
import numpy as np
from transformers import PretrainedConfig
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.models.pixtral import PixtralVisionModel
class PixtralProcessor(BaseMultimodalProcessor):
models = [PixtralVisionModel]
PAD_TOKEN = "<pad>"
IMG_BREAK_TOKEN_ID = 12
IMG_END_TOKEN_ID = 13
def get_patch_grid_size(
self,
*,
image_width: int,
image_height: int,
) -> tuple[int, int]:
max_width = max_height = self.image_size
patch_width = patch_height = self.patch_size
ratio = max(image_width / max_width, image_height / max_height)
if ratio > 1:
image_width = int(math.floor(image_width / ratio))
image_height = int(math.floor(image_height / ratio))
nrows, ncols = _get_pixtral_hf_num_image_tokens(
(image_height, image_width),
(patch_height, patch_width),
)
return ncols, nrows
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.image_token_id = getattr(
hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
)
# Instantiate the patcher logic helper using the class defined above
self.vision_config = hf_config.vision_config
self.image_size = self.vision_config.image_size
self.patch_size = self.vision_config.patch_size
self.multimodal_tokens = MultimodalSpecialTokens(
image_token=_processor.image_token
)
_processor.tokenizer.add_special_tokens(
{
"pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN),
}
)
async def _resize(self, image):
num_w_tokens, num_h_tokens = self.get_patch_grid_size(
image_width=image.size[0],
image_height=image.size[1],
)
new_size = (num_w_tokens * self.patch_size, num_h_tokens * self.patch_size)
return image.resize(new_size)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
mm_data = self.load_mm_data(
prompt=input_text,
multimodal_tokens=self.multimodal_tokens,
max_req_input_len=kwargs.get("max_req_input_len", 4096),
image_data=image_data,
return_text=True,
)
if mm_data.images:
resize_tasks = [self._resize(image) for image in mm_data.images]
mm_data.images = await asyncio.gather(*resize_tasks)
processor_output = self.process_mm_data(
input_text=mm_data.input_text,
images=mm_data.images,
)
if "pixel_values" in processor_output:
mm_items = [
MultimodalDataItem(
pixel_values=processor_output["pixel_values"],
image_sizes=processor_output["image_sizes"],
modality=Modality.IMAGE,
)
]
input_ids = processor_output["input_ids"].view(-1).tolist()
processor_output.update(
input_ids=input_ids,
mm_items=mm_items,
# there's no im_start_id for pixtral, only im_token and im_end_token
im_end_id=self.IMG_END_TOKEN_ID,
im_token_id=self.image_token_id,
)
return processor_output
......@@ -15,7 +15,8 @@
import math
import re
from typing import Iterable, List, Optional, Tuple
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
import numpy as np
import torch
......@@ -28,10 +29,18 @@ from transformers import (
Qwen2Config,
SiglipVisionModel,
)
from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
# leave till last and symbol only in case circular import
import sglang.srt.models as sgl_models
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
from sglang.srt.managers.mm_utils import general_mm_embed_routine
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,
......@@ -42,7 +51,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix, flatten_nested_list
from sglang.srt.utils import add_prefix, flatten_nested_list, logger
class LlavaBaseForCausalLM(nn.Module):
......@@ -114,7 +123,16 @@ class LlavaBaseForCausalLM(nn.Module):
image_inputs.image_offsets = offset_list
return input_ids
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
def encode_images(
self, pixel_values: Union[torch.Tensor, List[torch.Tensor]]
) -> torch.Tensor:
"""
encode images by vision tower and multimodal projector
Args:
pixel_values: torch.Tensor or List[torch.Tensor]: each tensor for an input image
Returns:
torch.Tensor: encoded image features from the input image; if multiple, flattened by seq_len axis
"""
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
......@@ -583,4 +601,229 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
)
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
"""
An adaptor class to enable support for multiple mmlm such as mistral-community/pixtral-12b
It follows the structure of (vision_tower, multi_modal_projector, language_model)
Once a model config is loaded, text_config and vision_config will be extracted, and
LlavaForConditionalGeneration will load the language_model and vision_tower models
according to config.
"""
MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
if hasattr(self.vision_tower, "pad_input_ids"):
return self.vision_tower.pad_input_ids(input_ids, image_inputs)
else:
return super().pad_input_ids(input_ids, image_inputs)
def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoModel):
"""
Get the SGLang model implementation class according to config.
Args:
config: The config object of the model.
auto_model_type: The type of the auto model.
Returns:
The SGLang model implementation class.
"""
config_cls_name = config.__class__.__name__
arch_name_mapping = self._config_cls_name_to_arch_name_mapping(auto_model_type)
if arch := arch_name_mapping.get(config_cls_name):
if isinstance(arch, tuple):
arch = arch[0]
logger.warning(
f"Multiple {auto_model_type.__name__} models found for submodule config `{config_cls_name}`, defaulting to [0]: {arch.__name__}"
)
try:
return sgl_models.registry.ModelRegistry.resolve_model_cls(arch)[0]
except Exception as e:
raise ValueError(
f"{auto_model_type.__name__} found a corresponding model `{arch}` for config class `{config_cls_name}`, but failed to load it from SGLang ModelRegistry. \n{e}"
)
else:
raise ValueError(
f"{auto_model_type.__name__} cannot find a corresponding model for config class `{config_cls_name}`"
)
@lru_cache
def _config_cls_name_to_arch_name_mapping(
self, auto_model_type: Type[AutoModel]
) -> Dict[str, str]:
mapping = {}
for config_cls, archs in auto_model_type._model_mapping.items():
if isinstance(archs, tuple):
mapping[config_cls.__name__] = tuple(arch.__name__ for arch in archs)
else:
mapping[config_cls.__name__] = archs.__name__
return mapping
def __init__(
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
assert hasattr(config, "text_config")
assert hasattr(config, "vision_config")
self.config = config
self.text_config = config.text_config
self.vision_config = config.vision_config
if not hasattr(self.config, "vocab_size"):
self.config.vocab_size = self.config.text_config.vocab_size
if not hasattr(self.config, "image_aspect_ratio"):
self.config.image_aspect_ratio = "anyres"
if not hasattr(self.config, "image_grid_pinpoints"):
# from transformers.models.llava_onevision.configuration_llava_onevision import LlavaOnevisionConfig
# self.config.image_grid_pinpoints = LlavaOnevisionConfig().image_grid_pinpoints
self.config.image_grid_pinpoints = [
[96, 96],
[224, 224],
[384, 384],
[512, 512],
[768, 768],
[1024, 1024],
]
if not hasattr(self.config, "mm_patch_merge_type"):
self.config.mm_patch_merge_type = "flat"
if not hasattr(self.config, "image_token_index"):
self.config.image_token_index = 10
if not hasattr(self.config, "projector_hidden_act"):
self.config.projector_hidden_act = "gelu"
self.vision_feature_layer = getattr(config, "vision_feature_layer", -1)
self.vision_feature_select_strategy = getattr(
config, "vision_feature_select_strategy", "full"
)
self.image_size = self.config.vision_config.image_size
self.patch_size = self.config.vision_config.patch_size
self.mm_patch_merge_type = config.mm_patch_merge_type
self.image_aspect_ratio = config.image_aspect_ratio
self.image_grid_pinpoints = config.image_grid_pinpoints
self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)
language_model_cls = self._get_sgl_model_cls(
config.text_config, AutoModelForCausalLM
)
vision_model_cls = self._get_sgl_model_cls(config.vision_config, AutoModel)
self.language_model = language_model_cls(
config.text_config,
quant_config=quant_config,
prefix=add_prefix("language_model", prefix),
)
self.vision_tower = vision_model_cls(
config.vision_config,
quant_config=quant_config,
prefix=add_prefix("vision_tower", prefix),
)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.language_model.model.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
"""Extract features from image inputs.
Args:
items: List of MultimodalDataItem objects containing image data
Note that an item can be either "image" or "multi-images"
Returns:
torch.Tensor: features from image inputs, concatenated
"""
features = []
for item in items:
# in each item, we assume pixel_values is always batched
pixel_values, image_sizes = item.pixel_values, item.image_sizes
image_outputs = self.vision_tower(
pixel_values, image_sizes, output_hidden_states=True
)
selected_image_feature = image_outputs.hidden_states[
self.vision_feature_layer
]
if self.vision_feature_select_strategy in ["default", "patch"]:
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(
f"Unexpected select feature: {self.vision_feature_select_strategy}"
)
features.append(
self.multi_modal_projector(selected_image_feature.squeeze(0))
)
ret = torch.cat(features, dim=0)
return ret
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
get_embedding: bool = False,
):
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
get_embedding=get_embedding,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
placeholder_tokens=None, # using mm_item.pad_value
positions=positions,
)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""Load weights for LlavaForConditionalGeneration.
Unlike the base class implementation, this one doesn't need to handle
weight name remapping as the weights are already properly structured with
'language_model' and 'vision_tower' prefixes in the safetensors files.
"""
if (
self.vision_feature_select_strategy == "patch"
or self.vision_feature_select_strategy == "full"
):
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
else:
raise ValueError(
f"Unexpected select feature: {self.vision_feature_select_strategy}"
)
# Create dictionaries for direct parameter loading
params_dict = dict(self.named_parameters())
# Load weights directly without remapping
for name, loaded_weight in weights:
for part in ("language_model", "vision_tower"):
if name.startswith(part):
name = name[len(part + ".") :]
getattr(self, part).load_weights([(name, loaded_weight)])
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = [
LlavaLlamaForCausalLM,
LlavaQwenForCausalLM,
LlavaMistralForCausalLM,
LlavaForConditionalGeneration,
]
# Copyright 2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Using mistral-community/pixtral-12b as reference.
"""
import logging
import math
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PixtralVisionConfig, PretrainedConfig
from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding
from transformers.models.pixtral.modeling_pixtral import (
generate_block_attention_mask as _get_pixtral_attention_mask,
)
from transformers.models.pixtral.modeling_pixtral import position_ids_in_meshgrid
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import MergedColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_loader.weight_utils import default_weight_loader
class PixtralHFMLP(nn.Module):
"""MLP for PixtralHFVisionModel using SGLang components."""
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__()
assert config.intermediate_size is not None
# Use MergedColumnParallelLinear for gate_up_proj to handle combined weights
self.gate_up_proj = MergedColumnParallelLinear(
input_size=config.hidden_size,
output_sizes=[config.intermediate_size, config.intermediate_size],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=config.intermediate_size,
output_size=config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up_output, _ = self.gate_up_proj(x)
# Apply SiLU activation and multiply
gate_up = self.act_fn(gate_up_output)
# Project back to hidden size
out, _ = self.down_proj(gate_up)
return out
class PixtralHFTransformerBlock(nn.Module):
"""Transformer block for PixtralHFVisionModel using SGLang components."""
def __init__(
self,
config: PretrainedConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__()
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
# Use SGLang's VisionAttention instead of vLLM's PixtralHFAttention
self.attention = VisionAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
projection_size=config.hidden_size,
use_qkv_parallel=True,
quant_config=quant_config,
dropout=0.0,
use_context_forward=False,
softmax_in_single_precision=False,
flatten_batch=False,
prefix=f"{prefix}.attention",
)
self.feed_forward = PixtralHFMLP(
config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
)
self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
# Ensure hidden_states has the batch dimension [batch, seq_len, hidden_dim]
batch_size, seq_len, hidden_dim = hidden_states.shape
# Apply attention norm - normalize along the last dimension
attn_normalized = self.attention_norm(hidden_states.view(-1, hidden_dim)).view(
batch_size, seq_len, hidden_dim
)
# Pass through attention layer
attention_output = self.attention(
attn_normalized,
attention_mask=attention_mask,
cu_seqlens=None,
position_embeddings=position_embeddings,
)
# Apply first residual connection
hidden_states = hidden_states + attention_output
# Apply feed-forward norm - normalize along the last dimension
ffn_normalized = self.ffn_norm(hidden_states.view(-1, hidden_dim)).view(
batch_size, seq_len, hidden_dim
)
# Pass through feed-forward layer
# First reshape to 2D for the feed-forward network, then reshape back
ffn_output = self.feed_forward(ffn_normalized)
# Apply second residual connection
output = hidden_states + ffn_output
return output
class PixtralHFTransformer(nn.Module):
"""Transformer for PixtralHFVisionModel using SGLang components."""
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()
num_hidden_layers = config.num_hidden_layers
if num_hidden_layers_override is not None:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList(
[
PixtralHFTransformerBlock(
config=config,
layer_id=layer_idx,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
)
for layer_idx in range(num_hidden_layers)
]
)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor],
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
return_all_hidden_states: bool = False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Forward pass through transformer layers.
Args:
x: Input tensor
attention_mask: Optional attention mask
position_embeddings: Optional position embeddings for rotary attention
return_all_hidden_states: Whether to return all hidden states
Returns:
Either the final hidden state, or a list of all hidden states if
return_all_hidden_states is True
"""
# For HF model compatibility, always start with the input
hidden_states = x
all_hidden_states = [hidden_states] if return_all_hidden_states else None
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, attention_mask, position_embeddings)
if return_all_hidden_states:
all_hidden_states.append(hidden_states)
if return_all_hidden_states:
return all_hidden_states
return hidden_states
def resolve_visual_encoder_outputs(
outputs: Union[torch.Tensor, List[torch.Tensor]],
feature_sample_layers: Optional[List[int]],
post_norm: Optional[nn.Module],
num_hidden_layers: int,
) -> torch.Tensor:
"""Resolve outputs from visual encoder based on feature_sample_layers."""
if feature_sample_layers is None:
# Just use the last layer's output
if isinstance(outputs, list):
outputs = outputs[-1]
if post_norm is not None:
outputs = post_norm(outputs)
return outputs
# Handle the case where we want to use specific layers
if not isinstance(outputs, list):
raise ValueError(
"Expected outputs to be a list when feature_sample_layers is provided"
)
# Validate layer indices
for layer_idx in feature_sample_layers:
if layer_idx < 0 or layer_idx > num_hidden_layers:
raise ValueError(
f"Feature sample layer index {layer_idx} is out of range "
f"[0, {num_hidden_layers}]"
)
# Collect outputs from specified layers
selected_outputs = [outputs[layer_idx] for layer_idx in feature_sample_layers]
# Combine the outputs
combined_outputs = torch.cat(selected_outputs, dim=-1)
if post_norm is not None:
combined_outputs = post_norm(combined_outputs)
return combined_outputs
class PixtralHFVisionModel(nn.Module):
"""Hugging Face Pixtral Vision Model implemented using SGLang components."""
DEFAULT_IMAGE_TOKEN_ID = 10
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
return self.input_padder.pad_input_tokens(input_ids, image_inputs)
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_conv = nn.Conv2d(
in_channels=config.num_channels,
out_channels=config.hidden_size,
kernel_size=config.patch_size,
stride=config.patch_size,
bias=False,
)
self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
self.transformer = PixtralHFTransformer(
config,
quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.transformer",
)
# Check that num_hidden_layers is valid
num_hidden_layers = config.num_hidden_layers
if len(self.transformer.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.transformer.layers)} "
"layers."
)
# Initialize patch position embedding
self.image_token_id = image_token_id
self.patch_positional_embedding = PixtralRotaryEmbedding(config)
self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
[self.image_token_id]
)
@property
def dtype(self):
return next(self.parameters()).dtype
@property
def device(self):
return next(self.parameters()).device
def forward(
self,
pixel_values: torch.Tensor,
image_sizes: list[tuple[int, int]],
output_hidden_states: bool = False,
feature_sample_layers: Optional[list[int]] = None,
) -> Union[torch.Tensor, tuple]:
"""
Args:
pixel_values: [batch_size, C, H, W], padded if multiple images
image_sizes: list of (H, W) for each image in the batch
output_hidden_states: Whether to return all hidden states.
feature_sample_layers: Layer indices whose features should be
concatenated and used as the visual encoder output. If none
are provided, the last layer is used.
Returns:
A tuple containing:
- hidden_states: Final model outputs (or selected layers if feature_sample_layers given)
- hidden_states tuple (optional): All hidden states if output_hidden_states=True
"""
# batch patch images
embeds_orig = self.patch_conv(
pixel_values.to(device=self.device, dtype=self.dtype)
)
# crop the embeddings
embeds_2d = [
embed[..., : h // self.patch_size, : w // self.patch_size]
for embed, (h, w) in zip(embeds_orig, image_sizes)
]
# flatten to sequence
embeds_1d = torch.cat([p.flatten(1).T for p in embeds_2d], dim=0)
embeds_featurized = self.ln_pre(embeds_1d).unsqueeze(0)
# positional embeddings
position_ids = position_ids_in_meshgrid(
embeds_2d,
max_width=self.image_size // self.patch_size,
).to(self.device)
# The original PixtralRotaryEmbedding expects 2D input but returns a tuple of tensors (cos, sin)
# These tensors are used by apply_rotary_pos_emb in the transformer blocks
position_embedding = self.patch_positional_embedding(
embeds_featurized, position_ids
)
attention_mask = _get_pixtral_attention_mask(
[p.shape[-2] * p.shape[-1] for p in embeds_2d], embeds_featurized
)
return_all_hidden_states = (
output_hidden_states or feature_sample_layers is not None
)
transformer_outputs = self.transformer(
embeds_featurized, # add batch dimension
attention_mask,
position_embedding,
return_all_hidden_states=return_all_hidden_states,
)
# Store all hidden states if requested
all_hidden_states = None
if isinstance(transformer_outputs, list):
all_hidden_states = transformer_outputs
# Use the last layer by default if feature_sample_layers is not specified
if feature_sample_layers is None:
out = transformer_outputs[-1]
else:
# Resolve outputs based on feature sample layers
out = resolve_visual_encoder_outputs(
transformer_outputs,
feature_sample_layers,
None,
self.config.num_hidden_layers,
)
else:
out = transformer_outputs
# Format return to be compatible with HuggingFace vision models
if output_hidden_states:
return type(
"VisualOutput",
(),
{
"last_hidden_state": out,
"hidden_states": all_hidden_states,
},
)
else:
return out
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
"""Load weights from a HuggingFace checkpoint with proper parameter mapping."""
params_dict = dict(self.named_parameters())
# for (param, weight, shard_id): load weight into param as param's shard_id part
stacked_params_mapping = [
(".attention.qkv_proj", ".attention.q_proj", "q"),
(".attention.qkv_proj", ".attention.k_proj", "k"),
(".attention.qkv_proj", ".attention.v_proj", "v"),
(".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
(".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
]
# Process each weight
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name in name:
# Replace the weight name part with the combined parameter name
transformed_name = name.replace(weight_name, param_name)
if transformed_name in params_dict:
param = params_dict[transformed_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight, shard_id)
break
else:
if ".attention.o_proj" in name:
alt_name = name.replace(".attention.o_proj", ".attention.proj")
if alt_name in params_dict:
name = alt_name
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
class PixtralVisionModel(PixtralHFVisionModel):
pass
# Register the model classes for external access
EntryClass = [PixtralVisionModel]
......@@ -19,7 +19,9 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import transformers
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForVision2Seq,
......@@ -211,7 +213,12 @@ class HFRunner:
# Load the model and tokenizer
if self.model_type == "generation":
self.base_model = AutoModelForCausalLM.from_pretrained(
config = AutoConfig.from_pretrained(model_path)
if model_archs := getattr(config, "architectures"):
model_cls = getattr(transformers, model_archs[0])
else:
model_cls = AutoModelForCausalLM
self.base_model = model_cls.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=self.trust_remote_code,
......
......@@ -14,14 +14,15 @@
"""
Usage:
To test a specific model:
1. Add it to ALL_OTHER_MODELS
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others`
To test a specific model locally:
1. Add it to ALL_MODELS, for example, `ModelCase("Qwen/Qwen2-1.5B")`
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels`
"""
import dataclasses
import multiprocessing as mp
import os
import random
import unittest
from typing import List
......@@ -53,8 +54,9 @@ CI_MODELS = [
ModelCase("google/gemma-2-2b"),
]
# All other models that do not run on the CI
ALL_OTHER_MODELS = [
# the complete set of models to test sglang's generation model
ALL_MODELS = [
*CI_MODELS,
ModelCase("Qwen/Qwen2-1.5B"),
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
......@@ -63,7 +65,7 @@ ALL_OTHER_MODELS = [
"THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True
),
ModelCase("openai-community/gpt2"),
ModelCase("microsoft/Phi-3-small-8k-instruct"),
ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True),
ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),
]
......@@ -117,12 +119,13 @@ class TestGenerationModels(CustomTestCase):
debug_text=f"model_path={model_path} prompts={prompts}",
)
@unittest.skipIf(not is_in_ci(), "Local test should run all models")
def test_ci_models(self):
for model_case in CI_MODELS:
for torch_dtype in TORCH_DTYPES:
prompts = DEFAULT_PROMPTS
# Skip long prompts for models that do not have a long context
prompts = DEFAULT_PROMPTS
if model_case.skip_long_prompt:
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
......@@ -131,25 +134,25 @@ class TestGenerationModels(CustomTestCase):
prompts, model_case, torch_dtype
)
def test_others(self):
if is_in_ci():
return
for model_case in ALL_OTHER_MODELS:
# Only run a specified model
if (
"ONLY_RUN" in os.environ
and os.environ["ONLY_RUN"] != model_case.model_path
):
continue
# Skip long prompts for models that do not have a long context
prompts = DEFAULT_PROMPTS
if model_case.skip_long_prompt:
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
# Assert the logits and output strs are close
self.assert_close_logits_and_output_strs(prompts, model_case, torch.float16)
@unittest.skipIf(is_in_ci(), "CI only runs selected models for simplicity")
def test_all_models(self):
for model_case in ALL_MODELS:
for torch_dtype in TORCH_DTYPES:
if (
"ONLY_RUN" in os.environ
and os.environ["ONLY_RUN"] != model_case.model_path
):
continue
# Skip long prompts for models that do not have a long context
prompts = DEFAULT_PROMPTS
if model_case.skip_long_prompt:
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
# Assert the logits and output strs are close
self.assert_close_logits_and_output_strs(
prompts, model_case, torch_dtype
)
if __name__ == "__main__":
......
......@@ -642,6 +642,28 @@ class TestMinicpmoServer(TestOpenAIVisionServer):
self._test_audio_ambient_completion()
class TestPixtralServer(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
cls.model = "mistral-community/pixtral-12b"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--mem-fraction-static",
"0.73",
],
)
cls.base_url += "/v1"
def test_video_chat_completion(self):
pass
class TestDeepseekVL2Server(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment