Unverified Commit 90532b76 authored by Baizhou Zhang's avatar Baizhou Zhang Committed by GitHub
Browse files

[Fix] Fix raw_bs bug when using flashinfer mla and eagle (#4557)

parent c0e9a36c
......@@ -52,6 +52,9 @@ class EAGLEDraftCudaGraphRunner:
self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
0
].get_cuda_graph_seq_len_fill_value()
self.seq_lens_cpu = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
if self.enable_torch_compile:
set_torch_compile_config()
......@@ -210,6 +213,12 @@ class EAGLEDraftCudaGraphRunner:
forward_batch.req_pool_indices = self.req_pool_indices[:bs]
forward_batch.positions = self.positions[:num_tokens]
# Special handle for seq_len_cpu used when flashinfer mla is used
if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
self.seq_lens_cpu.fill_(1)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
forward_batch, bs
)
......@@ -224,5 +233,7 @@ class EAGLEDraftCudaGraphRunner:
forward_batch.positions = self.positions[:raw_num_token]
forward_batch.seq_lens = self.seq_lens[:raw_bs]
forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
if forward_batch.decode_seq_lens_cpu is not None:
forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment