Commit b2e1fc35 authored by Roger Wang's avatar Roger Wang Committed by khluu
Browse files

[Bugfix][Core] Fix CPU memory leak from Request reference cycle in prefix caching (#34183)


Signed-off-by: default avatarRoger Wang <hey@rogerw.io>
(cherry picked from commit 8a5e0e2b)
parent 55a1baeb
......@@ -236,7 +236,7 @@ def test_prefix_caching_for_multi_turn():
req._all_token_ids = req.prompt_token_ids.copy()
req.all_token_ids = ConstantList(req._all_token_ids)
req.block_hashes = []
req.block_hashes = req.get_hash_new_full_blocks()
req.update_block_hashes()
# Schedule the next-turn requests.
for req in next_turn_requests:
......
......@@ -982,10 +982,8 @@ class Scheduler(SchedulerInterface):
session._all_token_ids.extend(update.prompt_token_ids or ())
session.prompt_token_ids.extend(update.prompt_token_ids or ())
# Update block hashes for the new tokens
# (mirrors Request.append_output_token_ids)
if session.get_hash_new_full_blocks is not None:
session.block_hashes.extend(session.get_hash_new_full_blocks())
# Update block hashes for the new tokens.
session.update_block_hashes()
session.num_prompt_tokens = len(session.prompt_token_ids)
session.arrival_time = update.arrival_time
session.sampling_params = update.sampling_params
......
......@@ -6,7 +6,6 @@ import time
from collections import deque
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any
import torch
......@@ -164,10 +163,11 @@ class Request:
self.num_external_computed_tokens = 0
self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
if block_hasher is not None:
self.get_hash_new_full_blocks = partial(block_hasher, self)
self.block_hashes = self.get_hash_new_full_blocks()
# Store the block hasher without binding self to avoid creating a
# reference cycle (Request -> partial -> Request) that prevents
# immediate garbage collection via reference counting.
self._block_hasher: Callable[[Request], list[BlockHash]] | None = block_hasher
self.update_block_hashes()
self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
......@@ -212,8 +212,12 @@ class Request:
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
if self.get_hash_new_full_blocks is not None:
self.block_hashes.extend(self.get_hash_new_full_blocks())
self.update_block_hashes()
def update_block_hashes(self) -> None:
"""Compute block hashes for any new full blocks and append them."""
if self._block_hasher is not None:
self.block_hashes.extend(self._block_hasher(self))
@property
def use_structured_output(self) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment