Unverified Commit a7c3f74b authored by Chunan Zeng, committed by GitHub

[FA3 Feature] Support multi modal Llama-3.2-11B-Vision-Instruct (#5103)

parent 5a144a8a
@@ -86,8 +86,8 @@ def eval_mmmu(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    args = add_common_sglang_args_and_parse(parser)
     EvalArgs.add_cli_args(parser)
+    args = add_common_sglang_args_and_parse(parser)
     args = parser.parse_args()
     eval_mmmu(args)
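
The reordering above matters because the common-args helper appears to parse the command line as part of its call (its return value is assigned to `args`), so the evaluation-specific flags from `EvalArgs.add_cli_args` have to be registered on the parser first. A minimal standalone sketch of the pitfall, using plain `argparse` with made-up flag names rather than the benchmark's actual arguments:

```python
# Illustrative only: argparse ignores flags that are registered after parsing.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=30000)     # a "common" flag
args_too_early = parser.parse_args(["--port", "8000"])      # parse before all flags exist
parser.add_argument("--eval-split", default="validation")   # an "eval" flag added too late
print(hasattr(args_too_early, "eval_split"))                # False: flag was never parsed
```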
@@ -42,6 +42,16 @@ class FlashAttentionMetadata:
     # Page table, the index of KV Cache Tables/Blocks
     page_table: torch.Tensor = None
 
+    # Encoder metadata
+    # Cumulative sequence lengths for encoder key
+    encoder_cu_seqlens_k: torch.Tensor = None
+    # Maximum sequence length for encoder key
+    encoder_max_seq_len_k: int = 0
+    # Sequence lengths for the forward batch
+    encoder_lens_int32: torch.Tensor = None
+    # Page table for the encoder
+    encoder_page_table: torch.Tensor = None
+
     @dataclass
     class LocalAttentionMetadata:
         local_query_start_loc: torch.Tensor = None  # cu_seqlens_q for local attention
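
The new fields mirror the existing decode-side metadata but describe the encoder (image) tokens that cross-attention layers read. In particular, `encoder_cu_seqlens_k` holds cumulative key lengths with a leading zero, the cumulative-offsets layout that flash-attention's varlen interface works with. A small sketch of how such a tensor is laid out, using a made-up encoder length:

```python
# Illustrative only: building a cu_seqlens-style tensor for one encoder sequence.
import torch

encoder_lens_int32 = torch.tensor([6404], dtype=torch.int32)   # hypothetical length
encoder_cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(encoder_lens_int32, dim=0, dtype=torch.int32), (1, 0)
)
print(encoder_cu_seqlens_k)                    # -> tensor([0, 6404], dtype=torch.int32)
encoder_max_seq_len_k = int(encoder_lens_int32.max().item())   # -> 6404
```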
@@ -435,6 +445,30 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 metadata.local_attn_metadata = local_metadata
 
+            # Encoder metadata for cross attention
+            if forward_batch.encoder_lens is not None:
+                assert (
+                    forward_batch.encoder_lens.numel() == 1
+                ), "Only encoder size 1 is supported for now"
+                metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
+                metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                    (1, 0),
+                )
+                metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
+                metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
+                ]
+
+                # Currently only support forward_batch.encoder_lens.numel() == 1
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices,
+                    metadata.encoder_max_seq_len_k : (
+                        metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                    ),
+                ]
+
         # Convert the page table to a strided format which is needed by FA3 API
         if self.page_size > 1:
             self.strided_indices = torch.arange(
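
The slicing in the hunk above assumes that, for a request with an encoder input, the request's row in `req_to_token` stores the encoder KV slots first and the decoder KV slots immediately after, so the cross-attention page table and the regular (self-attention) page table are two adjacent slices of the same row. A toy sketch of that split with hypothetical sizes, not sglang code:

```python
# Illustrative only: splitting one request's KV-slot row into encoder and decoder views.
import torch

encoder_max_seq_len_k, max_seq_len_k = 5, 3     # hypothetical lengths
req_to_token = torch.arange(16).unsqueeze(0)    # one request, 16 KV-cache slots
req_pool_indices = torch.tensor([0])

encoder_page_table = req_to_token[req_pool_indices, :encoder_max_seq_len_k]
page_table = req_to_token[
    req_pool_indices,
    encoder_max_seq_len_k : encoder_max_seq_len_k + max_seq_len_k,
]
print(encoder_page_table)   # tensor([[0, 1, 2, 3, 4]])
print(page_table)           # tensor([[5, 6, 7]])
```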
@@ -486,6 +520,7 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
+        causal = not layer.is_cross_attention
 
         # Check if we should use local attention
         use_local_attn = (
@@ -521,6 +556,12 @@ class FlashAttentionBackend(AttentionBackend):
             value_cache = value_cache.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 k_cache=key_cache,
@@ -531,7 +572,7 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                 max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
-                causal=True,
+                causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
                 k_descale=layer.k_scale,
@@ -614,6 +655,7 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
+        causal = not layer.is_cross_attention
 
         if not self.use_mla:
             # Do multi-head attention
@@ -627,17 +669,27 @@ class FlashAttentionBackend(AttentionBackend):
             )
             q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
 
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+            else:
+                page_table = metadata.page_table
+                cache_seqlens = metadata.cache_seqlens_int32
+                cu_seqlens_k = metadata.cu_seqlens_k
+
             o = flash_attn_with_kvcache(
                 q=q_reshaped,
                 k_cache=key_cache,
                 v_cache=value_cache,
-                page_table=metadata.page_table,
-                cache_seqlens=metadata.cache_seqlens_int32,
+                page_table=page_table,
+                cache_seqlens=cache_seqlens,
                 cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                cu_seqlens_k_new=cu_seqlens_k,
                 max_seqlen_q=1,
                 softmax_scale=layer.scaling,
-                causal=True,
+                causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
                 k_descale=layer.k_scale,
@@ -733,6 +785,21 @@ class FlashAttentionBackend(AttentionBackend):
                 ),
             }
 
+        self.encoder_metadata = {
+            "encoder_page_table": torch.zeros(
+                max_bs,
+                self.max_context_len,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "encoder_lens_int32": torch.zeros(
+                max_bs, dtype=torch.int32, device=self.device
+            ),
+            "encoder_cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+        }
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -818,6 +885,19 @@ class FlashAttentionBackend(AttentionBackend):
             self.target_verify_metadata[bs] = metadata
 
+        if encoder_lens is not None:
+            encoder_bs = encoder_lens.numel()
+            metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
+                :encoder_bs
+            ]
+            metadata.encoder_cu_seqlens_k = self.encoder_metadata[
+                "encoder_cu_seqlens_k"
+            ][: (encoder_bs + 1)]
+            metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
+                req_pool_indices, :
+            ]
+
         self.forward_metadata = metadata
 
     def init_forward_metadata_replay_cuda_graph(
@@ -903,6 +983,30 @@ class FlashAttentionBackend(AttentionBackend):
             page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
             metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
 
+        if encoder_lens is not None:
+            # Only support encoder size 1 for now
+            metadata.encoder_max_seq_len_k = encoder_lens[0]
+            metadata.encoder_lens_int32.copy_(encoder_lens[:1])
+            metadata.encoder_cu_seqlens_k.copy_(
+                torch.nn.functional.pad(
+                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                    (1, 0),
+                )
+            )
+
+            metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
+                self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
+            )
+
+            # Update the regular page table
+            page_table = self.req_to_token[
+                req_pool_indices,
+                metadata.encoder_max_seq_len_k : (
+                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                ),
+            ]
+            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
         self.forward_metadata = metadata
 
     def get_cuda_graph_seq_len_fill_value(self):
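
During CUDA-graph capture the backend hands out slices of the preallocated encoder buffers, and during replay it writes the per-step values into those same tensors with `copy_`, because a captured graph replays fixed kernel launches on fixed storage and cannot see freshly allocated tensors. A minimal sketch of that in-place update pattern with hypothetical shapes, not the backend's actual code:

```python
# Illustrative only: updating preallocated, graph-captured buffers in place.
import torch

max_bs = 4
encoder_lens_int32 = torch.zeros(max_bs, dtype=torch.int32)        # captured buffer
encoder_cu_seqlens_k = torch.zeros(max_bs + 1, dtype=torch.int32)  # captured buffer

def update_for_replay(encoder_lens: torch.Tensor) -> None:
    # copy_ keeps the buffers' storage (and the graph's recorded pointers) intact.
    bs = encoder_lens.numel()
    encoder_lens_int32[:bs].copy_(encoder_lens.to(torch.int32))
    encoder_cu_seqlens_k[: bs + 1].copy_(
        torch.nn.functional.pad(
            torch.cumsum(encoder_lens_int32[:bs], dim=0, dtype=torch.int32), (1, 0)
        )
    )

update_for_replay(torch.tensor([6404]))   # hypothetical single-image encoder length
print(encoder_lens_int32)     # tensor([6404, 0, 0, 0], dtype=torch.int32)
print(encoder_cu_seqlens_k)   # tensor([0, 6404, 0, 0, 0], dtype=torch.int32)
```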
@@ -956,7 +1060,7 @@ class FlashAttentionMultiStepBackend:
                 forward_batch.batch_size * self.topk,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
-                encoder_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
             )
@@ -973,7 +1077,7 @@ class FlashAttentionMultiStepBackend:
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
-                encoder_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
                 seq_lens_cpu=forward_batch.seq_lens_cpu,
@@ -886,7 +886,7 @@ class ModelRunner:
                 "Please use `--attention-backend flashinfer`."
             )
             logger.warning(
-                "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
+                "FlashAttention v3 Backend is in Beta. FP8 is not supported."
             )
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,