Unverified Commit 8a5e0e2b authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[Bugfix][Core] Fix CPU memory leak from Request reference cycle in prefix caching (#34183)


Signed-off-by: default avatarRoger Wang <hey@rogerw.io>
parent 4cde2e01
...@@ -236,7 +236,7 @@ def test_prefix_caching_for_multi_turn(): ...@@ -236,7 +236,7 @@ def test_prefix_caching_for_multi_turn():
req._all_token_ids = req.prompt_token_ids.copy() req._all_token_ids = req.prompt_token_ids.copy()
req.all_token_ids = ConstantList(req._all_token_ids) req.all_token_ids = ConstantList(req._all_token_ids)
req.block_hashes = [] req.block_hashes = []
req.block_hashes = req.get_hash_new_full_blocks() req.update_block_hashes()
# Schedule the next-turn requests. # Schedule the next-turn requests.
for req in next_turn_requests: for req in next_turn_requests:
......
...@@ -982,10 +982,8 @@ class Scheduler(SchedulerInterface): ...@@ -982,10 +982,8 @@ class Scheduler(SchedulerInterface):
session._all_token_ids.extend(update.prompt_token_ids or ()) session._all_token_ids.extend(update.prompt_token_ids or ())
session.prompt_token_ids.extend(update.prompt_token_ids or ()) session.prompt_token_ids.extend(update.prompt_token_ids or ())
# Update block hashes for the new tokens # Update block hashes for the new tokens.
# (mirrors Request.append_output_token_ids) session.update_block_hashes()
if session.get_hash_new_full_blocks is not None:
session.block_hashes.extend(session.get_hash_new_full_blocks())
session.num_prompt_tokens = len(session.prompt_token_ids) session.num_prompt_tokens = len(session.prompt_token_ids)
session.arrival_time = update.arrival_time session.arrival_time = update.arrival_time
session.sampling_params = update.sampling_params session.sampling_params = update.sampling_params
......
...@@ -6,7 +6,6 @@ import time ...@@ -6,7 +6,6 @@ import time
from collections import deque from collections import deque
from collections.abc import Callable, Mapping from collections.abc import Callable, Mapping
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import torch import torch
...@@ -164,10 +163,11 @@ class Request: ...@@ -164,10 +163,11 @@ class Request:
self.num_external_computed_tokens = 0 self.num_external_computed_tokens = 0
self.block_hashes: list[BlockHash] = [] self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None # Store the block hasher without binding self to avoid creating a
if block_hasher is not None: # reference cycle (Request -> partial -> Request) that prevents
self.get_hash_new_full_blocks = partial(block_hasher, self) # immediate garbage collection via reference counting.
self.block_hashes = self.get_hash_new_full_blocks() self._block_hasher: Callable[[Request], list[BlockHash]] | None = block_hasher
self.update_block_hashes()
self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache() self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
...@@ -212,8 +212,12 @@ class Request: ...@@ -212,8 +212,12 @@ class Request:
self._output_token_ids.extend(token_ids) self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids)
if self.get_hash_new_full_blocks is not None: self.update_block_hashes()
self.block_hashes.extend(self.get_hash_new_full_blocks())
def update_block_hashes(self) -> None:
"""Compute block hashes for any new full blocks and append them."""
if self._block_hasher is not None:
self.block_hashes.extend(self._block_hasher(self))
@property @property
def use_structured_output(self) -> bool: def use_structured_output(self) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment