Unverified Commit ab6f3487 authored by Nicolò Lucchesi, committed by GitHub
Browse files

[PD] Change kv_load_failure_policy Default from "recompute" to "fail" (#34896)


Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent 8dc8a99b
......@@ -197,8 +197,8 @@ For multi-host DP deployment, only need to provide the host/port of the head ins
The `kv_load_failure_policy` setting controls how the system handles failures when the decoder instance loads KV cache blocks from the prefiller instance:
- **fail** (recommended): Immediately fail the request with an error when KV load fails. This prevents performance degradation by avoiding recomputation of prefill work on the decode instance.
- **recompute** (default): Recompute failed blocks locally on the decode instance. This may cause performance _jitter_ on decode instances as the scheduled prefill will delay and interfere with other decodes. Furthermore, decode instances are typically configured with low-latency optimizations.
- **fail** (default): Immediately fail the request with an error when KV load fails. This prevents performance degradation by avoiding recomputation of prefill work on the decode instance.
- **recompute**: Recompute failed blocks locally on the decode instance. This may cause performance _jitter_ on decode instances as the scheduled prefill will delay and interfere with other decodes. Furthermore, decode instances are typically configured with low-latency optimizations.
!!! warning
    Using `kv_load_failure_policy="recompute"` can lead to performance degradation in production deployments. When KV loads fail, the decode instance will execute prefill work with decode-optimized configurations, which is inefficient and defeats the purpose of disaggregated prefilling. This also increases tail latency for other ongoing decode requests.
......
......@@ -42,6 +42,7 @@ def main():
"async_load": args.async_load,
},
kv_connector_module_path="load_recovery_example_connector",
kv_load_failure_policy="recompute",
)
out_file = (
"async_decode_recovered_output.txt"
......
......@@ -30,7 +30,7 @@ def _make_get_num_new_matched_tokens(
@pytest.fixture
def scheduler():
vllm_config = create_vllm_config()
vllm_config = create_vllm_config(kv_load_failure_policy="recompute")
return create_scheduler(vllm_config)
......
......@@ -5,7 +5,7 @@ from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from itertools import chain, count
from typing import Any
from typing import Any, Literal
import torch
......@@ -96,6 +96,7 @@ def create_vllm_config(
cache_dtype: str = "auto",
hf_overrides: dict[str, Any] | None = None,
attention_backend: str | None = None,
kv_load_failure_policy: Literal["recompute", "fail"] = "fail",
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
model_config = ModelConfig(
......@@ -125,6 +126,7 @@ def create_vllm_config(
kv_role="kv_both",
enable_permute_local_kv=enable_permute_local_kv,
kv_connector_extra_config=kv_connector_extra_config or {},
kv_load_failure_policy=kv_load_failure_policy,
)
attention_config = AttentionConfig(backend=attention_backend)
return VllmConfig(
......
......@@ -61,10 +61,10 @@ class KVTransferConfig:
enable_permute_local_kv: bool = False
"""Experiment feature flag to enable HND to NHD KV Transfer"""
kv_load_failure_policy: Literal["recompute", "fail"] = "recompute"
kv_load_failure_policy: Literal["recompute", "fail"] = "fail"
"""Policy for handling KV cache load failures.
'recompute': reschedule the request to recompute failed blocks (default)
'fail': immediately fail the request with an error finish reason"""
'recompute': reschedule the request to recompute failed blocks
'fail': immediately fail the request with an error finish reason (default)"""
def compute_hash(self) -> str:
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment