Unverified Commit 59911195 authored by jacky.cheng's avatar jacky.cheng Committed by GitHub
Browse files

[Fix] Resolve performance drop in speculative decoding aiter backend (#11087)

parent 424591d5
...@@ -619,7 +619,11 @@ class AiterAttnBackend(AttentionBackend): ...@@ -619,7 +619,11 @@ class AiterAttnBackend(AttentionBackend):
assert len(k.shape) == 3 assert len(k.shape) == 3
assert len(v.shape) == 3 assert len(v.shape) == 3
if forward_batch.forward_mode.is_extend(): if (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
if kv_indices.shape[0] == 0: if kv_indices.shape[0] == 0:
o = flash_attn_varlen_func( o = flash_attn_varlen_func(
q, q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment