Unverified Commit 603b3446 authored by Ke Bao, committed by GitHub

Fix FA3 swa spec verify topk>1 (#9658)

parent b6c14ec0
@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Optional, Union
 import numpy as np
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -64,6 +66,9 @@ class FlashAttentionMetadata:
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
 
+    # For sliding window attention topk>1 spec decoding
+    swa_spec_metadata: Optional[FlashAttentionMetadata] = None
+
     # Copied from:
     # https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
@@ -340,6 +345,13 @@ class FlashAttentionBackend(AttentionBackend):
             else None
         )
 
+        # For each layer, the sliding_window_size can be different. This is only used for preparing SWA metadata.
+        # We use `layer.sliding_window_size` to decide whether to use SWA for each layer.
+        self.sliding_window_size = model_runner.sliding_window_size
+        self.has_swa = (
+            self.sliding_window_size is not None and self.sliding_window_size > -1
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata so that all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
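Note on the new flag above: `has_swa` only records whether the model has sliding-window layers at all; the later per-layer decision still reads `layer.sliding_window_size`. A minimal sketch of the None/-1 convention, using a hypothetical helper `_has_swa` and made-up window values (not taken from the commit):

def _has_swa(sliding_window_size):
    # Mirrors the check above: None, or a value of -1 and below, means no sliding window attention.
    return sliding_window_size is not None and sliding_window_size > -1

assert _has_swa(4096)      # hypothetical model with sliding-window layers enabled
assert not _has_swa(None)  # no sliding window configured
assert not _has_swa(-1)    # sliding window explicitly disabled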
@@ -556,6 +568,12 @@
                     (1, 0),
                 )
                 self.forward_metadata_spec_decode_expand = metadata_expand
+
+                if self.has_swa:
+                    self._init_sliding_window_attn_spec_metadata(
+                        metadata, metadata_expand
+                    )
+
         elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -657,11 +675,10 @@
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since model.get_attention_sliding_window_size() already subtracts 1
         # here the window is inclusive on both sides
-        window_size = (
-            (layer.sliding_window_size, 0)
-            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
-            else (-1, -1)
-        )
+        is_swa = (
+            layer.sliding_window_size is not None and layer.sliding_window_size > -1
+        )
+        window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
 
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
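For readers unfamiliar with the FA3 convention used here: `window_size` is an inclusive `(left, right)` pair and `(-1, -1)` means no window (full causal attention). A small sketch of the selection logic above, assuming a made-up `_FakeLayer` class and window value (4095 is illustrative):

class _FakeLayer:
    def __init__(self, sliding_window_size):
        self.sliding_window_size = sliding_window_size

for layer in (_FakeLayer(4095), _FakeLayer(None)):
    is_swa = layer.sliding_window_size is not None and layer.sliding_window_size > -1
    window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
    # (4095, 0): each query attends to itself and the previous 4095 tokens (both bounds inclusive).
    # (-1, -1): no sliding window, i.e. ordinary full causal attention.
    print(layer.sliding_window_size, "->", window_size)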
@@ -684,8 +701,13 @@
         )
 
         # We do cascade attention for Target Verify with topk > 1
+        # We don't use cascade attention for Sliding Window Attention:
+        # - A different window size would need to be passed in for each q in the first stage of cascade attention, but the FA3 interface doesn't support passing in a list of window sizes.
+        # - The overhead of duplicating computation of the common prefix part is small for sliding window layers (seq_len <= window_size), so we can just expand it.
         use_cascade_attn = (
-            forward_batch.forward_mode.is_target_verify() and self.topk > 1
+            forward_batch.forward_mode.is_target_verify()
+            and self.topk > 1
+            and not is_swa
         )
 
         # For fa3 interface version compatibility, we put new fields into conditional keyword args
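The second bullet in the new comment can be made concrete with a rough estimate: an SWA layer attends to at most `window_size` keys per query, so duplicating (expanding) the shared prefix across the draft tokens is bounded by the window, while for a full-attention layer the duplicated work would grow with the whole sequence. A back-of-the-envelope sketch with invented numbers:

# Hypothetical sizes, for intuition only.
window_size = 4096       # per-layer sliding window
seq_len = 32768          # full context length of the request
num_draft_tokens = 8     # draft tokens to verify per request

# KV positions each draft token touches if the common prefix is simply expanded:
swa_kv_per_draft = min(seq_len, window_size)   # capped by the window
full_kv_per_draft = seq_len                    # not capped

print("duplicated KV reads, SWA layer:", num_draft_tokens * swa_kv_per_draft)    # 32768
print("duplicated KV reads, full attn:", num_draft_tokens * full_kv_per_draft)   # 262144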
@@ -700,13 +722,18 @@
                 cu_seqlens_q = local_metadata.local_query_start_loc
                 cache_seqlens = local_metadata.local_seqused_k
                 max_seqlen_q = local_metadata.local_max_query_len
-                max_seqlen_k = local_metadata.local_max_seq_len
+            elif is_swa and metadata.swa_spec_metadata is not None:
+                swa_spec_metadata = metadata.swa_spec_metadata
+                page_table = swa_spec_metadata.page_table
+                cu_seqlens_q = swa_spec_metadata.cu_seqlens_q
+                cache_seqlens = swa_spec_metadata.cache_seqlens_int32
+                max_seqlen_q = swa_spec_metadata.max_seq_len_q
+                cu_seqlens_k = swa_spec_metadata.cu_seqlens_k
             else:
                 page_table = metadata.page_table
                 cu_seqlens_q = metadata.cu_seqlens_q
                 cache_seqlens = metadata.cache_seqlens_int32
                 max_seqlen_q = metadata.max_seq_len_q
-                max_seqlen_k = metadata.max_seq_len_k
                 cu_seqlens_k = metadata.cu_seqlens_k
@@ -1377,6 +1404,32 @@
                 ),
             }
 
+            if self.has_swa:
+                self.target_verify_metadata_topk_swa = {
+                    "cache_seqlens": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "cu_seqlens_k": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "cu_seqlens_q": torch.arange(
+                        0,
+                        max_bs * self.speculative_num_draft_tokens + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "page_table": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens,
+                        self.max_context_len,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                }
+
         self.encoder_metadata = {
             "encoder_page_table": torch.zeros(
                 max_bs,
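The buffers above are sized for the CUDA-graph worst case: one row per draft token, i.e. `max_bs * speculative_num_draft_tokens` rows, and `cu_seqlens_q` is a plain `arange` because each draft token is verified as its own length-1 query. A shape sketch with hypothetical capture limits (`max_bs=2`, `speculative_num_draft_tokens=3`, `max_context_len=16` are made up):

import torch

max_bs, num_draft, max_context_len = 2, 3, 16  # hypothetical capture-time limits

cache_seqlens = torch.zeros(max_bs * num_draft, dtype=torch.int32)                 # shape [6]
cu_seqlens_k = torch.zeros(max_bs * num_draft + 1, dtype=torch.int32)              # shape [7]
cu_seqlens_q = torch.arange(0, max_bs * num_draft + 1, dtype=torch.int32)          # [0, 1, ..., 6]
page_table = torch.zeros(max_bs * num_draft, max_context_len, dtype=torch.int32)   # shape [6, 16]

# cu_seqlens_q increasing by exactly 1 per row encodes "one query token per draft token".
assert torch.equal(cu_seqlens_q[1:] - cu_seqlens_q[:-1], torch.ones(6, dtype=torch.int32))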
@@ -1564,6 +1617,28 @@
                 self.target_verify_metadata_topk_normal[bs] = metadata
                 self.target_verify_metadata_topk_expand[bs] = metadata_expand
+
+                if self.has_swa:
+                    metadata_swa = FlashAttentionMetadata()
+                    metadata_swa.cache_seqlens_int32 = (
+                        self.target_verify_metadata_topk_swa["cache_seqlens"][
+                            : bs * self.speculative_num_draft_tokens
+                        ]
+                    )
+                    metadata_swa.max_seq_len_q = 1
+                    metadata_swa.cu_seqlens_q = self.target_verify_metadata_topk_swa[
+                        "cu_seqlens_q"
+                    ][: bs * self.speculative_num_draft_tokens + 1]
+                    metadata_swa.cu_seqlens_k = self.target_verify_metadata_topk_swa[
+                        "cu_seqlens_k"
+                    ][: bs * self.speculative_num_draft_tokens + 1]
+                    metadata_swa.page_table = self.target_verify_metadata_topk_swa[
+                        "page_table"
+                    ][: bs * self.speculative_num_draft_tokens]
+
+                    self.target_verify_metadata_topk_swa[bs] = metadata_swa
+                    metadata.swa_spec_metadata = metadata_swa
         elif forward_mode.is_draft_extend():
             metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                 :bs
@@ -1804,6 +1879,12 @@
                     )
                 )
+
+                if self.has_swa:
+                    metadata_swa = self.target_verify_metadata_topk_swa[bs]
+                    self._init_sliding_window_attn_spec_metadata(
+                        metadata, metadata_expand, metadata_swa
+                    )
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
             metadata.cache_seqlens_int32.copy_(seq_lens)
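The capture/replay pair above follows the usual CUDA-graph pattern: at capture time the per-batch-size metadata aliases slices of the preallocated buffers, and at replay time fresh values are copied in place so the captured kernels see new contents at fixed addresses. A minimal sketch of that pattern (the tensor names below are illustrative, not the backend's):

import torch

pool = torch.zeros(8, dtype=torch.int32)   # preallocated once, like the buffers above

# "Capture": take a view for the current batch size; the graph records this address.
bs_view = pool[:4]

# "Replay": write new values in place; the view (and the captured graph) sees them.
new_seqlens = torch.tensor([5, 7, 2, 9], dtype=torch.int32)
bs_view.copy_(new_seqlens)

assert torch.equal(pool[:4], new_seqlens)  # same storage, updated contents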
@@ -2039,6 +2120,159 @@
         lam.local_max_query_len = int(seqlens_q_local_np.max())
         lam.local_max_seq_len = int(seqlens_k_local_np.max())
 
+    def _init_sliding_window_attn_spec_metadata(
+        self,
+        metadata: FlashAttentionMetadata,
+        metadata_expand: FlashAttentionMetadata,
+        metadata_swa: Optional[FlashAttentionMetadata] = None,
+    ):
+        # TODO: support page_size > 1 for swa spec
+        assert (
+            self.page_size == 1
+        ), "FlashAttention backend doesn't support topk > 1 speculative decoding with page size > 1 sliding window attention"
+
+        cache_seqlens_int32 = (
+            metadata.cache_seqlens_int32.repeat_interleave(
+                self.speculative_num_draft_tokens
+            )
+            + metadata_expand.cache_seqlens_int32
+        )
+        cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
+        )
+        bs = cache_seqlens_int32.shape[0]
+        page_table = (
+            metadata.page_table.new_zeros(
+                (bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])
+            )
+            if metadata_swa is None
+            else metadata_swa.page_table
+        )
+
+        prepare_swa_spec_page_table_triton(
+            page_table,
+            metadata.page_table,
+            metadata_expand.page_table,
+            metadata.cache_seqlens_int32,
+            metadata_expand.cache_seqlens_int32,
+            self.speculative_num_draft_tokens,
+        )
+
+        if metadata_swa is None:
+            metadata_swa = FlashAttentionMetadata()
+            metadata_swa.max_seq_len_q = 1
+            metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q
+            metadata_swa.cache_seqlens_int32 = cache_seqlens_int32
+            metadata_swa.cu_seqlens_k = cu_seqlens_k
+            metadata_swa.page_table = page_table
+        else:
+            metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32)
+            metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k)
+
+        metadata.swa_spec_metadata = metadata_swa
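To make the length bookkeeping in `_init_sliding_window_attn_spec_metadata` concrete: each request's prefix length is repeated once per draft token and added to that draft token's expanded KV length, and `cu_seqlens_k` is the zero-padded cumulative sum. A small pure-PyTorch walkthrough with invented numbers (2 requests, 3 draft tokens each):

import torch

num_draft = 3
prefix_lens = torch.tensor([10, 4], dtype=torch.int32)             # plays the role of metadata.cache_seqlens_int32
expand_lens = torch.tensor([1, 2, 2, 1, 1, 3], dtype=torch.int32)  # plays the role of metadata_expand.cache_seqlens_int32

cache_seqlens = prefix_lens.repeat_interleave(num_draft) + expand_lens
# -> tensor([11, 12, 12,  5,  5,  7], dtype=torch.int32)

cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
)
# -> tensor([ 0, 11, 23, 35, 40, 45, 52], dtype=torch.int32)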
+
+
+@triton.jit
+def _prepare_swa_spec_page_table_kernel(
+    dst_ptr,
+    src_a_ptr,
+    src_b_ptr,
+    seq_len_a_ptr,
+    seq_len_b_ptr,
+    dst_stride_m,
+    dst_stride_n,
+    a_stride_m,
+    a_stride_n,
+    b_stride_m,
+    b_stride_n,
+    LEN_A: tl.constexpr,
+    LEN_B: tl.constexpr,
+    REPEAT_STEP: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+    idx_a = pid_m // REPEAT_STEP
+    idx_b = pid_m
+
+    seq_len_a = tl.load(seq_len_a_ptr + idx_a)
+    seq_len_b = tl.load(seq_len_b_ptr + idx_b)
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    total_len = seq_len_a + seq_len_b
+
+    if pid_n * BLOCK_N >= total_len:
+        return
+
+    mask = offs_n < total_len
+    dst = dst_ptr + pid_m * dst_stride_m + offs_n * dst_stride_n
+
+    if (pid_n + 1) * BLOCK_N < seq_len_a:
+        a_ptr = src_a_ptr + idx_a * a_stride_m + offs_n * a_stride_n
+        a_mask = mask & (offs_n < LEN_A)
+        val = tl.load(a_ptr, mask=a_mask, other=0)
+        tl.store(dst, val, mask=mask)
+    elif pid_n * BLOCK_N >= seq_len_a:
+        offs_b = offs_n - seq_len_a
+        b_ptr = src_b_ptr + idx_b * b_stride_m + offs_b * b_stride_n
+        b_mask = mask & (offs_b < LEN_B)
+        val = tl.load(b_ptr, mask=b_mask, other=0)
+        tl.store(dst, val, mask=mask)
+    else:
+        # mixed part
+        a_offs = offs_n
+        a_mask = (a_offs < seq_len_a) & (a_offs < LEN_A)
+        a_ptr = src_a_ptr + idx_a * a_stride_m + a_offs * a_stride_n
+        a_val = tl.load(a_ptr, mask=a_mask, other=0)
+
+        b_offs = offs_n - seq_len_a
+        b_mask = (b_offs >= 0) & (b_offs < seq_len_b) & (b_offs < LEN_B)
+        b_ptr = src_b_ptr + idx_b * b_stride_m + b_offs * b_stride_n
+        b_val = tl.load(b_ptr, mask=b_mask, other=0)
+
+        result = tl.where(offs_n < seq_len_a, a_val, b_val)
+        tl.store(dst, result, mask=mask)
+
+
+def prepare_swa_spec_page_table_triton(
+    page_table_dst: torch.Tensor,
+    page_table_a: torch.Tensor,
+    page_table_b: torch.Tensor,  # expand page table
+    seq_len_a: torch.Tensor,
+    seq_len_b: torch.Tensor,  # expand seq lens
+    speculative_num_draft_tokens: int,
+):
+    # concat page_table and expand page_table by kv seq length
+    bs = seq_len_a.numel()
+    bs_expand = seq_len_b.numel()
+    assert bs_expand == bs * speculative_num_draft_tokens
+
+    LEN_A = page_table_a.shape[1]
+    LEN_B = page_table_b.shape[1]
+    LEN_OUT = LEN_A + LEN_B
+    REPEAT_STEP = speculative_num_draft_tokens
+    BLOCK_N = 256
+
+    grid = (bs_expand, triton.cdiv(LEN_OUT, BLOCK_N))
+    _prepare_swa_spec_page_table_kernel[grid](
+        page_table_dst,
+        page_table_a,
+        page_table_b,
+        seq_len_a,
+        seq_len_b,
+        page_table_dst.stride(0),
+        page_table_dst.stride(1),
+        page_table_a.stride(0),
+        page_table_a.stride(1),
+        page_table_b.stride(0),
+        page_table_b.stride(1),
+        LEN_A=LEN_A,
+        LEN_B=LEN_B,
+        REPEAT_STEP=REPEAT_STEP,
+        BLOCK_N=BLOCK_N,
+        num_warps=4,
+    )
+
 class FlashAttentionMultiStepBackend:
...
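As a sanity check on `prepare_swa_spec_page_table_triton`, the same row-wise concatenation can be written directly in PyTorch: row i of the destination receives `seq_len_a[i // num_draft_tokens]` entries from `page_table_a` followed by `seq_len_b[i]` entries from `page_table_b`. The reference below is a hedged sketch, not part of the commit; on a CUDA device its output can be compared against the Triton version.

import torch

def prepare_swa_spec_page_table_ref(dst, page_table_a, page_table_b, seq_len_a, seq_len_b, num_draft_tokens):
    # Pure-PyTorch reference of the row-wise concatenation done by the Triton kernel above.
    for i in range(seq_len_b.numel()):
        la = int(seq_len_a[i // num_draft_tokens])
        lb = int(seq_len_b[i])
        dst[i, :la] = page_table_a[i // num_draft_tokens, :la]
        dst[i, la : la + lb] = page_table_b[i, :lb]

# Tiny made-up example: 1 request, 2 draft tokens, page_size == 1.
seq_len_a = torch.tensor([3], dtype=torch.int32)      # prefix KV length per request
seq_len_b = torch.tensor([1, 2], dtype=torch.int32)   # expanded KV length per draft token
page_table_a = torch.tensor([[7, 8, 9, 0]], dtype=torch.int32)
page_table_b = torch.tensor([[20, 0], [20, 21]], dtype=torch.int32)
dst = torch.zeros(2, page_table_a.shape[1] + page_table_b.shape[1], dtype=torch.int32)

prepare_swa_spec_page_table_ref(dst, page_table_a, page_table_b, seq_len_a, seq_len_b, 2)
# dst -> [[7, 8, 9, 20,  0, 0],
#         [7, 8, 9, 20, 21, 0]]

Moving the same tensors to a CUDA device and calling prepare_swa_spec_page_table_triton(dst, page_table_a, page_table_b, seq_len_a, seq_len_b, 2) should fill the first seq_len_a + seq_len_b entries of each row identically.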