Unverified commit 10b544ae authored by Zhiqiang Xie, committed by GitHub

Hierarchical Caching Refactoring and Fixing TP issue (#4082)

parent 01090e8a
......@@ -30,6 +30,26 @@ from sglang.srt.mem_cache.memory_pool import (
logger = logging.getLogger(__name__)
class LayerDoneCounter:
def __init__(self, num_layers):
self.counter = num_layers
self.condition = threading.Condition()
def increment(self):
with self.condition:
self.counter += 1
self.condition.notify_all()
def wait_until(self, threshold):
with self.condition:
while self.counter <= threshold:
self.condition.wait()
def reset(self):
with self.condition:
self.counter = 0
class CacheOperation:
counter = 0
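A minimal standalone sketch of the synchronization pattern this counter enables; it reuses the LayerDoneCounter class from the hunk above, and the layer count and sleep are illustrative only. One thread transfers KV data layer by layer while consumers block only until the layer they need has arrived.

import threading
import time

counter = LayerDoneCounter(num_layers=4)
counter.reset()  # a load is in flight, so start from zero completed layers

def loader():
    for layer_id in range(4):
        time.sleep(0.01)       # stand-in for a host-to-device transfer of one layer
        counter.increment()    # this layer is now usable on the device

threading.Thread(target=loader, daemon=True).start()

for layer_id in range(4):
    counter.wait_until(layer_id)  # returns once counter > layer_id
    # attention for this layer can now read the freshly loaded KV buffers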
......@@ -132,6 +152,7 @@ class HiCacheController:
self,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
mem_pool_host: MHATokenToKVPoolHost,
load_cache_event: threading.Event = None,
write_policy: str = "write_through_selective",
):
self.mem_pool_device_allocator = token_to_kv_pool_allocator
......@@ -139,6 +160,10 @@ class HiCacheController:
self.mem_pool_host = mem_pool_host
self.write_policy = write_policy
self.load_cache_event = load_cache_event
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
if write_policy not in [
"write_through",
"write_through_selective",
......@@ -165,7 +190,7 @@ class HiCacheController:
target=self.write_thread_func_buffer, daemon=True
)
self.load_thread = threading.Thread(
target=self.load_thread_func_buffer, daemon=True
target=self.load_thread_func_layer_by_layer, daemon=True
)
self.write_thread.start()
self.load_thread.start()
......@@ -186,7 +211,7 @@ class HiCacheController:
target=self.write_thread_func_buffer, daemon=True
)
self.load_thread = threading.Thread(
target=self.load_thread_func_buffer, daemon=True
target=self.load_thread_func_layer_by_layer, daemon=True
)
self.stop_event.clear()
self.write_thread.start()
......@@ -273,6 +298,42 @@ class HiCacheController:
except Exception as e:
logger.error(e)
def load_thread_func_layer_by_layer(self):
"""
Load KV caches from host memory to device memory layer by layer.
"""
with torch.cuda.stream(self.load_stream):
while not self.stop_event.is_set():
self.load_cache_event.wait(timeout=1)
if not self.load_cache_event.is_set():
continue
self.load_cache_event.clear()
batch_operation = None
while self.load_queue.qsize() > 0:
op = self.load_queue.get(block=True)
if batch_operation is None:
batch_operation = op
else:
batch_operation.merge(op)
if batch_operation is None:
continue
self.layer_done_counter.reset()
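# from this point, consumers calling wait_until(layer_id) block until that layer's transfer completes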
for i in range(self.mem_pool_host.layer_num):
flat_data = self.mem_pool_host.get_flat_data_by_layer(
batch_operation.host_indices, i
)
self.mem_pool_device.transfer_per_layer(
batch_operation.device_indices, flat_data, i
)
self.layer_done_counter.increment()
self.mem_pool_host.complete_io(batch_operation.host_indices)
for node_id in batch_operation.node_ids:
if node_id != 0:
self.ack_load_queue.put(node_id)
def write_aux_func(self, no_wait=False):
"""
Auxiliary function to prepare the buffer for write operations.
......
......@@ -315,6 +315,7 @@ class Req:
# The relative logprob_start_len in an extend batch
self.extend_logprob_start_len = 0
self.last_node = None
self.last_node_global = None
# Whether the request is chunked. It increments whenever
# it is chunked, and decrements whenever the chunked request is
......@@ -389,13 +390,24 @@ class Req:
# Whether request reached finished condition
return self.finished_reason is not None
def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
def init_next_round_input(
self,
tree_cache: Optional[BasePrefixCache] = None,
enable_hierarchical_cache=False,
):
self.fill_ids = self.origin_input_ids + self.output_ids
if tree_cache is not None:
# tree_cache is None if the prefix is not computed with the tree cache.
self.prefix_indices, self.last_node = tree_cache.match_prefix(
rid=self.rid, key=self.adjust_max_prefix_ids()
)
if enable_hierarchical_cache:
self.prefix_indices, self.last_node, self.last_node_global = (
tree_cache.match_prefix(
key=self.adjust_max_prefix_ids(), include_evicted=True
)
)
else:
self.prefix_indices, self.last_node = tree_cache.match_prefix(
rid=self.rid, key=self.adjust_max_prefix_ids()
)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
def adjust_max_prefix_ids(self):
......
......@@ -73,9 +73,15 @@ class CacheAgnosticPolicy(Enum):
class SchedulePolicy:
Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]
def __init__(self, policy: str, tree_cache: BasePrefixCache):
def __init__(
self,
policy: str,
tree_cache: BasePrefixCache,
enable_hierarchical_cache: bool = False,
):
self.policy = self._validate_and_adjust_policy(policy, tree_cache)
self.tree_cache = tree_cache
self.enable_hierarchical_cache = enable_hierarchical_cache
# It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache(
......@@ -149,9 +155,14 @@ class SchedulePolicy:
prefix_ids = r.adjust_max_prefix_ids()
# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=prefix_ids
)
if self.enable_hierarchical_cache:
r.prefix_indices, r.last_node, r.last_node_global = (
self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
)
else:
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=prefix_ids
)
# NOTE(sang): This logic is for in-batch prefix caching;
# If more than one request has a small matching prefix from
......@@ -428,7 +439,9 @@ class PrefillAdder:
return self.budget_state()
def add_one_req(self, req: Req, has_chunked_req: bool):
def add_one_req(
self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
):
if req.sampling_params.ignore_eos and self.tree_cache.disable:
return self.add_one_req_ignore_eos(req, has_chunked_req)
......@@ -448,6 +461,18 @@ class PrefillAdder:
if total_tokens > self.rem_total_tokens:
return AddReqResult.NO_TOKEN
if (
enable_hierarchical_cache
and req.last_node_global is not None
and req.last_node_global.evicted
):
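# the deepest matched node exists only in host memory; schedule its KV data to be loaded back onto the device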
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
req.last_node_global, req.prefix_indices
)
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
input_tokens = req.extend_input_len
prefix_len = len(req.prefix_indices)
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
# Non-chunked prefill
self.can_run_list.append(req)
......
......@@ -265,12 +265,10 @@ class Scheduler:
f"context_len={self.model_config.context_len}"
)
# Init memory pool and cache
self.init_memory_pool_and_cache()
# Init running status
self.waiting_queue: List[Req] = []
self.staging_reqs = {}
# The running decoding batch for continuous batching
self.running_batch: Optional[ScheduleBatch] = None
# The current forward batch
......@@ -308,7 +306,9 @@ class Scheduler:
self.grammar_backend = None
# Init schedule policy and new token estimation
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
self.policy = SchedulePolicy(
self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache
)
assert (
server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness"
......@@ -431,6 +431,7 @@ class Scheduler:
self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_cache_group=self.tp_worker.get_tp_cpu_group(),
)
else:
self.tree_cache = RadixCache(
......@@ -1005,6 +1006,11 @@ class Scheduler:
self.batch_is_full = True
return None
if self.enable_hierarchical_cache:
# Check for completed hierarchical cache operations to release memory
self.tree_cache.writing_check()
self.tree_cache.loading_check()
# Get priority queue
prefix_computed = self.policy.calc_priority(self.waiting_queue)
......@@ -1048,32 +1054,14 @@ class Scheduler:
self.batch_is_full = True
break
req.init_next_round_input(None if prefix_computed else self.tree_cache)
req.init_next_round_input(
None if prefix_computed else self.tree_cache,
self.enable_hierarchical_cache,
)
if self.enable_hierarchical_cache and req.last_node is not None:
if req.last_node.evicted:
# loading KV cache for the request
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
req.last_node,
req.prefix_indices,
adder.rem_total_tokens,
)
if req.last_node.loading:
# to prevent frequent cache invalidation
if req.rid in self.staging_reqs:
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
self.tree_cache.inc_lock_ref(req.last_node)
self.staging_reqs[req.rid] = req.last_node
continue
elif req.last_node.loading:
if not self.tree_cache.loading_complete(req.last_node):
continue
if req.rid in self.staging_reqs:
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
del self.staging_reqs[req.rid]
res = adder.add_one_req(req, self.chunked_req)
res = adder.add_one_req(
req, self.chunked_req, self.enable_hierarchical_cache
)
if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN:
if self.enable_hierarchical_cache:
......@@ -1094,6 +1082,9 @@ class Scheduler:
x for x in self.waiting_queue if x not in set(can_run_list)
]
if self.enable_hierarchical_cache:
self.tree_cache.read_to_load_cache()
if adder.new_chunked_req is not None:
assert self.chunked_req is None
self.chunked_req = adder.new_chunked_req
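A condensed sketch of the handshake introduced above between the scheduler and the loading thread; the enqueue step and queue contents are simplified stand-ins for what init_load_back and the cache controller actually do. The scheduler queues load operations and sets load_cache_event via read_to_load_cache(), while load_thread_func_layer_by_layer wakes up, drains the queue, and performs the per-layer transfers.

import queue
import threading

load_cache_event = threading.Event()
load_queue = queue.Queue()

def schedule_loads(ops):
    # stand-in for init_load_back(): enqueue host-to-device copy operations
    for op in ops:
        load_queue.put(op)
    # read_to_load_cache(): wake the loader without blocking the scheduler
    load_cache_event.set()

def loader_loop(stop_event):
    while not stop_event.is_set():
        if not load_cache_event.wait(timeout=1):
            continue
        load_cache_event.clear()
        while not load_queue.empty():
            op = load_queue.get()
            # the real thread merges queued operations, then transfers layer by
            # layer and increments the LayerDoneCounter after each layer
            ...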
......
import heapq
import logging
import threading
import time
from typing import List, Optional
import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPoolHost,
ReqToTokenPool,
......@@ -22,12 +24,18 @@ class HiRadixCache(RadixCache):
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
tp_cache_group: torch.distributed.ProcessGroup,
):
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
token_to_kv_pool_allocator.get_kvcache()
)
self.tp_group = tp_cache_group
self.load_cache_event = threading.Event()
self.cache_controller = HiCacheController(
token_to_kv_pool_allocator, self.token_to_kv_pool_host
token_to_kv_pool_allocator,
self.token_to_kv_pool_host,
load_cache_event=self.load_cache_event,
)
# record the nodes with ongoing write through
......@@ -55,14 +63,12 @@ class HiRadixCache(RadixCache):
def write_backup(self, node: TreeNode):
host_indices = self.cache_controller.write(
device_indices=node.value,
priority=-self.get_height(node),
node_id=node.id,
)
if host_indices is None:
self.evict_host(len(node.value))
host_indices = self.cache_controller.write(
device_indices=node.value,
priority=-self.get_height(node),
node_id=node.id,
)
if host_indices is not None:
......@@ -83,14 +89,20 @@ class HiRadixCache(RadixCache):
node.hit_count = 0
def writing_check(self):
while not self.cache_controller.ack_write_queue.empty():
try:
ack_id = self.cache_controller.ack_write_queue.get_nowait()
self.dec_lock_ref(self.ongoing_write_through[ack_id])
# clear the reference
del self.ongoing_write_through[ack_id]
except Exception:
break
queue_size = torch.tensor(
self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
# synchronize TP workers so every rank makes the same update to its radix cache
torch.distributed.all_reduce(
queue_size,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
for _ in range(queue_size.item()):
ack_id = self.cache_controller.ack_write_queue.get()
self.dec_lock_ref(self.ongoing_write_through[ack_id])
del self.ongoing_write_through[ack_id]
def loading_check(self):
while not self.cache_controller.ack_load_queue.empty():
......@@ -108,8 +120,6 @@ class HiRadixCache(RadixCache):
break
def evictable_size(self):
self.writing_check()
self.loading_check()
return self.evictable_size_
def evict(self, num_tokens: int, evict_callback=None):
......@@ -242,10 +252,6 @@ class HiRadixCache(RadixCache):
return device_indices
def loading_complete(self, node: TreeNode):
self.loading_check()
return node.loading == False
def init_load_back(
self,
last_node: TreeNode,
......@@ -272,6 +278,28 @@ class HiRadixCache(RadixCache):
return last_node, prefix_indices
def read_to_load_cache(self):
self.load_cache_event.set()
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
if self.disable:
return [], self.root_node
value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
value = torch.concat(value)
else:
value = torch.tensor([], dtype=torch.int32)
last_node_global = last_node
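# walk back from the deepest match to the nearest ancestor whose KV data is still resident on device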
while last_node.evicted:
last_node = last_node.parent
if include_evicted:
return value, last_node, last_node_global
else:
return value, last_node
def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.time()
value = []
......
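A toy illustration, in plain Python with no process group, of why writing_check above reduces the per-rank ack queue size with ReduceOp.MIN: different TP ranks may have observed different numbers of write-through acks at the moment of the check, and consuming only the minimum keeps every rank's radix tree in lockstep. The rank counts below are made up.

# Hypothetical snapshot of ack_write_queue.qsize() on each TP rank.
acks_seen = {0: 3, 1: 2, 2: 3, 3: 3}

# ReduceOp.MIN across ranks yields the count every rank can safely consume.
consume_now = min(acks_seen.values())  # -> 2

# Each rank pops exactly `consume_now` acks, releasing the same locks and
# applying identical updates; the remaining acks are picked up by a later
# writing_check() call once the slower rank has caught up.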
......@@ -206,6 +206,8 @@ class MHATokenToKVPool(KVCache):
self.layer_num = layer_num
self._create_buffers()
self.layer_transfer_counter = None
k_size, v_size = self.get_kv_size_bytes()
logger.info(
f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
......@@ -267,12 +269,28 @@ class MHATokenToKVPool(KVCache):
self.k_buffer[i][indices] = k_data[i]
self.v_buffer[i][indices] = v_data[i]
def register_layer_transfer_counter(self, layer_transfer_counter):
self.layer_transfer_counter = layer_transfer_counter
def transfer_per_layer(self, indices, flat_data, layer_id):
# transfer prepared data from host to device
flat_data = flat_data.to(device=self.device, non_blocking=False)
k_data, v_data = flat_data[0], flat_data[1]
self.k_buffer[layer_id][indices] = k_data
self.v_buffer[layer_id][indices] = v_data
def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id)
if self.store_dtype != self.dtype:
return self.k_buffer[layer_id].view(self.dtype)
return self.k_buffer[layer_id]
def get_value_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id)
if self.store_dtype != self.dtype:
return self.v_buffer[layer_id].view(self.dtype)
return self.v_buffer[layer_id]
......@@ -530,6 +548,9 @@ class MHATokenToKVPoolHost:
def get_flat_data(self, indices):
return self.kv_buffer[:, :, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[:, layer_id, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, :, indices] = flat_data
......
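A small sketch of the host-buffer indexing that get_flat_data_by_layer relies on; only the leading [K/V, layer, token] layout is taken from the slicing above, and the trailing dimensions here are invented for illustration.

import torch

# Toy host KV buffer: dim 0 selects K or V, dim 1 the layer, dim 2 the token slot.
kv_buffer = torch.zeros(2, 4, 16, 8)   # 2 (K/V) x 4 layers x 16 tokens x head_dim 8
indices = torch.tensor([3, 5, 6])

all_layers = kv_buffer[:, :, indices]  # get_flat_data: every layer for these tokens
one_layer = kv_buffer[:, 1, indices]   # get_flat_data_by_layer(indices, layer_id=1)

assert all_layers.shape == (2, 4, 3, 8)
assert one_layer.shape == (2, 3, 8)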