Unverified Commit ca8d02ab authored by Stefan He, committed by GitHub

FA3 Spec Decoding to support top k = 1 and add cuda graph support (#5050)


Co-authored-by: Qingquan Song <ustcsqq@gmail.com>
Co-authored-by: Chunan Zeng <zcnrex@gmail.com>
parent 3f287b85
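
This change reworks how FlashAttentionMetadata is built for normal decode, draft decode (on the draft worker), and target verify (on the target worker), and preallocates fixed-size buffers so that decode and target-verify batches can be captured and replayed as CUDA graphs. As a quick orientation, the standalone sketch below (illustrative values only, not part of the patch) shows the cumulative-sequence-length layout that the target-verify path constructs with torch.arange and torch.cumsum:

import torch

# Illustrative batch: 2 requests with 7 and 11 cached tokens, verifying
# draft_token_num = 4 speculative tokens per request.
seq_lens = torch.tensor([7, 11], dtype=torch.int32)
draft_token_num = 4
bs = seq_lens.numel()

# Each request contributes draft_token_num query tokens, so query offsets
# advance in fixed strides of draft_token_num: [0, 4, 8].
cu_seqlens_q = torch.arange(
    0, bs * draft_token_num + 1, draft_token_num, dtype=torch.int32
)

# Keys cover the cached prefix plus the draft tokens: [11, 15] -> [0, 11, 26].
cache_seqlens = seq_lens + draft_token_num
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
)

print(cu_seqlens_q)  # tensor([0, 4, 8], dtype=torch.int32)
print(cu_seqlens_k)  # tensor([ 0, 11, 26], dtype=torch.int32)

The decode paths build the same cu_seqlens_k layout from the cached sequence lengths, with cu_seqlens_q reduced to one query token per request.
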
@@ -27,19 +27,42 @@ from sgl_kernel.flash_attn import flash_attn_with_kvcache

@dataclass
class FlashAttentionMetadata:
    """Metadata initialized once per model forward pass;
    each layer's forward pass can reuse it."""

    # Cumulative sequence lengths for query
    cu_seqlens_q: torch.Tensor = None
    # Cumulative sequence lengths for key
    cu_seqlens_k: torch.Tensor = None
    # Maximum sequence length for query
    max_seq_len_q: int = 0
    # Maximum sequence length for key
    max_seq_len_k: int = 0
    # Window size (typically used by Gemma)
    window_size: tuple = (-1, -1)
    # Page table, the index of KV Cache Tables/Blocks
    page_table: torch.Tensor = None
    # Sequence lengths for the forward batch
    cache_seqlens_int32: torch.Tensor = None

class FlashAttentionBackend(AttentionBackend):
    """FlashAttention backend implementation.

    Note about the init:
    - If no spec decoding:
        - FlashAttentionBackend is initialized once when the server starts.
    - If spec decoding:
        - FlashAttentionBackend is initialized once for the target worker.
        - FlashAttentionMultiStepBackend is initialized once for the draft worker;
          it spawns speculative_num_steps FlashAttentionBackend instances for the draft worker.

    Note about CUDA Graph:
    - We only support CUDA Graph for Decode (Normal Decode and Draft Decode) and Target Verify.
    - We don't support CUDA Graph for Extend and Draft Extend.
    - At server init, init_cuda_graph_state is called first, then init_forward_metadata_capture_cuda_graph.
    - For each forward batch, init_forward_metadata_replay_cuda_graph is called first, then the graph is replayed.
    """
    def __init__(
        self,
@@ -56,41 +79,42 @@ class FlashAttentionBackend(AttentionBackend):
            and model_runner.model_config.is_encoder_decoder
        ), "Sliding window and cross attention are not supported together"

        self.forward_metadata: FlashAttentionMetadata = None
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.decode_cuda_graph_metadata = {}
        self.target_verify_metadata = {}
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.page_size = model_runner.page_size
        self.use_mla = (
            model_runner.model_config.attention_arch == AttentionArch.MLA
        ) and (not global_server_args_dict["disable_mla"])
        self.skip_prefill = skip_prefill

        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
        assert (
            topk <= 1
        ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
        self.topk = 1
        self.step_id = step_id
        self.speculative_num_steps = speculative_num_steps
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize forward metadata to cache repetitive calculations."""
        metadata = FlashAttentionMetadata()
        seqlens_in_batch = forward_batch.seq_lens
        batch_size = len(seqlens_in_batch)
        device = seqlens_in_batch.device

        if forward_batch.forward_mode.is_decode():
            # Skip Prefill or Draft Decode
            # Note: Draft Decode will be run on the Draft Worker
            if forward_batch.spec_info is not None:
                metadata.cu_seqlens_q = torch.arange(
                    0, batch_size + 1, dtype=torch.int32, device=device
                )
                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
                metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
@@ -103,86 +127,58 @@ class FlashAttentionBackend(AttentionBackend):
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]
                cache_loc = forward_batch.out_cache_loc.view(
                    self.speculative_num_steps, -1
                ).T
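
                # For each request, copy the cache locations of the (step_id + 1) newly
                # speculated tokens into that request's page-table columns right after
                # its committed tokens, i.e. columns [seq_len, seq_len + step_id + 1).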
                for idx, single_seq_len in enumerate(seq_lens_with_decode):
                    real_bsz_start_idx = idx
                    real_bsz_end_idx = idx + 1
                    metadata.page_table[
                        real_bsz_start_idx:real_bsz_end_idx,
                        (single_seq_len - (self.step_id + 1)) : single_seq_len,
                    ] = cache_loc[
                        real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
                    ]
            else:
                # Normal Decode without Spec Decoding
                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
                )
                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]
                metadata.cu_seqlens_q = torch.arange(
                    0, batch_size + 1, dtype=torch.int32, device=device
                )
        elif forward_batch.forward_mode.is_target_verify():
            # Note: Target Verify will be run on the Target Worker
            draft_token_num = forward_batch.spec_info.draft_token_num
            metadata.cache_seqlens_int32 = (
                forward_batch.seq_lens + draft_token_num
            ).to(torch.int32)
            metadata.max_seq_len_q = draft_token_num
            metadata.max_seq_len_k = (
                forward_batch.seq_lens_cpu.max().item() + draft_token_num
            )
            metadata.cu_seqlens_q = torch.arange(
                0,
                batch_size * draft_token_num + 1,
                draft_token_num,
                dtype=torch.int32,
                device=device,
            )
            metadata.cu_seqlens_k = torch.nn.functional.pad(
                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
                (1, 0),
            )
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]
        elif forward_batch.forward_mode.is_extend_or_draft_extend():
            # Normal or Draft Extend (both run on the Target Worker)
            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
            metadata.cu_seqlens_k = torch.nn.functional.pad(
                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
@@ -208,7 +204,6 @@ class FlashAttentionBackend(AttentionBackend):
            metadata.max_seq_len_q = metadata.max_seq_len_k

        # Precompute strided indices
        if self.page_size > 1:
            self.strided_indices = torch.arange(
                0, metadata.page_table.shape[1], self.page_size, device=self.device
@@ -227,7 +222,6 @@ class FlashAttentionBackend(AttentionBackend):
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        if k is not None:
            assert v is not None
            if save_kv_cache:
@@ -262,7 +256,7 @@ class FlashAttentionBackend(AttentionBackend):
        page_table = metadata.page_table

        # Use Flash Attention for prefill
        if not self.use_mla:
            # Do multi-head attention
            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
@@ -368,7 +362,6 @@ class FlashAttentionBackend(AttentionBackend):
            if layer.sliding_window_size is not None
            else (-1, -1)
        )

        page_table = metadata.page_table

        if not self.use_mla:
@@ -437,7 +430,6 @@ class FlashAttentionBackend(AttentionBackend):
                k_descale=layer.k_scale,
                v_descale=layer.v_scale,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def init_cuda_graph_state(self, max_bs: int):
@@ -449,11 +441,6 @@ class FlashAttentionBackend(AttentionBackend):
        This creates fixed-size tensors that will be reused during CUDA graph replay
        to avoid memory allocations.
        """
        self.decode_cuda_graph_metadata = {
            # Page table for token mapping (batch_size, max_context_len)
            "page_table": torch.zeros(
@@ -462,6 +449,39 @@ class FlashAttentionBackend(AttentionBackend):
                dtype=torch.int32,
                device=self.device,
            ),
            "page_table_draft_decode": torch.zeros(
                max_bs,
                (self.max_context_len + self.page_size - 1) // self.page_size,
                dtype=torch.int32,
                device=self.device,
            ),
            "strided_indices": torch.arange(
                0, self.max_context_len, self.page_size, device=self.device
            ),
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
            "cu_seqlens_q": torch.arange(
                0, max_bs + 128, dtype=torch.int32, device=self.device
            ),
            "cu_seqlens_k": torch.zeros(
                max_bs + 128, dtype=torch.int32, device=self.device
            ),
        }

        self.target_verify_metadata = {
            "page_table": torch.zeros(
                max_bs,
                (self.max_context_len + self.page_size - 1) // self.page_size,
                dtype=torch.int32,
                device=self.device,
            ),
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
            "cu_seqlens_q": torch.zeros(
                max_bs + 128, dtype=torch.int32, device=self.device
            ),
            "cu_seqlens_k": torch.zeros(
                max_bs + 128, dtype=torch.int32, device=self.device
            ),
            "max_seqlen_q": 0,
            "strided_indices": torch.arange(
                0, self.max_context_len, self.page_size, device=self.device
            ),
@@ -479,27 +499,89 @@ class FlashAttentionBackend(AttentionBackend):
    ):
        """Initialize forward metadata for capturing CUDA graph."""
        metadata = FlashAttentionMetadata()
        device = seq_lens.device
        if forward_mode.is_decode():
            if spec_info is not None:
                # Draft Decode
                metadata.cu_seqlens_q = torch.arange(
                    0, bs + 1, dtype=torch.int32, device=device
                )
                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
                    "cache_seqlens"
                ][:bs]

                if forward_mode.is_cuda_graph():
                    metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
                        : bs + 1
                    ]

                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )

                metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
                metadata.page_table = self.decode_cuda_graph_metadata[
                    "page_table_draft_decode"
                ][req_pool_indices, :]
            else:
                # Normal Decode
                # Get sequence information
                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
                batch_size = len(seq_lens)
                device = seq_lens.device
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
                )
                # Precompute maximum sequence length
                metadata.max_seq_len_k = seq_lens.max().item()
                # Precompute page table
                metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
                    req_pool_indices, :
                ]
                # Precompute cumulative sequence lengths
                metadata.cu_seqlens_q = torch.arange(
                    0, batch_size + 1, dtype=torch.int32, device=device
                )
            self.decode_cuda_graph_metadata[bs] = metadata
        elif forward_mode.is_target_verify():
            draft_token_num = spec_info.draft_token_num

            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
                :bs
            ]
            metadata.cache_seqlens_int32.copy_(
                (seq_lens + draft_token_num).to(torch.int32)
            )

            metadata.max_seq_len_q = draft_token_num
            metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num

            metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
                torch.arange(
                    0,
                    bs * draft_token_num + 1,
                    draft_token_num,
                    dtype=torch.int32,
                    device=device,
                )
            ]
            cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
            cu_k.copy_(
                torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
            )
            metadata.cu_seqlens_k = cu_k
            metadata.page_table = self.target_verify_metadata["page_table"][
                req_pool_indices, :
            ]

            self.target_verify_metadata[bs] = metadata

        self.forward_metadata = metadata
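
    # The metadata tensors created at capture time are reused across replays: the
    # replay path below refreshes them in place with copy_() / fill_() instead of
    # allocating new tensors, which is what CUDA graph replay requires.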
    def init_forward_metadata_replay_cuda_graph(
@@ -512,28 +594,91 @@ class FlashAttentionBackend(AttentionBackend):
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
        out_cache_loc: torch.Tensor = None,
    ):
        # """Initialize forward metadata for replaying CUDA graph."""
        device = seq_lens.device
        seq_lens = seq_lens[:bs]
        req_pool_indices = req_pool_indices[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
        if forward_mode.is_decode():
            metadata = self.decode_cuda_graph_metadata[bs]

            if spec_info is not None:
                # Draft Decode
                max_len = seq_lens_cpu.max().item()
                metadata.max_seq_len_k = max_len + (self.step_id + 1)

                metadata.cache_seqlens_int32.copy_(
                    (seq_lens + (self.step_id + 1)).to(torch.int32)
                )

                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)

                metadata.cu_seqlens_k.copy_(
                    torch.nn.functional.pad(
                        torch.cumsum(
                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                        ),
                        (1, 0),
                    )
                )

                page_table = self.req_to_token[
                    req_pool_indices, : metadata.max_seq_len_k
                ]

                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
            else:
                # Normal Decode
                max_len = seq_lens_cpu.max().item()
                metadata.max_seq_len_k = max_len

                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
                )

                max_seq_pages = (
                    metadata.max_seq_len_k + self.page_size - 1
                ) // self.page_size
                page_indices = self.req_to_token[
                    :,
                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
                ]
                page_indices = page_indices[req_pool_indices] // self.page_size
                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                metadata.page_table[:, max_seq_pages:].fill_(0)
        elif forward_mode.is_target_verify():
            metadata = self.target_verify_metadata[bs]
            draft_token_num = spec_info.draft_token_num

            metadata.cu_seqlens_q.copy_(
                torch.arange(
                    0,
                    bs * draft_token_num + 1,
                    draft_token_num,
                    dtype=torch.int32,
                    device=device,
                )
            )
            metadata.cache_seqlens_int32.copy_(
                (seq_lens + draft_token_num).to(torch.int32)
            )

            metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num

            metadata.cu_seqlens_k.copy_(
                torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
            )
            page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

        self.forward_metadata = metadata
    def get_cuda_graph_seq_len_fill_value(self):
@@ -555,7 +700,6 @@ class FlashAttentionMultiStepBackend:
            self.attn_backends.append(
                FlashAttentionBackend(
                    model_runner,
                    topk=self.topk,
                    speculative_num_steps=self.speculative_num_steps,
                    step_id=i,
@@ -570,7 +714,10 @@ class FlashAttentionMultiStepBackend:
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(max_bs)

    def init_forward_metadata_capture_cuda_graph(
        self,
        forward_batch: ForwardBatch,
    ):
        assert forward_batch.spec_info is not None
        assert isinstance(forward_batch.spec_info, EagleDraftInput)
@@ -601,4 +748,5 @@ class FlashAttentionMultiStepBackend:
            forward_mode=ForwardMode.DECODE,
            spec_info=forward_batch.spec_info,
            seq_lens_cpu=forward_batch.seq_lens_cpu,
            out_cache_loc=forward_batch.out_cache_loc,
        )
@@ -104,6 +104,9 @@ class ForwardMode(IntEnum):
            or self == ForwardMode.IDLE
        )

    def is_extend_or_draft_extend(self):
        return self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND

    def is_dummy_first(self):
        return self == ForwardMode.DUMMY_FIRST