Unverified Commit 94cde109 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Llama3.2 vision model support (#1551)

parent 00611286
......@@ -8,16 +8,12 @@ version = "0.3.4"
description = "SGLang is yet another fast serving framework for large language models and vision language models."
readme = "README.md"
requires-python = ">=3.8"
license = {file = "LICENSE"}
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]
dependencies = [
"requests",
"tqdm",
"numpy",
]
dependencies = ["requests", "tqdm", "numpy"]
[project.optional-dependencies]
runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
......@@ -32,7 +28,14 @@ srt_xpu = ["sglang[runtime_common]"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
test = [
"jsonlines",
"matplotlib",
"pandas",
"sentence_transformers",
"accelerate",
"peft",
]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
dev = ["sglang[all]", "sglang[test]"]
......@@ -43,7 +46,23 @@ dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
[tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
exclude = [
"assets*",
"benchmark*",
"docs*",
"dist*",
"playground*",
"scripts*",
"tests*",
]
[tool.wheel]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
exclude = [
"assets*",
"benchmark*",
"docs*",
"dist*",
"playground*",
"scripts*",
"tests*",
]
......@@ -227,8 +227,9 @@ def extend(reqs, model_runner):
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
tree_cache=None,
model_config=model_runner.model_config,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
batch.prepare_for_extend()
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
......
......@@ -229,6 +229,7 @@ register_chat_template(
),
},
stop_str=("<|eot_id|>",),
image_token="<|image|>",
)
)
......
......@@ -89,6 +89,8 @@ class ModelConfig:
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
self.vocab_size = self.hf_text_config.vocab_size
self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
......
......@@ -509,6 +509,19 @@ register_conv_template(
)
)
register_conv_template(
Conversation(
name="llama_3_vision",
system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
roles=("user", "assistant"),
sep_style=SeparatorStyle.LLAMA3,
sep="",
stop_str=["<|end_of_text|>", "<|eot_id|>"],
image_token="<|image|>",
)
)
register_conv_template(
Conversation(
name="llava_llama_3",
......
from abc import ABC, abstractmethod
from typing import Optional
import torch
from torch import nn
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -19,7 +21,11 @@ class AttentionBackend(ABC):
raise NotImplementedError()
def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor] = None,
):
"""Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError()
......@@ -30,6 +36,7 @@ class AttentionBackend(ABC):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor] = None,
):
"""Init the metadata for a forward pass for replying a cuda graph."""
raise NotImplementedError()
......@@ -43,7 +50,7 @@ class AttentionBackend(ABC):
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: nn.Module,
layer: RadixAttention,
forward_batch: ForwardBatch,
):
"""Run forward on an attention layer."""
......@@ -57,7 +64,7 @@ class AttentionBackend(ABC):
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: nn.Module,
layer: RadixAttention,
forward_batch: ForwardBatch,
):
"""Run a forward for decode."""
......@@ -68,7 +75,7 @@ class AttentionBackend(ABC):
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: nn.Module,
layer: RadixAttention,
forward_batch: ForwardBatch,
):
"""Run a forward for extend."""
......
......@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
......@@ -134,8 +135,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
)
def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens=None,
):
# NOTE: encoder_lens expected to be zeros or None
self.forward_metadata = (
self.cuda_graph_start_loc,
self.cuda_graph_attn_logits,
......@@ -149,14 +155,18 @@ class DoubleSparseAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens=None,
):
# NOTE: encoder_lens expected to be zeros or None
self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
def get_cuda_graph_seq_len_fill_value(self):
return 1
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
......@@ -172,7 +182,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
layer, forward_batch.out_cache_loc, k, v, k_label
)
(
......@@ -201,7 +211,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
)
return o
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
......@@ -231,7 +243,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
layer, forward_batch.out_cache_loc, k, v, k_label
)
# NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
......
......@@ -11,7 +11,6 @@ from enum import Enum, auto
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
import triton
import triton.language as tl
......@@ -21,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_flashinfer_available
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
if is_flashinfer_available():
......@@ -56,13 +56,13 @@ class FlashInferAttnBackend(AttentionBackend):
assert not (
model_runner.sliding_window_size is not None
and model_runner.has_cross_attention
and model_runner.model_config.is_encoder_decoder
), "Sliding window and cross attention are not supported together"
if model_runner.sliding_window_size is not None:
self.num_wrappers = 2
self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
elif model_runner.has_cross_attention:
elif model_runner.model_config.is_encoder_decoder:
self.num_wrappers = 2
self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
else:
......@@ -128,6 +128,8 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
decode_wrappers=None,
encoder_lens=forward_batch.encoder_lens,
)
self.forward_metadata = (self.decode_wrappers,)
else:
......@@ -144,13 +146,11 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.req_pool_indices,
forward_batch.seq_lens,
prefix_lens,
use_ragged,
use_ragged=use_ragged,
encoder_lens=forward_batch.encoder_lens,
)
self.forward_metadata = (
use_ragged,
extend_no_prefix,
)
self.forward_metadata = (use_ragged, extend_no_prefix)
def init_cuda_graph_state(self, max_bs: int):
cuda_graph_kv_indices = torch.zeros(
......@@ -163,7 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
]
def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: torch.Tensor = None,
):
decode_wrappers = []
for i in range(self.num_wrappers):
......@@ -181,7 +185,11 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_decode.update(
req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers
req_pool_indices,
seq_lens,
seq_lens_sum,
decode_wrappers=decode_wrappers,
encoder_lens=encoder_lens,
)
self.cuda_graph_metadata[bs] = decode_wrappers
self.forward_metadata = (decode_wrappers,)
......@@ -192,34 +200,42 @@ class FlashInferAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: torch.Tensor = None,
):
self.indices_updater_decode.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
self.cuda_graph_metadata[bs],
decode_wrappers=self.cuda_graph_metadata[bs],
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
)
def get_cuda_graph_seq_len_fill_value(self):
return 0
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
prefill_wrapper_paged = self.prefill_wrappers_paged[
self._get_wrapper_idx(layer)
]
use_ragged, extend_no_prefix = self.forward_metadata
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not use_ragged:
if k is not None:
assert v is not None
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=True,
causal=not layer.is_cross_attention,
sm_scale=layer.scaling,
window_left=layer.sliding_window_size,
logits_soft_cap=layer.logit_cap,
......@@ -247,20 +263,23 @@ class FlashInferAttnBackend(AttentionBackend):
o, _ = merge_state(o1, s1, o2, s2)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if k is not None:
assert v is not None
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
......@@ -271,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def _get_wrapper_idx(self, layer: nn.Module):
def _get_wrapper_idx(self, layer: RadixAttention):
if self.num_wrappers == 1:
return 0
......@@ -298,6 +317,8 @@ class FlashInferIndicesUpdaterDecode:
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
......@@ -305,20 +326,27 @@ class FlashInferIndicesUpdaterDecode:
self.decode_wrappers = attn_backend.decode_wrappers
# Dispatch
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
self.update = self.update_sliding_window
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
self.update = self.update_cross_attention
else:
assert attn_backend.num_wrappers == 1
assert self.attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper
def update(
self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
):
# Keep the signature for type checking, will be initialized during runtime
raise NotImplementedError()
def update_single_wrapper(
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers=None,
encoder_lens=None,
):
decode_wrappers = decode_wrappers or self.decode_wrappers
self.call_begin_forward(
......@@ -336,6 +364,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers=None,
encoder_lens=None,
):
decode_wrappers = decode_wrappers or self.decode_wrappers
......@@ -363,8 +392,35 @@ class FlashInferIndicesUpdaterDecode:
kv_start_idx_tmp,
)
def update_cross_attention(self):
raise NotImplementedError()
def update_cross_attention(
self,
req_pool_indices,
seq_lens,
seq_lens_sum,
decode_wrappers=None,
encoder_lens=None,
):
decode_wrappers = decode_wrappers or self.decode_wrappers
for wrapper_id in range(2):
if wrapper_id == 0:
# Normal attention
paged_kernel_lens = seq_lens
kv_start_idx = encoder_lens
else:
# Cross attention
paged_kernel_lens = encoder_lens
kv_start_idx = torch.zeros_like(encoder_lens)
seq_lens_sum = encoder_lens.sum().item()
self.call_begin_forward(
decode_wrappers[wrapper_id],
req_pool_indices,
paged_kernel_lens,
seq_lens_sum,
self.kv_indptr[wrapper_id],
kv_start_idx,
)
def call_begin_forward(
self,
......@@ -421,6 +477,8 @@ class FlashInferIndicesUpdaterPrefill:
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
......@@ -430,16 +488,20 @@ class FlashInferIndicesUpdaterPrefill:
self.wrappers_paged = attn_backend.prefill_wrappers_paged
# Dispatch
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
self.update = self.update_sliding_window
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
self.update = self.update_cross_attention
else:
assert attn_backend.num_wrappers == 1
assert self.attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper
def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
# Keep the signature for type checking, will be initialized during runtime
raise NotImplementedError()
def update_single_wrapper(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
):
if use_ragged:
paged_kernel_lens = prefix_lens
......@@ -460,7 +522,7 @@ class FlashInferIndicesUpdaterPrefill:
)
def update_sliding_window(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
):
for wrapper_id in range(2):
if wrapper_id == 0:
......@@ -487,8 +549,31 @@ class FlashInferIndicesUpdaterPrefill:
use_ragged,
)
def update_cross_attention(self):
raise NotImplementedError()
def update_cross_attention(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
):
for wrapper_id in range(2):
if wrapper_id == 0:
# normal attention
paged_kernel_lens = seq_lens
kv_start_idx = encoder_lens
else:
# cross attention
paged_kernel_lens = encoder_lens
kv_start_idx = torch.zeros_like(encoder_lens)
self.call_begin_forward(
self.wrapper_ragged,
self.wrappers_paged[wrapper_id],
req_pool_indices,
paged_kernel_lens,
seq_lens,
prefix_lens,
kv_start_idx,
self.kv_indptr[wrapper_id],
self.qo_indptr[wrapper_id],
use_ragged,
)
def call_begin_forward(
self,
......
......@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
......@@ -81,8 +82,13 @@ class TritonAttnBackend(AttentionBackend):
)
def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens=None,
):
# NOTE: encoder_lens expected to be zeros or None
self.forward_metadata = (
self.cuda_graph_start_loc,
self.cuda_graph_attn_logits,
......@@ -96,14 +102,18 @@ class TritonAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens=None,
):
# NOTE: encoder_lens expected to be zeros or None
self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
def get_cuda_graph_seq_len_fill_value(self):
return 1
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
......@@ -111,7 +121,7 @@ class TritonAttnBackend(AttentionBackend):
o = torch.empty_like(q)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
layer, forward_batch.out_cache_loc, k, v
)
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
......@@ -133,7 +143,9 @@ class TritonAttnBackend(AttentionBackend):
)
return o
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
......@@ -147,7 +159,7 @@ class TritonAttnBackend(AttentionBackend):
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
layer, forward_batch.out_cache_loc, k, v
)
self.decode_attention_fwd(
......
......@@ -33,26 +33,32 @@ def init_global_processor(server_args: ServerArgs):
class BaseImageProcessor(ABC):
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(server_args,),
max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
)
@abstractmethod
async def process_images_async(self, image_data, **kwargs):
async def process_images_async(self, image_data, input_text, **kwargs):
pass
class DummyImageProcessor(BaseImageProcessor):
def __init__(self):
pass
async def process_images_async(self, *args, **kwargs):
return None
class LlavaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor):
self.hf_config = hf_config
self._image_processor = _image_processor
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(server_args,),
max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
)
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(
......@@ -119,7 +125,7 @@ class LlavaImageProcessor(BaseImageProcessor):
)
async def process_images_async(
self, image_data: List[Union[str, bytes]], request_obj
self, image_data: List[Union[str, bytes]], input_text, request_obj
):
if not image_data:
return None
......@@ -177,6 +183,54 @@ class LlavaImageProcessor(BaseImageProcessor):
}
class MllamaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(images, input_text):
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
return global_processor(images, input_text, return_tensors="pt")
async def _process_single_image(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MllamaImageProcessor._process_single_image_task,
images,
input_text,
)
else:
image_inputs = self._processor(images, input_text, return_tensors="pt")
return image_inputs
async def process_images_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
if not image_data:
return None
if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text)
if not isinstance(image_data, list):
image_data = [image_data]
if len(image_data) > 0:
images = [load_image(image)[0] for image in image_data]
else:
images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text)
image_inputs["image_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
return image_inputs
class Qwen2VLImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor):
self.hf_config = hf_config
......@@ -237,7 +291,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
return self._process_single_image_task(image_data)
async def process_images_async(
self, image_data: List[Union[str, bytes]], request_obj
self, image_data: List[Union[str, bytes]], input_text, request_obj
):
if not image_data:
return None
......@@ -292,12 +346,14 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
def get_image_processor(
hf_config, server_args: ServerArgs, _image_processor
hf_config, server_args: ServerArgs, processor
) -> BaseImageProcessor:
if "Qwen2VLForConditionalGeneration" in hf_config.architectures:
return Qwen2VLImageProcessor(hf_config, server_args, _image_processor)
if "MllamaForConditionalGeneration" in hf_config.architectures:
return MllamaImageProcessor(hf_config, server_args, processor)
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
else:
return LlavaImageProcessor(hf_config, server_args, _image_processor)
return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
def get_dummy_image_processor():
......
......@@ -36,6 +36,7 @@ from typing import List, Optional, Tuple, Union
import torch
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
......@@ -121,11 +122,12 @@ class ImageInputs:
"""The image related inputs."""
pixel_values: torch.Tensor
image_hash: int
image_hashes: Optional[list] = None
image_sizes: Optional[list] = None
image_offsets: Optional[list] = None
pad_values: Optional[list] = None
modalities: Optional[list] = None
num_image_tokens: Optional[int] = None
image_embeds: Optional[List[torch.Tensor]] = None
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
......@@ -138,19 +140,27 @@ class ImageInputs:
# Use image hash as fake token_ids, which is then used for prefix matching
ret = ImageInputs(
pixel_values=obj["pixel_values"],
image_hash=hash(tuple(obj["image_hashes"])),
image_grid_thws=obj.get("image_grid_thws"),
image_hashes=hash(tuple(obj["image_hashes"])),
)
image_hash = ret.image_hash
image_hash = ret.image_hashes
ret.pad_values = [
(image_hash) % vocab_size,
(image_hash >> 16) % vocab_size,
(image_hash >> 32) % vocab_size,
(image_hash >> 64) % vocab_size,
]
ret.image_sizes = obj["image_sizes"]
# Only when pixel values is not None we have modalities
ret.modalities = obj["modalities"] or ["image"]
optional_args = [
"image_sizes",
"modalities",
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
]
for arg in optional_args:
if arg in obj:
setattr(ret, arg, obj[arg])
return ret
......@@ -416,6 +426,10 @@ class ScheduleBatch:
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool: BaseTokenToKVPool = None
tree_cache: BasePrefixCache = None
# For utility
model_config: ModelConfig = None
forward_mode: ForwardMode = None
sampling_info: SamplingBatchInfo = None
......@@ -440,6 +454,12 @@ class ScheduleBatch:
extend_num_tokens: int = None
decoding_reqs: List[Req] = None
# For encoder-decoder
encoder_cached: Optional[List[bool]] = None
encoder_lens: Optional[torch.Tensor] = None
encoder_lens_cpu: Optional[List[int]] = None
encoder_out_cache_loc: Optional[torch.Tensor] = None
# Stream
has_stream: bool = False
......@@ -450,12 +470,20 @@ class ScheduleBatch:
device: str = "cuda"
@classmethod
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
def init_new(
cls,
reqs,
req_to_token_pool,
token_to_kv_pool,
tree_cache,
model_config,
):
return cls(
reqs=reqs,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool=token_to_kv_pool,
tree_cache=tree_cache,
model_config=model_config,
return_logprob=any(req.return_logprob for req in reqs),
has_stream=any(req.stream for req in reqs),
has_regex=any(req.regex_fsm for req in reqs),
......@@ -493,7 +521,78 @@ class ScheduleBatch:
return out_cache_loc
def prepare_for_extend(self, vocab_size: int):
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
self.encoder_lens_cpu = []
self.encoder_cached = []
for req in self.reqs:
im = req.image_inputs
if im is None or im.num_image_tokens is None:
# No image input
self.encoder_lens_cpu.append(0)
self.encoder_cached.append(True)
else:
self.encoder_lens_cpu.append(im.num_image_tokens)
self.encoder_cached.append(
self.forward_mode.is_decode()
or len(req.prefix_indices) >= im.num_image_tokens
)
self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
self.device, non_blocking=True
)
# Strip encoder infos
pt = 0
decoder_out_cache_loc = []
encoder_out_cache_loc = []
for i, req in enumerate(self.reqs):
encoder_len = self.encoder_lens_cpu[i]
seq_lens[i] -= encoder_len
if len(req.prefix_indices) < encoder_len:
# NOTE: the encoder part should considered as a whole
assert len(req.prefix_indices) == 0
input_ids[i] = input_ids[i][encoder_len:]
encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
decoder_out_cache_loc.append(
self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
)
self.extend_lens[i] -= encoder_len
self.extend_num_tokens -= encoder_len
else:
decoder_out_cache_loc.append(
self.out_cache_loc[pt : pt + req.extend_input_len]
)
self.prefix_lens[i] -= encoder_len
pt += req.extend_input_len
# Reassign
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
self.device, non_blocking=True
)
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
if not decoder_out_cache_loc:
self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
else:
self.out_cache_loc = torch.cat(decoder_out_cache_loc)
if not encoder_out_cache_loc:
self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
else:
self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
assert len(self.out_cache_loc) == self.extend_num_tokens
def prepare_for_extend(self):
self.forward_mode = ForwardMode.EXTEND
bs = len(self.reqs)
......@@ -561,8 +660,13 @@ class ScheduleBatch:
self.extend_lens = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
if self.model_config.is_encoder_decoder:
self.prepare_encoder_info_extend(input_ids, seq_lens)
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self, vocab_size, global_server_args_dict["disable_penalizer"]
self,
self.model_config.vocab_size,
global_server_args_dict["disable_penalizer"],
)
def mix_with_running(self, running_batch: "ScheduleBatch"):
......@@ -752,6 +856,10 @@ class ScheduleBatch:
return jump_forward_reqs
def prepare_encoder_info_decode(self):
# Reset the encoder cached status
self.encoder_cached = [True] * len(self.reqs)
def prepare_for_decode(self, enable_overlap: bool = False):
self.forward_mode = ForwardMode.DECODE
......@@ -766,16 +874,22 @@ class ScheduleBatch:
bs = len(self.reqs)
self.out_cache_loc = self.alloc_token_slots(bs)
if self.model_config.is_encoder_decoder:
locs = self.encoder_lens + self.seq_lens
self.prepare_encoder_info_decode()
else:
locs = self.seq_lens
if enable_overlap:
# Do not use in-place operations in the overlap mode
self.req_to_token_pool.write(
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
(self.req_pool_indices, locs), self.out_cache_loc
)
self.seq_lens = self.seq_lens + 1
else:
# A faster in-place version
self.req_to_token_pool.write(
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
(self.req_pool_indices, locs), self.out_cache_loc
)
self.seq_lens.add_(1)
self.seq_lens_sum += bs
......@@ -802,6 +916,10 @@ class ScheduleBatch:
# No need to filter
return
if self.model_config.is_encoder_decoder:
self.encoder_lens = self.encoder_lens[keep_indices]
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
self.reqs = [self.reqs[i] for i in keep_indices]
new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
self.device, non_blocking=True
......@@ -828,6 +946,11 @@ class ScheduleBatch:
# needs to be called with pre-merged Batch.reqs.
self.sampling_info.merge_batch(other.sampling_info)
# Encoder-decoder infos
if self.model_config.is_encoder_decoder:
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
self.req_pool_indices = torch.concat(
[self.req_pool_indices, other.req_pool_indices]
)
......@@ -850,14 +973,11 @@ class ScheduleBatch:
def get_model_worker_batch(self):
if self.forward_mode.is_decode():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
image_inputs
) = None
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
extend_seq_lens = self.extend_lens
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
image_inputs = [r.image_inputs for r in self.reqs]
if self.has_regex:
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
......@@ -887,7 +1007,11 @@ class ScheduleBatch:
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens,
image_inputs=image_inputs,
image_inputs=[r.image_inputs for r in self.reqs],
encoder_cached=self.encoder_cached,
encoder_lens=self.encoder_lens,
encoder_lens_cpu=self.encoder_lens_cpu,
encoder_out_cache_loc=self.encoder_out_cache_loc,
lora_paths=[req.lora_path for req in self.reqs],
sampling_info=self.sampling_info,
mrope_positions_delta=mrope_positions_delta,
......@@ -897,6 +1021,7 @@ class ScheduleBatch:
# Only contain fields that will be used by process_batch_result
return ScheduleBatch(
reqs=self.reqs,
model_config=self.model_config,
forward_mode=self.forward_mode,
out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob,
......@@ -944,6 +1069,12 @@ class ModelWorkerBatch:
# For multimodal
image_inputs: Optional[List[ImageInputs]]
# For encoder-decoder
encoder_cached: Optional[List[bool]]
encoder_lens: Optional[torch.Tensor]
encoder_lens_cpu: Optional[List[int]]
encoder_out_cache_loc: Optional[torch.Tensor]
# For LoRA
lora_paths: Optional[List[str]]
......
......@@ -662,8 +662,9 @@ class Scheduler:
self.req_to_token_pool,
self.token_to_kv_pool,
self.tree_cache,
self.model_config,
)
new_batch.prepare_for_extend(self.model_config.vocab_size)
new_batch.prepare_for_extend()
# Mixed-style chunked prefill
if self.is_mixed_chunk and self.running_batch is not None:
......
......@@ -122,7 +122,7 @@ class TokenizerManager:
# We want to parallelize the image pre-processing so we create an executor for it
self.image_processor = get_image_processor(
self.hf_config, server_args, self.processor.image_processor
self.hf_config, server_args, self.processor
)
else:
self.tokenizer = get_tokenizer(
......@@ -191,8 +191,10 @@ class TokenizerManager:
sampling_params = self._get_sampling_params(obj.sampling_params)
if self.is_generation:
image_inputs = await self.image_processor.process_images_async(
obj.image_data, obj
obj.image_data, input_text or input_ids, obj
)
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num
......@@ -217,8 +219,10 @@ class TokenizerManager:
sampling_params = self._get_sampling_params(obj.sampling_params[index])
if self.is_generation:
image_inputs = await self.image_processor.process_images_async(
obj.image_data[index], obj
obj.image_data[index], input_text or input_ids, obj
)
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
return_logprob = obj.return_logprob[index]
logprob_start_len = obj.logprob_start_len[index]
top_logprobs_num = obj.top_logprobs_num[index]
......@@ -263,8 +267,10 @@ class TokenizerManager:
sampling_params = SamplingParams(**obj.sampling_params[0])
sampling_params.max_new_tokens = 0
image_inputs = await self.image_processor.process_images_async(
obj.image_data[0], obj
obj.image_data[0], input_text or input_ids, obj
)
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
return_logprob = obj.return_logprob[0]
logprob_start_len = obj.logprob_start_len[0]
top_logprobs_num = obj.top_logprobs_num[0]
......
......@@ -26,6 +26,8 @@ from typing import List, Tuple, Union
import torch
from sglang.srt.layers.radix_attention import RadixAttention
logger = logging.getLogger(__name__)
......@@ -41,13 +43,17 @@ class ReqToTokenPool:
)
self.free_slots = list(range(size))
self.write_records = []
self.use_records = use_records
if use_records:
# records all write operations
if self.use_records:
self.write = self.write_with_records
else:
self.write = self.write_without_records
def write(self, indices, values):
# Keep the signature for type checking, will be initialized during runtime
raise NotImplementedError()
def available_size(self):
return len(self.free_slots)
......@@ -154,7 +160,7 @@ class BaseTokenToKVPool:
def set_kv_buffer(
self,
layer_id: int,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
......@@ -209,11 +215,12 @@ class MHATokenToKVPool(BaseTokenToKVPool):
def set_kv_buffer(
self,
layer_id: int,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
if cache_v.dtype != self.dtype:
......@@ -265,11 +272,12 @@ class MLATokenToKVPool(BaseTokenToKVPool):
def set_kv_buffer(
self,
layer_id: int,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
if self.store_dtype != self.dtype:
......@@ -324,13 +332,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
def set_kv_buffer(
self,
layer_id: int,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
cache_label: torch.Tensor,
):
# NOTE(Andy): ignore the dtype check
layer_id = layer.layer_id
self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v
self.label_buffer[layer_id][loc] = cache_label
......@@ -105,6 +105,7 @@ class CudaGraphRunner:
self.graph_memory_pool = None
self.use_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
# Batch sizes to capture
if self.model_runner.server_args.disable_cuda_graph_padding:
......@@ -132,6 +133,9 @@ class CudaGraphRunner:
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
)
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self.encoder_len_fill_value = 0
if self.use_torch_compile:
set_torch_compile_config()
......@@ -144,9 +148,18 @@ class CudaGraphRunner:
)
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
if self.is_encoder_decoder:
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
self.encoder_lens = torch.full(
(self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
)
else:
self.encoder_lens = None
# Capture
try:
self.capture()
with self.model_capture_mode():
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n"
......@@ -157,11 +170,32 @@ class CudaGraphRunner:
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)
def can_run(self, batch_size: int):
if self.disable_padding:
return batch_size in self.graphs
else:
return batch_size <= self.max_bs
@contextmanager
def model_capture_mode(self):
if hasattr(self.model_runner.model, "capture_mode"):
self.model_runner.model.capture_mode = True
yield
if hasattr(self.model_runner.model, "capture_mode"):
self.model_runner.model.capture_mode = False
def can_run(self, forward_batch: ForwardBatch):
is_bs_supported = (
forward_batch.batch_size in self.graphs
if self.disable_padding
else forward_batch.batch_size <= self.max_bs
)
# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
# because the full_text_row_masked_out_mask tensor will always be ones
is_encoder_lens_supported = (
torch.all(forward_batch.encoder_lens > 0)
if self.is_encoder_decoder
else True
)
return is_bs_supported and is_encoder_lens_supported
def capture(self):
with graph_capture() as graph_capture_context:
......@@ -188,11 +222,19 @@ class CudaGraphRunner:
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
out_cache_loc = self.out_cache_loc[:bs]
if self.is_encoder_decoder:
encoder_lens = self.encoder_lens[:bs]
else:
encoder_lens = None
seq_lens_sum = seq_lens.sum().item()
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs, req_pool_indices, seq_lens
bs,
req_pool_indices,
seq_lens,
encoder_lens,
)
# Run and capture
......@@ -208,6 +250,7 @@ class CudaGraphRunner:
attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens_sum,
encoder_lens=encoder_lens,
return_logprob=False,
top_logprobs_nums=[0] * bs,
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
......@@ -251,6 +294,8 @@ class CudaGraphRunner:
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
......@@ -258,6 +303,7 @@ class CudaGraphRunner:
self.req_pool_indices,
self.seq_lens,
forward_batch.seq_lens_sum,
self.encoder_lens,
)
# Replay
......
......@@ -108,6 +108,12 @@ class ForwardBatch:
# For multimodal
image_inputs: Optional[List[ImageInputs]] = None
# Encoder-decoder
encoder_cached: Optional[List[bool]] = None
encoder_lens: Optional[torch.Tensor] = None
encoder_lens_cpu: Optional[List[int]] = None
encoder_out_cache_loc: Optional[torch.Tensor] = None
# For LoRA
lora_paths: Optional[List[str]] = None
......@@ -194,6 +200,11 @@ class ForwardBatch:
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
out_cache_loc=batch.out_cache_loc,
image_inputs=batch.image_inputs,
encoder_cached=batch.encoder_cached,
encoder_lens=batch.encoder_lens,
encoder_lens_cpu=batch.encoder_lens_cpu,
encoder_out_cache_loc=batch.encoder_out_cache_loc,
seq_lens_sum=batch.seq_lens_sum,
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
......@@ -212,11 +223,11 @@ class ForwardBatch:
],
axis=0,
)
ret.image_inputs = batch.image_inputs
ret.extend_num_tokens = batch.extend_num_tokens
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True)
ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True)
......
......@@ -270,7 +270,6 @@ class ModelRunner:
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures, self.server_args.is_embedding
)
......@@ -510,7 +509,7 @@ class ModelRunner:
"Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
assert not self.has_cross_attention, (
assert not self.model_config.is_encoder_decoder, (
"Cross attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
......@@ -558,9 +557,7 @@ class ModelRunner:
self.cuda_graph_runner = CudaGraphRunner(self)
def forward_decode(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
forward_batch.batch_size
):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
return self.cuda_graph_runner.replay(forward_batch)
forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
......
# Adapted from:
# https://github.com/vllm-project/vllm/blob/7193774b1ff8603ad5bf4598e5efba0d9a39b436/vllm/model_executor/models/mllama.py
"""PyTorch Mllama model."""
import math
from typing import Iterable, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers.models.mllama.configuration_mllama as config_mllama
import vllm.distributed.parallel_state as ps
from torch import nn
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
from transformers.models.mllama.modeling_mllama import (
_prepare_aspect_ratio_attention_mask,
)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
class ColumnParallelConv2dPatch(torch.nn.Module):
"""Conv2D Patching layer with model parallelism.
Column parallel over unfolded input.
Arguments:
in_channels: Input channels.
out_channels: Output channels.
kernel_size: Size of convolution kernel.
stride (default 1): Stride for convolution.
bias (default False): Use bias in Conv2d.
Input: (bsz, in_channels, width, height)
Output: (bsz, num_tokens, out_channels)
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
bias: bool = False,
) -> None:
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
self._linear = ColumnParallelLinear(
in_channels * kernel_size[0] * kernel_size[1],
out_channels,
bias=bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self._unfold(x)
x = x.permute(0, 2, 1)
x, _ = self._linear(x)
return x
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
def __init__(self, config: config_mllama.MllamaVisionConfig, is_gated: bool = True):
super().__init__()
self.max_num_tiles = config.max_num_tiles
self.hidden_size = config.hidden_size
self.max_aspect_ratio_id = config.max_aspect_ratio_id
self.is_gated = is_gated
self.embedding = nn.Embedding(
self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size
)
if is_gated:
self.gate = nn.Parameter(torch.zeros(1))
def forward(
self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
) -> torch.Tensor:
embeddings = self.embedding(aspect_ratio_ids)
embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)
if self.is_gated:
embeddings = embeddings * self.gate.tanh()
hidden_state = hidden_state + embeddings
return hidden_state
class MllamaPrecomputedPositionEmbedding(nn.Module):
def __init__(self, config: config_mllama.MllamaVisionConfig):
super().__init__()
self.max_num_tiles = config.max_num_tiles
self.max_aspect_ratio_id = config.max_aspect_ratio_id
self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
self.hidden_size = config.hidden_size
self.scale = config.hidden_size**-0.5
self.gate = nn.Parameter(torch.zeros(1))
# position embedding
position_embedding = torch.randn(self.num_patches, self.hidden_size)
self.embedding = nn.Parameter(self.scale * position_embedding)
# tile position embedding
self.tile_embedding = nn.Embedding(
self.max_aspect_ratio_id + 1,
self.max_num_tiles * self.num_patches * self.hidden_size,
)
def forward(
self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
) -> torch.Tensor:
# position embeddings
gated_position_embedding = (1 - self.gate.tanh()) * self.embedding
hidden_state = hidden_state + gated_position_embedding.view(
1, 1, self.num_patches, self.hidden_size
)
# precomputed tile position embeddings
tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
batch_size = hidden_state.shape[0]
tile_position_embedding = tile_position_embedding.reshape(
batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
)
gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
hidden_state = hidden_state + gated_tile_position_embedding
return hidden_state
class MllamaVisionSdpaAttention(nn.Module):
def __init__(self, config: config_mllama.MllamaVisionConfig):
super().__init__()
model_parallel_size = get_tensor_model_parallel_world_size()
self.embed_dim = config.hidden_size
self.num_heads = config.attention_heads
self.head_dim = config.hidden_size // config.attention_heads
self.num_local_heads = self.num_heads // model_parallel_size
self.q_size = self.num_local_heads * self.head_dim
self.kv_size = self.num_local_heads * self.head_dim
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=False,
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=False,
input_is_parallel=True,
)
def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_state)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.view(
q.shape[0], q.shape[1], self.num_local_heads, self.head_dim
).transpose(1, 2)
k = k.view(
k.shape[0], k.shape[1], self.num_local_heads, self.head_dim
).transpose(1, 2)
v = v.view(
v.shape[0], v.shape[1], self.num_local_heads, self.head_dim
).transpose(1, 2)
# TODO: remove padding in image encoder
attn_output = F.scaled_dot_product_attention(
q, k, v, attn_mask=attention_mask, dropout_p=0.0
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(
attn_output.shape[0], attn_output.shape[1], -1
)
output, _ = self.o_proj(attn_output)
return output
class MllamaVisionMLP(nn.Module):
def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class MllamaVisionEncoderLayer(nn.Module):
def __init__(
self, config: config_mllama.MllamaVisionConfig, is_gated: bool = False
):
super().__init__()
self.hidden_size = config.hidden_size
self.num_attention_heads = config.attention_heads
self.is_gated = is_gated
self.intermediate_size = config.intermediate_size
self.self_attn = MllamaVisionSdpaAttention(config)
self.mlp = MllamaVisionMLP(config)
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
self.post_attention_layernorm = nn.LayerNorm(
self.hidden_size, eps=config.norm_eps
)
# there used to be an if else here, no code path
if is_gated:
self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)
def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
# Self Attention
residual = hidden_state
hidden_state = self.input_layernorm(hidden_state)
hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
hidden_state = residual + gate_attn * hidden_state
# Feed forward
residual = hidden_state
hidden_state = self.post_attention_layernorm(hidden_state)
hidden_state = self.mlp(hidden_state)
gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
hidden_state = residual + gate_ffn * hidden_state
return hidden_state
class MllamaVisionEncoder(nn.Module):
def __init__(
self,
config: config_mllama.MllamaVisionConfig,
num_layers=32,
is_gated=False,
output_hidden_states=None,
):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)]
)
self.output_hidden_states = output_hidden_states or []
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutput]:
encoder_states = ()
for i, encoder_layer in enumerate(self.layers):
if i in self.output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
hidden_states = encoder_layer(
hidden_states,
attention_mask,
)
if len(self.layers) - 1 in self.output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
return hidden_states, encoder_states
class MllamaVisionModel(nn.Module):
def __init__(self, config: config_mllama.MllamaVisionConfig):
super().__init__()
self.image_size = config.image_size
self.patch_size = config.patch_size
self.max_num_tiles = config.max_num_tiles
self.hidden_size = config.hidden_size
self.in_channels = config.num_channels
self.intermediate_layers_indices = config.intermediate_layers_indices
self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
self.scale = config.hidden_size**-0.5
self.patch_embedding = ColumnParallelConv2dPatch(
in_channels=config.num_channels,
out_channels=self.hidden_size,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config)
self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
config, is_gated=True
)
self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
config, is_gated=True
)
# layer norms
self.layernorm_pre = nn.LayerNorm(self.hidden_size)
self.layernorm_post = nn.LayerNorm(self.hidden_size)
# encoders
self.transformer = MllamaVisionEncoder(
config,
config.num_hidden_layers,
is_gated=False,
output_hidden_states=config.intermediate_layers_indices,
)
self.global_transformer = MllamaVisionEncoder(
config, config.num_global_layers, is_gated=True
)
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
batch_size, _, hidden_size = hidden_state.shape
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
return hidden_state
def forward(
self,
pixel_values: torch.Tensor,
aspect_ratio_ids: torch.Tensor,
aspect_ratio_mask: torch.Tensor,
) -> torch.Tensor:
batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
pixel_values.shape
)
pixel_values = pixel_values.reshape(
batch_size * num_concurrent_media * num_tiles, num_channels, height, width
)
aspect_ratio_ids = aspect_ratio_ids.reshape(
batch_size * num_concurrent_media, -1
)
# patch embedding
patch_embeds = self.patch_embedding(
pixel_values.to(self.layernorm_pre.weight.dtype)
)
hidden_state = patch_embeds
hidden_state = ps.get_tp_group().all_gather(hidden_state)
# tile embeddings
_, num_patches, dim = hidden_state.shape
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media, num_tiles, -1, dim
)
hidden_state = self.pre_tile_positional_embedding(
hidden_state, aspect_ratio_ids
)
# apply cls token
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media * num_tiles, num_patches, dim
)
hidden_state = self.apply_class_embedding(hidden_state)
num_patches += 1
# apply position embeddings
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media, num_tiles, num_patches, dim
)
hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
# apply encoder
hidden_state = self.layernorm_pre(hidden_state)
# Compute the number of tokens to pad
num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
# Compute padding tuple for pad function
padding = (
0,
0,
0,
num_padding_patches,
) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
# Pad the tensor
hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
slice_index = -num_padding_patches if num_padding_patches > 0 else None
attention_mask = aspect_ratio_mask.reshape(
batch_size * num_concurrent_media, -1
)
attention_mask = _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask=attention_mask,
num_patches=self.num_patches,
target_length=hidden_state.shape[2],
dtype=self.layernorm_pre.weight.dtype,
)
hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
output = self.transformer(
hidden_state,
attention_mask=attention_mask,
)
hidden_state, intermediate_hidden_states = output[0], output[1]
intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)
# apply global encoder
hidden_state = self.layernorm_post(hidden_state)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
dim,
)
hidden_state = self.post_tile_positional_embedding(
hidden_state, aspect_ratio_ids
)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles * (num_patches + num_padding_patches),
dim,
)
hidden_state = self.global_transformer(
hidden_state, attention_mask=attention_mask
)[0]
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
dim,
)
hidden_state = hidden_state[:, :, :slice_index]
# adding intermediate layer outputs
hidden_state = hidden_state.reshape(
batch_size, num_concurrent_media, num_tiles, num_patches, dim
)
intermediate_hidden_states = intermediate_hidden_states.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
-1,
)
intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
intermediate_hidden_states = intermediate_hidden_states.reshape(
batch_size, num_concurrent_media, num_tiles, num_patches, -1
)
hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
return hidden_state
class MllamaTextRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class MllamaTextCrossAttention(nn.Module):
def __init__(
self,
config: Optional[config_mllama.MllamaTextConfig] = None,
layer_id: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.model_parallel_size = get_tensor_model_parallel_world_size()
self.num_heads = self.config.num_attention_heads
self.num_local_heads = self.num_heads // self.model_parallel_size
self.num_key_value_heads = self.config.num_key_value_heads
self.num_local_key_value_heads = (
self.num_key_value_heads // self.model_parallel_size
)
self.dropout = config.dropout
self.hidden_size = config.hidden_size
self.head_dim = config.hidden_size // self.num_heads
self.layer_id = layer_id
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.q_local_size = self.num_local_heads * self.head_dim
self.kv_local_size = self.num_local_key_value_heads * self.head_dim
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.num_heads,
self.num_key_value_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.hidden_size,
bias=False,
input_is_parallel=True,
quant_config=quant_config,
)
# vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
# use huggingface's instead
self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.scaling = self.head_dim**-0.5
self.attn = RadixAttention(
self.num_local_heads,
self.head_dim,
self.scaling,
self.num_local_key_value_heads,
layer_id=layer_id,
is_cross_attention=True,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
cross_attention_states: Optional[torch.Tensor],
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv_dec, _ = self.qkv_proj(hidden_states)
q, _, _ = qkv_dec.split(
[self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1
)
if cross_attention_states is None:
k = None
v = None
else:
qkv_enc, _ = self.qkv_proj(cross_attention_states)
_, k, v = qkv_enc.split(
[self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1
)
k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
k = self.k_norm(k)
q = q.view(-1, self.num_local_heads, self.head_dim)
q = self.q_norm(q)
output = self.attn(q, k, v, forward_batch)
out, _ = self.o_proj(output)
return out
class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
"""Cross-attention transformer block with tanh-gated attention
and feedforward."""
def __init__(
self,
config: config_mllama.MllamaTextConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig],
) -> None:
super().__init__()
self.layer_id = layer_id
self.cross_attn = MllamaTextCrossAttention(
config=config,
layer_id=layer_id,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))
self.mlp = LlamaMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor,
cross_attention_mask: torch.Tensor,
full_text_row_masked_out_mask: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.cross_attn(
hidden_states=hidden_states,
attention_mask=cross_attention_mask,
cross_attention_states=cross_attention_states,
forward_batch=forward_batch,
)
hidden_states = full_text_row_masked_out_mask * hidden_states
hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = full_text_row_masked_out_mask * hidden_states
hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
return hidden_states
class MllamaTextModel(nn.Module):
config_class = config_mllama.MllamaTextConfig
base_model_prefix = "model"
def __init__(
self,
config: config_mllama.MllamaTextConfig,
quant_config: Optional[QuantizationConfig],
cache_config=None,
):
super().__init__()
self.padding_id = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size + 8, config.hidden_size
)
self.cross_attention_layers = config.cross_attention_layers
layers = []
for layer_id in range(config.num_hidden_layers):
if layer_id in self.cross_attention_layers:
layers.append(
MllamaCrossAttentionDecoderLayer(
config, layer_id, quant_config=quant_config
)
)
else:
# TODO: force LlamaDecoderLayer to config.attention_bias=False
layers.append(
LlamaDecoderLayer(
config, quant_config=quant_config, layer_id=layer_id
)
)
self.layers = nn.ModuleList(layers)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.LongTensor,
positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]],
forward_batch: ForwardBatch,
skip_cross_attention: bool,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
for _, decoder_layer in enumerate(self.layers):
if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
if not skip_cross_attention:
hidden_states = decoder_layer(
hidden_states=hidden_states,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
forward_batch=forward_batch,
)
elif isinstance(decoder_layer, LlamaDecoderLayer):
hidden_states, residual = decoder_layer(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
residual=None,
)
hidden_states = hidden_states + residual
else:
raise ValueError(f"Unknown decoder layer type {type(decoder_layer)}")
hidden_states = self.norm(hidden_states)
return hidden_states
class MllamaForCausalLM(nn.Module):
config_class = config_mllama.MllamaTextConfig
base_model_prefix = "language_model"
_no_split_modules = [
"MllamaCrossAttentionDecoderLayer",
"MllamaSelfAttentionDecoderLayer",
]
def __init__(
self,
config: config_mllama.MllamaTextConfig,
quant_config: Optional[QuantizationConfig],
cache_config=None,
):
super().__init__()
self.vocab_size = config.vocab_size
self.model = MllamaTextModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
)
def forward(
self,
input_ids: torch.LongTensor,
positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]],
forward_batch: ForwardBatch,
skip_cross_attention: bool,
) -> torch.Tensor:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
forward_batch=forward_batch,
skip_cross_attention=skip_cross_attention,
)
return hidden_states
class MllamaForConditionalGeneration(nn.Module):
def __init__(
self,
config: config_mllama.MllamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
):
super().__init__()
self.vocab_size = config.text_config.vocab_size
self.hidden_size = config.text_config.hidden_size
self.max_num_tiles = config.vision_config.max_num_tiles
self.vision_output_dim = config.vision_config.vision_output_dim
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
self.image_size = config.vision_config.image_size
self.vision_model = MllamaVisionModel(config.vision_config)
self.language_model = MllamaForCausalLM(
config.text_config,
cache_config=cache_config,
quant_config=quant_config,
)
self.multi_modal_projector = nn.Linear(
config.vision_config.vision_output_dim,
config.text_config.hidden_size,
bias=True,
)
self.logits_processor = LogitsProcessor(config.text_config)
self.capture_mode = False
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
pixel_values = image_inputs.pixel_values
pad_values = image_inputs.pad_values
num_concurrent_media, num_tiles = pixel_values.shape[1:3]
num_patches = self.vision_model.num_patches
image_len = num_concurrent_media * num_tiles * num_patches
image_inputs.num_image_tokens = image_len
pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))
return pad_ids[:image_len] + input_ids
def _batch_image_inputs(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode() or all(forward_batch.encoder_cached):
return None, None, None, None
# pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
max_num_images = max_num_tiles = bs = 0
for i, im in enumerate(forward_batch.image_inputs):
if not forward_batch.encoder_cached[i] and im is not None:
max_num_images = max(max_num_images, im.pixel_values.shape[1])
max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
bs += 1
if max_num_images * max_num_tiles * bs == 0:
return None, None, None, None
with forward_batch.out_cache_loc.device:
batched_images = torch.zeros(
bs,
max_num_images,
max_num_tiles,
3,
self.image_size,
self.image_size,
dtype=torch.float32,
)
batched_ar_ids = torch.ones(
bs, max_num_images, dtype=torch.int64, device="cuda"
)
batched_ar_mask = torch.zeros(
bs, max_num_images, max_num_tiles, dtype=torch.int64
)
i = 0
encoder_lens_need = []
for k, im in enumerate(forward_batch.image_inputs):
if forward_batch.encoder_cached[k] or im is None:
continue
encoder_lens_need.append(forward_batch.encoder_lens[k])
for j in range(im.pixel_values.shape[1]):
img = im.pixel_values[0, j]
num_tiles = img.shape[0]
batched_images[i, j, :num_tiles] = img
batched_ar_ids[i, j] = im.aspect_ratio_ids[0, j]
batched_ar_mask[i, j, :num_tiles] = im.aspect_ratio_mask[0, j]
i += 1
return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need
def flat_encoder_result(
self, cross_attention_states: torch.Tensor, encoder_lens_need: List[int]
):
# NOTE: not all encoders need computation, some are cached
head_dim = cross_attention_states.shape[-1]
total_encoder_len = sum(encoder_lens_need)
cross_attention_states_flat = torch.zeros(
total_encoder_len,
head_dim,
device=cross_attention_states.device,
dtype=cross_attention_states.dtype,
)
i = start_pos = 0
for encoder_len in encoder_lens_need:
if encoder_len == 0:
continue
end_pos = start_pos + encoder_len
cross_attention_states_flat[start_pos:end_pos] = cross_attention_states[i][
:encoder_len
]
i += 1
start_pos += encoder_len
return cross_attention_states_flat
def get_full_text_row_masked_out_mask(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode():
full_text_row_masked_out_mask = forward_batch.encoder_lens != 0
else:
full_text_row_masked_out_mask = torch.ones(
forward_batch.extend_seq_lens.sum(), dtype=torch.bool
)
start_pos = 0
for seq_len, encoder_len in zip(
forward_batch.seq_lens.tolist(), forward_batch.encoder_lens_cpu
):
if encoder_len == 0:
full_text_row_masked_out_mask[start_pos : start_pos + seq_len] = (
False
)
start_pos += encoder_len
full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
forward_batch.seq_lens.device
)
return full_text_row_masked_out_mask.reshape(-1, 1)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> Union[Tuple, CausalLMOutputWithPast]:
batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
self._batch_image_inputs(forward_batch)
)
# TODO: support multi-image by this mask
cross_attention_mask = None
cross_attention_states = None
if self.capture_mode:
# NOTE: when doing cuda graph capture, we do not want to skip cross attention
# Make is a constant value to avoid cuda graph capture issue
skip_cross_attention = False
else:
# NOTE: we do not need image_inputs when prefill
assert len(forward_batch.encoder_lens) == len(forward_batch.seq_lens)
assert len(forward_batch.encoder_lens_cpu) == len(forward_batch.seq_lens)
skip_cross_attention = forward_batch.encoder_lens.max() == 0
if not skip_cross_attention:
full_text_row_masked_out_mask = self.get_full_text_row_masked_out_mask(
forward_batch
)
else:
full_text_row_masked_out_mask = None
if batched_images is not None:
# NOTE: llama's reference implementation runs vision model on CPU
cross_attention_states = self.vision_model(
batched_images, batched_ar_ids, batched_ar_mask
)
cross_attention_states = self.multi_modal_projector(cross_attention_states)
bs, _, _, _, image_token_dim = cross_attention_states.shape
cross_attention_states = cross_attention_states.view(
bs, -1, image_token_dim
)
cross_attention_states = self.flat_encoder_result(
cross_attention_states, encoder_lens_need
)
hidden_states = self.language_model(
input_ids=input_ids,
positions=positions,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
forward_batch=forward_batch,
skip_cross_attention=skip_cross_attention,
)
return self.logits_processor(
input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
updated_params = set()
for name, loaded_weight in weights:
if "patch_embedding.weight" in name:
name = name.replace(
"patch_embedding.weight", "patch_embedding._linear.weight"
)
loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
updated_params.add(name)
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict.pop(name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = MllamaForConditionalGeneration
......@@ -605,7 +605,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
]
positions = forward_batch.mrope_positions
if image_inputs is None or len(image_inputs) == 0:
if (
forward_batch.forward_mode.is_decode()
or image_inputs is None
or len(image_inputs) == 0
):
inputs_embeds = self.model.embed_tokens(input_ids)
else:
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
......
......@@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures):
or "LlavaQwenForCausalLM" in model_architectures
or "LlavaMistralForCausalLM" in model_architectures
or "LlavaVidForCausalLM" in model_architectures
or "MllamaForConditionalGeneration" in model_architectures
or "Qwen2VLForConditionalGeneration" in model_architectures
):
return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment