Unverified Commit fbebcb7a authored by Mick, committed by GitHub

model: support mllama4 (#5144)

parent 87eddedf
@@ -486,8 +486,8 @@ multimodal_model_archs = [
     "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
-    # TODO: add multimodal support for "Llama4ForConditionalGeneration",
     "LlavaLlamaForCausalLM",
+    "Llama4ForConditionalGeneration",
     "LlavaMistralForCausalLM",
     "LlavaQwenForCausalLM",
     "LlavaVidForCausalLM",
......
@@ -148,7 +148,8 @@ def get_embedding_and_mask(
         placeholder_tensor,
     ).unsqueeze(-1)
-    num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
+    num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text."
@@ -172,7 +173,7 @@ def get_embedding_and_mask(
             embedding = embedding[-num_multimodal:, :]
         else:
             raise RuntimeError(
-                "Insufficient multimodal embedding length. This is an internal error"
+                f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
             )
     return embedding, special_multimodal_mask
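Note: switching to .sum().item() keeps the token count a plain Python int, so the comparison and the new f-string in the error message print a number instead of a tensor repr. A minimal, self-contained sketch of the same check; the function and argument names here are illustrative, not the SGLang API:

import torch

def check_mm_token_count(
    input_ids: torch.Tensor, image_token_id: int, num_mm_tokens_in_embedding: int
) -> None:
    # Boolean mask marking the multimodal placeholder tokens in input_ids.
    special_multimodal_mask = input_ids == image_token_id
    # .item() turns the 0-dim tensor into a plain Python int, so the comparison
    # and the error message below see a number rather than a tensor.
    num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
    if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
        raise RuntimeError(
            f"Insufficient multimodal embedding length: "
            f"{num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}."
        )

# Three placeholder tokens (id 9) and three embedding rows: passes silently.
check_mm_token_count(torch.tensor([1, 9, 9, 9, 2]), image_token_id=9, num_mm_tokens_in_embedding=3)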
......
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Union

 import torch
-from PIL import Image
-from transformers import Llama4Processor
 from transformers.image_utils import SizeDict
-from transformers.models.llama4.image_processing_llama4 import (
+from transformers.models.llama4.image_processing_llama4_fast import (
     find_supported_resolutions,
     get_best_fit,
 )
@@ -15,7 +13,6 @@ from sglang.srt.managers.multimodal_processors.base_processor import (
 )
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
-from sglang.srt.utils import load_image


 class Mllama4ImageProcessor(BaseMultimodalProcessor):
@@ -25,6 +22,9 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.vision_config = hf_config.vision_config
         self.text_config = hf_config.text_config
+        self.boi_token_index = hf_config.boi_token_index
+        self.eoi_token_index = hf_config.eoi_token_index
+        self.image_token_index = hf_config.image_token_index
         self.multimodal_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token
         )
@@ -54,19 +54,16 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         )

         # Process the images using the processor
-        processor = Llama4Processor.from_pretrained(
-            self.server_args.model_path, **kwargs
-        )
+        processor = self._processor

         # Process the prompt and images
-        image_inputs = processor(
-            text=processed_data.input_text,
+        processor_output = self.process_mm_data(
+            input_text=processed_data.input_text,
             images=processed_data.images,
-            return_tensors="pt",
         )

         # Handle image resolutions and aspect ratios
-        if "pixel_values" in image_inputs:
+        if "pixel_values" in processor_output:
             image_processor = processor.image_processor
             tokenizer = self._processor.tokenizer
@@ -100,8 +97,8 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             ]

             # Add to image_inputs
-            image_inputs["aspect_ratios"] = aspect_ratios
-            image_inputs["patches_per_image"] = torch.tensor(patches_per_image)
+            processor_output["aspect_ratios"] = aspect_ratios
+            processor_output["patches_per_image"] = torch.tensor(patches_per_image)

             # Process embed_is_patch
             vocab = tokenizer.get_vocab()
@@ -109,7 +106,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             image_end_id = vocab.get(processor.end_of_img_token, -1)

             if patch_id != -1 and image_end_id != -1:
-                input_ids = image_inputs["input_ids"].view(-1)
+                input_ids = processor_output["input_ids"].view(-1)

                 # Remove BOS token if present
                 if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
@@ -129,33 +126,21 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
                 for per_image_input_ids in split_input_ids:
                     embed_is_patch.append(per_image_input_ids == patch_id)

-                image_inputs["embed_is_patch"] = embed_is_patch
+                processor_output["embed_is_patch"] = embed_is_patch

         # Convert to the format expected by SGLang
-        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+        processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
+        processor_output["im_start_id"] = self.boi_token_index
+        processor_output["im_end_id"] = self.eoi_token_index
+        processor_output["im_token_id"] = self.image_token_index

         # Add metadata for image processing
-        image_inputs["mm_items"] = [
+        processor_output["mm_items"] = [
             MultimodalDataItem(
-                pixel_values=image_inputs["pixel_values"],
+                pixel_values=processor_output["pixel_values"],
                 modality=Modality.IMAGE,
-                # Add additional metadata needed for Llama4 vision processing
-                embed_is_patch=image_inputs.get("embed_is_patch", None),
-                aspect_ratios=image_inputs.get("aspect_ratios", None),
-                patches_per_image=image_inputs.get("patches_per_image", None),
             )
         ]

-        return image_inputs
-
-    def get_patch_per_chunk(self):
-        """Calculate patches per chunk based on vision config"""
-        image_size = self.vision_config.image_size
-        patch_size = self.vision_config.patch_size
-        assert (
-            image_size % patch_size == 0
-        ), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
-        ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
-        return (image_size // patch_size) ** 2 // ds_ratio
+        return processor_output
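For reference, the removed get_patch_per_chunk helper computes (image_size // patch_size) ** 2 // ds_ratio with ds_ratio = round(1 / pixel_shuffle_ratio ** 2). A worked example, assuming typical Llama 4 vision settings (image_size=336, patch_size=14, pixel_shuffle_ratio=0.5; these values are an assumption, read them from hf_config.vision_config in practice):

# Hypothetical config values; in practice they come from hf_config.vision_config.
image_size = 336
patch_size = 14
pixel_shuffle_ratio = 0.5

assert image_size % patch_size == 0
ds_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))            # 1 / 0.25 -> 4
patches_per_chunk = (image_size // patch_size) ** 2 // ds_ratio  # 24 ** 2 // 4
print(patches_per_chunk)  # 144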
......
 from __future__ import annotations
+import hashlib
 from enum import Enum, auto
 # Copyright 2023-2024 SGLang Team
@@ -157,7 +158,7 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or other
+    A single multimodal data, from a single image/video/audio or others
     """

     modality: Modality
@@ -195,25 +196,54 @@ class MultimodalDataItem:
     def set_pad_value(self):
         """
-        Set the pad value after first hashign the data
+        Set the pad value after first hashing the data
         """

-        def tensor_hash(f):
-            f_list = flatten_nested_list(f)
-            f_list = [x.flatten() if isinstance(x, torch.Tensor) else x for x in f_list]
-            f_cat = torch.concat(f_list).contiguous().numpy().tobytes()
-            return hash(f_cat)
+        def data_hash(data) -> int:
+            hash_bytes = hashlib.sha256(data).digest()[:8]
+            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+        def tensor_hash(tensor_list) -> int:
+            """
+            hash a tensor or a tensor list
+            """
+            tensor = tensor_list
+            if isinstance(tensor_list, list):
+                tensor_list = flatten_nested_list(tensor_list)
+                tensor_list = [
+                    x.flatten() if isinstance(x, torch.Tensor) else x
+                    for x in tensor_list
+                ]
+                tensor = torch.concat(tensor_list)
+
+            tensor = tensor.detach().contiguous()
+
+            if tensor.dtype == torch.bfloat16:
+                # memoryview() doesn't support PyTorch's BFloat16 dtype
+                tensor = tensor.float()
+
+            if tensor.is_cuda:
+                tensor_cpu = torch.frombuffer(
+                    tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel()
+                ).clone()
+            else:
+                tensor_cpu = tensor
+
+            mv = memoryview(tensor_cpu.numpy())
+            return data_hash(mv.tobytes())

         def hash_feature(f):
             if isinstance(f, list):
                 if isinstance(f[0], torch.Tensor):
                     return tensor_hash(f)
-                return hash(tuple(flatten_nested_list(f)))
+                return data_hash(tuple(flatten_nested_list(f)))
             elif isinstance(f, np.ndarray):
                 arr = np.ascontiguousarray(f)
                 arr_bytes = arr.tobytes()
-                return hash(arr_bytes)
-            return hash(f)
+                return data_hash(arr_bytes)
+            elif isinstance(f, torch.Tensor):
+                return tensor_hash([f])
+            return data_hash(f)

         if self.is_audio():
             self.hash = hash_feature(self.audio_features)
@@ -256,7 +286,7 @@ class MultimodalInputs:
     mrope_position_delta: Optional[torch.Tensor] = None

     # image
-    im_token_id: Optional[torch.Tensor] = None
+    im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
     im_end_id: Optional[int] = None
     slice_start_id: Optional[int] = None
@@ -330,10 +360,8 @@ class MultimodalInputs:
         # args needed to be merged
         optional_args = [
-            "items",
-            "image_offsets",
+            "mm_items",
             "image_pad_len",
-            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
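Note: the pad-value hashing now goes through SHA-256 truncated to 64 bits instead of Python's built-in hash(), which (as far as the change suggests) makes the value deterministic across processes; hash() of bytes and str is salted per interpreter via PYTHONHASHSEED. A standalone sketch of the same idea, using illustrative names rather than the exact SGLang helpers and a simplified CPU copy instead of the frombuffer path:

import hashlib
import torch

def data_hash(data: bytes) -> int:
    # First 8 bytes of SHA-256 -> stable unsigned 64-bit integer.
    return int.from_bytes(hashlib.sha256(data).digest()[:8], byteorder="big", signed=False)

def tensor_hash(t: torch.Tensor) -> int:
    t = t.detach().contiguous()
    if t.dtype == torch.bfloat16:
        t = t.float()  # numpy has no bfloat16, so widen before taking raw bytes
    if t.is_cuda:
        t = t.cpu()    # simplified: copy to host before serializing
    return data_hash(t.numpy().tobytes())

x = torch.arange(8, dtype=torch.bfloat16)
print(tensor_hash(x))  # identical in every process, unlike built-in hash() on bytes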
......
@@ -466,6 +466,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
     ):
         super().__init__(config, quant_config, prefix)

+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
     def _init_model(
         self,
         config: Llama4TextConfig,
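Note: get_input_embeddings exposes the text embedding table so the shared multimodal routine can build inputs_embeds itself: embed the token ids, then overwrite the rows at image-placeholder positions with projected vision features. A minimal sketch of that merge step under assumed shapes and names (not the actual general_mm_embed_routine signature):

import torch
import torch.nn as nn

def merge_multimodal_embeddings(
    input_ids: torch.Tensor,     # [seq_len]
    embed_tokens: nn.Embedding,  # what get_input_embeddings() returns
    image_embeds: torch.Tensor,  # [num_image_tokens, hidden]
    im_token_id: int,
) -> torch.Tensor:
    inputs_embeds = embed_tokens(input_ids)  # [seq_len, hidden]
    mask = input_ids == im_token_id          # positions of image placeholders
    assert mask.sum().item() == image_embeds.size(0)
    inputs_embeds[mask] = image_embeds.to(inputs_embeds.dtype)
    return inputs_embeds

# Example: two placeholder tokens (id 7) filled from a fake vision output.
with torch.no_grad():
    emb = nn.Embedding(32, 16)
    ids = torch.tensor([1, 7, 7, 2])
    out = merge_multimodal_embeddings(ids, emb, torch.randn(2, 16), im_token_id=7)
    print(out.shape)  # torch.Size([4, 16])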
......
+# Adapted from vllm/mllama4.py
 from collections.abc import Iterable
-from typing import Optional, Set, Tuple
+from typing import List, Optional, Set, Tuple

 import torch
 from torch import nn
-from transformers import Llama4Config
+from transformers import Llama4Config, Llama4VisionModel
+from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternImageTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix
@@ -30,6 +35,9 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.config = config
         self.quant_config = quant_config

+        self.vision_model = Llama4VisionModel(config.vision_config)
+        self.multi_modal_projector = Llama4MultiModalProjector(config)
+
         # Initialize the language model
         from sglang.srt.models.llama4 import Llama4ForCausalLM
@@ -41,6 +49,29 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)

+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_token_id: int = mm_inputs.im_token_id
+
+        pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
+
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_image_feature(
+        self,
+        items: List[MultimodalDataItem],
+    ) -> torch.Tensor:
+        pixel_values = (
+            torch.concat([item.pixel_values for item in items])
+            .to(next(self.vision_model.parameters()).device)
+            .type(next(self.vision_model.parameters()).dtype)
+        )
+
+        image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
+        image_features = image_outputs.last_hidden_state
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        projected_vision_flat = self.multi_modal_projector(vision_flat)
+        return projected_vision_flat
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -49,7 +80,15 @@ class Llama4ForConditionalGeneration(nn.Module):
         **kwargs: object,
     ) -> torch.Tensor:
-        return self.language_model(input_ids, positions, forward_batch)
+        hs = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            positions=positions,
+        )
+
+        return hs

     def permute_qk_weight_for_rotary(
         self,
@@ -108,17 +147,17 @@ class Llama4ForConditionalGeneration(nn.Module):
         )

         for name, loaded_weight in weights:
-            if name.startswith("vision_model") or name.startswith(
-                "multi_modal_projector"
-            ):
-                continue
-            name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
+            if not "vision" in name:
+                name, loaded_weight = self.permute_qk_weight_for_rotary(
+                    name, loaded_weight
+                )

             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+                if "vision" in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = param.weight_loader
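Note: get_image_feature concatenates the per-item pixel_values, runs them through Llama4VisionModel, flattens the [num_chunks, patches_per_chunk, hidden] output, and projects it to the text hidden size so it can later replace the image placeholder tokens. A shape-only sketch of that flow, with a dummy linear standing in for the real projector; all dimensions here are illustrative assumptions, not values read from the config:

import torch
import torch.nn as nn

# Illustrative sizes only; real values come from config.vision_config / text_config.
num_chunks, patches_per_chunk = 5, 144
vision_hidden, text_hidden = 1408, 5120

# Stand-in for Llama4MultiModalProjector: project vision features to text width.
multi_modal_projector = nn.Linear(vision_hidden, text_hidden, bias=False)

# Pretend this is vision_model(...).last_hidden_state for 5 image chunks.
image_features = torch.randn(num_chunks, patches_per_chunk, vision_hidden)
vision_flat = image_features.view(-1, image_features.size(-1))  # [720, 1408]
projected_vision_flat = multi_modal_projector(vision_flat)      # [720, 5120]
print(projected_vision_flat.shape)  # torch.Size([720, 5120])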
......
@@ -307,7 +307,6 @@ class TestOpenAIVisionServer(CustomTestCase):
         self.assertGreater(len(video_response), 0)

     def test_regex(self):
-        return
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)

         regex = (
@@ -683,6 +682,31 @@ class TestJanusProServer(TestOpenAIVisionServer):
     pass


+class TestLlama4Server(TestOpenAIVisionServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--chat-template",
+                "llama-4",
+                "--mem-fraction-static",
+                "0.8",
+                "--tp-size=8",
+                "--context-length=8192",
+            ],
+        )
+        cls.base_url += "/v1"
+
+    def test_video_chat_completion(self):
+        pass
+
+
 class TestGemma3itServer(TestOpenAIVisionServer):
     @classmethod
     def setUpClass(cls):
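Note: once the TestLlama4Server fixture has launched the server, the inherited vision tests talk to it through the OpenAI-compatible /v1 endpoint. A minimal request sketch against that endpoint; the base_url and image URL are placeholders, while the model name and api_key mirror the test class above:

import openai

client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")  # assumed local URL
response = client.chat.completions.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe this image in one sentence."},
            ],
        }
    ],
    temperature=0,
)
print(response.choices[0].message.content)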
......