Unverified Commit ea93079b authored by Wenchen Lo, committed by GitHub

model: adapt mllama4 to VisionAttention (#8512)


Co-authored-by: root <mickjagger19@icloud.com>
parent 4bec99ec
@@ -14,7 +14,6 @@
 """Utilities for Huggingface Transformers."""
 import contextlib
-import logging
 import os
 import warnings
 from pathlib import Path
@@ -45,7 +44,7 @@ from sglang.srt.configs import (
 )
 from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
-from sglang.srt.utils import is_remote_url, lru_cache_frozenset
+from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -317,15 +316,31 @@ def get_processor(
     if config.model_type not in {"llava", "clip"}:
         kwargs["use_fast"] = use_fast

-    processor = AutoProcessor.from_pretrained(
-        tokenizer_name,
-        *args,
-        trust_remote_code=trust_remote_code,
-        revision=revision,
-        **kwargs,
-    )
+    try:
+        processor = AutoProcessor.from_pretrained(
+            tokenizer_name,
+            *args,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            **kwargs,
+        )
+    except ValueError as e:
+        error_message = str(e)
+        if "does not have a slow version" in error_message:
+            logger.info(
+                f"Processor {tokenizer_name} does not have a slow version. Automatically use fast version"
+            )
+            kwargs["use_fast"] = True
+            processor = AutoProcessor.from_pretrained(
+                tokenizer_name,
+                *args,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                **kwargs,
+            )
+        else:
+            raise e

     tokenizer = get_tokenizer_from_processor(processor)
     attach_additional_stop_token_ids(tokenizer)
......
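The retry above targets processors that only ship a fast image processor. A minimal sketch of the resulting behavior, using a hypothetical repo id (get_processor and its use_fast flag come from this diff; everything else is illustrative):

# Previously this raised ValueError("... does not have a slow version ...");
# now get_processor logs the situation and retries with use_fast=True.
processor = get_processor("org/fast-only-model", use_fast=False)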
@@ -4,7 +4,7 @@ import dataclasses
 import functools
 import math
 from functools import lru_cache, partial
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -308,6 +308,7 @@ class VisionFlash3Attention(nn.Module):
         cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
+
         output = flash_attn_varlen_func(
             q,
             k,
@@ -358,6 +359,9 @@ class VisionAttention(nn.Module):
         qkv_bias: bool = True,
         qk_normalization: bool = False,
         layer_norm_eps: float = 1e-06,
+        customized_position_embedding_applier: Callable[
+            [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]
+        ] = None,
         **kwargs,
     ):
         super().__init__()
@@ -392,6 +396,7 @@ class VisionAttention(nn.Module):
             self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
         )

+        # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
@@ -401,6 +406,9 @@ class VisionAttention(nn.Module):
         print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

+        self.customized_position_embedding_applier = (
+            customized_position_embedding_applier
+        )
         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
             num_heads=self.num_attention_heads_per_partition,
@@ -473,13 +481,13 @@ class VisionAttention(nn.Module):
         if x.dim() == 2:
             x = x.unsqueeze(0)
         assert x.dim() == 3, x.shape
-        bsz, s, _ = x.shape
+        x_shape = x.shape
+        bsz, s, _ = x_shape
         head = self.num_attention_heads_per_partition
         kv_head = self.num_attention_kv_heads_per_partition
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
             qkv, _ = self.qkv_proj(x)
             q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
             # [b, s, embed_dim] --> [b * s, head, head_size]
@@ -508,16 +516,25 @@ class VisionAttention(nn.Module):
             ]

         if position_embeddings is not None:
-            cos, sin = position_embeddings
             original_shape = q.shape
-            # [total_tokens, head, head_size]
-            q = q.view(-1, head, self.head_size)
-            k = k.view(-1, head, self.head_size)
-            q, k = apply_rotary_pos_emb(q, k, cos, sin)
+            if self.customized_position_embedding_applier is not None:
+                q, k = self.customized_position_embedding_applier(
+                    q, k, position_embeddings, x_shape
+                )
+                q = q.view(original_shape)
+                k = k.view(original_shape)
+            else:
+                cos, sin = position_embeddings
+                # [total_tokens, head, head_size]
+                q = q.view(-1, head, self.head_size)
+                k = k.view(-1, head, self.head_size)
+                q, k = apply_rotary_pos_emb(q, k, cos, sin)

-            q = q.view(original_shape)
-            k = k.view(original_shape)
+                q = q.view(original_shape)
+                k = k.view(original_shape)

         if q.dim() == 4:
             # [b, s, head, head_size] --> [b * s, head, head_size]
......
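The hunk above adds the customized_position_embedding_applier hook, so a vision model can supply its own rotary application instead of the default apply_rotary_pos_emb path. A minimal sketch of the contract; the constructor arguments mirror the Llama4 call later in this diff, while the sizes and the pass-through body are illustrative only:

hidden_size, num_heads = 1408, 16  # illustrative sizes

def passthrough_applier(q, k, position_embeddings, x_shape):
    # Receives q/k as produced by the qkv projection, whatever object was passed
    # to forward(position_embeddings=...), and the [bsz, seq, embed_dim] shape
    # captured before the split; must return the transformed (q, k).
    return q, k

attn = VisionAttention(
    hidden_size,
    num_heads,
    hidden_size,
    use_qkv_parallel=True,
    qkv_backend="sdpa",
    customized_position_embedding_applier=passthrough_applier,
)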
@@ -70,7 +70,6 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
-    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -202,13 +201,29 @@ class TokenizerManager:
         if self.model_config.is_multimodal:
             import_processors()
-            _processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-                revision=server_args.revision,
-                use_fast=not server_args.disable_fast_image_processor,
-            )
+            try:
+                _processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                    use_fast=not server_args.disable_fast_image_processor,
+                )
+            except ValueError as e:
+                error_message = str(e)
+                if "does not have a slow version" in error_message:
+                    logger.info(
+                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
+                    )
+                    _processor = get_processor(
+                        server_args.tokenizer_path,
+                        tokenizer_mode=server_args.tokenizer_mode,
+                        trust_remote_code=server_args.trust_remote_code,
+                        revision=server_args.revision,
+                        use_fast=True,
+                    )
+                else:
+                    raise e
             transport_mode = _determine_tensor_transport_mode(self.server_args)

             # We want to parallelize the image pre-processing so we create an executor for it
......
@@ -241,13 +241,22 @@ class Llama4Attention(nn.Module):
             if self.use_qk_norm
             else None
         )
+
+        qkv_quant_config = quant_config
+        o_quant_config = quant_config
+        if quant_config and hasattr(quant_config, "ignore") and quant_config.ignore:
+            if add_prefix("q_proj", prefix) in quant_config.ignore:
+                qkv_quant_config = None
+            if add_prefix("o_proj", prefix) in quant_config.ignore:
+                o_quant_config = None
+
         self.qkv_proj = QKVParallelLinear(
             hidden_size=hidden_size,
             head_size=self.head_dim,
             total_num_heads=self.total_num_heads,
             total_num_kv_heads=self.total_num_kv_heads,
             bias=bias,
-            quant_config=quant_config,
+            quant_config=qkv_quant_config,
             prefix=add_prefix("qkv_proj", prefix),
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,

@@ -257,7 +266,7 @@ class Llama4Attention(nn.Module):
             input_size=self.total_num_heads * self.head_dim,
             output_size=hidden_size,
             bias=bias_o_proj,
-            quant_config=quant_config,
+            quant_config=o_quant_config,
             prefix=add_prefix("o_proj", prefix),
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
......
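The new qkv_quant_config / o_quant_config logic lets a quantized checkpoint opt individual projections out of quantization through the quant config's ignore list. A small illustration with hypothetical module names:

class DummyQuantConfig:
    # Hypothetical ignore list; real checkpoints carry fully prefixed module names.
    ignore = ["language_model.model.layers.0.self_attn.o_proj"]

# With prefix="language_model.model.layers.0.self_attn",
# add_prefix("o_proj", prefix) matches the entry above, so o_proj is built
# without a quant_config while qkv_proj keeps the original one.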
 import json as json_lib
 import logging
+import math
 import os
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple

 import torch
 from torch import nn
-from transformers import Llama4Config
+from transformers import Llama4Config, Llama4VisionConfig
 from transformers.models.llama4.modeling_llama4 import (
     Llama4MultiModalProjector,
-    Llama4VisionModel,
+    vision_apply_rotary_emb,
 )

+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig

@@ -26,10 +33,10 @@ from sglang.srt.managers.schedule_batch import (
     global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cpu
+from sglang.srt.utils import is_cpu

 _is_cpu = is_cpu()

 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,

@@ -39,6 +46,376 @@ from sglang.srt.utils import add_prefix

 logger = logging.getLogger(__name__)
class Llama4VisionMLP(nn.Module):
    def __init__(
        self,
        input_size: int,
        intermediate_size: int,
        output_size: int,
        bias: bool,
        output_activation: bool,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        cls_fc1 = ReplicatedLinear if use_data_parallel else ColumnParallelLinear
        self.fc1 = cls_fc1(
            input_size=input_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
        self.fc2 = cls_fc2(
            input_size=intermediate_size,
            output_size=output_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.activation_fn = nn.GELU()
        self.output_activation = output_activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        if self.output_activation:
            return self.activation_fn(hidden_states)
        return hidden_states


def pixel_shuffle(input_tensor, shuffle_ratio):
    # input_tensor: [batch_size, num_patches, channels]
    batch_size, num_patches, channels = input_tensor.shape
    patch_size = int(math.sqrt(num_patches))

    input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
    batch_size, height, width, channels = input_tensor.size()

    reshaped_tensor = input_tensor.view(
        batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
    )
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

    reshaped_tensor = reshaped_tensor.view(
        batch_size,
        int(height * shuffle_ratio),
        int(width * shuffle_ratio),
        int(channels / (shuffle_ratio**2)),
    )
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

    output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])
    return output_tensor
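# Shape check for pixel_shuffle above (illustrative values, not part of the
# model code): with shuffle_ratio = 0.5 the token count shrinks by ratio**2
# while the channel dim grows by 1 / ratio**2.
_demo_patches = torch.randn(1, 16, 64)  # a 4x4 grid of 64-dim patches
assert pixel_shuffle(_demo_patches, 0.5).shape == (1, 4, 256)  # -> 4 tokens of 256 dims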
class Llama4VisionPixelShuffleMLP(nn.Module):
    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
        self.mlp = Llama4VisionMLP(
            input_size=config.intermediate_size,
            intermediate_size=config.projector_input_dim,
            output_size=config.projector_output_dim,
            bias=config.multi_modal_projector_bias,
            output_activation=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
        encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
        return self.mlp(encoded_patches)


def apply_position_embedding(q, k, freqs_ci, shape):
    # [batch_size_times_num_tiles, num_channels]
    input_shape = shape[:2]
    # [batch_size_times_num_tiles, num_channels, num_heads, head_dim]
    hidden_shape = (*input_shape, *q.shape[-2:])

    q = q.view(hidden_shape)
    k = k.view(hidden_shape)

    q, k = vision_apply_rotary_emb(q, k, freqs_ci)
    return q, k
class Llama4VisionEncoderLayer(nn.Module):
    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.intermediate_size = config.intermediate_size

        self.self_attn = VisionAttention(
            self.hidden_size,
            self.num_attention_heads,
            self.hidden_size,
            use_qkv_parallel=True,
            # vision_model is explicitly ignored in Maverick-17B-128E-Instruct-FP8
            quant_config=None,
            dropout=0.0,
            qkv_backend="sdpa",
            softmax_in_single_precision=False,
            flatten_batch=False,
            prefix=add_prefix("self_attn", prefix),
            qkv_bias=True,
            customized_position_embedding_applier=apply_position_embedding,
        )
        self.mlp = Llama4VisionMLP(
            input_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            output_size=config.hidden_size,
            bias=True,
            output_activation=False,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_state: torch.Tensor,
        freqs_ci: torch.Tensor,
    ):
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state, position_embeddings=freqs_ci)
        hidden_state = residual + hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        hidden_state = residual + hidden_state

        outputs = hidden_state
        return outputs
class Llama4VisionEncoder(nn.Module):
    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Llama4VisionEncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_ci: torch.Tensor,  # TODO: move this to an attribute instead of keeping it around
    ) -> torch.Tensor:
        r"""
        Args:
            hidden_states (`torch.FloatTensor` of shape
                `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to
                directly pass an embedded representation. This is useful if you
                want more control over how to convert `input_ids` indices into
                associated vectors than the model's internal embedding
                lookup matrix.
        """
        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(hidden_states, freqs_ci=freqs_ci)
            hidden_states = layer_outputs

        return hidden_states
class Llama4UnfoldConvolution(nn.Module):
    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        kernel_size = config.patch_size
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
        params = {
            "input_size": config.num_channels * kernel_size[0] * kernel_size[1],
            "output_size": config.hidden_size,
            "bias": False,
            "quant_config": quant_config,
            "prefix": f"{prefix}.linear",
        }
        if use_data_parallel:
            cls = ReplicatedLinear
        else:
            cls = ColumnParallelLinear
            params["gather_output"] = True
        self.linear = cls(**params)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.unfold(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 1)
        hidden_states, _ = self.linear(hidden_states)
        return hidden_states


class Llama4VisionRotaryEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        idx = config.image_size // config.patch_size
        img_idx = torch.arange(idx**2, dtype=torch.int32).reshape(idx**2, 1)
        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
        img_idx[-1, -1] = -2  # ID_CLS_TOKEN
        frequencies_x = img_idx % idx  # get the coordinates of the 2d matrix along x
        frequencies_y = img_idx // idx  # get the coordinates of the 2d matrix along y
        freq_dim = config.hidden_size // config.num_attention_heads // 2
        rope_freq = 1.0 / (
            config.rope_theta
            ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim)
        )
        freqs_x = (
            (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
        ).repeat_interleave(2, dim=-1)
        freqs_y = (
            (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
        ).repeat_interleave(2, dim=-1)
        freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
        freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
        freq_cis = torch.view_as_complex(
            torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
        )
        self.freqs_ci = freq_cis  # idx**2, idx**2, idx * 2

    def forward(self, hidden_states):
        return self.freqs_ci.to(hidden_states.device)
class Llama4VisionModel(nn.Module):
    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels

        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = Llama4UnfoldConvolution(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.patch_embedding",
        )

        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
        self.positional_embedding_vlm = nn.Parameter(
            self.scale * torch.randn(self.num_patches, self.hidden_size)
        )

        self.rotary_embedding = Llama4VisionRotaryEmbedding(config)

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
        self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)

        # encoders
        self.model = Llama4VisionEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.model",
        )
        self.vision_adapter = Llama4VisionPixelShuffleMLP(
            config,
            quant_config,
            prefix=f"{prefix}.vision_adapter",
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        # Patch embedding
        hidden_state = self.patch_embedding(pixel_values)
        num_tiles, num_patches, hidden_dim = hidden_state.shape

        # Add cls token
        class_embedding = self.class_embedding.expand(
            hidden_state.shape[0], 1, hidden_state.shape[-1]
        )
        hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
        num_patches += 1

        # Position embeddings
        hidden_state = hidden_state.reshape(
            num_tiles,
            1,
            num_patches,
            hidden_dim,
        )
        positional_embedding = self.positional_embedding_vlm.to(
            dtype=hidden_state.dtype, device=hidden_state.device
        )
        hidden_state = hidden_state + positional_embedding
        hidden_state = self.layernorm_pre(hidden_state)
        hidden_state = hidden_state.view(num_tiles, -1, hidden_dim)

        freqs_ci = self.rotary_embedding(pixel_values)

        # Apply encoder
        hidden_state = self.model(hidden_state, freqs_ci=freqs_ci)
        hidden_state = self.layernorm_post(hidden_state)

        # Remove CLS token output
        hidden_state = hidden_state[:, :-1, :]

        # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
        hidden_state = self.vision_adapter(hidden_state)

        return hidden_state
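# Usage sketch for the vision tower above (shapes inferred from the code, not
# measured): pixel_values is [num_tiles, num_channels, image_size, image_size];
# the output carries (image_size // patch_size) ** 2 * pixel_shuffle_ratio ** 2
# tokens per tile, each of width projector_output_dim, ready for the
# multi-modal projector.
#
#     vision_model = Llama4VisionModel(config.vision_config)
#     image_embeds = vision_model(pixel_values)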
 class Llama4ForConditionalGeneration(nn.Module):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],

@@ -60,7 +437,8 @@ class Llama4ForConditionalGeneration(nn.Module):
         if not self.has_vision_weights:
             logger.warning(
                 "No vision weights found in checkpoint. Model will run in text-only mode. "
-                "Multimodal capabilities (image processing) will be unavailable."
+                "Multimodal capabilities (vision understanding) will be unavailable. "
+                "Please note that this warning might be inaccurate if the weights haven't been fully downloaded."
             )

         self.has_vision = (

@@ -68,7 +446,12 @@ class Llama4ForConditionalGeneration(nn.Module):
         )

         if self.has_vision:
-            self.vision_model = Llama4VisionModel(config.vision_config)
+            self.vision_model = Llama4VisionModel(
+                config.vision_config,
+                quant_config=quant_config,
+                prefix=add_prefix("vision_model", prefix),
+            )
             self.multi_modal_projector = Llama4MultiModalProjector(config)
         else:
             self.vision_model = None
@@ -112,7 +495,6 @@ class Llama4ForConditionalGeneration(nn.Module):
                 filename="model.safetensors.index.json",
                 cache_dir=None,
             )
-
             if index_file_path and os.path.exists(index_file_path):
                 return self._check_vision_weights_in_index(index_file_path)

@@ -120,7 +502,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             # If we can't access the cache, fall back to config-based detection
             pass

-        # Fallback assume text-only
+        # Fallback, assume text-only
         return False

     def _check_vision_weights_in_index(self, index_file: str) -> bool:

@@ -131,7 +513,6 @@ class Llama4ForConditionalGeneration(nn.Module):
         vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]

         weight_names = index_data.get("weight_map", {}).keys()
-
         return any(
             pattern in weight_name
             for weight_name in weight_names
@@ -150,17 +531,17 @@ class Llama4ForConditionalGeneration(nn.Module):
         # For text-only models, return None or raise an error
         if not self.has_vision or self.vision_model is None:
             raise ValueError("Vision model not available for text-only checkpoint")

         pixel_values = (
             torch.concat([item.feature for item in items])
             .to(next(self.vision_model.parameters()).device)
             .type(next(self.vision_model.parameters()).dtype)
         )

-        image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
-        image_features = image_outputs.last_hidden_state
+        image_features = self.vision_model(pixel_values)
+
         vision_flat = image_features.view(-1, image_features.size(-1))
         projected_vision_flat = self.multi_modal_projector(vision_flat)

         return projected_vision_flat
     def forward(

@@ -246,31 +627,47 @@ class Llama4ForConditionalGeneration(nn.Module):
             num_experts=num_experts,
         )

+        loaded_params = set()
         for name, loaded_weight in weights:
             if self._should_skip_weight(name):
                 continue

             name = self._transform_weight_name(name)

-            if "vision" not in name:
+            if "vision" in name:
+                name = name.replace(".self_attn.o_proj", ".self_attn.proj")
+            else:
                 name, loaded_weight = self.permute_qk_weight_for_rotary(
                     name, loaded_weight
                 )

             if self._handle_scale_remapping(name, params_dict):
+                loaded_params.add(name)
                 continue

             if self._handle_stacked_params(
-                name, loaded_weight, stacked_params_mapping, params_dict
+                name, loaded_weight, stacked_params_mapping, params_dict, loaded_params
             ):
                 continue

             if self._handle_expert_weights(
-                name, loaded_weight, expert_params_mapping, params_dict, num_experts
+                name,
+                loaded_weight,
+                expert_params_mapping,
+                params_dict,
+                num_experts,
+                loaded_params,
             ):
                 continue

+            loaded_params.add(name)
             self._handle_default_weight(name, loaded_weight, params_dict)

+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            logger.warning(
+                f"Some weights are not initialized from checkpoints {unloaded_params}"
+            )
+
     def _should_skip_weight(self, name: str) -> bool:
         """Check if we should skip loading this weight."""
@@ -301,11 +698,13 @@ class Llama4ForConditionalGeneration(nn.Module):
         loaded_weight: torch.Tensor,
         stacked_params_mapping: list,
         params_dict: dict,
+        loaded_params: set,
     ) -> bool:
         """Handle stacked parameter loading. Returns True if handled."""
         for param_name, weight_name, shard_id in stacked_params_mapping:
-            if weight_name in name and "vision" not in name:
+            if weight_name in name:
                 transformed_name = name.replace(weight_name, param_name)
+                loaded_params.add(transformed_name)
                 param = params_dict[transformed_name]
                 param.weight_loader(param, loaded_weight, shard_id)
                 return True

@@ -318,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         expert_params_mapping: list,
         params_dict: dict,
         num_experts: int,
+        loaded_params: set,
     ) -> bool:
         """Handle expert weight loading for MoE (Mixture of Experts) layers.
@@ -336,16 +736,16 @@ class Llama4ForConditionalGeneration(nn.Module):
         if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
             return self._handle_other_expert_params(
-                name, loaded_weight, expert_params_mapping, params_dict
+                name, loaded_weight, expert_params_mapping, params_dict, loaded_params
             )

         if "scale" in name:
             return self._handle_expert_scale_params(
-                name, loaded_weight, params_dict, num_experts
+                name, loaded_weight, params_dict, num_experts, loaded_params
             )
         else:
             return self._handle_expert_weight_params(
-                name, loaded_weight, params_dict, num_experts
+                name, loaded_weight, params_dict, num_experts, loaded_params
             )

     def _handle_other_expert_params(

@@ -354,6 +754,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         loaded_weight: torch.Tensor,
         expert_params_mapping: list,
         params_dict: dict,
+        loaded_params: set,
    ) -> bool:
        """Handle expert parameters that are not gate_up_proj or down_proj weights.

@@ -362,6 +763,7 @@ class Llama4ForConditionalGeneration(nn.Module):
            loaded_weight: The weight tensor to be loaded
            expert_params_mapping: List of tuples mapping checkpoint names to model parameters
            params_dict: Dictionary of model parameters
+           loaded_params: Set of loaded parameter names

        Returns:
            bool: True if parameter was found and handled, False otherwise

@@ -373,6 +775,7 @@ class Llama4ForConditionalGeneration(nn.Module):
                param.weight_loader(
                    param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
                )
+               loaded_params.add(transformed_name)
                return True
        return False
@@ -411,6 +814,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         loaded_weight: torch.Tensor,
         params_dict: dict,
         num_experts: int,
+        loaded_params: set,
     ) -> bool:
         """Handle quantization scale parameters for expert weights.

@@ -419,6 +823,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             loaded_weight: Scale tensor to be loaded
             params_dict: Dictionary of model parameters
             num_experts: Total number of experts for broadcast operations
+            loaded_params: Set of loaded parameter names

         Returns:
             bool: True (always handles scale parameters)

@@ -447,6 +852,7 @@ class Llama4ForConditionalGeneration(nn.Module):
                 # Load the same scale for all experts
                 for expert_id in range(num_experts):
                     param.data[expert_id] = loaded_weight
+        loaded_params.add(transformed_name)

         return True

@@ -456,6 +862,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         loaded_weight: torch.Tensor,
         params_dict: dict,
         num_experts: int,
+        loaded_params: set,
     ) -> bool:
         """Handle actual weight tensors for expert layers (gate_up_proj and down_proj).

@@ -464,6 +871,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             loaded_weight: Weight tensor(s) to be loaded
             params_dict: Dictionary of model parameters
             num_experts: Total number of experts for tensor distribution
+            loaded_params: Set of loaded parameter names

         Returns:
             bool: True (always handles weight parameters)

@@ -486,6 +894,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             param = params_dict[param_name]
             weight_loader = param.weight_loader
+            loaded_params.add(param_name)

             # Handle the case where loaded_weight might be a single tensor for all experts
             if weight_chunk.dim() == 2:
......
@@ -12,7 +12,6 @@ import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast

-from sglang.srt.managers.mm_utils import TransportProxyTensor
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import load_audio, load_image, load_video, logger

@@ -218,8 +217,10 @@ class BaseMultimodalProcessor(ABC):
                 kwargs["audio"] = audios

             processor = self._processor
-            if hasattr(processor, "image_processor") and isinstance(
-                processor.image_processor, BaseImageProcessorFast
+            if (
+                hasattr(processor, "image_processor")
+                and isinstance(processor.image_processor, BaseImageProcessorFast)
+                and not self.server_args.disable_fast_image_processor
             ):
                 kwargs["device"] = "cuda"
             result = processor.__call__(