Unverified commit 19bc77f0, authored by DarkSharpness and committed by GitHub
Browse files

[Fix] Fix hicache backend (#8991)

parent 86497d99
...@@ -611,12 +611,7 @@ class Scheduler( ...@@ -611,12 +611,7 @@ class Scheduler(
hicache_ratio=server_args.hicache_ratio, hicache_ratio=server_args.hicache_ratio,
hicache_size=server_args.hicache_size, hicache_size=server_args.hicache_size,
hicache_write_policy=server_args.hicache_write_policy, hicache_write_policy=server_args.hicache_write_policy,
hicache_io_backend=( hicache_io_backend=server_args.hicache_io_backend,
"direct"
if server_args.attention_backend
== "fa3" # hot fix for incompatibility
else server_args.hicache_io_backend
),
hicache_mem_layout=server_args.hicache_mem_layout, hicache_mem_layout=server_args.hicache_mem_layout,
hicache_storage_backend=server_args.hicache_storage_backend, hicache_storage_backend=server_args.hicache_storage_backend,
hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy, hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
......
...@@ -403,7 +403,6 @@ class ModelRunner: ...@@ -403,7 +403,6 @@ class ModelRunner:
is_hopper_with_cuda_12_3() is_hopper_with_cuda_12_3()
and is_no_spec_infer_or_topk_one(server_args) and is_no_spec_infer_or_topk_one(server_args)
and is_fa3_default_architecture(self.model_config.hf_config) and is_fa3_default_architecture(self.model_config.hf_config)
and (not server_args.enable_hierarchical_cache)
): ):
server_args.attention_backend = "fa3" server_args.attention_backend = "fa3"
elif _is_hip: elif _is_hip:
...@@ -416,9 +415,7 @@ class ModelRunner: ...@@ -416,9 +415,7 @@ class ModelRunner:
) )
else: else:
# MLA architecture # MLA architecture
if is_hopper_with_cuda_12_3() and ( if is_hopper_with_cuda_12_3():
not server_args.enable_hierarchical_cache
):
server_args.attention_backend = "fa3" server_args.attention_backend = "fa3"
elif is_sm100_supported(): elif is_sm100_supported():
server_args.attention_backend = "flashinfer" server_args.attention_backend = "flashinfer"
...@@ -506,6 +503,27 @@ class ModelRunner: ...@@ -506,6 +503,27 @@ class ModelRunner:
if self.model_config.context_len > 8192: if self.model_config.context_len > 8192:
self.mem_fraction_static *= 0.85 self.mem_fraction_static *= 0.85
if (
server_args.enable_hierarchical_cache
and server_args.hicache_io_backend == "kernel"
):
# fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
if server_args.decode_attention_backend is None:
if not self.use_mla_backend:
server_args.decode_attention_backend = (
"flashinfer" if is_flashinfer_available() else "triton"
)
else:
server_args.decode_attention_backend = (
"flashinfer" if is_sm100_supported() else "triton"
)
elif server_args.decode_attention_backend == "fa3":
server_args.hicache_io_backend = "direct"
logger.warning(
"FlashAttention3 decode backend is not compatible with hierarchical cache. "
f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
)
def init_torch_distributed(self): def init_torch_distributed(self):
logger.info("Init torch distributed begin.") logger.info("Init torch distributed begin.")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment