"vscode:/vscode.git/clone" did not exist on "1f8f1261e31931c241a2b4517523aec7f5622771"
Unverified commit fbebcb7a authored by Mick, committed by GitHub

model: support mllama4 (#5144)

parent 87eddedf
......@@ -486,8 +486,8 @@ multimodal_model_archs = [
"Gemma3ForConditionalGeneration",
"Grok1VForCausalLM",
"Grok1AForCausalLM",
# TODO: add multimodal support for "Llama4ForConditionalGeneration",
"LlavaLlamaForCausalLM",
"Llama4ForConditionalGeneration",
"LlavaMistralForCausalLM",
"LlavaQwenForCausalLM",
"LlavaVidForCausalLM",
......
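The first hunk promotes `Llama4ForConditionalGeneration` from a TODO comment to a real entry in `multimodal_model_archs`, which is what lets the runtime route Llama 4 checkpoints through the image pipeline. A minimal sketch of how such a registry is typically consulted follows; the helper name is illustrative, not necessarily SGLang's exact API.

```python
def is_multimodal_arch(hf_config, multimodal_model_archs) -> bool:
    # hf_config.architectures usually holds names such as
    # ["Llama4ForConditionalGeneration"]; any match enables the image path.
    return any(
        arch in multimodal_model_archs
        for arch in getattr(hf_config, "architectures", [])
    )
```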
......@@ -148,7 +148,8 @@ def get_embedding_and_mask(
placeholder_tensor,
).unsqueeze(-1)
num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
logger.warning(
f"Number of tokens in multimodal embedding does not match those in the input text."
......@@ -172,7 +173,7 @@ def get_embedding_and_mask(
embedding = embedding[-num_multimodal:, :]
else:
raise RuntimeError(
"Insufficient multimodal embedding length. This is an internal error"
f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
)
return embedding, special_multimodal_mask
......
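The `.sum().item()` change matters mostly for the log and error text: `Tensor.sum()` returns a 0-d tensor that formats as `tensor(5)`, while `.item()` yields a plain Python int that reads cleanly in the f-strings above. A standalone illustration, assuming only PyTorch:

```python
import torch

mask = torch.tensor([True, False, True])
print(f"{mask.sum()=}")         # mask.sum()=tensor(2) -- noisy in log messages
print(f"{mask.sum().item()=}")  # mask.sum().item()=2  -- plain int, as used above
```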
from typing import List, Mapping, Optional, Tuple, Union
from typing import List, Union
import torch
from PIL import Image
from transformers import Llama4Processor
from transformers.image_utils import SizeDict
from transformers.models.llama4.image_processing_llama4 import (
from transformers.models.llama4.image_processing_llama4_fast import (
find_supported_resolutions,
get_best_fit,
)
......@@ -15,7 +13,6 @@ from sglang.srt.managers.multimodal_processors.base_processor import (
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
from sglang.srt.utils import load_image
class Mllama4ImageProcessor(BaseMultimodalProcessor):
......@@ -25,6 +22,9 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
super().__init__(hf_config, server_args, _processor)
self.vision_config = hf_config.vision_config
self.text_config = hf_config.text_config
self.boi_token_index = hf_config.boi_token_index
self.eoi_token_index = hf_config.eoi_token_index
self.image_token_index = hf_config.image_token_index
self.multimodal_tokens = MultimodalSpecialTokens(
image_token=_processor.image_token
)
......@@ -54,19 +54,16 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
)
# Process the images using the processor
processor = Llama4Processor.from_pretrained(
self.server_args.model_path, **kwargs
)
processor = self._processor
# Process the prompt and images
image_inputs = processor(
text=processed_data.input_text,
processor_output = self.process_mm_data(
input_text=processed_data.input_text,
images=processed_data.images,
return_tensors="pt",
)
# Handle image resolutions and aspect ratios
if "pixel_values" in image_inputs:
if "pixel_values" in processor_output:
image_processor = processor.image_processor
tokenizer = self._processor.tokenizer
......@@ -100,8 +97,8 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
]
# Add to image_inputs
image_inputs["aspect_ratios"] = aspect_ratios
image_inputs["patches_per_image"] = torch.tensor(patches_per_image)
processor_output["aspect_ratios"] = aspect_ratios
processor_output["patches_per_image"] = torch.tensor(patches_per_image)
# Process embed_is_patch
vocab = tokenizer.get_vocab()
......@@ -109,7 +106,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
image_end_id = vocab.get(processor.end_of_img_token, -1)
if patch_id != -1 and image_end_id != -1:
input_ids = image_inputs["input_ids"].view(-1)
input_ids = processor_output["input_ids"].view(-1)
# Remove BOS token if present
if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
......@@ -129,33 +126,21 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
for per_image_input_ids in split_input_ids:
embed_is_patch.append(per_image_input_ids == patch_id)
image_inputs["embed_is_patch"] = embed_is_patch
processor_output["embed_is_patch"] = embed_is_patch
# Convert to the format expected by SGLang
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
processor_output["im_start_id"] = self.boi_token_index
processor_output["im_end_id"] = self.eoi_token_index
processor_output["im_token_id"] = self.image_token_index
# Add metadata for image processing
image_inputs["mm_items"] = [
processor_output["mm_items"] = [
MultimodalDataItem(
pixel_values=image_inputs["pixel_values"],
pixel_values=processor_output["pixel_values"],
modality=Modality.IMAGE,
# Add additional metadata needed for Llama4 vision processing
embed_is_patch=image_inputs.get("embed_is_patch", None),
aspect_ratios=image_inputs.get("aspect_ratios", None),
patches_per_image=image_inputs.get("patches_per_image", None),
)
]
return image_inputs
def get_patch_per_chunk(self):
"""Calculate patches per chunk based on vision config"""
image_size = self.vision_config.image_size
patch_size = self.vision_config.patch_size
assert (
image_size % patch_size == 0
), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
return (image_size // patch_size) ** 2 // ds_ratio
return processor_output
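After this rewrite, the processor no longer reloads `Llama4Processor` from disk; it reuses `self._processor`, routes everything through `process_mm_data`, and returns a dict the scheduler can consume directly. Roughly what that dict looks like is sketched below; every value is an illustrative placeholder (the token ids are not Llama 4's real vocabulary ids, and `mm_items` is shown as a plain summary rather than the `MultimodalDataItem` dataclass the code actually builds).

```python
example_processor_output = {
    "input_ids": [1, 42, 7, 7, 7, 43, 9],  # flat Python list (note .tolist()[0] above)
    "im_start_id": 42,                     # hf_config.boi_token_index
    "im_end_id": 43,                       # hf_config.eoi_token_index
    "im_token_id": 7,                      # hf_config.image_token_index
    "mm_items": [
        {
            "modality": "IMAGE",
            "pixel_values": "tensor of image patches",
            "embed_is_patch": "bool mask per image",
            "aspect_ratios": "per-image (ratio_h, ratio_w)",
            "patches_per_image": "tensor of patch counts",
        }
    ],
}
```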
from __future__ import annotations
import hashlib
from enum import Enum, auto
# Copyright 2023-2024 SGLang Team
......@@ -157,7 +158,7 @@ class Modality(Enum):
@dataclasses.dataclass
class MultimodalDataItem:
"""
A single multimodal data, from a single image/video/audio or other
A single multimodal data item, from one image/video/audio input or another modality
"""
modality: Modality
......@@ -195,25 +196,54 @@ class MultimodalDataItem:
def set_pad_value(self):
"""
Set the pad value after first hashign the data
Set the pad value after first hashing the data
"""
def tensor_hash(f):
f_list = flatten_nested_list(f)
f_list = [x.flatten() if isinstance(x, torch.Tensor) else x for x in f_list]
f_cat = torch.concat(f_list).contiguous().numpy().tobytes()
return hash(f_cat)
def data_hash(data) -> int:
hash_bytes = hashlib.sha256(data).digest()[:8]
return int.from_bytes(hash_bytes, byteorder="big", signed=False)
def tensor_hash(tensor_list) -> int:
"""
hash a tensor or a tensor list
"""
tensor = tensor_list
if isinstance(tensor_list, list):
tensor_list = flatten_nested_list(tensor_list)
tensor_list = [
x.flatten() if isinstance(x, torch.Tensor) else x
for x in tensor_list
]
tensor = torch.concat(tensor_list)
tensor = tensor.detach().contiguous()
if tensor.dtype == torch.bfloat16:
# memoryview() doesn't support PyTorch's BFloat16 dtype
tensor = tensor.float()
if tensor.is_cuda:
tensor_cpu = torch.frombuffer(
tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel()
).clone()
else:
tensor_cpu = tensor
mv = memoryview(tensor_cpu.numpy())
return data_hash(mv.tobytes())
def hash_feature(f):
if isinstance(f, list):
if isinstance(f[0], torch.Tensor):
return tensor_hash(f)
return hash(tuple(flatten_nested_list(f)))
return data_hash(tuple(flatten_nested_list(f)))
elif isinstance(f, np.ndarray):
arr = np.ascontiguousarray(f)
arr_bytes = arr.tobytes()
return hash(arr_bytes)
return hash(f)
return data_hash(arr_bytes)
elif isinstance(f, torch.Tensor):
return tensor_hash([f])
return data_hash(f)
if self.is_audio():
self.hash = hash_feature(self.audio_features)
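Replacing the built-in `hash()` with a truncated SHA-256 makes the pad value deterministic: Python salts `hash()` of `bytes`/`str` per interpreter run (hash randomization), so the old value could differ between processes, whereas the new `data_hash` is stable. A minimal standalone illustration of the same idea (not the exact code path above):

```python
import hashlib
import torch

def data_hash(data: bytes) -> int:
    # First 8 bytes of SHA-256 as an unsigned int: stable across processes,
    # unlike built-in hash() on bytes, which is salted per interpreter run.
    return int.from_bytes(hashlib.sha256(data).digest()[:8], byteorder="big", signed=False)

t = torch.randn(3, 224, 224, dtype=torch.bfloat16)
# bfloat16 has no buffer-protocol/NumPy counterpart, hence the float() upcast,
# mirroring the dtype handling in tensor_hash above.
stable_id = data_hash(t.float().contiguous().numpy().tobytes())
print(stable_id)
```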
......@@ -256,7 +286,7 @@ class MultimodalInputs:
mrope_position_delta: Optional[torch.Tensor] = None
# image
im_token_id: Optional[torch.Tensor] = None
im_token_id: Optional[int] = None
im_start_id: Optional[int] = None
im_end_id: Optional[int] = None
slice_start_id: Optional[int] = None
......@@ -330,10 +360,8 @@ class MultimodalInputs:
# args needed to be merged
optional_args = [
"items",
"image_offsets",
"mm_items",
"image_pad_len",
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
]
for arg in optional_args:
self_arg = getattr(self, arg, None)
......
......@@ -466,6 +466,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
):
super().__init__(config, quant_config, prefix)
def get_input_embeddings(self):
return self.model.embed_tokens
def _init_model(
self,
config: Llama4TextConfig,
......
# Adapted from vllm/mllama4.py
from collections.abc import Iterable
from typing import Optional, Set, Tuple
from typing import List, Optional, Set, Tuple
import torch
from torch import nn
from transformers import Llama4Config
from transformers import Llama4Config, Llama4VisionModel
from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternImageTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
......@@ -30,6 +35,9 @@ class Llama4ForConditionalGeneration(nn.Module):
self.config = config
self.quant_config = quant_config
self.vision_model = Llama4VisionModel(config.vision_config)
self.multi_modal_projector = Llama4MultiModalProjector(config)
# Initialize the language model
from sglang.srt.models.llama4 import Llama4ForCausalLM
......@@ -41,6 +49,29 @@ class Llama4ForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config.text_config)
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
# Get all special token IDs
im_token_id: int = mm_inputs.im_token_id
pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(
self,
items: List[MultimodalDataItem],
) -> torch.Tensor:
pixel_values = (
torch.concat([item.pixel_values for item in items])
.to(next(self.vision_model.parameters()).device)
.type(next(self.vision_model.parameters()).dtype)
)
image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
image_features = image_outputs.last_hidden_state
vision_flat = image_features.view(-1, image_features.size(-1))
projected_vision_flat = self.multi_modal_projector(vision_flat)
return projected_vision_flat
def forward(
self,
input_ids: torch.Tensor,
......@@ -49,7 +80,15 @@ class Llama4ForConditionalGeneration(nn.Module):
**kwargs: object,
) -> torch.Tensor:
return self.language_model(input_ids, positions, forward_batch)
hs = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
positions=positions,
)
return hs
def permute_qk_weight_for_rotary(
self,
......@@ -108,17 +147,17 @@ class Llama4ForConditionalGeneration(nn.Module):
)
for name, loaded_weight in weights:
if name.startswith("vision_model") or name.startswith(
"multi_modal_projector"
):
continue
name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
if not "vision" in name:
name, loaded_weight = self.permute_qk_weight_for_rotary(
name, loaded_weight
)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "vision" in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
......
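Taken together, the `mllama4.py` changes wire the vision tower into SGLang's generic multimodal routine: `pad_input_ids` marks image-token positions, `get_image_feature` turns pixel values into projected vision embeddings, and `general_mm_embed_routine` merges those embeddings with the text embeddings before the language model runs. A heavily simplified sketch of that merge step is shown below; the real logic lives in `general_mm_embed_routine`, and the function here is an assumption for illustration only.

```python
import torch

def merge_image_embeddings(
    input_ids: torch.Tensor,    # (seq_len,)
    text_embeds: torch.Tensor,  # (seq_len, hidden)
    image_embeds: torch.Tensor, # (num_image_tokens, hidden)
    im_token_id: int,
) -> torch.Tensor:
    # Scatter projected image features into the slots occupied by image tokens;
    # this mirrors, conceptually, what the embedding routine does.
    merged = text_embeds.clone()
    mask = input_ids == im_token_id
    merged[mask] = image_embeds.to(merged.dtype)[: int(mask.sum().item())]
    return merged
```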
......@@ -307,7 +307,6 @@ class TestOpenAIVisionServer(CustomTestCase):
self.assertGreater(len(video_response), 0)
def test_regex(self):
return
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
regex = (
......@@ -683,6 +682,31 @@ class TestJanusProServer(TestOpenAIVisionServer):
pass
class TestLlama4Server(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
cls.model = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--chat-template",
"llama-4",
"--mem-fraction-static",
"0.8",
"--tp-size=8",
"--context-length=8192",
],
)
cls.base_url += "/v1"
def test_video_chat_completion(self):
pass
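Once the test's server is up, the same endpoint can be exercised manually with the OpenAI client, mirroring what `TestOpenAIVisionServer` does; the base URL and image URL below are placeholders standing in for `DEFAULT_URL_FOR_TEST` and the test's fixture images.

```python
import openai

client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")
resp = client.chat.completions.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }],
)
print(resp.choices[0].message.content)
```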
class TestGemma3itServer(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
......