Unverified Commit 94cde109 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Llama3.2 vision model support (#1551)

parent 00611286
...@@ -8,16 +8,12 @@ version = "0.3.4" ...@@ -8,16 +8,12 @@ version = "0.3.4"
description = "SGLang is yet another fast serving framework for large language models and vision language models." description = "SGLang is yet another fast serving framework for large language models and vision language models."
readme = "README.md" readme = "README.md"
requires-python = ">=3.8" requires-python = ">=3.8"
license = {file = "LICENSE"} license = { file = "LICENSE" }
classifiers = [ classifiers = [
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License", "License :: OSI Approved :: Apache Software License",
] ]
dependencies = [ dependencies = ["requests", "tqdm", "numpy"]
"requests",
"tqdm",
"numpy",
]
[project.optional-dependencies] [project.optional-dependencies]
runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
...@@ -32,7 +28,14 @@ srt_xpu = ["sglang[runtime_common]"] ...@@ -32,7 +28,14 @@ srt_xpu = ["sglang[runtime_common]"]
openai = ["openai>=1.0", "tiktoken"] openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"] anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"] litellm = ["litellm>=1.0.0"]
test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"] test = [
"jsonlines",
"matplotlib",
"pandas",
"sentence_transformers",
"accelerate",
"peft",
]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
dev = ["sglang[all]", "sglang[test]"] dev = ["sglang[all]", "sglang[test]"]
...@@ -43,7 +46,23 @@ dev_xpu = ["sglang[all_xpu]", "sglang[test]"] ...@@ -43,7 +46,23 @@ dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues" "Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"] exclude = [
"assets*",
"benchmark*",
"docs*",
"dist*",
"playground*",
"scripts*",
"tests*",
]
[tool.wheel] [tool.wheel]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"] exclude = [
"assets*",
"benchmark*",
"docs*",
"dist*",
"playground*",
"scripts*",
"tests*",
]
...@@ -227,8 +227,9 @@ def extend(reqs, model_runner): ...@@ -227,8 +227,9 @@ def extend(reqs, model_runner):
req_to_token_pool=model_runner.req_to_token_pool, req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool, token_to_kv_pool=model_runner.token_to_kv_pool,
tree_cache=None, tree_cache=None,
model_config=model_runner.model_config,
) )
batch.prepare_for_extend(model_runner.model_config.vocab_size) batch.prepare_for_extend()
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch) logits_output = model_runner.forward(forward_batch)
......
...@@ -229,6 +229,7 @@ register_chat_template( ...@@ -229,6 +229,7 @@ register_chat_template(
), ),
}, },
stop_str=("<|eot_id|>",), stop_str=("<|eot_id|>",),
image_token="<|image|>",
) )
) )
......
...@@ -89,6 +89,8 @@ class ModelConfig: ...@@ -89,6 +89,8 @@ class ModelConfig:
self.num_hidden_layers = self.hf_text_config.num_hidden_layers self.num_hidden_layers = self.hf_text_config.num_hidden_layers
self.vocab_size = self.hf_text_config.vocab_size self.vocab_size = self.hf_text_config.vocab_size
self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289 # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int: def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads.""" """Returns the total number of KV heads."""
......
...@@ -509,6 +509,19 @@ register_conv_template( ...@@ -509,6 +509,19 @@ register_conv_template(
) )
) )
register_conv_template(
Conversation(
name="llama_3_vision",
system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
roles=("user", "assistant"),
sep_style=SeparatorStyle.LLAMA3,
sep="",
stop_str=["<|end_of_text|>", "<|eot_id|>"],
image_token="<|image|>",
)
)
register_conv_template( register_conv_template(
Conversation( Conversation(
name="llava_llama_3", name="llava_llama_3",
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional
import torch import torch
from torch import nn from torch import nn
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
...@@ -19,7 +21,11 @@ class AttentionBackend(ABC): ...@@ -19,7 +21,11 @@ class AttentionBackend(ABC):
raise NotImplementedError() raise NotImplementedError()
def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor] = None,
): ):
"""Init the metadata for a forward pass for capturing a cuda graph.""" """Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError() raise NotImplementedError()
...@@ -30,6 +36,7 @@ class AttentionBackend(ABC): ...@@ -30,6 +36,7 @@ class AttentionBackend(ABC):
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor] = None,
): ):
"""Init the metadata for a forward pass for replying a cuda graph.""" """Init the metadata for a forward pass for replying a cuda graph."""
raise NotImplementedError() raise NotImplementedError()
...@@ -43,7 +50,7 @@ class AttentionBackend(ABC): ...@@ -43,7 +50,7 @@ class AttentionBackend(ABC):
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
layer: nn.Module, layer: RadixAttention,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
): ):
"""Run forward on an attention layer.""" """Run forward on an attention layer."""
...@@ -57,7 +64,7 @@ class AttentionBackend(ABC): ...@@ -57,7 +64,7 @@ class AttentionBackend(ABC):
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
layer: nn.Module, layer: RadixAttention,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
): ):
"""Run a forward for decode.""" """Run a forward for decode."""
...@@ -68,7 +75,7 @@ class AttentionBackend(ABC): ...@@ -68,7 +75,7 @@ class AttentionBackend(ABC):
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
layer: nn.Module, layer: RadixAttention,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
): ):
"""Run a forward for extend.""" """Run a forward for extend."""
......
...@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict ...@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
...@@ -134,8 +135,13 @@ class DoubleSparseAttnBackend(AttentionBackend): ...@@ -134,8 +135,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
) )
def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens=None,
): ):
# NOTE: encoder_lens expected to be zeros or None
self.forward_metadata = ( self.forward_metadata = (
self.cuda_graph_start_loc, self.cuda_graph_start_loc,
self.cuda_graph_attn_logits, self.cuda_graph_attn_logits,
...@@ -149,14 +155,18 @@ class DoubleSparseAttnBackend(AttentionBackend): ...@@ -149,14 +155,18 @@ class DoubleSparseAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
encoder_lens=None,
): ):
# NOTE: encoder_lens expected to be zeros or None
self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
def get_cuda_graph_seq_len_fill_value(self): def get_cuda_graph_seq_len_fill_value(self):
return 1 return 1
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# TODO: reuse the buffer across layers # TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim: if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
...@@ -172,7 +182,7 @@ class DoubleSparseAttnBackend(AttentionBackend): ...@@ -172,7 +182,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
) )
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v, k_label layer, forward_batch.out_cache_loc, k, v, k_label
) )
( (
...@@ -201,7 +211,9 @@ class DoubleSparseAttnBackend(AttentionBackend): ...@@ -201,7 +211,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
) )
return o return o
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# During torch.compile, there is a bug in rotary_emb that causes the # During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly. # output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
...@@ -231,7 +243,7 @@ class DoubleSparseAttnBackend(AttentionBackend): ...@@ -231,7 +243,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
) )
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v, k_label layer, forward_batch.out_cache_loc, k, v, k_label
) )
# NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
......
...@@ -11,7 +11,6 @@ from enum import Enum, auto ...@@ -11,7 +11,6 @@ from enum import Enum, auto
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
import torch.nn as nn
import triton import triton
import triton.language as tl import triton.language as tl
...@@ -21,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch ...@@ -21,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_flashinfer_available from sglang.srt.utils import is_flashinfer_available
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
if is_flashinfer_available(): if is_flashinfer_available():
...@@ -56,13 +56,13 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -56,13 +56,13 @@ class FlashInferAttnBackend(AttentionBackend):
assert not ( assert not (
model_runner.sliding_window_size is not None model_runner.sliding_window_size is not None
and model_runner.has_cross_attention and model_runner.model_config.is_encoder_decoder
), "Sliding window and cross attention are not supported together" ), "Sliding window and cross attention are not supported together"
if model_runner.sliding_window_size is not None: if model_runner.sliding_window_size is not None:
self.num_wrappers = 2 self.num_wrappers = 2
self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
elif model_runner.has_cross_attention: elif model_runner.model_config.is_encoder_decoder:
self.num_wrappers = 2 self.num_wrappers = 2
self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
else: else:
...@@ -128,6 +128,8 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -128,6 +128,8 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.req_pool_indices, forward_batch.req_pool_indices,
forward_batch.seq_lens, forward_batch.seq_lens,
forward_batch.seq_lens_sum, forward_batch.seq_lens_sum,
decode_wrappers=None,
encoder_lens=forward_batch.encoder_lens,
) )
self.forward_metadata = (self.decode_wrappers,) self.forward_metadata = (self.decode_wrappers,)
else: else:
...@@ -144,13 +146,11 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -144,13 +146,11 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.req_pool_indices, forward_batch.req_pool_indices,
forward_batch.seq_lens, forward_batch.seq_lens,
prefix_lens, prefix_lens,
use_ragged, use_ragged=use_ragged,
encoder_lens=forward_batch.encoder_lens,
) )
self.forward_metadata = ( self.forward_metadata = (use_ragged, extend_no_prefix)
use_ragged,
extend_no_prefix,
)
def init_cuda_graph_state(self, max_bs: int): def init_cuda_graph_state(self, max_bs: int):
cuda_graph_kv_indices = torch.zeros( cuda_graph_kv_indices = torch.zeros(
...@@ -163,7 +163,11 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -163,7 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
] ]
def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: torch.Tensor = None,
): ):
decode_wrappers = [] decode_wrappers = []
for i in range(self.num_wrappers): for i in range(self.num_wrappers):
...@@ -181,7 +185,11 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -181,7 +185,11 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens_sum = seq_lens.sum().item() seq_lens_sum = seq_lens.sum().item()
self.indices_updater_decode.update( self.indices_updater_decode.update(
req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers req_pool_indices,
seq_lens,
seq_lens_sum,
decode_wrappers=decode_wrappers,
encoder_lens=encoder_lens,
) )
self.cuda_graph_metadata[bs] = decode_wrappers self.cuda_graph_metadata[bs] = decode_wrappers
self.forward_metadata = (decode_wrappers,) self.forward_metadata = (decode_wrappers,)
...@@ -192,34 +200,42 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -192,34 +200,42 @@ class FlashInferAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
encoder_lens: torch.Tensor = None,
): ):
self.indices_updater_decode.update( self.indices_updater_decode.update(
req_pool_indices[:bs], req_pool_indices[:bs],
seq_lens[:bs], seq_lens[:bs],
seq_lens_sum, seq_lens_sum,
self.cuda_graph_metadata[bs], decode_wrappers=self.cuda_graph_metadata[bs],
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
) )
def get_cuda_graph_seq_len_fill_value(self): def get_cuda_graph_seq_len_fill_value(self):
return 0 return 0
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
prefill_wrapper_paged = self.prefill_wrappers_paged[ prefill_wrapper_paged = self.prefill_wrappers_paged[
self._get_wrapper_idx(layer) self._get_wrapper_idx(layer)
] ]
use_ragged, extend_no_prefix = self.forward_metadata use_ragged, extend_no_prefix = self.forward_metadata
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not use_ragged: if not use_ragged:
if k is not None: if k is not None:
assert v is not None assert v is not None
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
layer.layer_id, forward_batch.out_cache_loc, k, v
)
o = prefill_wrapper_paged.forward( o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=True, causal=not layer.is_cross_attention,
sm_scale=layer.scaling, sm_scale=layer.scaling,
window_left=layer.sliding_window_size, window_left=layer.sliding_window_size,
logits_soft_cap=layer.logit_cap, logits_soft_cap=layer.logit_cap,
...@@ -247,20 +263,23 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -247,20 +263,23 @@ class FlashInferAttnBackend(AttentionBackend):
o, _ = merge_state(o1, s1, o2, s2) o, _ = merge_state(o1, s1, o2, s2)
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
layer.layer_id, forward_batch.out_cache_loc, k, v
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim) return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)] decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if k is not None: if k is not None:
assert v is not None assert v is not None
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
layer.layer_id, forward_batch.out_cache_loc, k, v
)
o = decode_wrapper.forward( o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
...@@ -271,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -271,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
return o.view(-1, layer.tp_q_head_num * layer.head_dim) return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def _get_wrapper_idx(self, layer: nn.Module): def _get_wrapper_idx(self, layer: RadixAttention):
if self.num_wrappers == 1: if self.num_wrappers == 1:
return 0 return 0
...@@ -298,6 +317,8 @@ class FlashInferIndicesUpdaterDecode: ...@@ -298,6 +317,8 @@ class FlashInferIndicesUpdaterDecode:
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1) self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
self.sliding_window_size = model_runner.sliding_window_size self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
# Buffers and wrappers # Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len self.kv_last_page_len = attn_backend.kv_last_page_len
...@@ -305,20 +326,27 @@ class FlashInferIndicesUpdaterDecode: ...@@ -305,20 +326,27 @@ class FlashInferIndicesUpdaterDecode:
self.decode_wrappers = attn_backend.decode_wrappers self.decode_wrappers = attn_backend.decode_wrappers
# Dispatch # Dispatch
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
self.update = self.update_sliding_window self.update = self.update_sliding_window
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
self.update = self.update_cross_attention self.update = self.update_cross_attention
else: else:
assert attn_backend.num_wrappers == 1 assert self.attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper self.update = self.update_single_wrapper
def update(
self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
):
# Keep the signature for type checking, will be initialized during runtime
raise NotImplementedError()
def update_single_wrapper( def update_single_wrapper(
self, self,
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
decode_wrappers=None, decode_wrappers=None,
encoder_lens=None,
): ):
decode_wrappers = decode_wrappers or self.decode_wrappers decode_wrappers = decode_wrappers or self.decode_wrappers
self.call_begin_forward( self.call_begin_forward(
...@@ -336,6 +364,7 @@ class FlashInferIndicesUpdaterDecode: ...@@ -336,6 +364,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
decode_wrappers=None, decode_wrappers=None,
encoder_lens=None,
): ):
decode_wrappers = decode_wrappers or self.decode_wrappers decode_wrappers = decode_wrappers or self.decode_wrappers
...@@ -363,8 +392,35 @@ class FlashInferIndicesUpdaterDecode: ...@@ -363,8 +392,35 @@ class FlashInferIndicesUpdaterDecode:
kv_start_idx_tmp, kv_start_idx_tmp,
) )
def update_cross_attention(self): def update_cross_attention(
raise NotImplementedError() self,
req_pool_indices,
seq_lens,
seq_lens_sum,
decode_wrappers=None,
encoder_lens=None,
):
decode_wrappers = decode_wrappers or self.decode_wrappers
for wrapper_id in range(2):
if wrapper_id == 0:
# Normal attention
paged_kernel_lens = seq_lens
kv_start_idx = encoder_lens
else:
# Cross attention
paged_kernel_lens = encoder_lens
kv_start_idx = torch.zeros_like(encoder_lens)
seq_lens_sum = encoder_lens.sum().item()
self.call_begin_forward(
decode_wrappers[wrapper_id],
req_pool_indices,
paged_kernel_lens,
seq_lens_sum,
self.kv_indptr[wrapper_id],
kv_start_idx,
)
def call_begin_forward( def call_begin_forward(
self, self,
...@@ -421,6 +477,8 @@ class FlashInferIndicesUpdaterPrefill: ...@@ -421,6 +477,8 @@ class FlashInferIndicesUpdaterPrefill:
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1) self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
self.sliding_window_size = model_runner.sliding_window_size self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
# Buffers and wrappers # Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len self.kv_last_page_len = attn_backend.kv_last_page_len
...@@ -430,16 +488,20 @@ class FlashInferIndicesUpdaterPrefill: ...@@ -430,16 +488,20 @@ class FlashInferIndicesUpdaterPrefill:
self.wrappers_paged = attn_backend.prefill_wrappers_paged self.wrappers_paged = attn_backend.prefill_wrappers_paged
# Dispatch # Dispatch
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
self.update = self.update_sliding_window self.update = self.update_sliding_window
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
self.update = self.update_cross_attention self.update = self.update_cross_attention
else: else:
assert attn_backend.num_wrappers == 1 assert self.attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper self.update = self.update_single_wrapper
def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
# Keep the signature for type checking, will be initialized during runtime
raise NotImplementedError()
def update_single_wrapper( def update_single_wrapper(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
): ):
if use_ragged: if use_ragged:
paged_kernel_lens = prefix_lens paged_kernel_lens = prefix_lens
...@@ -460,7 +522,7 @@ class FlashInferIndicesUpdaterPrefill: ...@@ -460,7 +522,7 @@ class FlashInferIndicesUpdaterPrefill:
) )
def update_sliding_window( def update_sliding_window(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
): ):
for wrapper_id in range(2): for wrapper_id in range(2):
if wrapper_id == 0: if wrapper_id == 0:
...@@ -487,8 +549,31 @@ class FlashInferIndicesUpdaterPrefill: ...@@ -487,8 +549,31 @@ class FlashInferIndicesUpdaterPrefill:
use_ragged, use_ragged,
) )
def update_cross_attention(self): def update_cross_attention(
raise NotImplementedError() self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
):
for wrapper_id in range(2):
if wrapper_id == 0:
# normal attention
paged_kernel_lens = seq_lens
kv_start_idx = encoder_lens
else:
# cross attention
paged_kernel_lens = encoder_lens
kv_start_idx = torch.zeros_like(encoder_lens)
self.call_begin_forward(
self.wrapper_ragged,
self.wrappers_paged[wrapper_id],
req_pool_indices,
paged_kernel_lens,
seq_lens,
prefix_lens,
kv_start_idx,
self.kv_indptr[wrapper_id],
self.qo_indptr[wrapper_id],
use_ragged,
)
def call_begin_forward( def call_begin_forward(
self, self,
......
...@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict ...@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
...@@ -81,8 +82,13 @@ class TritonAttnBackend(AttentionBackend): ...@@ -81,8 +82,13 @@ class TritonAttnBackend(AttentionBackend):
) )
def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens=None,
): ):
# NOTE: encoder_lens expected to be zeros or None
self.forward_metadata = ( self.forward_metadata = (
self.cuda_graph_start_loc, self.cuda_graph_start_loc,
self.cuda_graph_attn_logits, self.cuda_graph_attn_logits,
...@@ -96,14 +102,18 @@ class TritonAttnBackend(AttentionBackend): ...@@ -96,14 +102,18 @@ class TritonAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
encoder_lens=None,
): ):
# NOTE: encoder_lens expected to be zeros or None
self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
def get_cuda_graph_seq_len_fill_value(self): def get_cuda_graph_seq_len_fill_value(self):
return 1 return 1
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# TODO: reuse the buffer across layers # TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim: if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
...@@ -111,7 +121,7 @@ class TritonAttnBackend(AttentionBackend): ...@@ -111,7 +121,7 @@ class TritonAttnBackend(AttentionBackend):
o = torch.empty_like(q) o = torch.empty_like(q)
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v layer, forward_batch.out_cache_loc, k, v
) )
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
...@@ -133,7 +143,9 @@ class TritonAttnBackend(AttentionBackend): ...@@ -133,7 +143,9 @@ class TritonAttnBackend(AttentionBackend):
) )
return o return o
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# During torch.compile, there is a bug in rotary_emb that causes the # During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly. # output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
...@@ -147,7 +159,7 @@ class TritonAttnBackend(AttentionBackend): ...@@ -147,7 +159,7 @@ class TritonAttnBackend(AttentionBackend):
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v layer, forward_batch.out_cache_loc, k, v
) )
self.decode_attention_fwd( self.decode_attention_fwd(
......
...@@ -33,26 +33,32 @@ def init_global_processor(server_args: ServerArgs): ...@@ -33,26 +33,32 @@ def init_global_processor(server_args: ServerArgs):
class BaseImageProcessor(ABC): class BaseImageProcessor(ABC):
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(server_args,),
max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
)
@abstractmethod @abstractmethod
async def process_images_async(self, image_data, **kwargs): async def process_images_async(self, image_data, input_text, **kwargs):
pass pass
class DummyImageProcessor(BaseImageProcessor): class DummyImageProcessor(BaseImageProcessor):
def __init__(self):
pass
async def process_images_async(self, *args, **kwargs): async def process_images_async(self, *args, **kwargs):
return None return None
class LlavaImageProcessor(BaseImageProcessor): class LlavaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor): def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config super().__init__(hf_config, server_args, _processor)
self._image_processor = _image_processor
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(server_args,),
max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
)
@staticmethod @staticmethod
def _process_single_image_task( def _process_single_image_task(
...@@ -119,7 +125,7 @@ class LlavaImageProcessor(BaseImageProcessor): ...@@ -119,7 +125,7 @@ class LlavaImageProcessor(BaseImageProcessor):
) )
async def process_images_async( async def process_images_async(
self, image_data: List[Union[str, bytes]], request_obj self, image_data: List[Union[str, bytes]], input_text, request_obj
): ):
if not image_data: if not image_data:
return None return None
...@@ -177,6 +183,54 @@ class LlavaImageProcessor(BaseImageProcessor): ...@@ -177,6 +183,54 @@ class LlavaImageProcessor(BaseImageProcessor):
} }
class MllamaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(images, input_text):
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
return global_processor(images, input_text, return_tensors="pt")
async def _process_single_image(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MllamaImageProcessor._process_single_image_task,
images,
input_text,
)
else:
image_inputs = self._processor(images, input_text, return_tensors="pt")
return image_inputs
async def process_images_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
if not image_data:
return None
if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text)
if not isinstance(image_data, list):
image_data = [image_data]
if len(image_data) > 0:
images = [load_image(image)[0] for image in image_data]
else:
images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text)
image_inputs["image_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
return image_inputs
class Qwen2VLImageProcessor(BaseImageProcessor): class Qwen2VLImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor): def __init__(self, hf_config, server_args, _image_processor):
self.hf_config = hf_config self.hf_config = hf_config
...@@ -237,7 +291,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor): ...@@ -237,7 +291,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
return self._process_single_image_task(image_data) return self._process_single_image_task(image_data)
async def process_images_async( async def process_images_async(
self, image_data: List[Union[str, bytes]], request_obj self, image_data: List[Union[str, bytes]], input_text, request_obj
): ):
if not image_data: if not image_data:
return None return None
...@@ -292,12 +346,14 @@ class Qwen2VLImageProcessor(BaseImageProcessor): ...@@ -292,12 +346,14 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
def get_image_processor( def get_image_processor(
hf_config, server_args: ServerArgs, _image_processor hf_config, server_args: ServerArgs, processor
) -> BaseImageProcessor: ) -> BaseImageProcessor:
if "Qwen2VLForConditionalGeneration" in hf_config.architectures: if "MllamaForConditionalGeneration" in hf_config.architectures:
return Qwen2VLImageProcessor(hf_config, server_args, _image_processor) return MllamaImageProcessor(hf_config, server_args, processor)
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
else: else:
return LlavaImageProcessor(hf_config, server_args, _image_processor) return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
def get_dummy_image_processor(): def get_dummy_image_processor():
......
...@@ -36,6 +36,7 @@ from typing import List, Optional, Tuple, Union ...@@ -36,6 +36,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained import RegexGuide from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
...@@ -121,11 +122,12 @@ class ImageInputs: ...@@ -121,11 +122,12 @@ class ImageInputs:
"""The image related inputs.""" """The image related inputs."""
pixel_values: torch.Tensor pixel_values: torch.Tensor
image_hash: int image_hashes: Optional[list] = None
image_sizes: Optional[list] = None image_sizes: Optional[list] = None
image_offsets: Optional[list] = None image_offsets: Optional[list] = None
pad_values: Optional[list] = None pad_values: Optional[list] = None
modalities: Optional[list] = None modalities: Optional[list] = None
num_image_tokens: Optional[int] = None
image_embeds: Optional[List[torch.Tensor]] = None image_embeds: Optional[List[torch.Tensor]] = None
aspect_ratio_ids: Optional[List[torch.Tensor]] = None aspect_ratio_ids: Optional[List[torch.Tensor]] = None
...@@ -138,19 +140,27 @@ class ImageInputs: ...@@ -138,19 +140,27 @@ class ImageInputs:
# Use image hash as fake token_ids, which is then used for prefix matching # Use image hash as fake token_ids, which is then used for prefix matching
ret = ImageInputs( ret = ImageInputs(
pixel_values=obj["pixel_values"], pixel_values=obj["pixel_values"],
image_hash=hash(tuple(obj["image_hashes"])), image_hashes=hash(tuple(obj["image_hashes"])),
image_grid_thws=obj.get("image_grid_thws"),
) )
image_hash = ret.image_hash image_hash = ret.image_hashes
ret.pad_values = [ ret.pad_values = [
(image_hash) % vocab_size, (image_hash) % vocab_size,
(image_hash >> 16) % vocab_size, (image_hash >> 16) % vocab_size,
(image_hash >> 32) % vocab_size, (image_hash >> 32) % vocab_size,
(image_hash >> 64) % vocab_size, (image_hash >> 64) % vocab_size,
] ]
ret.image_sizes = obj["image_sizes"]
# Only when pixel values is not None we have modalities optional_args = [
ret.modalities = obj["modalities"] or ["image"] "image_sizes",
"modalities",
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
]
for arg in optional_args:
if arg in obj:
setattr(ret, arg, obj[arg])
return ret return ret
...@@ -416,6 +426,10 @@ class ScheduleBatch: ...@@ -416,6 +426,10 @@ class ScheduleBatch:
req_to_token_pool: ReqToTokenPool = None req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool: BaseTokenToKVPool = None token_to_kv_pool: BaseTokenToKVPool = None
tree_cache: BasePrefixCache = None tree_cache: BasePrefixCache = None
# For utility
model_config: ModelConfig = None
forward_mode: ForwardMode = None forward_mode: ForwardMode = None
sampling_info: SamplingBatchInfo = None sampling_info: SamplingBatchInfo = None
...@@ -440,6 +454,12 @@ class ScheduleBatch: ...@@ -440,6 +454,12 @@ class ScheduleBatch:
extend_num_tokens: int = None extend_num_tokens: int = None
decoding_reqs: List[Req] = None decoding_reqs: List[Req] = None
# For encoder-decoder
encoder_cached: Optional[List[bool]] = None
encoder_lens: Optional[torch.Tensor] = None
encoder_lens_cpu: Optional[List[int]] = None
encoder_out_cache_loc: Optional[torch.Tensor] = None
# Stream # Stream
has_stream: bool = False has_stream: bool = False
...@@ -450,12 +470,20 @@ class ScheduleBatch: ...@@ -450,12 +470,20 @@ class ScheduleBatch:
device: str = "cuda" device: str = "cuda"
@classmethod @classmethod
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): def init_new(
cls,
reqs,
req_to_token_pool,
token_to_kv_pool,
tree_cache,
model_config,
):
return cls( return cls(
reqs=reqs, reqs=reqs,
req_to_token_pool=req_to_token_pool, req_to_token_pool=req_to_token_pool,
token_to_kv_pool=token_to_kv_pool, token_to_kv_pool=token_to_kv_pool,
tree_cache=tree_cache, tree_cache=tree_cache,
model_config=model_config,
return_logprob=any(req.return_logprob for req in reqs), return_logprob=any(req.return_logprob for req in reqs),
has_stream=any(req.stream for req in reqs), has_stream=any(req.stream for req in reqs),
has_regex=any(req.regex_fsm for req in reqs), has_regex=any(req.regex_fsm for req in reqs),
...@@ -493,7 +521,78 @@ class ScheduleBatch: ...@@ -493,7 +521,78 @@ class ScheduleBatch:
return out_cache_loc return out_cache_loc
def prepare_for_extend(self, vocab_size: int): def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
self.encoder_lens_cpu = []
self.encoder_cached = []
for req in self.reqs:
im = req.image_inputs
if im is None or im.num_image_tokens is None:
# No image input
self.encoder_lens_cpu.append(0)
self.encoder_cached.append(True)
else:
self.encoder_lens_cpu.append(im.num_image_tokens)
self.encoder_cached.append(
self.forward_mode.is_decode()
or len(req.prefix_indices) >= im.num_image_tokens
)
self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
self.device, non_blocking=True
)
# Strip encoder infos
pt = 0
decoder_out_cache_loc = []
encoder_out_cache_loc = []
for i, req in enumerate(self.reqs):
encoder_len = self.encoder_lens_cpu[i]
seq_lens[i] -= encoder_len
if len(req.prefix_indices) < encoder_len:
# NOTE: the encoder part should considered as a whole
assert len(req.prefix_indices) == 0
input_ids[i] = input_ids[i][encoder_len:]
encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
decoder_out_cache_loc.append(
self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
)
self.extend_lens[i] -= encoder_len
self.extend_num_tokens -= encoder_len
else:
decoder_out_cache_loc.append(
self.out_cache_loc[pt : pt + req.extend_input_len]
)
self.prefix_lens[i] -= encoder_len
pt += req.extend_input_len
# Reassign
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
self.device, non_blocking=True
)
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
if not decoder_out_cache_loc:
self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
else:
self.out_cache_loc = torch.cat(decoder_out_cache_loc)
if not encoder_out_cache_loc:
self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
else:
self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
assert len(self.out_cache_loc) == self.extend_num_tokens
def prepare_for_extend(self):
self.forward_mode = ForwardMode.EXTEND self.forward_mode = ForwardMode.EXTEND
bs = len(self.reqs) bs = len(self.reqs)
...@@ -561,8 +660,13 @@ class ScheduleBatch: ...@@ -561,8 +660,13 @@ class ScheduleBatch:
self.extend_lens = [r.extend_input_len for r in reqs] self.extend_lens = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
if self.model_config.is_encoder_decoder:
self.prepare_encoder_info_extend(input_ids, seq_lens)
self.sampling_info = SamplingBatchInfo.from_schedule_batch( self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self, vocab_size, global_server_args_dict["disable_penalizer"] self,
self.model_config.vocab_size,
global_server_args_dict["disable_penalizer"],
) )
def mix_with_running(self, running_batch: "ScheduleBatch"): def mix_with_running(self, running_batch: "ScheduleBatch"):
...@@ -752,6 +856,10 @@ class ScheduleBatch: ...@@ -752,6 +856,10 @@ class ScheduleBatch:
return jump_forward_reqs return jump_forward_reqs
def prepare_encoder_info_decode(self):
# Reset the encoder cached status
self.encoder_cached = [True] * len(self.reqs)
def prepare_for_decode(self, enable_overlap: bool = False): def prepare_for_decode(self, enable_overlap: bool = False):
self.forward_mode = ForwardMode.DECODE self.forward_mode = ForwardMode.DECODE
...@@ -766,16 +874,22 @@ class ScheduleBatch: ...@@ -766,16 +874,22 @@ class ScheduleBatch:
bs = len(self.reqs) bs = len(self.reqs)
self.out_cache_loc = self.alloc_token_slots(bs) self.out_cache_loc = self.alloc_token_slots(bs)
if self.model_config.is_encoder_decoder:
locs = self.encoder_lens + self.seq_lens
self.prepare_encoder_info_decode()
else:
locs = self.seq_lens
if enable_overlap: if enable_overlap:
# Do not use in-place operations in the overlap mode # Do not use in-place operations in the overlap mode
self.req_to_token_pool.write( self.req_to_token_pool.write(
(self.req_pool_indices, self.seq_lens), self.out_cache_loc (self.req_pool_indices, locs), self.out_cache_loc
) )
self.seq_lens = self.seq_lens + 1 self.seq_lens = self.seq_lens + 1
else: else:
# A faster in-place version # A faster in-place version
self.req_to_token_pool.write( self.req_to_token_pool.write(
(self.req_pool_indices, self.seq_lens), self.out_cache_loc (self.req_pool_indices, locs), self.out_cache_loc
) )
self.seq_lens.add_(1) self.seq_lens.add_(1)
self.seq_lens_sum += bs self.seq_lens_sum += bs
...@@ -802,6 +916,10 @@ class ScheduleBatch: ...@@ -802,6 +916,10 @@ class ScheduleBatch:
# No need to filter # No need to filter
return return
if self.model_config.is_encoder_decoder:
self.encoder_lens = self.encoder_lens[keep_indices]
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
self.reqs = [self.reqs[i] for i in keep_indices] self.reqs = [self.reqs[i] for i in keep_indices]
new_indices = torch.tensor(keep_indices, dtype=torch.int32).to( new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
self.device, non_blocking=True self.device, non_blocking=True
...@@ -828,6 +946,11 @@ class ScheduleBatch: ...@@ -828,6 +946,11 @@ class ScheduleBatch:
# needs to be called with pre-merged Batch.reqs. # needs to be called with pre-merged Batch.reqs.
self.sampling_info.merge_batch(other.sampling_info) self.sampling_info.merge_batch(other.sampling_info)
# Encoder-decoder infos
if self.model_config.is_encoder_decoder:
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
self.req_pool_indices = torch.concat( self.req_pool_indices = torch.concat(
[self.req_pool_indices, other.req_pool_indices] [self.req_pool_indices, other.req_pool_indices]
) )
...@@ -850,14 +973,11 @@ class ScheduleBatch: ...@@ -850,14 +973,11 @@ class ScheduleBatch:
def get_model_worker_batch(self): def get_model_worker_batch(self):
if self.forward_mode.is_decode(): if self.forward_mode.is_decode():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = ( extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
image_inputs
) = None
else: else:
extend_seq_lens = self.extend_lens extend_seq_lens = self.extend_lens
extend_prefix_lens = self.prefix_lens extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens extend_logprob_start_lens = self.extend_logprob_start_lens
image_inputs = [r.image_inputs for r in self.reqs]
if self.has_regex: if self.has_regex:
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs] self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
...@@ -887,7 +1007,11 @@ class ScheduleBatch: ...@@ -887,7 +1007,11 @@ class ScheduleBatch:
extend_seq_lens=extend_seq_lens, extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens, extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens, extend_logprob_start_lens=extend_logprob_start_lens,
image_inputs=image_inputs, image_inputs=[r.image_inputs for r in self.reqs],
encoder_cached=self.encoder_cached,
encoder_lens=self.encoder_lens,
encoder_lens_cpu=self.encoder_lens_cpu,
encoder_out_cache_loc=self.encoder_out_cache_loc,
lora_paths=[req.lora_path for req in self.reqs], lora_paths=[req.lora_path for req in self.reqs],
sampling_info=self.sampling_info, sampling_info=self.sampling_info,
mrope_positions_delta=mrope_positions_delta, mrope_positions_delta=mrope_positions_delta,
...@@ -897,6 +1021,7 @@ class ScheduleBatch: ...@@ -897,6 +1021,7 @@ class ScheduleBatch:
# Only contain fields that will be used by process_batch_result # Only contain fields that will be used by process_batch_result
return ScheduleBatch( return ScheduleBatch(
reqs=self.reqs, reqs=self.reqs,
model_config=self.model_config,
forward_mode=self.forward_mode, forward_mode=self.forward_mode,
out_cache_loc=self.out_cache_loc, out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob, return_logprob=self.return_logprob,
...@@ -944,6 +1069,12 @@ class ModelWorkerBatch: ...@@ -944,6 +1069,12 @@ class ModelWorkerBatch:
# For multimodal # For multimodal
image_inputs: Optional[List[ImageInputs]] image_inputs: Optional[List[ImageInputs]]
# For encoder-decoder
encoder_cached: Optional[List[bool]]
encoder_lens: Optional[torch.Tensor]
encoder_lens_cpu: Optional[List[int]]
encoder_out_cache_loc: Optional[torch.Tensor]
# For LoRA # For LoRA
lora_paths: Optional[List[str]] lora_paths: Optional[List[str]]
......
...@@ -662,8 +662,9 @@ class Scheduler: ...@@ -662,8 +662,9 @@ class Scheduler:
self.req_to_token_pool, self.req_to_token_pool,
self.token_to_kv_pool, self.token_to_kv_pool,
self.tree_cache, self.tree_cache,
self.model_config,
) )
new_batch.prepare_for_extend(self.model_config.vocab_size) new_batch.prepare_for_extend()
# Mixed-style chunked prefill # Mixed-style chunked prefill
if self.is_mixed_chunk and self.running_batch is not None: if self.is_mixed_chunk and self.running_batch is not None:
......
...@@ -122,7 +122,7 @@ class TokenizerManager: ...@@ -122,7 +122,7 @@ class TokenizerManager:
# We want to parallelize the image pre-processing so we create an executor for it # We want to parallelize the image pre-processing so we create an executor for it
self.image_processor = get_image_processor( self.image_processor = get_image_processor(
self.hf_config, server_args, self.processor.image_processor self.hf_config, server_args, self.processor
) )
else: else:
self.tokenizer = get_tokenizer( self.tokenizer = get_tokenizer(
...@@ -191,8 +191,10 @@ class TokenizerManager: ...@@ -191,8 +191,10 @@ class TokenizerManager:
sampling_params = self._get_sampling_params(obj.sampling_params) sampling_params = self._get_sampling_params(obj.sampling_params)
if self.is_generation: if self.is_generation:
image_inputs = await self.image_processor.process_images_async( image_inputs = await self.image_processor.process_images_async(
obj.image_data, obj obj.image_data, input_text or input_ids, obj
) )
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
return_logprob = obj.return_logprob return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num top_logprobs_num = obj.top_logprobs_num
...@@ -217,8 +219,10 @@ class TokenizerManager: ...@@ -217,8 +219,10 @@ class TokenizerManager:
sampling_params = self._get_sampling_params(obj.sampling_params[index]) sampling_params = self._get_sampling_params(obj.sampling_params[index])
if self.is_generation: if self.is_generation:
image_inputs = await self.image_processor.process_images_async( image_inputs = await self.image_processor.process_images_async(
obj.image_data[index], obj obj.image_data[index], input_text or input_ids, obj
) )
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
return_logprob = obj.return_logprob[index] return_logprob = obj.return_logprob[index]
logprob_start_len = obj.logprob_start_len[index] logprob_start_len = obj.logprob_start_len[index]
top_logprobs_num = obj.top_logprobs_num[index] top_logprobs_num = obj.top_logprobs_num[index]
...@@ -263,8 +267,10 @@ class TokenizerManager: ...@@ -263,8 +267,10 @@ class TokenizerManager:
sampling_params = SamplingParams(**obj.sampling_params[0]) sampling_params = SamplingParams(**obj.sampling_params[0])
sampling_params.max_new_tokens = 0 sampling_params.max_new_tokens = 0
image_inputs = await self.image_processor.process_images_async( image_inputs = await self.image_processor.process_images_async(
obj.image_data[0], obj obj.image_data[0], input_text or input_ids, obj
) )
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
return_logprob = obj.return_logprob[0] return_logprob = obj.return_logprob[0]
logprob_start_len = obj.logprob_start_len[0] logprob_start_len = obj.logprob_start_len[0]
top_logprobs_num = obj.top_logprobs_num[0] top_logprobs_num = obj.top_logprobs_num[0]
......
...@@ -26,6 +26,8 @@ from typing import List, Tuple, Union ...@@ -26,6 +26,8 @@ from typing import List, Tuple, Union
import torch import torch
from sglang.srt.layers.radix_attention import RadixAttention
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -41,13 +43,17 @@ class ReqToTokenPool: ...@@ -41,13 +43,17 @@ class ReqToTokenPool:
) )
self.free_slots = list(range(size)) self.free_slots = list(range(size))
self.write_records = [] self.write_records = []
self.use_records = use_records
if use_records: if self.use_records:
# records all write operations
self.write = self.write_with_records self.write = self.write_with_records
else: else:
self.write = self.write_without_records self.write = self.write_without_records
def write(self, indices, values):
# Keep the signature for type checking, will be initialized during runtime
raise NotImplementedError()
def available_size(self): def available_size(self):
return len(self.free_slots) return len(self.free_slots)
...@@ -154,7 +160,7 @@ class BaseTokenToKVPool: ...@@ -154,7 +160,7 @@ class BaseTokenToKVPool:
def set_kv_buffer( def set_kv_buffer(
self, self,
layer_id: int, layer: RadixAttention,
loc: torch.Tensor, loc: torch.Tensor,
cache_k: torch.Tensor, cache_k: torch.Tensor,
cache_v: torch.Tensor, cache_v: torch.Tensor,
...@@ -209,11 +215,12 @@ class MHATokenToKVPool(BaseTokenToKVPool): ...@@ -209,11 +215,12 @@ class MHATokenToKVPool(BaseTokenToKVPool):
def set_kv_buffer( def set_kv_buffer(
self, self,
layer_id: int, layer: RadixAttention,
loc: torch.Tensor, loc: torch.Tensor,
cache_k: torch.Tensor, cache_k: torch.Tensor,
cache_v: torch.Tensor, cache_v: torch.Tensor,
): ):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype: if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype) cache_k = cache_k.to(self.dtype)
if cache_v.dtype != self.dtype: if cache_v.dtype != self.dtype:
...@@ -265,11 +272,12 @@ class MLATokenToKVPool(BaseTokenToKVPool): ...@@ -265,11 +272,12 @@ class MLATokenToKVPool(BaseTokenToKVPool):
def set_kv_buffer( def set_kv_buffer(
self, self,
layer_id: int, layer: RadixAttention,
loc: torch.Tensor, loc: torch.Tensor,
cache_k: torch.Tensor, cache_k: torch.Tensor,
cache_v: torch.Tensor, cache_v: torch.Tensor,
): ):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype: if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype) cache_k = cache_k.to(self.dtype)
if self.store_dtype != self.dtype: if self.store_dtype != self.dtype:
...@@ -324,13 +332,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool): ...@@ -324,13 +332,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
def set_kv_buffer( def set_kv_buffer(
self, self,
layer_id: int, layer: RadixAttention,
loc: torch.Tensor, loc: torch.Tensor,
cache_k: torch.Tensor, cache_k: torch.Tensor,
cache_v: torch.Tensor, cache_v: torch.Tensor,
cache_label: torch.Tensor, cache_label: torch.Tensor,
): ):
# NOTE(Andy): ignore the dtype check # NOTE(Andy): ignore the dtype check
layer_id = layer.layer_id
self.k_buffer[layer_id][loc] = cache_k self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v self.v_buffer[layer_id][loc] = cache_v
self.label_buffer[layer_id][loc] = cache_label self.label_buffer[layer_id][loc] = cache_label
...@@ -105,6 +105,7 @@ class CudaGraphRunner: ...@@ -105,6 +105,7 @@ class CudaGraphRunner:
self.graph_memory_pool = None self.graph_memory_pool = None
self.use_torch_compile = model_runner.server_args.enable_torch_compile self.use_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
# Batch sizes to capture # Batch sizes to capture
if self.model_runner.server_args.disable_cuda_graph_padding: if self.model_runner.server_args.disable_cuda_graph_padding:
...@@ -132,6 +133,9 @@ class CudaGraphRunner: ...@@ -132,6 +133,9 @@ class CudaGraphRunner:
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
) )
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self.encoder_len_fill_value = 0
if self.use_torch_compile: if self.use_torch_compile:
set_torch_compile_config() set_torch_compile_config()
...@@ -144,9 +148,18 @@ class CudaGraphRunner: ...@@ -144,9 +148,18 @@ class CudaGraphRunner:
) )
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32) self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
if self.is_encoder_decoder:
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
self.encoder_lens = torch.full(
(self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
)
else:
self.encoder_lens = None
# Capture # Capture
try: try:
self.capture() with self.model_capture_mode():
self.capture()
except RuntimeError as e: except RuntimeError as e:
raise Exception( raise Exception(
f"Capture cuda graph failed: {e}\n" f"Capture cuda graph failed: {e}\n"
...@@ -157,11 +170,32 @@ class CudaGraphRunner: ...@@ -157,11 +170,32 @@ class CudaGraphRunner:
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
) )
def can_run(self, batch_size: int): @contextmanager
if self.disable_padding: def model_capture_mode(self):
return batch_size in self.graphs if hasattr(self.model_runner.model, "capture_mode"):
else: self.model_runner.model.capture_mode = True
return batch_size <= self.max_bs
yield
if hasattr(self.model_runner.model, "capture_mode"):
self.model_runner.model.capture_mode = False
def can_run(self, forward_batch: ForwardBatch):
is_bs_supported = (
forward_batch.batch_size in self.graphs
if self.disable_padding
else forward_batch.batch_size <= self.max_bs
)
# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
# because the full_text_row_masked_out_mask tensor will always be ones
is_encoder_lens_supported = (
torch.all(forward_batch.encoder_lens > 0)
if self.is_encoder_decoder
else True
)
return is_bs_supported and is_encoder_lens_supported
def capture(self): def capture(self):
with graph_capture() as graph_capture_context: with graph_capture() as graph_capture_context:
...@@ -188,11 +222,19 @@ class CudaGraphRunner: ...@@ -188,11 +222,19 @@ class CudaGraphRunner:
req_pool_indices = self.req_pool_indices[:bs] req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs] seq_lens = self.seq_lens[:bs]
out_cache_loc = self.out_cache_loc[:bs] out_cache_loc = self.out_cache_loc[:bs]
if self.is_encoder_decoder:
encoder_lens = self.encoder_lens[:bs]
else:
encoder_lens = None
seq_lens_sum = seq_lens.sum().item() seq_lens_sum = seq_lens.sum().item()
# Attention backend # Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs, req_pool_indices, seq_lens bs,
req_pool_indices,
seq_lens,
encoder_lens,
) )
# Run and capture # Run and capture
...@@ -208,6 +250,7 @@ class CudaGraphRunner: ...@@ -208,6 +250,7 @@ class CudaGraphRunner:
attn_backend=self.model_runner.attn_backend, attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc, out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens_sum, seq_lens_sum=seq_lens_sum,
encoder_lens=encoder_lens,
return_logprob=False, return_logprob=False,
top_logprobs_nums=[0] * bs, top_logprobs_nums=[0] * bs,
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64), positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
...@@ -251,6 +294,8 @@ class CudaGraphRunner: ...@@ -251,6 +294,8 @@ class CudaGraphRunner:
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc) self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
# Attention backend # Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
...@@ -258,6 +303,7 @@ class CudaGraphRunner: ...@@ -258,6 +303,7 @@ class CudaGraphRunner:
self.req_pool_indices, self.req_pool_indices,
self.seq_lens, self.seq_lens,
forward_batch.seq_lens_sum, forward_batch.seq_lens_sum,
self.encoder_lens,
) )
# Replay # Replay
......
...@@ -108,6 +108,12 @@ class ForwardBatch: ...@@ -108,6 +108,12 @@ class ForwardBatch:
# For multimodal # For multimodal
image_inputs: Optional[List[ImageInputs]] = None image_inputs: Optional[List[ImageInputs]] = None
# Encoder-decoder
encoder_cached: Optional[List[bool]] = None
encoder_lens: Optional[torch.Tensor] = None
encoder_lens_cpu: Optional[List[int]] = None
encoder_out_cache_loc: Optional[torch.Tensor] = None
# For LoRA # For LoRA
lora_paths: Optional[List[str]] = None lora_paths: Optional[List[str]] = None
...@@ -194,6 +200,11 @@ class ForwardBatch: ...@@ -194,6 +200,11 @@ class ForwardBatch:
req_pool_indices=batch.req_pool_indices, req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens, seq_lens=batch.seq_lens,
out_cache_loc=batch.out_cache_loc, out_cache_loc=batch.out_cache_loc,
image_inputs=batch.image_inputs,
encoder_cached=batch.encoder_cached,
encoder_lens=batch.encoder_lens,
encoder_lens_cpu=batch.encoder_lens_cpu,
encoder_out_cache_loc=batch.encoder_out_cache_loc,
seq_lens_sum=batch.seq_lens_sum, seq_lens_sum=batch.seq_lens_sum,
return_logprob=batch.return_logprob, return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums, top_logprobs_nums=batch.top_logprobs_nums,
...@@ -212,11 +223,11 @@ class ForwardBatch: ...@@ -212,11 +223,11 @@ class ForwardBatch:
], ],
axis=0, axis=0,
) )
ret.image_inputs = batch.image_inputs
ret.extend_num_tokens = batch.extend_num_tokens ret.extend_num_tokens = batch.extend_num_tokens
ret.extend_seq_lens = torch.tensor( ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32 batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True) ).to(device, non_blocking=True)
ret.extend_prefix_lens = torch.tensor( ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32 batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True) ).to(device, non_blocking=True)
......
...@@ -270,7 +270,6 @@ class ModelRunner: ...@@ -270,7 +270,6 @@ class ModelRunner:
if hasattr(self.model, "get_attention_sliding_window_size") if hasattr(self.model, "get_attention_sliding_window_size")
else None else None
) )
self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
self.is_generation = is_generation_model( self.is_generation = is_generation_model(
self.model_config.hf_config.architectures, self.server_args.is_embedding self.model_config.hf_config.architectures, self.server_args.is_embedding
) )
...@@ -510,7 +509,7 @@ class ModelRunner: ...@@ -510,7 +509,7 @@ class ModelRunner:
"Window attention is not supported in the triton attention backend. " "Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`." "Please use `--attention-backend flashinfer`."
) )
assert not self.has_cross_attention, ( assert not self.model_config.is_encoder_decoder, (
"Cross attention is not supported in the triton attention backend. " "Cross attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`." "Please use `--attention-backend flashinfer`."
) )
...@@ -558,9 +557,7 @@ class ModelRunner: ...@@ -558,9 +557,7 @@ class ModelRunner:
self.cuda_graph_runner = CudaGraphRunner(self) self.cuda_graph_runner = CudaGraphRunner(self)
def forward_decode(self, forward_batch: ForwardBatch): def forward_decode(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run( if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
forward_batch.batch_size
):
return self.cuda_graph_runner.replay(forward_batch) return self.cuda_graph_runner.replay(forward_batch)
forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64) forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
......
This diff is collapsed.
...@@ -605,7 +605,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -605,7 +605,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
] ]
positions = forward_batch.mrope_positions positions = forward_batch.mrope_positions
if image_inputs is None or len(image_inputs) == 0: if (
forward_batch.forward_mode.is_decode()
or image_inputs is None
or len(image_inputs) == 0
):
inputs_embeds = self.model.embed_tokens(input_ids) inputs_embeds = self.model.embed_tokens(input_ids)
else: else:
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
......
...@@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures): ...@@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures):
or "LlavaQwenForCausalLM" in model_architectures or "LlavaQwenForCausalLM" in model_architectures
or "LlavaMistralForCausalLM" in model_architectures or "LlavaMistralForCausalLM" in model_architectures
or "LlavaVidForCausalLM" in model_architectures or "LlavaVidForCausalLM" in model_architectures
or "MllamaForConditionalGeneration" in model_architectures
or "Qwen2VLForConditionalGeneration" in model_architectures or "Qwen2VLForConditionalGeneration" in model_architectures
): ):
return True return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment