Commit a5156371 authored by lizhigong

Merge branch 'v0.5.4_dev_lzg' into 'v0.5.4_dev'

Fix extend errors and precision issues; support radix cache and chunked prefill

See merge request OpenDAS/sglang!10
parents 46da9556 abe30c35
@@ -432,11 +432,18 @@ class DCUMLABackend(AttentionBackend):
         layer: "RadixAttention",
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
         sinks=None,
     ):
-        if (
+        if save_kv_cache:
+            return self.forward_decode(q,k,v,layer,forward_batch, save_kv_cache)
+        if ((
             forward_batch.forward_mode == ForwardMode.EXTEND
-            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND)
         ):
             # flash_attn不支持fp8,fp8无法正常执行extend
             if not self.skip_prefill:
@@ -444,7 +451,7 @@ class DCUMLABackend(AttentionBackend):
                 #     q, k, v, layer, forward_batch, save_kv_cache, sinks
                 # )
                 return self.flashattn_backend.forward_extend(
-                    q, k, v, layer, forward_batch, save_kv_cache, sinks
+                    q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
                 )
             else:
                 raise RuntimeError("skip prefill but use forward_extend")
......
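For context on the new q_rope/k_rope parameters: in multi-head latent attention the query and key carry a separate rotary-embedded slice, and this patch forwards those slices through forward_extend instead of leaving them fused into q and k. The snippet below is only an illustrative sketch of that split; the dimension names (kv_lora_rank, qk_rope_head_dim) and shapes are assumptions, not taken from this repository.

import torch

# Assumed MLA-style head layout: a latent ("nope") part plus a rotary ("rope") part.
num_tokens, num_heads = 8, 16
kv_lora_rank, qk_rope_head_dim = 512, 64

q = torch.randn(num_tokens, num_heads, kv_lora_rank + qk_rope_head_dim)

# Split off the rotary slice; in this patch the rope slice is what would be
# handed to forward_extend as q_rope (and analogously k_rope for the key).
q_nope, q_rope = q.split([kv_lora_rank, qk_rope_head_dim], dim=-1)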
@@ -1618,7 +1618,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
         self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
-        self.seq_lens_sum = self.seq_lens.sum()
+        self.seq_lens_sum = self.seq_lens.sum().item()
         self.output_ids = self.output_ids[keep_indices_device]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
......
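The only change in this hunk is the added .item(): seq_lens.sum() returns a zero-dimensional tensor, while .item() converts it to a plain Python int so seq_lens_sum can be used directly in host-side arithmetic and comparisons. A minimal illustration:

import torch

seq_lens = torch.tensor([3, 5, 7])
as_tensor = seq_lens.sum()          # 0-dim torch.Tensor
as_int = seq_lens.sum().item()      # plain Python int
print(type(as_tensor).__name__, type(as_int).__name__)  # Tensor int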
@@ -174,6 +174,7 @@ MLA_ATTENTION_BACKENDS = [
 CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
     "flashinfer",
     "fa3",
+    "dcu_mla",
     "fa4",
     "flashmla",
     "cutlass_mla",
@@ -2238,7 +2239,6 @@ class ModelRunner:
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
-
         if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
......
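The ModelRunner hunk only drops a blank line around the can_run_graph check. For readers unfamiliar with that check, it implements the usual CUDA-graph dispatch pattern; the sketch below is a simplified, assumed version, not the actual ModelRunner code:

def run_forward(model, graph_runner, forward_batch):
    # Replay a pre-captured CUDA graph only when a runner exists and it reports
    # the batch as compatible with what was captured; otherwise run eagerly.
    can_run_graph = graph_runner is not None and graph_runner.can_run(forward_batch)
    if can_run_graph:
        return graph_runner.replay(forward_batch)
    return model.forward(forward_batch)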