Unverified Commit ca8d02ab authored by Stefan He's avatar Stefan He Committed by GitHub
Browse files

FA3 Spec Decoding to support top k = 1 and add cuda graph support (#5050)


Co-authored-by: default avatarQingquan Song <ustcsqq@gmail.com>
Co-authored-by: default avatarChunan Zeng <zcnrex@gmail.com>
parent 3f287b85
...@@ -27,19 +27,42 @@ from sgl_kernel.flash_attn import flash_attn_with_kvcache ...@@ -27,19 +27,42 @@ from sgl_kernel.flash_attn import flash_attn_with_kvcache
@dataclass @dataclass
class FlashAttentionMetadata: class FlashAttentionMetadata:
"""Metadata for decode operations to avoid redundant computations.""" """Metadata to be init once in the model forward pass,
each layer's forward pass can reuse the metadata."""
# Cumulative sequence lengths for query
cu_seqlens_q: torch.Tensor = None cu_seqlens_q: torch.Tensor = None
# Cumulative sequence lengths for key
cu_seqlens_k: torch.Tensor = None cu_seqlens_k: torch.Tensor = None
# Maximum sequence length for query
max_seq_len_q: int = 0 max_seq_len_q: int = 0
# Maximum sequence length for key
max_seq_len_k: int = 0 max_seq_len_k: int = 0
# Window size (typically used by Gemma)
window_size: tuple = (-1, -1) window_size: tuple = (-1, -1)
# Page table, the index of KV Cache Tables/Blocks
page_table: torch.Tensor = None page_table: torch.Tensor = None
# Sequence lengths for the forward batch
cache_seqlens_int32: torch.Tensor = None cache_seqlens_int32: torch.Tensor = None
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
"""FlashAttention backend implementation.""" """FlashAttention backend implementation.
Note about the init:
- If no spec decoding
- FlashAttentionBackend will be init once when the server starts.
- If spec decoding
- FlashAttentionBackend will be init once for the target worker
- FlashAttentionMultiStepBackend will be once for the draft worker
- It will spawn num_steps FlashAttentionBackend for the draft worker
Note about CUDA Graph:
- We only support CUDA Graph for Decode (Normal Decode and Draft Decode) and Target Verify.
- We don't support CUDA Graph for Extend and Draft Extend.
- When server init, init_cuda_graph_state will be called first and then init_cuda_graph_capture will be called.
- For each forward batch, init_replay_cuda_graph will be called first and then replay the graph.
"""
def __init__( def __init__(
self, self,
...@@ -56,41 +79,42 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -56,41 +79,42 @@ class FlashAttentionBackend(AttentionBackend):
and model_runner.model_config.is_encoder_decoder and model_runner.model_config.is_encoder_decoder
), "Sliding window and cross attention are not supported together" ), "Sliding window and cross attention are not supported together"
# Initialize metadata
self.forward_metadata: FlashAttentionMetadata = None self.forward_metadata: FlashAttentionMetadata = None
self.max_context_len = model_runner.model_config.context_len self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device self.device = model_runner.device
self.decode_cuda_graph_metadata = {} self.decode_cuda_graph_metadata = {}
self.target_verify_metadata = {}
self.req_to_token = model_runner.req_to_token_pool.req_to_token self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.page_size = model_runner.page_size self.page_size = model_runner.page_size
self.use_mla = ( self.use_mla = (
model_runner.model_config.attention_arch == AttentionArch.MLA model_runner.model_config.attention_arch == AttentionArch.MLA
) and (not global_server_args_dict["disable_mla"]) ) and (not global_server_args_dict["disable_mla"])
self.skip_prefill = skip_prefill self.skip_prefill = skip_prefill
self.topk = topk
self.speculative_num_steps = speculative_num_steps # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
assert (
topk <= 1
), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
self.topk = 1
self.step_id = step_id self.step_id = step_id
self.speculative_num_steps = speculative_num_steps
def init_forward_metadata(self, forward_batch: ForwardBatch): def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize forward metadata to cache repetitive calculations.""" """Initialize forward metadata to cache repetitive calculations."""
# Create metadata based on forward mode
metadata = FlashAttentionMetadata() metadata = FlashAttentionMetadata()
# Get sequence information
seqlens_in_batch = forward_batch.seq_lens seqlens_in_batch = forward_batch.seq_lens
# Precompute int32 version of sequence lengths
batch_size = len(seqlens_in_batch) batch_size = len(seqlens_in_batch)
device = seqlens_in_batch.device device = seqlens_in_batch.device
if forward_batch.forward_mode.is_decode():
if forward_batch.forward_mode == ForwardMode.DECODE: # Skip Prefill or Draft Decode
if self.skip_prefill: # Note: Draft Decode will be ran on the Draft Worker
if forward_batch.spec_info is not None:
metadata.cu_seqlens_q = torch.arange( metadata.cu_seqlens_q = torch.arange(
0, batch_size * self.topk + 1, dtype=torch.int32, device=device 0, batch_size + 1, dtype=torch.int32, device=device
) )
seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1) seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
metadata.cache_seqlens_int32 = ( metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
(seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32)
)
metadata.cu_seqlens_k = torch.nn.functional.pad( metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum( torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
...@@ -103,86 +127,58 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -103,86 +127,58 @@ class FlashAttentionBackend(AttentionBackend):
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k forward_batch.req_pool_indices, : metadata.max_seq_len_k
] ]
metadata.page_table = metadata.page_table.repeat_interleave(
self.topk, dim=0
)
cache_loc = forward_batch.out_cache_loc.view( cache_loc = forward_batch.out_cache_loc.view(
self.speculative_num_steps, -1 self.speculative_num_steps, -1
).T ).T
# Calculate page table indices and cache location indices to update the page table.
batch_indices = torch.arange(
batch_size, device=device
).repeat_interleave(self.topk * (self.step_id + 1))
topk_indices = torch.arange(self.topk, device=device).repeat(
batch_size * (self.step_id + 1)
)
row_indices = batch_indices * self.topk + topk_indices
page_table_col_base_indices = seqlens_in_batch.unsqueeze(
1
) + torch.arange(self.step_id + 1, device=device)
page_table_col_indices = page_table_col_base_indices.view(-1).repeat(
self.topk
)
cache_loc_col_indices = torch.arange(
self.step_id + 1, device=device, dtype=torch.int32
).repeat(batch_size * self.topk)
metadata.page_table[row_indices, page_table_col_indices] = cache_loc[ for idx, single_seq_len in enumerate(seq_lens_with_decode):
row_indices, cache_loc_col_indices real_bsz_start_idx = idx
].to(torch.int32) real_bsz_end_idx = idx + 1
else: metadata.page_table[
real_bsz_start_idx:real_bsz_end_idx,
(single_seq_len - (self.step_id + 1)) : single_seq_len,
] = cache_loc[
real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
]
else: # Normal Decode without Spec Decoding
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.cu_seqlens_k = torch.nn.functional.pad( metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
) )
# Precompute maximum sequence length
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
# Precompute page table
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k forward_batch.req_pool_indices, : metadata.max_seq_len_k
] ]
metadata.cu_seqlens_q = torch.arange( metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device 0, batch_size + 1, dtype=torch.int32, device=device
) )
elif forward_batch.forward_mode == ForwardMode.TARGET_VERIFY: elif forward_batch.forward_mode.is_target_verify():
# Note: Target Verify will be ran on the Target Worker
draft_token_num = forward_batch.spec_info.draft_token_num draft_token_num = forward_batch.spec_info.draft_token_num
metadata.cache_seqlens_int32 = (
metadata.cu_seqlens_q = torch.arange( forward_batch.seq_lens + draft_token_num
0, batch_size * draft_token_num + 1, dtype=torch.int32, device=device ).to(torch.int32)
metadata.max_seq_len_q = draft_token_num
metadata.max_seq_len_k = (
forward_batch.seq_lens_cpu.max().item() + draft_token_num
) )
metadata.cu_seqlens_q = torch.arange(
aug_seq_lens = (forward_batch.seq_lens + draft_token_num).to(torch.int32) 0,
metadata.cache_seqlens_int32 = aug_seq_lens.repeat_interleave( batch_size * draft_token_num + 1,
forward_batch.spec_info.draft_token_num draft_token_num,
dtype=torch.int32,
device=device,
) )
metadata.cu_seqlens_k = torch.nn.functional.pad( metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32), torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
(1, 0), (1, 0),
) )
metadata.max_seq_len_k = (
forward_batch.seq_lens_cpu.max().item() + draft_token_num
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k forward_batch.req_pool_indices, : metadata.max_seq_len_k
].repeat_interleave(draft_token_num, dim=0) ]
aug_cum_len = torch.nn.functional.pad(
torch.cumsum(aug_seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
for idx, single_seq_len in enumerate(aug_seq_lens):
metadata.page_table[
idx * draft_token_num : (idx + 1) * draft_token_num, :single_seq_len
] *= forward_batch.spec_info.custom_mask[
aug_cum_len[idx]
* draft_token_num : aug_cum_len[idx + 1]
* draft_token_num
].view(
draft_token_num, -1
)
metadata.max_seq_len_q = 1 elif forward_batch.forward_mode.is_extend_or_draft_extend():
else: # Normal or Draft Extend (Both of them will be ran on the Target Worker)
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.cu_seqlens_k = torch.nn.functional.pad( metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
...@@ -208,7 +204,6 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -208,7 +204,6 @@ class FlashAttentionBackend(AttentionBackend):
metadata.max_seq_len_q = metadata.max_seq_len_k metadata.max_seq_len_q = metadata.max_seq_len_k
# Precompute strided indices # Precompute strided indices
# [0, page_size, 2 * page_size, ...]
if self.page_size > 1: if self.page_size > 1:
self.strided_indices = torch.arange( self.strided_indices = torch.arange(
0, metadata.page_table.shape[1], self.page_size, device=self.device 0, metadata.page_table.shape[1], self.page_size, device=self.device
...@@ -227,7 +222,6 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -227,7 +222,6 @@ class FlashAttentionBackend(AttentionBackend):
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
save_kv_cache=True, save_kv_cache=True,
): ):
if k is not None: if k is not None:
assert v is not None assert v is not None
if save_kv_cache: if save_kv_cache:
...@@ -262,7 +256,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -262,7 +256,7 @@ class FlashAttentionBackend(AttentionBackend):
page_table = metadata.page_table page_table = metadata.page_table
# # Use Flash Attention for prefill # Use Flash Attention for prefill
if not self.use_mla: if not self.use_mla:
# Do multi-head attention # Do multi-head attention
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
...@@ -368,7 +362,6 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -368,7 +362,6 @@ class FlashAttentionBackend(AttentionBackend):
if layer.sliding_window_size is not None if layer.sliding_window_size is not None
else (-1, -1) else (-1, -1)
) )
page_table = metadata.page_table page_table = metadata.page_table
if not self.use_mla: if not self.use_mla:
...@@ -437,7 +430,6 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -437,7 +430,6 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=layer.k_scale, k_descale=layer.k_scale,
v_descale=layer.v_scale, v_descale=layer.v_scale,
) )
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def init_cuda_graph_state(self, max_bs: int): def init_cuda_graph_state(self, max_bs: int):
...@@ -449,11 +441,6 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -449,11 +441,6 @@ class FlashAttentionBackend(AttentionBackend):
This creates fixed-size tensors that will be reused during CUDA graph replay This creates fixed-size tensors that will be reused during CUDA graph replay
to avoid memory allocations. to avoid memory allocations.
""" """
if self.speculative_num_steps > 0:
raise NotImplementedError(
"FlashAttentionBackend Spec Decoding does not support CUDA graph yet, stay tuned!"
)
self.decode_cuda_graph_metadata = { self.decode_cuda_graph_metadata = {
# Page table for token mapping (batch_size, max_context_len) # Page table for token mapping (batch_size, max_context_len)
"page_table": torch.zeros( "page_table": torch.zeros(
...@@ -462,6 +449,39 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -462,6 +449,39 @@ class FlashAttentionBackend(AttentionBackend):
dtype=torch.int32, dtype=torch.int32,
device=self.device, device=self.device,
), ),
"page_table_draft_decode": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
dtype=torch.int32,
device=self.device,
),
"strided_indices": torch.arange(
0, self.max_context_len, self.page_size, device=self.device
),
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
"cu_seqlens_q": torch.arange(
0, max_bs + 128, dtype=torch.int32, device=self.device
),
"cu_seqlens_k": torch.zeros(
max_bs + 128, dtype=torch.int32, device=self.device
),
}
self.target_verify_metadata = {
"page_table": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
dtype=torch.int32,
device=self.device,
),
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
"cu_seqlens_q": torch.zeros(
max_bs + 128, dtype=torch.int32, device=self.device
),
"cu_seqlens_k": torch.zeros(
max_bs + 128, dtype=torch.int32, device=self.device
),
"max_seqlen_q": 0,
"strided_indices": torch.arange( "strided_indices": torch.arange(
0, self.max_context_len, self.page_size, device=self.device 0, self.max_context_len, self.page_size, device=self.device
), ),
...@@ -479,27 +499,89 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -479,27 +499,89 @@ class FlashAttentionBackend(AttentionBackend):
): ):
"""Initialize forward metadata for capturing CUDA graph.""" """Initialize forward metadata for capturing CUDA graph."""
metadata = FlashAttentionMetadata() metadata = FlashAttentionMetadata()
# Get sequence information
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
batch_size = len(seq_lens)
device = seq_lens.device device = seq_lens.device
metadata.cu_seqlens_k = torch.nn.functional.pad( if forward_mode.is_decode():
torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0) if spec_info is not None:
) # Draft Decode
# Precompute maximum sequence length metadata.cu_seqlens_q = torch.arange(
metadata.max_seq_len_k = seq_lens.max().item() 0, bs + 1, dtype=torch.int32, device=device
# Precompute page table )
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][ metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
req_pool_indices, : "cache_seqlens"
] ][:bs]
if forward_mode.is_cuda_graph():
# Precompute cumulative sequence lengths metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
metadata.cu_seqlens_q = torch.arange( : bs + 1
0, batch_size + 1, dtype=torch.int32, device=device ]
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
metadata.page_table = self.decode_cuda_graph_metadata[
"page_table_draft_decode"
][req_pool_indices, :]
else:
# Normal Decode
# Get sequence information
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
batch_size = len(seq_lens)
device = seq_lens.device
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
# Precompute maximum sequence length
metadata.max_seq_len_k = seq_lens.max().item()
# Precompute page table
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
req_pool_indices, :
]
# Precompute cumulative sequence lengths
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
)
self.decode_cuda_graph_metadata[bs] = metadata
elif forward_mode.is_target_verify():
draft_token_num = spec_info.draft_token_num
metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
:bs
]
metadata.cache_seqlens_int32.copy_(
(seq_lens + draft_token_num).to(torch.int32)
) )
else:
raise ValueError("Do not support Prefill Mode cuda graph") metadata.max_seq_len_q = draft_token_num
self.decode_cuda_graph_metadata[bs] = metadata metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num
metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
torch.arange(
0,
bs * draft_token_num + 1,
draft_token_num,
dtype=torch.int32,
device=device,
)
]
cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
cu_k.copy_(
torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
)
metadata.cu_seqlens_k = cu_k
metadata.page_table = self.target_verify_metadata["page_table"][
req_pool_indices, :
]
self.target_verify_metadata[bs] = metadata
self.forward_metadata = metadata self.forward_metadata = metadata
def init_forward_metadata_replay_cuda_graph( def init_forward_metadata_replay_cuda_graph(
...@@ -512,28 +594,91 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -512,28 +594,91 @@ class FlashAttentionBackend(AttentionBackend):
forward_mode: ForwardMode, forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor], seq_lens_cpu: Optional[torch.Tensor],
out_cache_loc: torch.Tensor = None,
): ):
# """Initialize forward metadata for replaying CUDA graph.""" # """Initialize forward metadata for replaying CUDA graph."""
metadata = self.decode_cuda_graph_metadata[bs] device = seq_lens.device
seq_lens = seq_lens[:bs]
req_pool_indices = req_pool_indices[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
if forward_mode.is_decode():
metadata = self.decode_cuda_graph_metadata[bs]
if spec_info is not None:
# Draft Decode
max_len = seq_lens_cpu.max().item()
metadata.max_seq_len_k = max_len + (self.step_id + 1)
metadata.cache_seqlens_int32.copy_(
(seq_lens + (self.step_id + 1)).to(torch.int32)
)
# For CPU operations metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)
max_len = seq_lens_cpu[:bs].max().item()
metadata.max_seq_len_k = max_len
# For GPU operations metadata.cu_seqlens_k.copy_(
seq_lens_in_batch = seq_lens[:bs] torch.nn.functional.pad(
metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32) torch.cumsum(
metadata.cu_seqlens_k = torch.nn.functional.pad( metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0) ),
) (1, 0),
)
)
page_table = self.req_to_token[
req_pool_indices, : metadata.max_seq_len_k
]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
else:
# Normal Decode
max_len = seq_lens_cpu.max().item()
metadata.max_seq_len_k = max_len
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
page_indices = self.req_to_token[
:,
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
]
page_indices = page_indices[req_pool_indices] // self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
metadata.page_table[:, max_seq_pages:].fill_(0)
elif forward_mode.is_target_verify():
metadata = self.target_verify_metadata[bs]
draft_token_num = spec_info.draft_token_num
metadata.cu_seqlens_q.copy_(
torch.arange(
0,
bs * draft_token_num + 1,
draft_token_num,
dtype=torch.int32,
device=device,
)
)
metadata.cache_seqlens_int32.copy_(
(seq_lens + draft_token_num).to(torch.int32)
)
metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num
metadata.cu_seqlens_k.copy_(
torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
)
page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size
page_indices = self.req_to_token[
:, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
]
page_indices = page_indices[req_pool_indices[:bs]] // self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
metadata.page_table[:, max_seq_pages:].fill_(0)
self.forward_metadata = metadata self.forward_metadata = metadata
def get_cuda_graph_seq_len_fill_value(self): def get_cuda_graph_seq_len_fill_value(self):
...@@ -555,7 +700,6 @@ class FlashAttentionMultiStepBackend: ...@@ -555,7 +700,6 @@ class FlashAttentionMultiStepBackend:
self.attn_backends.append( self.attn_backends.append(
FlashAttentionBackend( FlashAttentionBackend(
model_runner, model_runner,
skip_prefill=True,
topk=self.topk, topk=self.topk,
speculative_num_steps=self.speculative_num_steps, speculative_num_steps=self.speculative_num_steps,
step_id=i, step_id=i,
...@@ -570,7 +714,10 @@ class FlashAttentionMultiStepBackend: ...@@ -570,7 +714,10 @@ class FlashAttentionMultiStepBackend:
for i in range(self.speculative_num_steps): for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(max_bs) self.attn_backends[i].init_cuda_graph_state(max_bs)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): def init_forward_metadata_capture_cuda_graph(
self,
forward_batch: ForwardBatch,
):
assert forward_batch.spec_info is not None assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput) assert isinstance(forward_batch.spec_info, EagleDraftInput)
...@@ -601,4 +748,5 @@ class FlashAttentionMultiStepBackend: ...@@ -601,4 +748,5 @@ class FlashAttentionMultiStepBackend:
forward_mode=ForwardMode.DECODE, forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info, spec_info=forward_batch.spec_info,
seq_lens_cpu=forward_batch.seq_lens_cpu, seq_lens_cpu=forward_batch.seq_lens_cpu,
out_cache_loc=forward_batch.out_cache_loc,
) )
...@@ -104,6 +104,9 @@ class ForwardMode(IntEnum): ...@@ -104,6 +104,9 @@ class ForwardMode(IntEnum):
or self == ForwardMode.IDLE or self == ForwardMode.IDLE
) )
def is_extend_or_draft_extend(self):
return self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND
def is_dummy_first(self): def is_dummy_first(self):
return self == ForwardMode.DUMMY_FIRST return self == ForwardMode.DUMMY_FIRST
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment