Commit abe30c35 authored by lizhigong's avatar lizhigong
Browse files

修复extern报错、精度问题,支持radix cache和chunk prefill (Fix extern errors and precision issues; support radix cache and chunk prefill)

parent 46da9556
......@@ -432,11 +432,18 @@ class DCUMLABackend(AttentionBackend):
layer: "RadixAttention",
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
sinks=None,
):
if (
if save_kv_cache:
return self.forward_decode(q,k,v,layer,forward_batch, save_kv_cache)
if ((
forward_batch.forward_mode == ForwardMode.EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND)
):
# flash_attn不支持fp8,fp8无法正常执行extend
if not self.skip_prefill:
......@@ -444,7 +451,7 @@ class DCUMLABackend(AttentionBackend):
# q, k, v, layer, forward_batch, save_kv_cache, sinks
# )
return self.flashattn_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, sinks
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
)
else:
raise RuntimeError("skip prefill but use forward_extend")
......
......@@ -1618,7 +1618,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
self.out_cache_loc = None
self.seq_lens_sum = self.seq_lens.sum()
self.seq_lens_sum = self.seq_lens.sum().item()
self.output_ids = self.output_ids[keep_indices_device]
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob:
......
......@@ -174,6 +174,7 @@ MLA_ATTENTION_BACKENDS = [
CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
"flashinfer",
"fa3",
"dcu_mla",
"fa4",
"flashmla",
"cutlass_mla",
......@@ -2238,7 +2239,6 @@ class ModelRunner:
and self.graph_runner
and self.graph_runner.can_run(forward_batch)
)
if can_run_graph:
ret = self.graph_runner.replay(
forward_batch,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment