Unverified commit 7ddf8e83, authored by Lianmin Zheng, committed by GitHub
Browse files

[EAGLE] Fix draft kv cache layout for fa3 and topk > 1 (#7239)

parent 8321f8e4
@@ -406,9 +406,10 @@ class FlashAttentionBackend(AttentionBackend):
                 dtype=torch.int32,
                 device=device,
             )
-            cache_loc = forward_batch.out_cache_loc.view(
-                self.speculative_num_steps, -1
-            ).T.contiguous()
+            # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
+            cache_loc = forward_batch.out_cache_loc.view(
+                -1, self.speculative_num_steps
+            )
             metadata_expand.page_table = (
                 cache_loc[:, :decode_length].contiguous().to(torch.int32)
             )
@@ -1636,9 +1637,8 @@ class FlashAttentionBackend(AttentionBackend):
             # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
             metadata_expand = self.draft_decode_metadata_topk_expand[bs]
             decode_length = self.speculative_step_id + 1
-            cache_loc = out_cache_loc.view(
-                self.speculative_num_steps, -1
-            ).T.contiguous()
+            # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
+            cache_loc = out_cache_loc.view(-1, self.speculative_num_steps)
             metadata_expand.page_table[: cache_loc.shape[0]].copy_(
                 cache_loc[:, :decode_length]
             )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.