"vscode:/vscode.git/clone" did not exist on "c399de396dbb464be0935f910703eff9f11667ad"
Unverified commit 948b01a0, authored by DarkSharpness and committed by GitHub

[Refactor] Remove Hicache Load & Write threads (#10127)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent cdc56ef6
@@ -18,7 +18,7 @@ import math
 import threading
 import time
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple

 import torch
@@ -43,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
 logger = logging.getLogger(__name__)


+class LayerLoadingEvent:
+    def __init__(self, num_layers: int):
+        self._num_layers = num_layers
+        self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
+        self.start_event = torch.cuda.Event()  # start event on controller stream
+
+    def complete(self, layer_index: int):
+        assert 0 <= layer_index < self._num_layers
+        self.load_events[layer_index].record()
+
+    def wait(self, layer_index: int):
+        torch.cuda.current_stream().wait_event(self.load_events[layer_index])
+
+    @property
+    def finish_event(self):
+        return self.load_events[-1]
+
+
 class LayerDoneCounter:
-    def __init__(self, num_layers):
+    def __init__(self, num_layers: int):
         self.num_layers = num_layers
         # extra producer and consumer counters for overlap mode
         self.num_counters = 3
-        self.counters = [num_layers] * self.num_counters
-        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
-        self.producer_index = 0
-        self.consumer_index = 0
-
-    def next_producer(self):
-        return (self.producer_index + 1) % self.num_counters
+        self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
+        self.producer_index = -1
+        self.consumer_index = -1

     def update_producer(self):
-        self.producer_index = self.next_producer()
+        self.producer_index = (self.producer_index + 1) % self.num_counters
+        assert self.events[
+            self.producer_index
+        ].finish_event.query(), (
+            "Producer finish event should be ready before being reused."
+        )
         return self.producer_index

-    def set_consumer(self, index):
+    def set_consumer(self, index: int):
         self.consumer_index = index

-    def increment(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] += 1
-            self.conditions[self.producer_index].notify_all()
-
-    def wait_until(self, threshold):
-        with self.conditions[self.consumer_index]:
-            while self.counters[self.consumer_index] <= threshold:
-                self.conditions[self.consumer_index].wait()
+    def wait_until(self, threshold: int):
+        if self.consumer_index < 0:
+            return
+        self.events[self.consumer_index].wait(threshold)

     def reset(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] = 0
+        self.producer_index = -1
+        self.consumer_index = -1


 class CacheOperation:
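Note: the new `LayerLoadingEvent` / `LayerDoneCounter` pair replaces the counter-plus-condition-variable handshake with per-layer CUDA events: the loader records one event per layer on its own stream, and the consumer stream waits only on the event of the layer it is about to read, so no Python-side locking or blocking is involved. A minimal, self-contained sketch of that pattern (illustrative buffer names and sizes, not the controller's actual data structures; assumes a CUDA device is available):

```python
import torch

num_layers = 4
host_kv = [torch.randn(1024, pin_memory=True) for _ in range(num_layers)]
device_kv = [torch.empty(1024, device="cuda") for _ in range(num_layers)]

load_stream = torch.cuda.Stream()
layer_events = [torch.cuda.Event() for _ in range(num_layers)]

# Producer: copy layer by layer on a side stream, recording an event per layer.
with torch.cuda.stream(load_stream):
    for i in range(num_layers):
        device_kv[i].copy_(host_kv[i], non_blocking=True)
        layer_events[i].record()  # layer i is ready once this event completes

# Consumer: the default (compute) stream waits only for the layer it needs,
# so computation on early layers overlaps with copies of later layers.
for i in range(num_layers):
    torch.cuda.current_stream().wait_event(layer_events[i])
    out = device_kv[i] * 2  # safe: ordered after the layer-i copy
torch.cuda.synchronize()
```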
@@ -99,36 +113,30 @@ class CacheOperation:
         # default priority is the order of creation
         self.priority = priority if priority is not None else self.id

-    def merge(self, other: "CacheOperation") -> None:
-        # multiple operations can be merged into a single operation for batch processing
-        self.host_indices = torch.cat([self.host_indices, other.host_indices])
-        self.device_indices = torch.cat([self.device_indices, other.device_indices])
-        self.priority = min(self.priority, other.priority)
-        self.node_ids.extend(other.node_ids)
-
-    def split(self, factor) -> List["CacheOperation"]:
-        # split an operation into smaller operations to reduce the size of intermediate buffers
-        if factor <= 1:
-            return [self]
-
-        chunk_size = math.ceil(len(self.host_indices) / factor)
-        split_ops = []
-        for i in range(0, len(self.host_indices), chunk_size):
-            split_ops.append(
-                CacheOperation(
-                    host_indices=self.host_indices[i : i + chunk_size],
-                    device_indices=self.device_indices[i : i + chunk_size],
-                    node_id=0,
-                )
-            )
-        # Inherit the node_ids on the final chunk
-        if split_ops:
-            split_ops[-1].node_ids = self.node_ids
-
-        return split_ops
-
-    def __lt__(self, other: "CacheOperation"):
-        return self.priority < other.priority
+    @staticmethod
+    def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
+        assert len(ops) > 0
+        if len(ops) == 1:
+            return ops[0]
+
+        host_indices = torch.cat([op.host_indices for op in ops])
+        device_indices = torch.cat([op.device_indices for op in ops])
+        node_ids = []
+        priority = min(op.priority for op in ops)
+        for op in ops:
+            node_ids.extend(op.node_ids)
+        merged_op = CacheOperation(host_indices, device_indices, -1, priority)
+        merged_op.node_ids = node_ids
+        return merged_op
+
+    def __lt__(self, other: CacheOperation):
+        return self.priority < other.priority
+
+
+class HiCacheAck(NamedTuple):
+    start_event: torch.cuda.Event
+    finish_event: torch.cuda.Event
+    node_ids: List[int]


 class TransferBuffer:
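Note: `merge_ops` folds every queued `CacheOperation` into one batched operation by concatenating the index tensors, keeping the smallest priority, and collecting all node ids so acks can later be fanned back out. A small stand-alone illustration of the same idea with a stripped-down operation type (the `Op` dataclass below is hypothetical, not the real `CacheOperation`):

```python
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class Op:
    host_indices: torch.Tensor
    device_indices: torch.Tensor
    priority: int
    node_ids: List[int] = field(default_factory=list)


def merge_ops(ops: List[Op]) -> Op:
    # Concatenate index tensors, keep the best (smallest) priority,
    # and gather every node id so completions can be acknowledged per node.
    assert ops
    if len(ops) == 1:
        return ops[0]
    merged = Op(
        host_indices=torch.cat([op.host_indices for op in ops]),
        device_indices=torch.cat([op.device_indices for op in ops]),
        priority=min(op.priority for op in ops),
    )
    for op in ops:
        merged.node_ids.extend(op.node_ids)
    return merged


ops = [
    Op(torch.tensor([0, 1]), torch.tensor([10, 11]), priority=2, node_ids=[7]),
    Op(torch.tensor([2, 3]), torch.tensor([12, 13]), priority=1, node_ids=[8]),
]
batch = merge_ops(ops)
print(batch.host_indices.tolist(), batch.node_ids)  # [0, 1, 2, 3] [7, 8]
```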
@@ -236,7 +244,7 @@ class HiCacheController:
         mem_pool_host: HostKVCache,
         page_size: int,
         tp_group: torch.distributed.ProcessGroup,
-        load_cache_event: threading.Event = None,
+        load_cache_event: threading.Event,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
         storage_backend: Optional[str] = None,
@@ -340,8 +348,9 @@ class HiCacheController:
             self.page_set_func = self._3fs_zero_copy_page_set
             self.batch_exists_func = self._3fs_zero_copy_batch_exists

-        self.load_cache_event = load_cache_event
-        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
+        self.device = self.mem_pool_device.device
+        self.layer_num = self.mem_pool_device.layer_num
+        self.layer_done_counter = LayerDoneCounter(self.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

         if write_policy not in [
@@ -351,11 +360,11 @@
         ]:
             raise ValueError(f"Invalid write policy: {write_policy}")

-        self.write_queue = PriorityQueue()
-        self.load_queue = PriorityQueue()
-
-        self.ack_write_queue = Queue()
-        self.ack_load_queue = Queue()
+        # self.write_queue = PriorityQueue[CacheOperation]()
+        self.load_queue: List[CacheOperation] = []
+        self.write_queue: List[CacheOperation] = []
+
+        self.ack_load_queue: List[HiCacheAck] = []
+        self.ack_write_queue: List[HiCacheAck] = []

         self.stop_event = threading.Event()
         self.write_buffer = TransferBuffer(self.stop_event)
@@ -366,16 +375,6 @@
         self.write_stream = torch.cuda.Stream()
         self.load_stream = torch.cuda.Stream()

-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
-        self.write_thread.start()
-        self.load_thread.start()
-
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
                 target=self.prefetch_thread_func, daemon=True
@@ -432,15 +431,13 @@
     def reset(self):
         self.stop_event.set()
-        self.write_thread.join()
-        self.load_thread.join()

-        self.write_queue.queue.clear()
-        self.load_queue.queue.clear()
+        self.write_queue.clear()
+        self.load_queue.clear()
         self.write_buffer.clear()
         self.load_buffer.clear()
-        self.ack_write_queue.queue.clear()
-        self.ack_load_queue.queue.clear()
+        self.ack_write_queue.clear()
+        self.ack_load_queue.clear()
         if self.enable_storage:
             self.prefetch_thread.join()
             self.backup_thread.join()
@@ -449,15 +446,7 @@
             self.prefetch_revoke_queue.queue.clear()
             self.ack_backup_queue.queue.clear()

-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
         self.stop_event.clear()
-        self.write_thread.start()
-        self.load_thread.start()

         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
@@ -473,7 +462,7 @@
         self,
         device_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Back up KV caches from device memory to host memory.
@@ -482,17 +471,46 @@
         if host_indices is None:
             return None

         self.mem_pool_host.protect_write(host_indices)
-        torch.cuda.current_stream().synchronize()
-        self.write_queue.put(
+        self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
+        self.start_writing()
         return host_indices

+    def start_writing(self) -> None:
+        if len(self.write_queue) == 0:
+            return
+
+        op = CacheOperation.merge_ops(self.write_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.write_queue.clear()
+
+        start_event = torch.cuda.Event()
+        finish_event = torch.cuda.Event()
+
+        start_event.record()
+        with torch.cuda.stream(self.write_stream):
+            start_event.wait(self.write_stream)
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
+            )
+            self.mem_pool_host.complete_io(op.host_indices)
+            finish_event.record()
+
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the write stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.write_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.write_stream)
+
+        self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
+
     def load(
         self,
         host_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Load KV caches from host memory to device memory.
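Note on the write path shown above: the blocking `torch.cuda.current_stream().synchronize()` is gone. `start_writing` records a start event on the caller's stream, makes `write_stream` wait on it, records a finish event after the copy, and uses `Tensor.record_stream` so the index tensors stay alive until the write stream is done with them. A minimal sketch of that ordering pattern (illustrative buffers; assumes a CUDA device):

```python
import torch

src = torch.randn(1 << 20, device="cuda")
dst = torch.empty(1 << 20, pin_memory=True)
write_stream = torch.cuda.Stream()

start_event = torch.cuda.Event()
finish_event = torch.cuda.Event()

start_event.record()  # marks "src is ready" on the current (producer) stream
with torch.cuda.stream(write_stream):
    start_event.wait(write_stream)     # do not start copying before src is ready
    dst.copy_(src, non_blocking=True)  # device -> pinned-host copy on write_stream
    finish_event.record()
    # src was allocated on the default stream; tell the caching allocator it is
    # still in use by write_stream so its memory is not reused too early.
    src.record_stream(write_stream)

# A poller can later check completion without blocking the host:
if finish_event.query():
    print("write-back finished")
else:
    finish_event.synchronize()  # block only if the result is needed right now
```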
@@ -501,17 +519,18 @@
         if device_indices is None:
             return None

         self.mem_pool_host.protect_load(host_indices)
-        # to ensure the device indices are ready before accessed by another CUDA stream
-        torch.cuda.current_stream().synchronize()
-        self.load_queue.put(
+        self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
         return device_indices

-    def move_indices(self, host_indices, device_indices):
+    def move_indices(self, op: CacheOperation):
+        host_indices, device_indices = op.host_indices, op.device_indices
         # move indices to GPU if using kernels, to host if using direct indexing
         if self.io_backend == "kernel":
-            return host_indices.to(self.mem_pool_device.device), device_indices
+            if not host_indices.is_cuda:
+                host_indices = host_indices.to(self.device, non_blocking=True)
+            return host_indices, device_indices
         elif self.io_backend == "direct":
             device_indices = device_indices.cpu()
             host_indices, idx = host_indices.sort()
@@ -519,58 +538,20 @@
         else:
             raise ValueError(f"Unsupported io backend")

-    def write_thread_func_direct(self):
-        """
-        Directly write through KV caches to host memory without buffering.
-        """
-        torch.cuda.set_stream(self.write_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                host_indices, device_indices = self.move_indices(
-                    operation.host_indices, operation.device_indices
-                )
-                self.mem_pool_host.backup_from_device_all_layer(
-                    self.mem_pool_device, host_indices, device_indices, self.io_backend
-                )
-                self.write_stream.synchronize()
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_write_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    def load_thread_func_layer_by_layer(self):
-        """
-        Load KV caches from host memory to device memory layer by layer.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            self.load_cache_event.wait(timeout=1)
-            if not self.load_cache_event.is_set():
-                continue
-            self.load_cache_event.clear()
-            self.layer_done_counter.update_producer()
-
-            batch_operation = None
-            while self.load_queue.qsize() > 0:
-                op = self.load_queue.get(block=True)
-                if batch_operation is None:
-                    batch_operation = op
-                else:
-                    batch_operation.merge(op)
-            if batch_operation is None:
-                continue
-
-            # start layer-wise KV cache transfer from CPU to GPU
-            self.layer_done_counter.reset()
-            host_indices, device_indices = self.move_indices(
-                batch_operation.host_indices, batch_operation.device_indices
-            )
-            for i in range(self.mem_pool_host.layer_num):
+    def start_loading(self) -> int:
+        if len(self.load_queue) == 0:
+            return -1
+
+        producer_id = self.layer_done_counter.update_producer()
+        op = CacheOperation.merge_ops(self.load_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.load_queue.clear()
+        producer_event = self.layer_done_counter.events[producer_id]
+        producer_event.start_event.record()
+
+        with torch.cuda.stream(self.load_stream):
+            producer_event.start_event.wait(self.load_stream)
+            for i in range(self.layer_num):
                 self.mem_pool_host.load_to_device_per_layer(
                     self.mem_pool_device,
                     host_indices,
@@ -578,13 +559,24 @@
                     i,
                     self.io_backend,
                 )
-                self.load_stream.synchronize()
-                self.layer_done_counter.increment()
+                producer_event.complete(i)
+            self.mem_pool_host.complete_io(op.host_indices)

+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the load stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.load_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.load_stream)

-            self.mem_pool_host.complete_io(batch_operation.host_indices)
-            for node_id in batch_operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
+        self.ack_load_queue.append(
+            HiCacheAck(
+                start_event=producer_event.start_event,
+                finish_event=producer_event.finish_event,
+                node_ids=op.node_ids,
+            )
+        )
+        return producer_id

     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
...
@@ -911,7 +911,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     is_prefill_only: bool = False

     # hicache pointer for synchronizing data loading from CPU to GPU
-    hicache_consumer_index: int = 0
+    hicache_consumer_index: int = -1

     @classmethod
     def init_new(
@@ -1897,7 +1897,7 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
-    hicache_consumer_index: int = 0
+    hicache_consumer_index: int = -1

     # Overlap event
     launch_done: Optional[threading.Event] = None
...
@@ -1807,10 +1807,6 @@ class Scheduler(
         if self.spec_algorithm.is_none():
             model_worker_batch = batch.get_model_worker_batch()

-            # update the consumer index of hicache to the running batch
-            self.tp_worker.set_hicache_consumer(
-                model_worker_batch.hicache_consumer_index
-            )
-
             if self.pp_group.is_last_rank:
                 logits_output, next_token_ids, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
...
@@ -12,10 +12,11 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations

 import logging
 import threading
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union

 import torch
@@ -45,6 +46,9 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed

+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)
@@ -167,10 +171,10 @@ class TpModelWorker:

         self.hicache_layer_transfer_counter = None

-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter

-    def set_hicache_consumer(self, consumer_index):
+    def set_hicache_consumer(self, consumer_index: int):
         if self.hicache_layer_transfer_counter is not None:
             self.hicache_layer_transfer_counter.set_consumer(consumer_index)
@@ -230,6 +234,9 @@ class TpModelWorker:
     ) -> Tuple[
         Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
     ]:
+        # update the consumer index of hicache to the running batch
+        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
+
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         pp_proxy_tensors = None
...
@@ -12,13 +12,14 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations

 import dataclasses
 import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple

 import psutil
 import torch
@@ -38,6 +39,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import DynamicGradMode, get_compiler_backend
 from sglang.utils import get_exception_traceback

+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)
@@ -79,7 +83,7 @@ class TpModelWorkerClient:
         )

         # Launch threads
-        self.input_queue = Queue()
+        self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
         self.output_queue = Queue()
         self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
@@ -93,13 +97,9 @@ class TpModelWorkerClient:

         self.hicache_layer_transfer_counter = None

-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter

-    def set_hicache_consumer(self, consumer_index):
-        if self.hicache_layer_transfer_counter is not None:
-            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
-
     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -147,7 +147,7 @@
     @DynamicGradMode()
     def forward_thread_func_(self):
         batch_pt = 0
-        batch_lists = [None] * 2
+        batch_lists: List = [None] * 2
         while True:
             model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
@@ -169,8 +169,6 @@
             input_ids = model_worker_batch.input_ids
             resolve_future_token_ids(input_ids, self.future_token_ids_map)

-            # update the consumer index of hicache to the running batch
-            self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
-
             # Run forward
             logits_output, next_token_ids, can_run_cuda_graph = (
                 self.worker.forward_batch_generation(
...
@@ -201,41 +201,57 @@ class HiRadixCache(RadixCache):
         if write_back:
             # blocking till all write back complete
             while len(self.ongoing_write_through) > 0:
-                ack_id = self.cache_controller.ack_write_queue.get()
-                del self.ongoing_write_through[ack_id]
+                for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
+                    finish_event.synchronize()
+                    for ack_id in ack_list:
+                        del self.ongoing_write_through[ack_id]
+                self.cache_controller.ack_write_queue.clear()
+                assert len(self.ongoing_write_through) == 0
             return
-        queue_size = torch.tensor(
-            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
-        )
+
+        # NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
+        if len(self.ongoing_write_through) == 0:
+            return
+
+        finish_count = 0
+        for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
+            if not finish_event.query():
+                break
+            finish_count += 1
+
+        queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
         if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to radix cache
+            # synchronize TP workers to make the same update to radix cache
             torch.distributed.all_reduce(
                 queue_size,
                 op=torch.distributed.ReduceOp.MIN,
                 group=self.tp_group,
             )
-        for _ in range(queue_size.item()):
-            ack_id = self.cache_controller.ack_write_queue.get()
-            backuped_node = self.ongoing_write_through[ack_id]
-            self.dec_lock_ref(backuped_node)
-            del self.ongoing_write_through[ack_id]
-            if self.enable_storage:
-                self.write_backup_storage(backuped_node)
+
+        finish_count = int(queue_size.item())
+        while finish_count > 0:
+            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
+            finish_event.synchronize()
+            for ack_id in ack_list:
+                backuped_node = self.ongoing_write_through.pop(ack_id)
+                self.dec_lock_ref(backuped_node)
+                if self.enable_storage:
+                    self.write_backup_storage(backuped_node)
+            finish_count -= 1

     def loading_check(self):
-        while not self.cache_controller.ack_load_queue.empty():
-            try:
-                ack_id = self.cache_controller.ack_load_queue.get_nowait()
-                start_node, end_node = self.ongoing_load_back[ack_id]
-                self.dec_lock_ref(end_node)
-                while end_node != start_node:
-                    assert end_node.loading
-                    end_node.loading = False
-                    end_node = end_node.parent
-                # clear the reference
-                del self.ongoing_load_back[ack_id]
-            except Exception:
+        finish_count = 0
+        for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
+            if not finish_event.query():
+                # the KV cache loading is still ongoing
                 break
+            finish_count += 1
+            # no need to sync across TP workers as batch forwarding is synced
+            for ack_id in ack_list:
+                end_node = self.ongoing_load_back.pop(ack_id)
+                self.dec_lock_ref(end_node)
+
+        # ACK until all events are processed
+        del self.cache_controller.ack_load_queue[:finish_count]

     def evictable_size(self):
         return self.evictable_size_
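Note: `writing_check` no longer pops node ids from a `Queue`; it counts how many queued write-backs already have a completed `finish_event` (a non-blocking `query()`), all-reduces the minimum of that count across TP ranks so every rank commits the same prefix of the ack list, and only then synchronizes and releases the nodes. A sketch of the counting-and-agreement step, assuming the `(start_event, finish_event, node_ids)` ack layout introduced in this diff (the all-reduce only runs when a process group is initialized):

```python
import torch
import torch.distributed as dist


def count_committable(ack_queue) -> int:
    """ack_queue: list of (start_event, finish_event, node_ids) in submission order."""
    finished = 0
    for _, finish_event, _ in ack_queue:
        if not finish_event.query():  # non-blocking completion check
            break                     # later entries cannot be committed yet
        finished += 1

    # Every TP rank must commit the same prefix of the queue, so agree on the
    # minimum finished count across ranks (skipped in a single-process run).
    if dist.is_initialized() and dist.get_world_size() > 1:
        t = torch.tensor(finished, dtype=torch.int)
        dist.all_reduce(t, op=dist.ReduceOp.MIN)
        finished = int(t.item())
    return finished

# Entries 0 .. count_committable(queue) - 1 can then be synchronized and their
# nodes released on every rank, keeping the radix caches consistent.
```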
@@ -360,12 +376,11 @@
             # no sufficient GPU memory to load back KV caches
             return None

-        self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
+        self.ongoing_load_back[last_hit_node.id] = last_hit_node
         offset = 0
         for node in nodes_to_load:
             node.value = device_indices[offset : offset + len(node.host_value)]
             offset += len(node.host_value)
-            node.loading = True
         self.evictable_size_ += len(device_indices)
         self.inc_lock_ref(last_hit_node)
@@ -394,10 +409,12 @@
             last_node,
         )

-    def ready_to_load_host_cache(self):
-        producer_index = self.cache_controller.layer_done_counter.next_producer()
-        self.load_cache_event.set()
-        return producer_index
+    def ready_to_load_host_cache(self) -> int:
+        """
+        Notify the cache controller to start the KV cache loading.
+        Return the consumer index for the schedule batch manager to track.
+        """
+        return self.cache_controller.start_loading()

     def check_hicache_events(self):
         self.writing_check()
@@ -702,7 +719,6 @@
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
-        new_node.loading = child.loading
         new_node.hit_count = child.hit_count

         # split value and host value if exists
...
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+from __future__ import annotations
+
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

 """
@@ -27,7 +29,7 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -38,6 +40,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024
@@ -175,7 +180,7 @@
     ) -> None:
         raise NotImplementedError()

-    def register_layer_transfer_counter(self, layer_transfer_counter):
+    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
         self.layer_transfer_counter = layer_transfer_counter

     def get_cpu_copy(self, indices):
...
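Note: registering the `LayerDoneCounter` on the device KV pool is what lets attention consume layers while the host-to-device load is still in flight: before touching layer `i`, the pool makes the compute stream wait on that layer's event through the counter. A toy version of that consumer-side hook (the real pool's accessors, layer offsets, and buffer shapes differ; assumes a CUDA device):

```python
import torch


class TinyKVPool:
    """Toy device pool that defers to a per-layer transfer counter, if registered."""

    def __init__(self, num_layers: int, tokens: int, head_dim: int):
        self.k_buffer = [
            torch.zeros(tokens, head_dim, device="cuda") for _ in range(num_layers)
        ]
        self.layer_transfer_counter = None

    def register_layer_transfer_counter(self, counter):
        self.layer_transfer_counter = counter

    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        if self.layer_transfer_counter is not None:
            # Make the current (compute) stream wait until layer_id has been
            # loaded from host memory; no host-side blocking is involved.
            self.layer_transfer_counter.wait_until(layer_id)
        return self.k_buffer[layer_id]
```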
@@ -3,6 +3,7 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
+from typing import Optional

 import psutil
 import torch
@@ -169,7 +170,7 @@ class HostKVCache(abc.ABC):
         return len(self.free_slots)

     @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
         ), "The requested size should be a multiple of the page size."
...
@@ -53,8 +53,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()
         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # indicating the node is locked to protect from eviction
         # incremented when the node is referenced by a storage operation
         self.host_ref_counter = 0
...
@@ -60,8 +60,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()
         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # store the host indices of KV cache
         self.host_value = None
...