"tests/python/vscode:/vscode.git/clone" did not exist on "1785acffa60b128b9f415712ace693fc998dea96"
Unverified commit 93470a14 authored by Stefan He, committed by GitHub

Refactor and Optimize FA3 Code (#5090)


Co-authored-by: Qingquan Song <ustcsqq@gmail.com>
parent db452760
 from __future__ import annotations
-import numpy as np
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-"""
-Support different attention backends.
-Now there are three backends: FlashInfer, Triton and FlashAttention.
-Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
-"""
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
+import numpy as np
 import torch
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -30,22 +22,25 @@ from sgl_kernel.flash_attn import flash_attn_with_kvcache
 @dataclass
 class FlashAttentionMetadata:
     """Metadata to be init once in the model forward pass,
-    each layer's forward pass can reuse the metadata."""
-
-    # Cumulative sequence lengths for query
-    cu_seqlens_q: torch.Tensor = None
-    # Cumulative sequence lengths for key
-    cu_seqlens_k: torch.Tensor = None
+    each layer's forward pass can reuse the metadata.
+
+    For each init metadata function, we will try set up them in below order
+    """
+
+    # Sequence lengths for the forward batch
+    cache_seqlens_int32: torch.Tensor = None
     # Maximum sequence length for query
     max_seq_len_q: int = 0
     # Maximum sequence length for key
     max_seq_len_k: int = 0
+    # Cumulative sequence lengths for query
+    cu_seqlens_q: torch.Tensor = None
+    # Cumulative sequence lengths for key
+    cu_seqlens_k: torch.Tensor = None
     # Window size (typically used by Gemma)
     window_size: tuple = (-1, -1)
     # Page table, the index of KV Cache Tables/Blocks
     page_table: torch.Tensor = None
-    # Sequence lengths for the forward batch
-    cache_seqlens_int32: torch.Tensor = None
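The cumulative-length fields above follow the varlen FlashAttention convention. Below is a minimal standalone sketch (not part of this diff; the lengths are invented) of how cu_seqlens_k and max_seq_len_k are derived from per-request KV lengths, the same cumsum-plus-pad pattern used throughout this file.

# Standalone sketch: deriving cumulative sequence lengths from per-request lengths.
import torch

cache_seqlens_int32 = torch.tensor([5, 3, 7], dtype=torch.int32)  # KV length per request
# cu_seqlens_k is an exclusive prefix sum with a leading 0 -> [0, 5, 8, 15]
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
)
max_seq_len_k = int(cache_seqlens_int32.max().item())  # 7
print(cu_seqlens_k.tolist(), max_seq_len_k)  # [0, 5, 8, 15] 7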
 @dataclass
 class LocalAttentionMetadata:
@@ -270,9 +265,9 @@ class FlashAttentionBackend(AttentionBackend):
         self,
         model_runner: ModelRunner,
         skip_prefill: bool = False,
+        speculative_step_id=0,
         topk=0,
         speculative_num_steps=0,
-        step_id=0,
     ):
         super().__init__()
@@ -293,14 +288,12 @@ class FlashAttentionBackend(AttentionBackend):
         ) and (not global_server_args_dict["disable_mla"])
         self.skip_prefill = skip_prefill
-        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
-        assert (
-            topk <= 1
-        ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
-        self.topk = 1
-        self.step_id = step_id
+        self.topk = topk
         self.speculative_num_steps = speculative_num_steps
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+        self.speculative_step_id = speculative_step_id
         # Local attention settings
         self.attention_chunk_size = (
@@ -310,71 +303,59 @@ class FlashAttentionBackend(AttentionBackend):
         )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        """Initialize forward metadata to cache repetitive calculations."""
+        """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
         seqlens_in_batch = forward_batch.seq_lens
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
         if forward_batch.forward_mode.is_decode():
-            # Skip Prefill or Draft Decode
-            # Note: Draft Decode will be ran on the Draft Worker
+            # Draft Decode
             if forward_batch.spec_info is not None:
+                metadata.cache_seqlens_int32 = (
+                    seqlens_in_batch + (self.speculative_step_id + 1)
+                ).to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_q = torch.arange(
                     0, batch_size + 1, dtype=torch.int32, device=device
                 )
-                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
-                metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
                         metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     ),
                     (1, 0),
                 )
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
-                    self.step_id + 1
-                )
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-                cache_loc = forward_batch.out_cache_loc.view(
-                    self.speculative_num_steps, -1
-                ).T
-                for idx, single_seq_len in enumerate(seq_lens_with_decode):
-                    real_bsz_start_idx = idx
-                    real_bsz_end_idx = idx + 1
-                    metadata.page_table[
-                        real_bsz_start_idx:real_bsz_end_idx,
-                        (single_seq_len - (self.step_id + 1)) : single_seq_len,
-                    ] = cache_loc[
-                        real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
-                    ]
-            else:  # Normal Decode without Spec Decoding
+            else:
+                # Normal Decode
                 metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
                 )
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-                metadata.cu_seqlens_q = torch.arange(
-                    0, batch_size + 1, dtype=torch.int32, device=device
-                )
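A standalone sketch of the draft-decode bookkeeping above, assuming topk == 1 so each request contributes one query token per step (the values are invented, not taken from the backend):

# Standalone sketch: draft decode adds (speculative_step_id + 1) tokens per request.
import torch

seq_lens = torch.tensor([12, 30, 7])  # cached tokens per request
speculative_step_id = 2               # third draft step (0-based)

cache_seqlens_int32 = (seq_lens + (speculative_step_id + 1)).to(torch.int32)
max_seq_len_k = int(seq_lens.max().item()) + (speculative_step_id + 1)
# one new query token per request -> cu_seqlens_q is just 0..batch_size
cu_seqlens_q = torch.arange(0, len(seq_lens) + 1, dtype=torch.int32)
print(cache_seqlens_int32.tolist(), max_seq_len_k, cu_seqlens_q.tolist())
# [15, 33, 10] 33 [0, 1, 2, 3]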
         elif forward_batch.forward_mode.is_target_verify():
-            # Note: Target Verify will be ran on the Target Worker
-            draft_token_num = forward_batch.spec_info.draft_token_num
             metadata.cache_seqlens_int32 = (
-                forward_batch.seq_lens + draft_token_num
+                forward_batch.seq_lens + self.speculative_num_draft_tokens
             ).to(torch.int32)
-            metadata.max_seq_len_q = draft_token_num
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
             metadata.max_seq_len_k = (
-                forward_batch.seq_lens_cpu.max().item() + draft_token_num
+                forward_batch.seq_lens_cpu.max().item()
+                + self.speculative_num_draft_tokens
             )
             metadata.cu_seqlens_q = torch.arange(
                 0,
-                batch_size * draft_token_num + 1,
-                draft_token_num,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
                 dtype=torch.int32,
                 device=device,
             )
@@ -387,31 +368,27 @@ class FlashAttentionBackend(AttentionBackend):
             ]
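A standalone sketch of the target-verify shapes: every request carries exactly speculative_num_draft_tokens query tokens, so cu_seqlens_q becomes a constant-stride arange rather than a cumsum (values are invented):

# Standalone sketch: target verify with a fixed number of draft tokens per request.
import torch

batch_size = 3
speculative_num_draft_tokens = 4
seq_lens = torch.tensor([12, 30, 7])

cu_seqlens_q = torch.arange(
    0,
    batch_size * speculative_num_draft_tokens + 1,
    speculative_num_draft_tokens,
    dtype=torch.int32,
)
cache_seqlens = (seq_lens + speculative_num_draft_tokens).to(torch.int32)
print(cu_seqlens_q.tolist(), cache_seqlens.tolist())
# [0, 4, 8, 12] [16, 34, 11]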
         elif forward_batch.forward_mode.is_extend_or_draft_extend():
-            # Normal or Draft Extend (Both of them will be ran on the Target Worker)
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
             metadata.cu_seqlens_k = torch.nn.functional.pad(
                 torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
             )
-            # Precompute maximum sequence length
-            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
-            # Precompute page table
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
-            # Precompute cumulative sequence lengths
             if (
                 any(forward_batch.extend_prefix_lens_cpu)
                 or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
             ):
                 extend_seq_lens = forward_batch.extend_seq_lens
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
-                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
             else:
-                metadata.cu_seqlens_q = metadata.cu_seqlens_k
                 metadata.max_seq_len_q = metadata.max_seq_len_k
+                metadata.cu_seqlens_q = metadata.cu_seqlens_k
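A standalone sketch of the extend-mode split between query and key lengths: cu_seqlens_q is built from the newly extended tokens when a prefix is already cached, and simply aliases cu_seqlens_k otherwise (values are invented):

# Standalone sketch: query lengths cover only the newly extended tokens.
import torch

seq_lens = torch.tensor([10, 6])           # total KV length per request
extend_prefix_lens = torch.tensor([8, 0])  # tokens already in the KV cache
extend_seq_lens = seq_lens - extend_prefix_lens

cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
if extend_prefix_lens.any():
    cu_seqlens_q = torch.nn.functional.pad(
        torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
    )
else:
    cu_seqlens_q = cu_seqlens_k
print(cu_seqlens_q.tolist(), cu_seqlens_k.tolist())
# [0, 2, 8] [0, 10, 16]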
         # Setup local attention if enabled
         if (
@@ -458,7 +435,7 @@ class FlashAttentionBackend(AttentionBackend):
             )
             metadata.local_attn_metadata = local_metadata
-        # Precompute strided indices
+        # Convert the page table to a strided format which is needed by FA3 API
         if self.page_size > 1:
             self.strided_indices = torch.arange(
                 0, metadata.page_table.shape[1], self.page_size, device=self.device
@@ -498,7 +475,7 @@ class FlashAttentionBackend(AttentionBackend):
                     v,
                 )
-        # Use precomputed metadata
+        # Use precomputed metadata across all layers
         metadata = self.forward_metadata

         # Calculate window size (can be moved to metadata if layer properties don't change)
@@ -606,8 +583,6 @@ class FlashAttentionBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ) -> torch.Tensor:
-        """Forward pass with FlashAttention using precomputed metadata."""
-        # Save KV cache if needed
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -628,7 +603,7 @@ class FlashAttentionBackend(AttentionBackend):
                     v,
                 )
-        # Use precomputed metadata
+        # Use precomputed metadata across all layers
         metadata = self.forward_metadata

         # Calculate window size (can be moved to metadata if layer properties don't change)
@@ -639,12 +614,9 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-        page_table = metadata.page_table
         if not self.use_mla:
             # Do multi-head attention
-            # Get KV cache
             kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
             key_cache, value_cache = kv_cache[0], kv_cache[1]
             key_cache = key_cache.view(
@@ -654,13 +626,12 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )
-            # Pre-reshape query tensor
             q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
             o = flash_attn_with_kvcache(
                 q=q_reshaped,
                 k_cache=key_cache,
                 v_cache=value_cache,
-                page_table=page_table,
+                page_table=metadata.page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
@@ -696,7 +667,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_cache=k_rope_cache,
                 v_cache=c_kv_cache,
                 qv=q_nope,
-                page_table=page_table,
+                page_table=metadata.page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
@@ -719,7 +690,13 @@ class FlashAttentionBackend(AttentionBackend):
         to avoid memory allocations.
         """
         self.decode_cuda_graph_metadata = {
-            # Page table for token mapping (batch_size, max_context_len)
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
             "page_table": torch.zeros(
                 max_bs,
                 (self.max_context_len + self.page_size - 1) // self.page_size,
@@ -735,30 +712,22 @@ class FlashAttentionBackend(AttentionBackend):
             "strided_indices": torch.arange(
                 0, self.max_context_len, self.page_size, device=self.device
             ),
+        }
+        self.target_verify_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.arange(
-                0, max_bs + 128, dtype=torch.int32, device=self.device
+            "cu_seqlens_q": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
             ),
             "cu_seqlens_k": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
+                max_bs + 1, dtype=torch.int32, device=self.device
             ),
-        }
-        self.target_verify_metadata = {
             "page_table": torch.zeros(
                 max_bs,
                 (self.max_context_len + self.page_size - 1) // self.page_size,
                 dtype=torch.int32,
                 device=self.device,
             ),
-            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
-            ),
-            "cu_seqlens_k": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
-            ),
-            "max_seqlen_q": 0,
             "strided_indices": torch.arange(
                 0, self.max_context_len, self.page_size, device=self.device
             ),
@@ -780,24 +749,21 @@ class FlashAttentionBackend(AttentionBackend):
         if forward_mode.is_decode():
             if spec_info is not None:
                 # Draft Decode
-                metadata.cu_seqlens_q = torch.arange(
-                    0, bs + 1, dtype=torch.int32, device=device
-                )
                 metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
                     "cache_seqlens"
                 ][:bs]
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
                     : bs + 1
                 ]
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
                         metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     ),
                     (1, 0),
                 )
-                metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
                 metadata.page_table = self.decode_cuda_graph_metadata[
                     "page_table_draft_decode"
                 ][req_pool_indices, :]
@@ -822,37 +788,30 @@ class FlashAttentionBackend(AttentionBackend):
                 )
             self.decode_cuda_graph_metadata[bs] = metadata
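A standalone sketch of the preallocation pattern behind decode_cuda_graph_metadata and target_verify_metadata: buffers are created once at the maximum batch size, and capture/replay only slice and copy_ into them so the tensor addresses recorded in the CUDA graph stay fixed. Names and sizes below are illustrative, not the backend's real ones.

# Standalone sketch: stable buffers for CUDA graph capture/replay.
import torch

max_bs = 8
buffers = {
    "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32),
    "cu_seqlens_k": torch.zeros(max_bs + 1, dtype=torch.int32),
}

def replay_update(seq_lens: torch.Tensor) -> None:
    bs = seq_lens.numel()
    cache_seqlens = buffers["cache_seqlens"][:bs]  # view into the captured buffer
    cache_seqlens.copy_(seq_lens.to(torch.int32))  # in-place update, same storage
    buffers["cu_seqlens_k"][: bs + 1].copy_(
        torch.nn.functional.pad(
            torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
        )
    )

replay_update(torch.tensor([5, 3, 7]))
print(buffers["cu_seqlens_k"].tolist())  # [0, 5, 8, 15, 0, 0, 0, 0, 0]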
         elif forward_mode.is_target_verify():
-            draft_token_num = spec_info.draft_token_num
             metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
                 :bs
             ]
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens + draft_token_num).to(torch.int32)
+                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
             )
-            metadata.max_seq_len_q = draft_token_num
-            metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num
-            metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
-                torch.arange(
-                    0,
-                    bs * draft_token_num + 1,
-                    draft_token_num,
-                    dtype=torch.int32,
-                    device=device,
-                )
-            ]
-            cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
-            cu_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
-            )
-            metadata.cu_seqlens_k = cu_k
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                seq_lens.max().item() + self.speculative_num_draft_tokens
+            )
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
+            )
+            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
             metadata.page_table = self.target_verify_metadata["page_table"][
                 req_pool_indices, :
             ]
@@ -874,24 +833,21 @@ class FlashAttentionBackend(AttentionBackend):
         out_cache_loc: torch.Tensor = None,
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-        device = seq_lens.device
         seq_lens = seq_lens[:bs]
-        req_pool_indices = req_pool_indices[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
+        req_pool_indices = req_pool_indices[:bs]
         if forward_mode.is_decode():
             metadata = self.decode_cuda_graph_metadata[bs]
             if spec_info is not None:
                 # Draft Decode
-                max_len = seq_lens_cpu.max().item()
-                metadata.max_seq_len_k = max_len + (self.step_id + 1)
                 metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + (self.step_id + 1)).to(torch.int32)
+                    (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
                 )
-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)
+                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_k.copy_(
                     torch.nn.functional.pad(
                         torch.cumsum(
@@ -929,22 +885,13 @@ class FlashAttentionBackend(AttentionBackend):
         elif forward_mode.is_target_verify():
             metadata = self.target_verify_metadata[bs]
-            draft_token_num = spec_info.draft_token_num
-            metadata.cu_seqlens_q.copy_(
-                torch.arange(
-                    0,
-                    bs * draft_token_num + 1,
-                    draft_token_num,
-                    dtype=torch.int32,
-                    device=device,
-                )
-            )
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens + draft_token_num).to(torch.int32)
+                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
             )
-            metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num
+            metadata.max_seq_len_k = (
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
             metadata.cu_seqlens_k.copy_(
                 torch.nn.functional.pad(
                     torch.cumsum(
@@ -972,14 +919,19 @@ class FlashAttentionMultiStepBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
+        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
+        assert (
+            self.topk == 1
+        ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
         self.attn_backends = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 FlashAttentionBackend(
                     model_runner,
+                    speculative_step_id=i,
                     topk=self.topk,
                     speculative_num_steps=self.speculative_num_steps,
-                    step_id=i,
                 )
             )
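A standalone sketch of why one backend is kept per draft step: step i attends over the original sequence plus (i + 1) freshly drafted tokens, so each per-step backend needs its own speculative_step_id offset. The Step class below is hypothetical and only illustrates the arithmetic.

# Standalone sketch: per-step key-length offsets in multi-step speculative decoding.
import torch

class Step:
    def __init__(self, speculative_step_id: int):
        self.speculative_step_id = speculative_step_id

    def key_lens(self, seq_lens: torch.Tensor) -> torch.Tensor:
        return (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)

speculative_num_steps = 3
steps = [Step(i) for i in range(speculative_num_steps)]
seq_lens = torch.tensor([12, 30])
for s in steps:
    print(s.speculative_step_id, s.key_lens(seq_lens).tolist())
# 0 [13, 31]
# 1 [14, 32]
# 2 [15, 33]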