Unverified commit 948b01a0 authored by DarkSharpness, committed by GitHub

[Refactor] Remove Hicache Load & Write threads (#10127)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent cdc56ef6
@@ -18,7 +18,7 @@ import math
import threading
import time
from queue import Empty, Full, PriorityQueue, Queue
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
import torch
@@ -43,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
logger = logging.getLogger(__name__)
class LayerLoadingEvent:
def __init__(self, num_layers: int):
self._num_layers = num_layers
self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
self.start_event = torch.cuda.Event() # start event on controller stream
def complete(self, layer_index: int):
assert 0 <= layer_index < self._num_layers
self.load_events[layer_index].record()
def wait(self, layer_index: int):
torch.cuda.current_stream().wait_event(self.load_events[layer_index])
@property
def finish_event(self):
return self.load_events[-1]
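The new LayerLoadingEvent wraps one torch.cuda.Event per layer, so the compute stream can wait on exactly the layer it needs instead of polling a counter under a lock. A minimal sketch of that cross-stream handshake (tensor shapes and names are illustrative; a CUDA device is assumed):

import torch

num_layers = 4
load_stream = torch.cuda.Stream()
events = [torch.cuda.Event() for _ in range(num_layers)]
host_kv = [torch.randn(1024).pin_memory() for _ in range(num_layers)]
device_kv = [torch.empty(1024, device="cuda") for _ in range(num_layers)]

with torch.cuda.stream(load_stream):
    for i in range(num_layers):
        device_kv[i].copy_(host_kv[i], non_blocking=True)  # H2D copy on the side stream
        events[i].record()  # layer i is now ordered as "transferred"

# The compute stream only waits for layer 1; transfers for layers 2..3
# keep running in the background, overlapping with compute.
torch.cuda.current_stream().wait_event(events[1])
out = device_kv[1] * 2.0  # safe: ordered after the layer-1 copy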
class LayerDoneCounter:
def __init__(self, num_layers):
def __init__(self, num_layers: int):
self.num_layers = num_layers
# extra producer and consumer counters for overlap mode
self.num_counters = 3
self.counters = [num_layers] * self.num_counters
self.conditions = [threading.Condition() for _ in range(self.num_counters)]
self.producer_index = 0
self.consumer_index = 0
def next_producer(self):
return (self.producer_index + 1) % self.num_counters
self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
self.producer_index = -1
self.consumer_index = -1
def update_producer(self):
self.producer_index = self.next_producer()
self.producer_index = (self.producer_index + 1) % self.num_counters
assert self.events[
self.producer_index
].finish_event.query(), (
"Producer finish event should be ready before being reused."
)
return self.producer_index
def set_consumer(self, index):
def set_consumer(self, index: int):
self.consumer_index = index
def increment(self):
with self.conditions[self.producer_index]:
self.counters[self.producer_index] += 1
self.conditions[self.producer_index].notify_all()
def wait_until(self, threshold):
with self.conditions[self.consumer_index]:
while self.counters[self.consumer_index] <= threshold:
self.conditions[self.consumer_index].wait()
def wait_until(self, threshold: int):
if self.consumer_index < 0:
return
self.events[self.consumer_index].wait(threshold)
def reset(self):
with self.conditions[self.producer_index]:
self.counters[self.producer_index] = 0
self.producer_index = -1
self.consumer_index = -1
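LayerDoneCounter now rotates three LayerLoadingEvent sets (producer slots) instead of three lock-protected integer counters, and wait_until becomes a stream-level wait rather than a CPU block. A schematic of the handshake as this diff suggests it (not the exact call sites):

counter = LayerDoneCounter(num_layers=32)

# Producer (cache controller): claim the next slot; the assert in
# update_producer() guarantees its previous transfer already finished.
producer_id = counter.update_producer()
counter.events[producer_id].start_event.record()
# ... then counter.events[producer_id].complete(i) as each layer lands ...

# Consumer (model worker): the batch carries producer_id as its consumer index.
counter.set_consumer(producer_id)
counter.wait_until(5)  # enqueue a wait on layer 5's event; no-op if index < 0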
class CacheOperation:
@@ -99,36 +113,30 @@ class CacheOperation:
# default priority is the order of creation
self.priority = priority if priority is not None else self.id
def merge(self, other: "CacheOperation") -> None:
# multiple operations can be merged into a single operation for batch processing
self.host_indices = torch.cat([self.host_indices, other.host_indices])
self.device_indices = torch.cat([self.device_indices, other.device_indices])
self.priority = min(self.priority, other.priority)
self.node_ids.extend(other.node_ids)
def split(self, factor) -> List["CacheOperation"]:
# split an operation into smaller operations to reduce the size of intermediate buffers
if factor <= 1:
return [self]
chunk_size = math.ceil(len(self.host_indices) / factor)
split_ops = []
for i in range(0, len(self.host_indices), chunk_size):
split_ops.append(
CacheOperation(
host_indices=self.host_indices[i : i + chunk_size],
device_indices=self.device_indices[i : i + chunk_size],
node_id=0,
)
)
# Inherit the node_ids on the final chunk
if split_ops:
split_ops[-1].node_ids = self.node_ids
@staticmethod
def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
assert len(ops) > 0
if len(ops) == 1:
return ops[0]
host_indices = torch.cat([op.host_indices for op in ops])
device_indices = torch.cat([op.device_indices for op in ops])
node_ids = []
priority = min(op.priority for op in ops)
for op in ops:
node_ids.extend(op.node_ids)
merged_op = CacheOperation(host_indices, device_indices, -1, priority)
merged_op.node_ids = node_ids
return merged_op
def __lt__(self, other: CacheOperation):
return self.priority < other.priority
return split_ops
def __lt__(self, other: "CacheOperation"):
return self.priority < other.priority
class HiCacheAck(NamedTuple):
start_event: torch.cuda.Event
finish_event: torch.cuda.Event
node_ids: List[int]
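Instead of merging operations pairwise as they are popped from a PriorityQueue, the whole pending list is now collapsed by merge_ops into one batched copy, and a single HiCacheAck carries every affected node id. A small illustration with hypothetical indices:

import torch

a = CacheOperation(torch.tensor([0, 1]), torch.tensor([4, 5]), node_id=7)
b = CacheOperation(torch.tensor([2, 3]), torch.tensor([6, 8]), node_id=9)
merged = CacheOperation.merge_ops([a, b])
# merged.host_indices   -> tensor([0, 1, 2, 3])
# merged.device_indices -> tensor([4, 5, 6, 8])
# merged.node_ids       -> [7, 9]; one finish event later acks both nodes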
class TransferBuffer:
@@ -236,7 +244,7 @@ class HiCacheController:
mem_pool_host: HostKVCache,
page_size: int,
tp_group: torch.distributed.ProcessGroup,
load_cache_event: threading.Event = None,
load_cache_event: threading.Event,
write_policy: str = "write_through_selective",
io_backend: str = "",
storage_backend: Optional[str] = None,
@@ -340,8 +348,9 @@ class HiCacheController:
self.page_set_func = self._3fs_zero_copy_page_set
self.batch_exists_func = self._3fs_zero_copy_batch_exists
self.load_cache_event = load_cache_event
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
self.device = self.mem_pool_device.device
self.layer_num = self.mem_pool_device.layer_num
self.layer_done_counter = LayerDoneCounter(self.layer_num)
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
if write_policy not in [
@@ -351,11 +360,11 @@
]:
raise ValueError(f"Invalid write policy: {write_policy}")
self.write_queue = PriorityQueue()
self.load_queue = PriorityQueue()
self.ack_write_queue = Queue()
self.ack_load_queue = Queue()
# self.write_queue = PriorityQueue[CacheOperation]()
self.load_queue: List[CacheOperation] = []
self.write_queue: List[CacheOperation] = []
self.ack_load_queue: List[HiCacheAck] = []
self.ack_write_queue: List[HiCacheAck] = []
self.stop_event = threading.Event()
self.write_buffer = TransferBuffer(self.stop_event)
@@ -366,16 +375,6 @@
self.write_stream = torch.cuda.Stream()
self.load_stream = torch.cuda.Stream()
self.write_thread = threading.Thread(
target=self.write_thread_func_direct, daemon=True
)
self.load_thread = threading.Thread(
target=self.load_thread_func_layer_by_layer, daemon=True
)
self.write_thread.start()
self.load_thread.start()
if self.enable_storage:
self.prefetch_thread = threading.Thread(
target=self.prefetch_thread_func, daemon=True
@@ -432,15 +431,13 @@
def reset(self):
self.stop_event.set()
self.write_thread.join()
self.load_thread.join()
self.write_queue.queue.clear()
self.load_queue.queue.clear()
self.write_queue.clear()
self.load_queue.clear()
self.write_buffer.clear()
self.load_buffer.clear()
self.ack_write_queue.queue.clear()
self.ack_load_queue.queue.clear()
self.ack_write_queue.clear()
self.ack_load_queue.clear()
if self.enable_storage:
self.prefetch_thread.join()
self.backup_thread.join()
@@ -449,15 +446,7 @@
self.prefetch_revoke_queue.queue.clear()
self.ack_backup_queue.queue.clear()
self.write_thread = threading.Thread(
target=self.write_thread_func_direct, daemon=True
)
self.load_thread = threading.Thread(
target=self.load_thread_func_layer_by_layer, daemon=True
)
self.stop_event.clear()
self.write_thread.start()
self.load_thread.start()
if self.enable_storage:
self.prefetch_thread = threading.Thread(
@@ -473,7 +462,7 @@
self,
device_indices: torch.Tensor,
priority: Optional[int] = None,
node_id: int = 0,
node_id: int = -1,
) -> Optional[torch.Tensor]:
"""
Back up KV caches from device memory to host memory.
@@ -482,17 +471,46 @@
if host_indices is None:
return None
self.mem_pool_host.protect_write(host_indices)
torch.cuda.current_stream().synchronize()
self.write_queue.put(
self.write_queue.append(
CacheOperation(host_indices, device_indices, node_id, priority)
)
self.start_writing()
return host_indices
def start_writing(self) -> None:
if len(self.write_queue) == 0:
return
op = CacheOperation.merge_ops(self.write_queue)
host_indices, device_indices = self.move_indices(op)
self.write_queue.clear()
start_event = torch.cuda.Event()
finish_event = torch.cuda.Event()
start_event.record()
with torch.cuda.stream(self.write_stream):
start_event.wait(self.write_stream)
self.mem_pool_host.backup_from_device_all_layer(
self.mem_pool_device, host_indices, device_indices, self.io_backend
)
self.mem_pool_host.complete_io(op.host_indices)
finish_event.record()
# NOTE: We must keep the host and device indices alive here,
# because we have to guarantee that these tensors are not freed
# while the write stream is still executing.
if host_indices.is_cuda:
host_indices.record_stream(self.write_stream)
if device_indices.is_cuda:
device_indices.record_stream(self.write_stream)
self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
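start_writing issues the device-to-host copy directly on write_stream, with no helper thread. Two details make this safe: the start event orders the side stream after the producing stream, and Tensor.record_stream tells the caching allocator that another stream still reads the tensor, so its memory is not recycled early. The same pattern in isolation (a generic sketch assuming a CUDA device):

import torch

side_stream = torch.cuda.Stream()
src = torch.randn(1 << 20, device="cuda")
dst = torch.empty(1 << 20).pin_memory()

start_event = torch.cuda.Event()
finish_event = torch.cuda.Event()

start_event.record()                   # fence: src is fully produced up to here
with torch.cuda.stream(side_stream):
    start_event.wait(side_stream)      # side stream waits behind the fence
    dst.copy_(src, non_blocking=True)  # async D2H copy
    finish_event.record()
src.record_stream(side_stream)         # keep src alive until side_stream is done

if finish_event.query():               # poll later without blocking the CPU
    pass                               # copy finished; release the acks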
def load(
self,
host_indices: torch.Tensor,
priority: Optional[int] = None,
node_id: int = 0,
node_id: int = -1,
) -> Optional[torch.Tensor]:
"""
Load KV caches from host memory to device memory.
@@ -501,17 +519,18 @@
if device_indices is None:
return None
self.mem_pool_host.protect_load(host_indices)
# ensure the device indices are ready before they are accessed by another CUDA stream
torch.cuda.current_stream().synchronize()
self.load_queue.put(
self.load_queue.append(
CacheOperation(host_indices, device_indices, node_id, priority)
)
return device_indices
def move_indices(self, host_indices, device_indices):
def move_indices(self, op: CacheOperation):
host_indices, device_indices = op.host_indices, op.device_indices
# move indices to GPU if using kernels, to host if using direct indexing
if self.io_backend == "kernel":
return host_indices.to(self.mem_pool_device.device), device_indices
if not host_indices.is_cuda:
host_indices = host_indices.to(self.device, non_blocking=True)
return host_indices, device_indices
elif self.io_backend == "direct":
device_indices = device_indices.cpu()
host_indices, idx = host_indices.sort()
@@ -519,58 +538,20 @@
else:
raise ValueError(f"Unsupported io backend: {self.io_backend}")
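For the direct backend the host indices are sorted so that contiguous runs can be copied as large slices; the permutation returned by sort must then be applied to the device side as well, or the host/device pairing breaks. A toy illustration (values are made up):

import torch

host_indices = torch.tensor([7, 2, 3, 6])
device_indices = torch.tensor([40, 41, 42, 43])

sorted_host, idx = host_indices.sort()
paired_device = device_indices[idx]
# sorted_host   -> tensor([2, 3, 6, 7]): runs [2, 3] and [6, 7] copy as slices
# paired_device -> tensor([41, 42, 43, 40]): still matched element-wise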
def write_thread_func_direct(self):
"""
Directly write through KV caches to host memory without buffering.
"""
torch.cuda.set_stream(self.write_stream)
while not self.stop_event.is_set():
try:
operation = self.write_queue.get(block=True, timeout=1)
host_indices, device_indices = self.move_indices(
operation.host_indices, operation.device_indices
)
self.mem_pool_host.backup_from_device_all_layer(
self.mem_pool_device, host_indices, device_indices, self.io_backend
)
self.write_stream.synchronize()
self.mem_pool_host.complete_io(operation.host_indices)
for node_id in operation.node_ids:
if node_id != 0:
self.ack_write_queue.put(node_id)
except Empty:
continue
except Exception as e:
logger.error(e)
def start_loading(self) -> int:
if len(self.load_queue) == 0:
return -1
def load_thread_func_layer_by_layer(self):
"""
Load KV caches from host memory to device memory layer by layer.
"""
torch.cuda.set_stream(self.load_stream)
while not self.stop_event.is_set():
self.load_cache_event.wait(timeout=1)
if not self.load_cache_event.is_set():
continue
self.load_cache_event.clear()
self.layer_done_counter.update_producer()
batch_operation = None
while self.load_queue.qsize() > 0:
op = self.load_queue.get(block=True)
if batch_operation is None:
batch_operation = op
else:
batch_operation.merge(op)
if batch_operation is None:
continue
producer_id = self.layer_done_counter.update_producer()
op = CacheOperation.merge_ops(self.load_queue)
host_indices, device_indices = self.move_indices(op)
self.load_queue.clear()
producer_event = self.layer_done_counter.events[producer_id]
producer_event.start_event.record()
# start layer-wise KV cache transfer from CPU to GPU
self.layer_done_counter.reset()
host_indices, device_indices = self.move_indices(
batch_operation.host_indices, batch_operation.device_indices
)
for i in range(self.mem_pool_host.layer_num):
with torch.cuda.stream(self.load_stream):
producer_event.start_event.wait(self.load_stream)
for i in range(self.layer_num):
self.mem_pool_host.load_to_device_per_layer(
self.mem_pool_device,
host_indices,
@@ -578,13 +559,24 @@
i,
self.io_backend,
)
self.load_stream.synchronize()
self.layer_done_counter.increment()
self.mem_pool_host.complete_io(batch_operation.host_indices)
for node_id in batch_operation.node_ids:
if node_id != 0:
self.ack_load_queue.put(node_id)
producer_event.complete(i)
self.mem_pool_host.complete_io(op.host_indices)
# NOTE: We must keep the host and device indices alive here,
# because we have to guarantee that these tensors are not freed
# while the load stream is still executing.
if host_indices.is_cuda:
host_indices.record_stream(self.load_stream)
if device_indices.is_cuda:
device_indices.record_stream(self.load_stream)
self.ack_load_queue.append(
HiCacheAck(
start_event=producer_event.start_event,
finish_event=producer_event.finish_event,
node_ids=op.node_ids,
)
)
return producer_id
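On the consumer side, the KV pool waits on the matching per-layer event right before each layer's attention reads its slice, which is what lets the forward pass overlap with the host-to-device transfer. An illustrative, simplified forward loop (function and variable names are assumptions):

def forward_all_layers(layers, counter, hidden):
    for i, layer in enumerate(layers):
        # wait_until(i) enqueues a stream-level wait on layer i's event and
        # returns immediately when no load is in flight (consumer_index == -1)
        counter.wait_until(i)
        hidden = layer(hidden)
    return hidden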
def evict_device(
self, device_indices: torch.Tensor, host_indices: torch.Tensor
@@ -911,7 +911,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
is_prefill_only: bool = False
# hicache pointer for synchronizing data loading from CPU to GPU
hicache_consumer_index: int = 0
hicache_consumer_index: int = -1
@classmethod
def init_new(
@@ -1897,7 +1897,7 @@ class ModelWorkerBatch:
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None
hicache_consumer_index: int = 0
hicache_consumer_index: int = -1
# Overlap event
launch_done: Optional[threading.Event] = None
@@ -1807,10 +1806,6 @@ class Scheduler(
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()
# update the consumer index of hicache to the running batch
self.tp_worker.set_hicache_consumer(
model_worker_batch.hicache_consumer_index
)
if self.pp_group.is_last_rank:
logits_output, next_token_ids, can_run_cuda_graph = (
self.tp_worker.forward_batch_generation(model_worker_batch)
@@ -12,10 +12,11 @@
# limitations under the License.
# ==============================================================================
"""A tensor parallel worker."""
from __future__ import annotations
import logging
import threading
from typing import Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union
import torch
@@ -45,6 +46,9 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
if TYPE_CHECKING:
from sglang.srt.managers.cache_controller import LayerDoneCounter
logger = logging.getLogger(__name__)
@@ -167,10 +171,10 @@ class TpModelWorker:
self.hicache_layer_transfer_counter = None
def register_hicache_layer_transfer_counter(self, counter):
def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
self.hicache_layer_transfer_counter = counter
def set_hicache_consumer(self, consumer_index):
def set_hicache_consumer(self, consumer_index: int):
if self.hicache_layer_transfer_counter is not None:
self.hicache_layer_transfer_counter.set_consumer(consumer_index)
@@ -230,6 +234,9 @@
) -> Tuple[
Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
]:
# update the consumer index of hicache to the running batch
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
pp_proxy_tensors = None
@@ -12,13 +12,14 @@
# limitations under the License.
# ==============================================================================
"""A tensor parallel worker."""
from __future__ import annotations
import dataclasses
import logging
import signal
import threading
from queue import Queue
from typing import Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
import psutil
import torch
@@ -38,6 +39,9 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import DynamicGradMode, get_compiler_backend
from sglang.utils import get_exception_traceback
if TYPE_CHECKING:
from sglang.srt.managers.cache_controller import LayerDoneCounter
logger = logging.getLogger(__name__)
@@ -79,7 +83,7 @@ class TpModelWorkerClient:
)
# Launch threads
self.input_queue = Queue()
self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
self.output_queue = Queue()
self.forward_stream = torch.get_device_module(self.device).Stream()
self.forward_thread = threading.Thread(
@@ -93,13 +97,9 @@
self.hicache_layer_transfer_counter = None
def register_hicache_layer_transfer_counter(self, counter):
def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
self.hicache_layer_transfer_counter = counter
def set_hicache_consumer(self, consumer_index):
if self.hicache_layer_transfer_counter is not None:
self.hicache_layer_transfer_counter.set_consumer(consumer_index)
def get_worker_info(self):
return self.worker.get_worker_info()
@@ -147,7 +147,7 @@
@DynamicGradMode()
def forward_thread_func_(self):
batch_pt = 0
batch_lists = [None] * 2
batch_lists: List = [None] * 2
while True:
model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
@@ -169,8 +169,6 @@
input_ids = model_worker_batch.input_ids
resolve_future_token_ids(input_ids, self.future_token_ids_map)
# update the consumer index of hicache to the running batch
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
# Run forward
logits_output, next_token_ids, can_run_cuda_graph = (
self.worker.forward_batch_generation(
@@ -201,41 +201,57 @@ class HiRadixCache(RadixCache):
if write_back:
# blocking till all write back complete
while len(self.ongoing_write_through) > 0:
ack_id = self.cache_controller.ack_write_queue.get()
del self.ongoing_write_through[ack_id]
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
finish_event.synchronize()
for ack_id in ack_list:
del self.ongoing_write_through[ack_id]
self.cache_controller.ack_write_queue.clear()
assert len(self.ongoing_write_through) == 0
return
queue_size = torch.tensor(
self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
)
# NOTE: all ranks have the same ongoing_write_through, so the sync can be skipped when it is empty
if len(self.ongoing_write_through) == 0:
return
finish_count = 0
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
if not finish_event.query():
break
finish_count += 1
queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
if self.tp_world_size > 1:
# synchrnoize TP workers to make the same update to radix cache
# synchronize TP workers to make the same update to radix cache
torch.distributed.all_reduce(
queue_size,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
for _ in range(queue_size.item()):
ack_id = self.cache_controller.ack_write_queue.get()
backuped_node = self.ongoing_write_through[ack_id]
self.dec_lock_ref(backuped_node)
del self.ongoing_write_through[ack_id]
if self.enable_storage:
self.write_backup_storage(backuped_node)
finish_count = int(queue_size.item())
while finish_count > 0:
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
finish_event.synchronize()
for ack_id in ack_list:
backuped_node = self.ongoing_write_through.pop(ack_id)
self.dec_lock_ref(backuped_node)
if self.enable_storage:
self.write_backup_storage(backuped_node)
finish_count -= 1
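writing_check swaps the blocking Queue.get for non-blocking finish_event.query() polling, then all-reduces the per-rank finished count with MIN so that every TP rank pops exactly the same acks and the radix trees stay in lockstep. The consensus step as a standalone sketch (the helper name is an assumption):

import torch
import torch.distributed as dist

def agreed_finish_count(ack_queue, tp_group, tp_world_size):
    # count acks whose GPU work has completed, in FIFO order; stop at the
    # first unfinished event since acks must be released in order
    count = 0
    for _, finish_event, _ in ack_queue:
        if not finish_event.query():
            break
        count += 1
    if tp_world_size > 1:
        t = torch.tensor(count, dtype=torch.int)
        dist.all_reduce(t, op=dist.ReduceOp.MIN, group=tp_group)
        count = int(t.item())
    return count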
def loading_check(self):
while not self.cache_controller.ack_load_queue.empty():
try:
ack_id = self.cache_controller.ack_load_queue.get_nowait()
start_node, end_node = self.ongoing_load_back[ack_id]
self.dec_lock_ref(end_node)
while end_node != start_node:
assert end_node.loading
end_node.loading = False
end_node = end_node.parent
# clear the reference
del self.ongoing_load_back[ack_id]
except Exception:
finish_count = 0
for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
if not finish_event.query():
# the KV cache loading is still ongoing
break
finish_count += 1
# no need to sync across TP workers as batch forwarding is synced
for ack_id in ack_list:
end_node = self.ongoing_load_back.pop(ack_id)
self.dec_lock_ref(end_node)
# drop the acks that have been fully processed
del self.cache_controller.ack_load_queue[:finish_count]
def evictable_size(self):
return self.evictable_size_
@@ -360,12 +376,11 @@ class HiRadixCache(RadixCache):
# no sufficient GPU memory to load back KV caches
return None
self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
self.ongoing_load_back[last_hit_node.id] = last_hit_node
offset = 0
for node in nodes_to_load:
node.value = device_indices[offset : offset + len(node.host_value)]
offset += len(node.host_value)
node.loading = True
self.evictable_size_ += len(device_indices)
self.inc_lock_ref(last_hit_node)
@@ -394,10 +409,12 @@
last_node,
)
def ready_to_load_host_cache(self):
producer_index = self.cache_controller.layer_done_counter.next_producer()
self.load_cache_event.set()
return producer_index
def ready_to_load_host_cache(self) -> int:
"""
Notify the cache controller to start the KV cache loading.
Return the consumer index for the schedule batch manager to track.
"""
return self.cache_controller.start_loading()
def check_hicache_events(self):
self.writing_check()
@@ -702,7 +719,6 @@
new_node.parent = child.parent
new_node.lock_ref = child.lock_ref
new_node.key = child.key[:split_len]
new_node.loading = child.loading
new_node.hit_count = child.hit_count
# split value and host value if exists
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import annotations
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
"""
@@ -27,7 +29,7 @@ KVCache actually holds the physical kv cache.
import abc
import logging
from contextlib import nullcontext
from typing import Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -38,6 +40,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.managers.cache_controller import LayerDoneCounter
logger = logging.getLogger(__name__)
GB = 1024 * 1024 * 1024
@@ -175,7 +180,7 @@ class KVCache(abc.ABC):
) -> None:
raise NotImplementedError()
def register_layer_transfer_counter(self, layer_transfer_counter):
def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
self.layer_transfer_counter = layer_transfer_counter
def get_cpu_copy(self, indices):
@@ -3,6 +3,7 @@ import logging
import threading
from enum import IntEnum
from functools import wraps
from typing import Optional
import psutil
import torch
@@ -169,7 +170,7 @@ class HostKVCache(abc.ABC):
return len(self.free_slots)
@synchronized()
def alloc(self, need_size: int) -> torch.Tensor:
def alloc(self, need_size: int) -> Optional[torch.Tensor]:
assert (
need_size % self.page_size == 0
), "The requested size should be a multiple of the page size."
@@ -53,8 +53,6 @@ class TreeNode:
self.last_access_time = time.monotonic()
self.hit_count = 0
# indicating the node is loading KV cache from host
self.loading = False
# indicating the node is locked to protect from eviction
# incremented when the node is referenced by a storage operation
self.host_ref_counter = 0
@@ -60,8 +60,6 @@ class TreeNode:
self.last_access_time = time.monotonic()
self.hit_count = 0
# indicating the node is loading KV cache from host
self.loading = False
# store the host indices of KV cache
self.host_value = None