Unverified Commit ea93079b authored by Wenchen Lo, committed by GitHub

model: adapt mllama4 to VisionAttention (#8512)


Co-authored-by: root <mickjagger19@icloud.com>
parent 4bec99ec
@@ -14,7 +14,6 @@
 """Utilities for Huggingface Transformers."""
 import contextlib
-import logging
 import os
 import warnings
 from pathlib import Path
@@ -45,7 +44,7 @@ from sglang.srt.configs import (
 )
 from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
-from sglang.srt.utils import is_remote_url, lru_cache_frozenset
+from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -317,15 +316,31 @@ def get_processor(
     if config.model_type not in {"llava", "clip"}:
         kwargs["use_fast"] = use_fast
-    processor = AutoProcessor.from_pretrained(
-        tokenizer_name,
-        *args,
-        trust_remote_code=trust_remote_code,
-        revision=revision,
-        **kwargs,
-    )
+    try:
+        processor = AutoProcessor.from_pretrained(
+            tokenizer_name,
+            *args,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            **kwargs,
+        )
+    except ValueError as e:
+        error_message = str(e)
+        if "does not have a slow version" in error_message:
+            logger.info(
+                f"Processor {tokenizer_name} does not have a slow version. Automatically use fast version"
+            )
+            kwargs["use_fast"] = True
+            processor = AutoProcessor.from_pretrained(
+                tokenizer_name,
+                *args,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                **kwargs,
+            )
+        else:
+            raise e

     tokenizer = get_tokenizer_from_processor(processor)
     attach_additional_stop_token_ids(tokenizer)
...
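For readers skimming the hunk above: get_processor now catches the ValueError that Hugging Face raises when a model ships only a fast processor, logs the fallback, and retries with use_fast=True. Below is a minimal standalone sketch of the same pattern using only the public transformers API; the helper name is illustrative and not part of this change.

```python
from transformers import AutoProcessor


def load_processor_with_fast_fallback(name: str, use_fast: bool = False, **kwargs):
    """Sketch of the fallback above: retry with use_fast=True when the
    requested slow processor does not exist. Illustration only, not the diff."""
    try:
        return AutoProcessor.from_pretrained(name, use_fast=use_fast, **kwargs)
    except ValueError as e:
        if "does not have a slow version" in str(e):
            # Some processors ship only a fast (torch/Rust-backed) implementation.
            return AutoProcessor.from_pretrained(name, use_fast=True, **kwargs)
        raise
```

The same retry appears again in TokenizerManager further down, wrapped around sglang's get_processor rather than AutoProcessor directly.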
@@ -4,7 +4,7 @@ import dataclasses
 import functools
 import math
 from functools import lru_cache, partial
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -308,6 +308,7 @@ class VisionFlash3Attention(nn.Module):
         cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
+
         output = flash_attn_varlen_func(
             q,
             k,
@@ -358,6 +359,9 @@ class VisionAttention(nn.Module):
         qkv_bias: bool = True,
         qk_normalization: bool = False,
         layer_norm_eps: float = 1e-06,
+        customized_position_embedding_applier: Callable[
+            [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]
+        ] = None,
         **kwargs,
     ):
         super().__init__()
@@ -392,6 +396,7 @@ class VisionAttention(nn.Module):
             self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
         )

+        # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
@@ -401,6 +406,9 @@ class VisionAttention(nn.Module):
         print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

+        self.customized_position_embedding_applier = (
+            customized_position_embedding_applier
+        )
         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
             num_heads=self.num_attention_heads_per_partition,
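The new comment documents how the multimodal attention backend is picked. As a quick illustration of that precedence (server args first, then the constructor-passed qkv_backend, then the "sdpa" default), here is a free-standing helper; it is not code from this PR, and the backend names in the asserts are only examples.

```python
from typing import Optional


def resolve_mm_attention_backend(
    server_args_backend: Optional[str], passed_qkv_backend: Optional[str]
) -> str:
    # Mirrors the priority comment above: server_args wins, then the
    # backend passed to VisionAttention, then the "sdpa" default.
    if server_args_backend is not None:
        return server_args_backend
    return passed_qkv_backend if passed_qkv_backend is not None else "sdpa"


assert resolve_mm_attention_backend(None, None) == "sdpa"
assert resolve_mm_attention_backend(None, "fa3") == "fa3"
assert resolve_mm_attention_backend("triton_attn", "fa3") == "triton_attn"
```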
@@ -473,13 +481,13 @@ class VisionAttention(nn.Module):
         if x.dim() == 2:
             x = x.unsqueeze(0)
         assert x.dim() == 3, x.shape
-        bsz, s, _ = x.shape
+        x_shape = x.shape
+        bsz, s, _ = x_shape
         head = self.num_attention_heads_per_partition
         kv_head = self.num_attention_kv_heads_per_partition

         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
             qkv, _ = self.qkv_proj(x)
             q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

             # [b, s, embed_dim] --> [b * s, head, head_size]
@@ -508,16 +516,25 @@ class VisionAttention(nn.Module):
         ]

         if position_embeddings is not None:
-            cos, sin = position_embeddings
             original_shape = q.shape
-            # [total_tokens, head, head_size]
-            q = q.view(-1, head, self.head_size)
-            k = k.view(-1, head, self.head_size)
-            q, k = apply_rotary_pos_emb(q, k, cos, sin)
+            if self.customized_position_embedding_applier is not None:
+                q, k = self.customized_position_embedding_applier(
+                    q, k, position_embeddings, x_shape
+                )
+                q = q.view(original_shape)
+                k = k.view(original_shape)
+            else:
+                cos, sin = position_embeddings
+                # [total_tokens, head, head_size]
+                q = q.view(-1, head, self.head_size)
+                k = k.view(-1, head, self.head_size)
+                q, k = apply_rotary_pos_emb(q, k, cos, sin)

-            q = q.view(original_shape)
-            k = k.view(original_shape)
+                q = q.view(original_shape)
+                k = k.view(original_shape)

         if q.dim() == 4:
             # [b, s, head, head_size] --> [b * s, head, head_size]
...
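With the change above, VisionAttention.forward hands (q, k, position_embeddings, x_shape) to customized_position_embedding_applier when one was passed at construction time and expects (q, k) back; otherwise it falls back to the existing apply_rotary_pos_emb path. Below is a hedged sketch of what such a callable could look like, assuming position_embeddings is a precomputed complex frequency tensor of shape [seq, head_size // 2]; this is a generic rotary application for illustration, not the mllama4 implementation from this PR.

```python
from typing import Any, Tuple

import torch


def make_complex_rope_applier(head_size: int):
    """Build a callable matching the VisionAttention hook signature.

    Illustration only: closes over head_size and assumes position_embeddings
    is a complex tensor of shape [seq, head_size // 2].
    """

    def applier(
        q: torch.Tensor, k: torch.Tensor, position_embeddings: Any, x_shape: Any
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        bsz, seq, _ = x_shape
        freqs_ci = position_embeddings  # assumed complex, [seq, head_size // 2]

        def rotate(t: torch.Tensor) -> torch.Tensor:
            # [b, s, head * head_size] (or [b, s, head, head_size]) viewed as
            # complex pairs: [b, s, head, head_size // 2]
            t_c = torch.view_as_complex(
                t.float().reshape(bsz, seq, -1, head_size // 2, 2)
            )
            out = torch.view_as_real(t_c * freqs_ci.view(1, seq, 1, -1)).flatten(3)
            return out.type_as(t)

        return rotate(q), rotate(k)

    return applier


# Usage sketch: pass the applier when constructing VisionAttention so the
# customized branch above is taken instead of apply_rotary_pos_emb.
# attn = VisionAttention(..., customized_position_embedding_applier=make_complex_rope_applier(head_size))
```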
@@ -70,7 +70,6 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
-    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -202,13 +201,29 @@ class TokenizerManager:
         if self.model_config.is_multimodal:
             import_processors()
-            _processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-                revision=server_args.revision,
-                use_fast=not server_args.disable_fast_image_processor,
-            )
+            try:
+                _processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                    use_fast=not server_args.disable_fast_image_processor,
+                )
+            except ValueError as e:
+                error_message = str(e)
+                if "does not have a slow version" in error_message:
+                    logger.info(
+                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
+                    )
+                    _processor = get_processor(
+                        server_args.tokenizer_path,
+                        tokenizer_mode=server_args.tokenizer_mode,
+                        trust_remote_code=server_args.trust_remote_code,
+                        revision=server_args.revision,
+                        use_fast=True,
+                    )
+                else:
+                    raise e

             transport_mode = _determine_tensor_transport_mode(self.server_args)

             # We want to parallelize the image pre-processing so we create an executor for it
...
@@ -241,13 +241,22 @@ class Llama4Attention(nn.Module):
             if self.use_qk_norm
             else None
         )
+
+        qkv_quant_config = quant_config
+        o_quant_config = quant_config
+        if quant_config and hasattr(quant_config, "ignore") and quant_config.ignore:
+            if add_prefix("q_proj", prefix) in quant_config.ignore:
+                qkv_quant_config = None
+            if add_prefix("o_proj", prefix) in quant_config.ignore:
+                o_quant_config = None
+
         self.qkv_proj = QKVParallelLinear(
             hidden_size=hidden_size,
             head_size=self.head_dim,
             total_num_heads=self.total_num_heads,
             total_num_kv_heads=self.total_num_kv_heads,
             bias=bias,
-            quant_config=quant_config,
+            quant_config=qkv_quant_config,
             prefix=add_prefix("qkv_proj", prefix),
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
@@ -257,7 +266,7 @@ class Llama4Attention(nn.Module):
             input_size=self.total_num_heads * self.head_dim,
             output_size=hidden_size,
             bias=bias_o_proj,
-            quant_config=quant_config,
+            quant_config=o_quant_config,
             prefix=add_prefix("o_proj", prefix),
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
...
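The Llama4Attention change above reads the quantization config's ignore list and, when q_proj or o_proj is listed under this layer's prefix, builds the corresponding projection unquantized. Here is a compact sketch of that selection logic with a stand-in config object; sglang's real quant_config classes differ, and add_prefix below is a simplified re-implementation for the example.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class DummyQuantConfig:
    """Stand-in for a quantization config that exposes an `ignore` list."""

    ignore: List[str] = field(default_factory=list)


def add_prefix(name: str, prefix: str) -> str:
    # Simplified for the example: join prefix and name with a dot.
    return name if not prefix else f"{prefix}.{name}"


def select_quant_configs(
    quant_config: Optional[DummyQuantConfig], prefix: str
) -> Tuple[Optional[DummyQuantConfig], Optional[DummyQuantConfig]]:
    """Mirror of the logic added to Llama4Attention.__init__ (illustration only)."""
    qkv_quant_config = quant_config
    o_quant_config = quant_config
    if quant_config and getattr(quant_config, "ignore", None):
        if add_prefix("q_proj", prefix) in quant_config.ignore:
            qkv_quant_config = None
        if add_prefix("o_proj", prefix) in quant_config.ignore:
            o_quant_config = None
    return qkv_quant_config, o_quant_config


cfg = DummyQuantConfig(ignore=["model.layers.0.self_attn.q_proj"])
qkv_cfg, o_cfg = select_quant_configs(cfg, "model.layers.0.self_attn")
assert qkv_cfg is None  # q_proj ignored -> build the QKV projection unquantized
assert o_cfg is cfg  # o_proj not ignored -> keep quantization for o_proj
```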
This diff is collapsed.
@@ -12,7 +12,6 @@ import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast

-from sglang.srt.managers.mm_utils import TransportProxyTensor
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import load_audio, load_image, load_video, logger
@@ -218,8 +217,10 @@ class BaseMultimodalProcessor(ABC):
             kwargs["audio"] = audios

         processor = self._processor
-        if hasattr(processor, "image_processor") and isinstance(
-            processor.image_processor, BaseImageProcessorFast
+        if (
+            hasattr(processor, "image_processor")
+            and isinstance(processor.image_processor, BaseImageProcessorFast)
+            and not self.server_args.disable_fast_image_processor
         ):
             kwargs["device"] = "cuda"
         result = processor.__call__(
...
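The final hunk only moves image preprocessing to the GPU when the processor actually wraps a fast image processor and fast image processing has not been disabled through server args. A small sketch of that guard as a standalone predicate; the flag name follows server_args.disable_fast_image_processor from the diff, and the helper itself is hypothetical.

```python
from transformers import BaseImageProcessorFast


def should_preprocess_on_cuda(processor, disable_fast_image_processor: bool) -> bool:
    """Illustration of the new guard: only fast image processors get device="cuda"."""
    return (
        hasattr(processor, "image_processor")
        and isinstance(processor.image_processor, BaseImageProcessorFast)
        and not disable_fast_image_processor
    )


# Usage sketch inside the processor call path:
# if should_preprocess_on_cuda(self._processor, self.server_args.disable_fast_image_processor):
#     kwargs["device"] = "cuda"
```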