Unverified commit f04c80dc, authored by Chang Su and committed by GitHub

Add Llama4 support (#5092)


Co-authored-by: Cheng Wan <cwan39@gatech.edu>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
Co-authored-by: ispobock <ispobaoke@163.com>
parent d1bb1711
@@ -35,6 +35,7 @@ class RadixAttention(nn.Module):
sliding_window_size: int = -1,
is_cross_attention: bool = False,
prefix: str = "",
use_irope: bool = False,
):
super().__init__()
self.tp_q_head_num = num_heads
@@ -50,6 +51,7 @@ class RadixAttention(nn.Module):
self.is_cross_attention = is_cross_attention
self.k_scale = None
self.v_scale = None
self.use_irope = use_irope
def forward(
self,
......
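The hunk above only threads a `use_irope` flag through RadixAttention; attention backends are expected to read it when choosing between global attention and Llama4's chunked local attention on RoPE layers. A minimal sketch of that dispatch, with illustrative names only (not the real backend API):

# Minimal sketch (not the real backend code): pick chunked local attention for
# RoPE layers and global attention for NoPE layers based on `use_irope`.
from dataclasses import dataclass
from typing import Optional

@dataclass
class AttnLayerInfo:
    layer_id: int
    use_irope: bool  # True on RoPE layers that should attend within a local chunk

def pick_attention_kind(layer: AttnLayerInfo, attention_chunk_size: Optional[int]) -> str:
    # Hypothetical helper; real backends select a kernel/wrapper here.
    if layer.use_irope and attention_chunk_size is not None:
        return f"local attention, chunk={attention_chunk_size}"
    return "global attention"

for lid in range(4):
    info = AttnLayerInfo(layer_id=lid, use_irope=(lid + 1) % 4 != 0)
    print(lid, pick_attention_kind(info, attention_chunk_size=8192))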
@@ -733,6 +733,69 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
return new_freqs
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
):
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
inv_freqs = super()._compute_inv_freq(base)
inv_freqs = inv_freqs[: (self.rotary_dim // 2)]
return inv_freqs
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
# self.max_position_embeddings here is number of image patches
# i.e. (image_size // patch_size) ** 2
num_patches = self.max_position_embeddings
img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1)
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN
num_patches_single_dim = int(math.sqrt(num_patches))
frequencies_x = img_idx % num_patches_single_dim
frequencies_y = img_idx // num_patches_single_dim
freqs_x = (
(frequencies_x + 1)[..., None] * inv_freq[None, None, :]
).repeat_interleave(2, dim=-1)
freqs_y = (
(frequencies_y + 1)[..., None] * inv_freq[None, None, :]
).repeat_interleave(2, dim=-1)
freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
cache = torch.view_as_complex(
torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
)
return cache
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
broadcast_shape = [
d if i == 1 or i == (query_.ndim - 1) else 1
for i, d in enumerate(query_.shape)
]
freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
return query_out.type_as(query), key_out.type_as(key)
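The new Llama4VisionRotaryEmbedding is a 2-D RoPE over the image patch grid: each patch index is split into (x, y) coordinates, the (x + 1) and (y + 1) coordinates each drive half of the inverse frequencies, and the trailing CLS slot (id -2) is masked to zero so it is never rotated. A small illustrative sketch of the coordinate split, assuming a 4x4 patch grid:

# Illustrative sketch: how patch indices map to 2-D coordinates for the
# Llama4 vision RoPE cache, assuming 16 patches (4x4 grid) plus one CLS slot.
import math
import torch

num_patches = 16
side = int(math.sqrt(num_patches))            # 4 patches per side
img_idx = torch.arange(num_patches).reshape(-1, 1)
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
img_idx[-1, -1] = -2                          # CLS marker, frequencies zeroed later

x = img_idx % side                            # column of each patch
y = img_idx // side                           # row of each patch
print(torch.cat([img_idx, x, y], dim=1))      # id, x, y per position

# In _compute_cos_sin_cache, (x + 1) and (y + 1) each scale half of the
# inverse frequencies, and positions with id < 0 are masked to zero.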
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
......
from typing import List, Mapping, Optional, Tuple, Union
import torch
from PIL import Image
from transformers import Llama4Processor
from transformers.image_utils import SizeDict
from transformers.models.llama4.image_processing_llama4 import (
find_supported_resolutions,
get_best_fit,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
from sglang.srt.utils import load_image
class Mllama4ImageProcessor(BaseMultimodalProcessor):
models = [Llama4ForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.vision_config = hf_config.vision_config
self.text_config = hf_config.text_config
self.multimodal_tokens = MultimodalSpecialTokens(
image_token=_processor.image_token
)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
max_req_input_len=None,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text)
# Process images and text using the base processor's load_mm_data method
processed_data = self.load_mm_data(
prompt=input_text,
multimodal_tokens=self.multimodal_tokens,
max_req_input_len=max_req_input_len or 4096,
image_data=image_data,
return_text=True,
)
# Process the images using the processor
processor = Llama4Processor.from_pretrained(
self.server_args.model_path, **kwargs
)
# Process the prompt and images
image_inputs = processor(
text=processed_data.input_text,
images=processed_data.images,
return_tensors="pt",
)
# Handle image resolutions and aspect ratios
if "pixel_values" in image_inputs:
image_processor = processor.image_processor
tokenizer = self._processor.tokenizer
# Calculate tile size and find supported resolutions
tile_size = self.vision_config.image_size
max_num_tiles = getattr(self.vision_config, "max_patches", 1)
possible_resolutions = find_supported_resolutions(
max_num_chunks=max_num_tiles,
patch_size=SizeDict(height=tile_size, width=tile_size),
)
# Find best fit for each image
best_fit_sizes = [
get_best_fit(
(image.size[1], image.size[0]), # (height, width)
torch.tensor(possible_resolutions),
resize_to_max_canvas=image_processor.resize_to_max_canvas,
)
for image in processed_data.images
]
# Calculate aspect ratios and patches per image
aspect_ratios = [
(image_size[0] // tile_size, image_size[1] // tile_size)
for image_size in best_fit_sizes
]
patches_per_image = [
1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
]
# Add to image_inputs
image_inputs["aspect_ratios"] = aspect_ratios
image_inputs["patches_per_image"] = torch.tensor(patches_per_image)
# Process embed_is_patch
vocab = tokenizer.get_vocab()
patch_id = vocab.get(processor.img_patch_token, -1)
image_end_id = vocab.get(processor.end_of_img_token, -1)
if patch_id != -1 and image_end_id != -1:
input_ids = image_inputs["input_ids"].view(-1)
# Remove BOS token if present
if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
input_ids = input_ids[1:]
# Find image end indices and split input_ids
image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
if image_end_indices.size(0) > 0:
# Split at image boundaries
split_indices = (image_end_indices + 1)[:-1]
split_input_ids = torch.tensor_split(input_ids, split_indices)
split_input_ids = [x for x in split_input_ids if x.numel() > 0]
# Create embed_is_patch for each image
embed_is_patch = []
for per_image_input_ids in split_input_ids:
embed_is_patch.append(per_image_input_ids == patch_id)
image_inputs["embed_is_patch"] = embed_is_patch
# Convert to the format expected by SGLang
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
# Add metadata for image processing
image_inputs["mm_items"] = [
MultimodalDataItem(
pixel_values=image_inputs["pixel_values"],
modality=Modality.IMAGE,
# Add additional metadata needed for Llama4 vision processing
embed_is_patch=image_inputs.get("embed_is_patch", None),
aspect_ratios=image_inputs.get("aspect_ratios", None),
patches_per_image=image_inputs.get("patches_per_image", None),
)
]
return image_inputs
def get_patch_per_chunk(self):
"""Calculate patches per chunk based on vision config"""
image_size = self.vision_config.image_size
patch_size = self.vision_config.patch_size
assert (
image_size % patch_size == 0
), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
return (image_size // patch_size) ** 2 // ds_ratio
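`get_patch_per_chunk` gives the number of vision embeddings one tile contributes after pixel shuffling, and the `patches_per_image` logic above adds one global tile whenever an image is split into more than one chunk. A worked example with assumed config values (image_size=336, patch_size=14, pixel_shuffle_ratio=0.5):

# Worked example for get_patch_per_chunk, using assumed config values.
image_size = 336
patch_size = 14
pixel_shuffle_ratio = 0.5

assert image_size % patch_size == 0
ds_ratio = int(round(1.0 / (pixel_shuffle_ratio ** 2)))    # 4x downsampling
patches_per_chunk = (image_size // patch_size) ** 2 // ds_ratio
print(ds_ratio, patches_per_chunk)                          # 4, 144  (24*24 patches / 4)

# Tile count per image as computed above: a single tile stays 1, otherwise
# one extra (global) tile is added on top of the r_h x r_w grid.
for r_h, r_w in [(1, 1), (2, 2), (3, 4)]:
    print((r_h, r_w), 1 if r_h * r_w == 1 else 1 + r_h * r_w)   # 1, 5, 13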
@@ -128,6 +128,7 @@ class ModelRunner:
self.model_config.attention_arch == AttentionArch.MLA
and not server_args.disable_mla
)
self.attention_chunk_size = model_config.attention_chunk_size
# Model-specific adjustment
self.model_specific_adjustment()
......
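`attention_chunk_size` is the piece of model config the backends need for Llama4's chunked local attention: on RoPE layers a query attends causally only within its own chunk. A minimal mask sketch under that assumption (the real backends implement this inside their kernels rather than materializing a mask):

# Minimal sketch of a chunked (local) causal attention mask.
import torch

def chunked_causal_mask(seq_len: int, chunk_size: int) -> torch.Tensor:
    idx = torch.arange(seq_len)
    causal = idx[:, None] >= idx[None, :]                              # j <= i
    same_chunk = (idx[:, None] // chunk_size) == (idx[None, :] // chunk_size)
    return causal & same_chunk

print(chunked_causal_mask(seq_len=6, chunk_size=3).int())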
@@ -63,6 +63,7 @@ class LlamaMLP(nn.Module):
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@@ -78,6 +79,7 @@
bias=False,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
reduce_results=reduce_results,
)
if hidden_act != "silu":
raise ValueError(
@@ -281,7 +283,7 @@ class LlamaModel(nn.Module):
self.layers = make_layers(
config.num_hidden_layers,
lambda idx, prefix: LlamaDecoderLayer(
config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
),
prefix="model.layers",
)
@@ -375,9 +377,7 @@ class LlamaForCausalLM(nn.Module):
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = LlamaModel(
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)
self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
# Llama 3.2 1B Instruct set tie_word_embeddings to True
# Llama 3.1 8B Instruct set tie_word_embeddings to False
if self.config.tie_word_embeddings:
@@ -402,6 +402,14 @@ class LlamaForCausalLM(nn.Module):
self.capture_aux_hidden_states = False
def _init_model(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
return LlamaModel(config, quant_config=quant_config, prefix=prefix)
@torch.no_grad()
def forward(
self,
......
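The refactor above turns backbone construction into a `_init_model` hook, so `Llama4ForCausalLM` (below) can reuse all of `LlamaForCausalLM` while swapping in `Llama4Model`. A stripped-down sketch of the pattern with placeholder classes:

# Stripped-down sketch of the _init_model hook (placeholder classes, not the
# real LlamaForCausalLM / Llama4ForCausalLM).
class BaseForCausalLM:
    def __init__(self, config):
        self.config = config
        self.model = self._init_model(config)   # subclass decides the backbone

    def _init_model(self, config):
        return "LlamaModel(config)"

class VariantForCausalLM(BaseForCausalLM):
    def _init_model(self, config):
        return "Llama4Model(config)"             # only the backbone changes

print(BaseForCausalLM({}).model, VariantForCausalLM({}).model)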
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from
# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/models/llama4.py
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import Llama4TextConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
logger = logging.getLogger(__name__)
class Llama4MoE(nn.Module):
@torch.compile(dynamic=True, backend=get_compiler_backend())
@staticmethod
def custom_routing_function(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
hidden_states.dtype
)
return (
router_scores_aK.view(-1).reshape(router_scores_aK.shape),
router_indices_aK.to(torch.int32),
)
def __init__(
self,
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok
intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear(
config.hidden_size,
config.num_local_experts,
bias=False,
quant_config=None,
prefix=add_prefix("router", prefix),
)
self.experts = FusedMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
custom_routing_function=Llama4MoE.custom_routing_function,
intermediate_size=intermediate_size_moe,
reduce_results=False,
renormalize=False,
quant_config=quant_config,
apply_router_weight_on_input=True,
prefix=add_prefix("experts", prefix),
)
self.shared_expert = LlamaMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size_moe,
hidden_act="silu",
quant_config=quant_config,
prefix=add_prefix("shared_expert", prefix),
reduce_results=False, # We need to do scatter before reduce
)
def forward(self, hidden_states):
# router_scores: [num_tokens, num_experts]
router_logits, _ = self.router(hidden_states)
shared_out = self.shared_expert(hidden_states)
routed_out = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
out_aD = routed_out + shared_out
if self.tp_size > 1:
out_aD = tensor_model_parallel_all_reduce(out_aD)
return out_aD
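Llama4's router differs from the usual softmax-over-experts routing: it takes top-k on the raw logits first and applies a sigmoid only to the selected scores; the score is then applied to the expert input (`apply_router_weight_on_input=True`), and the shared expert's output is added before a single all-reduce. A standalone sketch of just the routing math, with illustrative shapes:

# Standalone sketch of Llama4's routing math: top-k on raw logits, then
# sigmoid on the selected scores (shapes are illustrative).
import torch

num_tokens, num_experts, top_k = 4, 8, 1
gating_output = torch.randn(num_tokens, num_experts)

scores, indices = torch.topk(gating_output, top_k, dim=-1)   # pick experts first
scores = torch.sigmoid(scores.float())                        # then squash scores

print(indices.squeeze(-1))   # chosen expert per token
print(scores.squeeze(-1))    # per-token weight applied to the expert input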
class Llama4Attention(nn.Module):
def __init__(
self,
config: Llama4TextConfig,
layer_id: int,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
bias_o_proj: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.layer_id = layer_id
self.hidden_size = hidden_size
self.use_rope = int((layer_id + 1) % 4 != 0)
self.use_qk_norm = config.use_qk_norm and self.use_rope
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.attn_temperature_tuning = config.attn_temperature_tuning
self.floor_scale = config.floor_scale
self.attn_scale = config.attn_scale
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.n_rep = self.num_heads // self.num_kv_heads
self.qk_norm = (
RMSNorm(
hidden_size=self.head_dim,
eps=config.rms_norm_eps,
)
if self.use_qk_norm
else None
)
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias_o_proj,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
)
is_neox_style = True
is_gguf = quant_config and quant_config.get_name() == "gguf"
if is_gguf and config.model_type in ["llama", "llama4"]:
is_neox_style = False
self.rotary_emb = (
get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=int(rope_theta),
rope_scaling=rope_scaling if rope_scaling != "default" else None,
is_neox_style=is_neox_style,
)
if self.use_rope
else None
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
prefix=add_prefix("attn", prefix),
use_irope=self.use_rope,
)
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
floor = torch.floor((positions + 1.0) / self.floor_scale)
attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
return attn_scale.unsqueeze(-1)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.rotary_emb is not None:
q, k = self.rotary_emb(positions, q, k)
if self.qk_norm is not None:
# TODO: support float
q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
q = self.qk_norm(q).to(q.dtype)
k = self.qk_norm(k).to(k.dtype)
q = q.reshape(-1, self.q_size)
k = k.reshape(-1, self.kv_size)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers:
# the inference-time temperature tuning function is customized to not affect short
# context while working at very long context.
if self.attn_temperature_tuning and not self.use_rope:
attn_scale = self._get_attn_scale(positions)
q = (q * attn_scale).to(q.dtype)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
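For NoPE layers, `_get_attn_scale` implements the temperature-tuning factor `log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1`, which stays at 1.0 for short contexts and grows logarithmically afterwards. A quick evaluation with assumed config values (floor_scale=8192, attn_scale=0.1):

# Quick check of the temperature-tuning factor used for NoPE layers,
# with assumed values floor_scale=8192 and attn_scale=0.1.
import torch

floor_scale, attn_scale = 8192.0, 0.1
positions = torch.tensor([0, 4096, 8191, 65535, 131071], dtype=torch.float32)

floor = torch.floor((positions + 1.0) / floor_scale)
scale = torch.log(floor + 1.0) * attn_scale + 1.0
print(scale)   # ~[1.000, 1.000, 1.069, 1.220, 1.283]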
class Llama4DecoderLayer(nn.Module):
def __init__(
self,
config: Llama4TextConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.layer_id = layer_id
self.hidden_size = config.hidden_size
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
max_position_embeddings = config.max_position_embeddings
self.self_attn = Llama4Attention(
config=config,
layer_id=layer_id,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=False,
bias_o_proj=False,
prefix=add_prefix("self_attn", prefix),
)
is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
if is_moe_layer:
self.feed_forward = Llama4MoE(
config=config,
quant_config=quant_config,
prefix=add_prefix("feed_forward", prefix),
)
else:
self.feed_forward = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size_mlp,
hidden_act="silu",
quant_config=quant_config,
prefix=add_prefix("feed_forward", prefix),
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
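Two interleaving rules shape each decoder layer: attention uses RoPE except on every 4th layer (NoPE, see `use_rope` above), and the feed-forward is MoE on every `interleave_moe_layer_step`-th layer and a dense MLP otherwise. A small enumeration assuming 8 layers and `interleave_moe_layer_step=2`, purely for illustration:

# Enumerate per-layer structure implied by the two interleaving rules above,
# assuming 8 layers and interleave_moe_layer_step=2 for illustration.
num_layers, interleave_moe_layer_step = 8, 2

for layer_id in range(num_layers):
    use_rope = (layer_id + 1) % 4 != 0                          # every 4th layer is NoPE
    is_moe = (layer_id + 1) % interleave_moe_layer_step == 0
    print(layer_id, "rope" if use_rope else "nope", "moe" if is_moe else "dense")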
class Llama4Model(nn.Module):
def __init__(
self,
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("embed_tokens", prefix),
)
self.layers = make_layers(
config.num_hidden_layers,
lambda idx, prefix: Llama4DecoderLayer(
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
),
prefix="model.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.layers_to_capture = []
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
aux_hidden_states = []
for i in range(len(self.layers)):
if i in self.layers_to_capture:
aux_hidden_states.append(hidden_states + residual)
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) == 0:
return hidden_states
return hidden_states, aux_hidden_states
class Llama4ForCausalLM(LlamaForCausalLM):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(
self,
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(config, quant_config, prefix)
def _init_model(
self,
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
return Llama4Model(config, quant_config=quant_config, prefix=prefix)
EntryClass = [Llama4ForCausalLM]
# TODO: add "Adapted from vllm/mllama4.py" header
from collections.abc import Iterable
from typing import Optional, Set, Tuple
import torch
from torch import nn
from transformers import Llama4Config
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
class Llama4ForConditionalGeneration(nn.Module):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
}
def __init__(
self,
config: Llama4Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.quant_config = quant_config
# Initialize the language model
from sglang.srt.models.llama4 import Llama4ForCausalLM
self.language_model = Llama4ForCausalLM(
config.text_config,
quant_config=quant_config,
prefix=add_prefix("language_model", prefix),
)
self.logits_processor = LogitsProcessor(config.text_config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
**kwargs: object,
) -> torch.Tensor:
return self.language_model(input_ids, positions, forward_batch)
def permute_qk_weight_for_rotary(
self,
name: str,
loaded_weight: torch.Tensor,
) -> Tuple[str, torch.Tensor]:
def permute(w: torch.Tensor, n_heads: int):
attn_in = self.language_model.config.head_dim * n_heads
attn_out = self.language_model.config.hidden_size
return (
w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
.transpose(1, 2)
.reshape(attn_in, attn_out)
)
modules = name.split(".")
# rotary embeds should be sliced
if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
loaded_weight = permute(
loaded_weight, self.language_model.config.num_key_value_heads
)
elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
loaded_weight = permute(
loaded_weight, self.language_model.config.num_attention_heads
)
return name, loaded_weight
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
(".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
(".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
(".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
(".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
]
params_dict = dict(self.named_parameters())
num_experts = self.config.text_config.num_local_experts
for name, loaded_weight in weights:
if name.startswith("vision_model") or name.startswith(
"multi_modal_projector"
):
continue
name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if ".experts" in name:
if ".gate_up_proj" in name:
name_list = [
name.replace(".experts.gate_up_proj", ".experts.w13_weight")
] * 2
loaded_weight_list = loaded_weight.chunk(2, dim=-1)
shard_id_list = ["w1", "w3"]
else:
name_list = [
name.replace(".experts.down_proj", ".experts.w2_weight")
]
shard_id_list = ["w2"]
loaded_weight_list = [loaded_weight]
for name, loaded_weight, shard_id in zip(
name_list, loaded_weight_list, shard_id_list
):
param = params_dict[name]
weight_loader = param.weight_loader
for expert_id in range(num_experts):
weight_loader(
param,
loaded_weight[expert_id].T,
name,
shard_id=shard_id,
expert_id=expert_id,
)
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
EntryClass = Llama4ForConditionalGeneration
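`permute_qk_weight_for_rotary` reorders the q/k projection rows of each head from the checkpoint's interleaved pair layout to the half-split layout used by the neox-style rotary embedding configured in `Llama4Attention`. A tiny numeric check of that permutation, with illustrative sizes:

# Tiny check of the row permutation done by permute_qk_weight_for_rotary:
# within each head, interleaved pairs [x0, y0, x1, y1, ...] become the
# half-split order [x0, x1, ..., y0, y1, ...].
import torch

n_heads, head_dim, hidden = 1, 4, 1
w = torch.arange(n_heads * head_dim, dtype=torch.float32).reshape(-1, hidden)

permuted = (
    w.view(n_heads, head_dim // 2, 2, hidden)
    .transpose(1, 2)
    .reshape(n_heads * head_dim, hidden)
)
print(w.flatten().tolist())          # [0.0, 1.0, 2.0, 3.0]
print(permuted.flatten().tolist())   # [0.0, 2.0, 1.0, 3.0]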