Unverified Commit 59911195 authored by jacky.cheng's avatar jacky.cheng Committed by GitHub
Browse files

[Fix] Resolve performance drop in speculative decoding aiter backend (#11087)

parent 424591d5
......@@ -619,7 +619,11 @@ class AiterAttnBackend(AttentionBackend):
assert len(k.shape) == 3
assert len(v.shape) == 3
if forward_batch.forward_mode.is_extend():
if (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
if kv_indices.shape[0] == 0:
o = flash_attn_varlen_func(
q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment