Unverified commit b39c266d, authored by omerpaz95 and committed by GitHub
Browse files

[KV Offload] Offload all KV blocks when doing prefill in P/D (#40346)


Signed-off-by: omerpaz95 <omerpaz95@gmail.com>
Signed-off-by: omerpaz95 <73347585+omerpaz95@users.noreply.github.com>
Co-authored-by: Or Ozeri <or@ozery.com>
parent 9558f439
......@@ -478,3 +478,59 @@ class TestSlidingWindowLookup:
sched._sliding_window_lookup(to_keys([1, 2, 3, 4]), 2, _EMPTY_REQ_CTX)
is None
)
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_do_remote_decode_stores_all_blocks(request_runner, async_scheduling: bool):
    """Check that a do_remote_decode=True request stores every GPU block.

    A plain request first offloads one block to CPU. After the GPU prefix
    cache is reset, a second request flagged with do_remote_decode=True
    loads that block back as its prefix. The store that follows must cover
    the loaded prefix blocks as well as the newly computed ones, so that in
    P/D disaggregation the prefill instance offloads the complete KV cache
    for a remote decode node to consume.
    """

    def store_everything(keys, req_context):
        # Pass every requested key on to generate_store_output.
        return generate_store_output(keys)

    offloaded_block_size = 12
    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=4,
        num_gpu_blocks=100,
        async_scheduling=async_scheduling,
    )

    # GPU block indexes making up the first offloaded block, and both blocks.
    prefix_gpu_blocks = (0, 1, 2)
    all_gpu_blocks = tuple(range(6))

    # Phase 1: an ordinary request stores one offloaded block (3 GPU blocks).
    one_block_prompt = [0] * offloaded_block_size
    runner.new_request(token_ids=one_block_prompt)
    runner.manager.prepare_store.side_effect = store_everything
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=prefix_gpu_blocks,
    )

    # Drop the GPU prefix cache so the next request has to load from CPU.
    runner.scheduler.reset_prefix_cache()

    # Phase 2: a do_remote_decode=True request spanning two offloaded blocks;
    # its first offloaded block is the one already stored in CPU.
    two_block_prompt = [0] * (2 * offloaded_block_size)
    runner.new_request(
        token_ids=two_block_prompt,
        kv_transfer_params={"do_remote_decode": True},
    )
    runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
    runner.manager.prepare_store.side_effect = store_everything

    # The first offloaded block is loaded back from CPU.
    runner.run(
        decoded_tokens=[0],
        expected_loaded_gpu_block_indexes=prefix_gpu_blocks,
    )

    # The store must cover ALL 6 GPU blocks (loaded prefix plus the newly
    # computed block), not only the 3 new ones.
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=all_gpu_blocks,
    )
......@@ -270,7 +270,11 @@ class RequestRunner:
slot_mapping={},
)
def new_request(self, token_ids: list[int]):
def new_request(
self,
token_ids: list[int],
kv_transfer_params: dict | None = None,
):
self.req_id += 1
sampling_params = SamplingParams(max_tokens=1000)
......@@ -283,6 +287,8 @@ class RequestRunner:
pooling_params=None,
block_hasher=self._block_hasher,
)
if kv_transfer_params is not None:
req.kv_transfer_params = kv_transfer_params
self.scheduler.add_request(req)
......
......@@ -314,6 +314,9 @@ class OffloadingConnectorScheduler:
num_locally_computed_tokens = req_status.num_locally_computed_tokens
num_cached_tokens = num_locally_computed_tokens + num_external_tokens
params = req_status.req_context.kv_transfer_params
do_remote_decode = params is not None and params.get("do_remote_decode")
keys_to_load: list[OffloadKey] = []
dst_block_ids: list[int] = []
# per group
......@@ -360,6 +363,10 @@ class OffloadingConnectorScheduler:
group_sizes.append(num_pending_gpu_blocks)
block_indices.append(num_locally_computed_gpu_blocks)
if not do_remote_decode:
# For P/D prefill requests (do_remote_decode=True), we do
# NOT skip saving the hit prefix, as we need to stream the
# entire KV cache so a remote decode node can consume it.
group_state.next_stored_block_idx = num_blocks
src_spec = self.manager.prepare_load(keys_to_load, req_status.req_context)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment