Unverified commit a2faf894, authored by b8zhong, committed by GitHub

[1/n] Enable DCA CUDA graph capture (#9537)

parent 7e61737d
@@ -1537,7 +1537,7 @@ class DualChunkFlashAttentionBackend(AttentionBackend):
             query_inter,
             key_cache,
             value_cache,
-            block_table[:, : decode_meta.max_seq_len_inter],
+            block_table,
             decode_meta.seq_lens_inter,
             softmax_scale,
             causal=False,
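The key change in this hunk is dropping the per-step slice block_table[:, : decode_meta.max_seq_len_inter], whose width varies with the longest inter-chunk sequence in the batch. CUDA graphs replay captured kernels with fixed shapes, so every tensor fed to the captured decode path must keep a static shape across steps; passing the full block_table satisfies that. Below is a minimal, standalone PyTorch sketch of the capture/replay pattern (illustrative only, not code from this commit; step() is a stand-in for the decode attention call):

import torch

assert torch.cuda.is_available()

# Fixed-shape buffers: CUDA graph replay reuses these exact allocations.
static_input = torch.zeros(8, 128, device="cuda")
static_output = torch.zeros(8, 128, device="cuda")

def step(x):
    # Stand-in for the decode attention kernel; input shape must not vary.
    return torch.relu(x) * 2.0

# Warm up on a side stream before capture, as CUDA graphs require.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output.copy_(step(static_input))
torch.cuda.current_stream().wait_stream(s)

# Capture one decode step into a graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output.copy_(step(static_input))

# Replay: refill the static buffer in place, then launch the captured graph.
static_input.copy_(torch.randn(8, 128, device="cuda"))
g.replay()

A data-dependent slice like the one removed above would change the input shape between capture and replay, which is exactly what this pattern forbids.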
@@ -878,10 +878,9 @@ class ServerArgs:
         if self.attention_backend == "dual_chunk_flash_attn":
             logger.warning(
-                "Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
+                "Mixed chunk and radix cache are disabled when using dual-chunk flash attention backend"
             )
             self.enable_mixed_chunk = False
-            self.disable_cuda_graph = True
             self.disable_radix_cache = True

     def _handle_page_size(self):
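The second hunk removes the forced self.disable_cuda_graph = True, so selecting the dual-chunk backend now leaves CUDA graph capture available while still disabling mixed chunk and radix cache. A hedged sketch of the resulting behavior (ServerArgsSketch and handle_attention_backend are illustrative stand-ins, not the real ServerArgs class):

import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@dataclass
class ServerArgsSketch:
    attention_backend: str = "dual_chunk_flash_attn"
    enable_mixed_chunk: bool = True
    disable_radix_cache: bool = False
    disable_cuda_graph: bool = False  # no longer forced to True by this backend

    def handle_attention_backend(self):
        # Mirrors the post-change hunk: mixed chunk and radix cache are
        # still turned off, but disable_cuda_graph is left untouched.
        if self.attention_backend == "dual_chunk_flash_attn":
            logger.warning(
                "Mixed chunk and radix cache are disabled when using "
                "dual-chunk flash attention backend"
            )
            self.enable_mixed_chunk = False
            self.disable_radix_cache = True

args = ServerArgsSketch()
args.handle_attention_backend()
assert not args.disable_cuda_graph  # CUDA graph capture remains available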