"docs/source/en/using-diffusers/controlling_generation.md" did not exist on "597908971392310d9bc56b9c7502ce4448156e24"
Unverified commit 10b544ae, authored by Zhiqiang Xie and committed by GitHub

Hierarchical Caching Refactoring and Fixing TP issue (#4082)

parent 01090e8a
@@ -30,6 +30,26 @@ from sglang.srt.mem_cache.memory_pool import (

 logger = logging.getLogger(__name__)


+class LayerDoneCounter:
+    def __init__(self, num_layers):
+        self.counter = num_layers
+        self.condition = threading.Condition()
+
+    def increment(self):
+        with self.condition:
+            self.counter += 1
+            self.condition.notify_all()
+
+    def wait_until(self, threshold):
+        with self.condition:
+            while self.counter <= threshold:
+                self.condition.wait()
+
+    def reset(self):
+        with self.condition:
+            self.counter = 0
+
+
 class CacheOperation:

     counter = 0
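A minimal usage sketch of how this counter coordinates layer-by-layer loading (illustrative only, not part of the diff; the two roles mirror the loader thread and the attention forward pass introduced later in this commit). Note that the constructor initializes the counter to num_layers so readers never block while no transfer is in flight; reset() re-arms it right before a batch of loads.

    import threading
    import time

    counter = LayerDoneCounter(num_layers=4)
    counter.reset()                         # a new batch of transfers is about to start

    def loader():
        for layer_id in range(4):
            time.sleep(0.01)                # stands in for one host-to-device copy
            counter.increment()             # this layer is now resident on the device

    def consumer():
        for layer_id in range(4):
            counter.wait_until(layer_id)    # returns once counter > layer_id
            print(f"layer {layer_id} ready")

    threading.Thread(target=loader, daemon=True).start()
    consumer()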
@@ -132,6 +152,7 @@ class HiCacheController:
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         mem_pool_host: MHATokenToKVPoolHost,
+        load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
@@ -139,6 +160,10 @@ class HiCacheController:
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy

+        self.load_cache_event = load_cache_event
+        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
+        self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
+
         if write_policy not in [
             "write_through",
             "write_through_selective",
@@ -165,7 +190,7 @@ class HiCacheController:
             target=self.write_thread_func_buffer, daemon=True
         )
         self.load_thread = threading.Thread(
-            target=self.load_thread_func_buffer, daemon=True
+            target=self.load_thread_func_layer_by_layer, daemon=True
         )
         self.write_thread.start()
         self.load_thread.start()
@@ -186,7 +211,7 @@ class HiCacheController:
             target=self.write_thread_func_buffer, daemon=True
         )
         self.load_thread = threading.Thread(
-            target=self.load_thread_func_buffer, daemon=True
+            target=self.load_thread_func_layer_by_layer, daemon=True
         )
         self.stop_event.clear()
         self.write_thread.start()
@@ -273,6 +298,42 @@ class HiCacheController:
             except Exception as e:
                 logger.error(e)

+    def load_thread_func_layer_by_layer(self):
+        """
+        Load KV caches from host memory to device memory layer by layer.
+        """
+        with torch.cuda.stream(self.load_stream):
+            while not self.stop_event.is_set():
+                self.load_cache_event.wait(timeout=1)
+                if not self.load_cache_event.is_set():
+                    continue
+                self.load_cache_event.clear()
+
+                batch_operation = None
+                while self.load_queue.qsize() > 0:
+                    op = self.load_queue.get(block=True)
+                    if batch_operation is None:
+                        batch_operation = op
+                    else:
+                        batch_operation.merge(op)
+                if batch_operation is None:
+                    continue
+
+                self.layer_done_counter.reset()
+                for i in range(self.mem_pool_host.layer_num):
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                    self.layer_done_counter.increment()
+
+                self.mem_pool_host.complete_io(batch_operation.host_indices)
+                for node_id in batch_operation.node_ids:
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+
     def write_aux_func(self, no_wait=False):
         """
         Auxiliary function to prepare the buffer for write operations.
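Each CacheOperation queued by the scheduler covers one request's prefix; the merge() calls above are assumed to concatenate the host and device index tensors (and node ids) so the per-layer copies run once per layer for the whole batch rather than once per request. A hedged sketch of that contract with illustrative field handling (the real CacheOperation is defined earlier in this file and may differ):

    import torch

    class CacheOperationSketch:
        # illustrative only -- mirrors how load_thread_func_layer_by_layer uses the fields
        def __init__(self, host_indices, device_indices, node_id):
            self.host_indices = host_indices      # indices into the host KV pool
            self.device_indices = device_indices  # matching indices into the device KV pool
            self.node_ids = [node_id]             # radix-tree nodes to acknowledge when done

        def merge(self, other):
            self.host_indices = torch.cat([self.host_indices, other.host_indices])
            self.device_indices = torch.cat([self.device_indices, other.device_indices])
            self.node_ids.extend(other.node_ids)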
......
@@ -315,6 +315,7 @@ class Req:
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
         self.last_node = None
+        self.last_node_global = None

         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -389,10 +390,21 @@ class Req:
         # Whether request reached finished condition
         return self.finished_reason is not None

-    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
+    def init_next_round_input(
+        self,
+        tree_cache: Optional[BasePrefixCache] = None,
+        enable_hierarchical_cache=False,
+    ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
             # tree cache is None if the prefix is not computed with tree cache.
-            self.prefix_indices, self.last_node = tree_cache.match_prefix(
-                rid=self.rid, key=self.adjust_max_prefix_ids()
-            )
+            if enable_hierarchical_cache:
+                self.prefix_indices, self.last_node, self.last_node_global = (
+                    tree_cache.match_prefix(
+                        key=self.adjust_max_prefix_ids(), include_evicted=True
+                    )
+                )
+            else:
+                self.prefix_indices, self.last_node = tree_cache.match_prefix(
+                    rid=self.rid, key=self.adjust_max_prefix_ids()
+                )
......
@@ -73,9 +73,15 @@ class CacheAgnosticPolicy(Enum):

 class SchedulePolicy:
     Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]

-    def __init__(self, policy: str, tree_cache: BasePrefixCache):
+    def __init__(
+        self,
+        policy: str,
+        tree_cache: BasePrefixCache,
+        enable_hierarchical_cache: bool = False,
+    ):
         self.policy = self._validate_and_adjust_policy(policy, tree_cache)
         self.tree_cache = tree_cache
+        self.enable_hierarchical_cache = enable_hierarchical_cache

         # It is used to find the matching prefix for in-batch prefix caching.
         self.waiting_queue_radix_tree = RadixCache(
@@ -149,6 +155,11 @@ class SchedulePolicy:
             prefix_ids = r.adjust_max_prefix_ids()

             # NOTE: the prefix_indices must always be aligned with last_node
-            r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                rid=r.rid, key=prefix_ids
-            )
+            if self.enable_hierarchical_cache:
+                r.prefix_indices, r.last_node, r.last_node_global = (
+                    self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
+                )
+            else:
+                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                    rid=r.rid, key=prefix_ids
+                )
@@ -428,7 +439,9 @@ class PrefillAdder:

         return self.budget_state()

-    def add_one_req(self, req: Req, has_chunked_req: bool):
+    def add_one_req(
+        self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
+    ):
         if req.sampling_params.ignore_eos and self.tree_cache.disable:
             return self.add_one_req_ignore_eos(req, has_chunked_req)

@@ -448,6 +461,18 @@ class PrefillAdder:
         if total_tokens > self.rem_total_tokens:
             return AddReqResult.NO_TOKEN

+        if (
+            enable_hierarchical_cache
+            and req.last_node_global is not None
+            and req.last_node_global.evicted
+        ):
+            req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
+                req.last_node_global, req.prefix_indices
+            )
+            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+            input_tokens = req.extend_input_len
+
         prefix_len = len(req.prefix_indices)

         if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
             # Non-chunked prefill
             self.can_run_list.append(req)
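When the deepest prefix match (last_node_global) lives only in host memory, admission now triggers the load-back right here: init_load_back is expected to re-allocate device slots for the evicted part of the prefix and enqueue the host-to-device copies, returning the (possibly shorter) last_node and prefix_indices actually backed by device memory, which is why extend_input_len is recomputed afterwards. A rough sketch of that contract with hypothetical helper names (not the real implementation):

    import torch

    def init_load_back_sketch(tree_cache, last_node_global, prefix_indices):
        # collect ancestors whose KV currently lives only on the host
        node, evicted = last_node_global, []
        while node.evicted:
            evicted.append(node)
            node = node.parent
        # bring them back, closest ancestor first, until device memory runs out (assumed policy)
        for n in reversed(evicted):
            device_indices = tree_cache.load_back(n)  # hypothetical: allocate + enqueue H2D copy
            if device_indices is None:
                break
            prefix_indices = torch.cat([prefix_indices, device_indices])
            node = n
        return node, prefix_indices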
......
@@ -265,12 +265,10 @@ class Scheduler:
                 f"context_len={self.model_config.context_len}"
             )

-        # Init memory pool and cache
         self.init_memory_pool_and_cache()

         # Init running status
         self.waiting_queue: List[Req] = []
-        self.staging_reqs = {}
         # The running decoding batch for continuous batching
         self.running_batch: Optional[ScheduleBatch] = None
         # The current forward batch
@@ -308,7 +306,9 @@ class Scheduler:
             self.grammar_backend = None

         # Init schedule policy and new token estimation
-        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
+        self.policy = SchedulePolicy(
+            self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache
+        )
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
@@ -431,6 +431,7 @@ class Scheduler:
             self.tree_cache = HiRadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                tp_cache_group=self.tp_worker.get_tp_cpu_group(),
             )
         else:
             self.tree_cache = RadixCache(
@@ -1005,6 +1006,11 @@ class Scheduler:
                 self.batch_is_full = True
             return None

+        if self.enable_hierarchical_cache:
+            # check for completion of hierarchical cache activities to release memory
+            self.tree_cache.writing_check()
+            self.tree_cache.loading_check()
+
         # Get priority queue
         prefix_computed = self.policy.calc_priority(self.waiting_queue)

@@ -1048,32 +1054,14 @@ class Scheduler:
                 self.batch_is_full = True
                 break

-            req.init_next_round_input(None if prefix_computed else self.tree_cache)
-
-            if self.enable_hierarchical_cache and req.last_node is not None:
-                if req.last_node.evicted:
-                    # loading KV cache for the request
-                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
-                        req.last_node,
-                        req.prefix_indices,
-                        adder.rem_total_tokens,
-                    )
-                    if req.last_node.loading:
-                        # to prevent frequent cache invalidation
-                        if req.rid in self.staging_reqs:
-                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                        self.tree_cache.inc_lock_ref(req.last_node)
-                        self.staging_reqs[req.rid] = req.last_node
-                        continue
-                elif req.last_node.loading:
-                    if not self.tree_cache.loading_complete(req.last_node):
-                        continue
-
-            if req.rid in self.staging_reqs:
-                self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                del self.staging_reqs[req.rid]
-
-            res = adder.add_one_req(req, self.chunked_req)
+            req.init_next_round_input(
+                None if prefix_computed else self.tree_cache,
+                self.enable_hierarchical_cache,
+            )
+
+            res = adder.add_one_req(
+                req, self.chunked_req, self.enable_hierarchical_cache
+            )
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
                     if self.enable_hierarchical_cache:
@@ -1094,6 +1082,9 @@ class Scheduler:
                 x for x in self.waiting_queue if x not in set(can_run_list)
             ]

+        if self.enable_hierarchical_cache:
+            self.tree_cache.read_to_load_cache()
+
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
             self.chunked_req = adder.new_chunked_req
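Taken together, the prefill scheduling path under hierarchical caching now reads roughly as follows (condensed pseudocode of the changes above, not literal code):

    # once per scheduling round
    if self.enable_hierarchical_cache:
        self.tree_cache.writing_check()   # release refs whose host write-back completed
        self.tree_cache.loading_check()   # release refs whose device load completed

    for req in self.waiting_queue:
        # prefix match may return an evicted last_node_global
        req.init_next_round_input(self.tree_cache, self.enable_hierarchical_cache)
        # init_load_back for evicted prefixes now happens inside add_one_req
        res = adder.add_one_req(req, self.chunked_req, self.enable_hierarchical_cache)
        if res != AddReqResult.CONTINUE:
            break

    if self.enable_hierarchical_cache:
        # wake the loader thread once every load for this batch is queued
        self.tree_cache.read_to_load_cache()

This replaces the old per-request staging bookkeeping (self.staging_reqs), which is why that field and loading_complete() are deleted elsewhere in this commit.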
......
 import heapq
 import logging
+import threading
 import time
 from typing import List, Optional

 import torch

 from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPoolHost,
     ReqToTokenPool,
@@ -22,12 +24,18 @@ class HiRadixCache(RadixCache):
         self,
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        tp_cache_group: torch.distributed.ProcessGroup,
     ):
         self.token_to_kv_pool_host = MHATokenToKVPoolHost(
             token_to_kv_pool_allocator.get_kvcache()
         )
+        self.tp_group = tp_cache_group
+
+        self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
-            token_to_kv_pool_allocator, self.token_to_kv_pool_host
+            token_to_kv_pool_allocator,
+            self.token_to_kv_pool_host,
+            load_cache_event=self.load_cache_event,
         )

         # record the nodes with ongoing write through
@@ -55,14 +63,12 @@ class HiRadixCache(RadixCache):
     def write_backup(self, node: TreeNode):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
-            priority=-self.get_height(node),
             node_id=node.id,
         )
         if host_indices is None:
            self.evict_host(len(node.value))
             host_indices = self.cache_controller.write(
                 device_indices=node.value,
-                priority=-self.get_height(node),
                 node_id=node.id,
             )
         if host_indices is not None:
@@ -83,14 +89,20 @@ class HiRadixCache(RadixCache):
             node.hit_count = 0

     def writing_check(self):
-        while not self.cache_controller.ack_write_queue.empty():
-            try:
-                ack_id = self.cache_controller.ack_write_queue.get_nowait()
-                self.dec_lock_ref(self.ongoing_write_through[ack_id])
-                # clear the reference
-                del self.ongoing_write_through[ack_id]
-            except Exception:
-                break
+        queue_size = torch.tensor(
+            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to radix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id = self.cache_controller.ack_write_queue.get()
+            self.dec_lock_ref(self.ongoing_write_through[ack_id])
+            del self.ongoing_write_through[ack_id]

     def loading_check(self):
         while not self.cache_controller.ack_load_queue.empty():
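This is the TP fix from the commit title: each tensor-parallel rank may have received a different number of write acknowledgements by the time writing_check runs, and if every rank popped whatever it happened to have, their radix trees and lock reference counts would drift apart. Reducing the queue size with MIN makes all ranks consume the same number of acks and apply identical updates. A minimal standalone illustration of the pattern (assumes torch.distributed is already initialized with a CPU-capable backend such as gloo):

    import torch
    import torch.distributed as dist

    def synced_ack_count(local_ready: int, group) -> int:
        # each rank proposes how many acks it could process locally ...
        t = torch.tensor(local_ready, dtype=torch.int)
        if dist.get_world_size(group=group) > 1:
            # ... and all ranks agree on the minimum, so they make the same tree updates
            dist.all_reduce(t, op=dist.ReduceOp.MIN, group=group)
        return int(t.item())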
@@ -108,8 +120,6 @@ class HiRadixCache(RadixCache):
                 break

     def evictable_size(self):
-        self.writing_check()
-        self.loading_check()
         return self.evictable_size_

     def evict(self, num_tokens: int, evict_callback=None):
@@ -242,10 +252,6 @@ class HiRadixCache(RadixCache):
         return device_indices

-    def loading_complete(self, node: TreeNode):
-        self.loading_check()
-        return node.loading == False
-
     def init_load_back(
         self,
         last_node: TreeNode,
@@ -272,6 +278,28 @@ class HiRadixCache(RadixCache):
         return last_node, prefix_indices
+    def read_to_load_cache(self):
+        self.load_cache_event.set()
+
+    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
+        if self.disable:
+            return [], self.root_node
+
+        value, last_node = self._match_prefix_helper(self.root_node, key)
+        if value:
+            value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int32)
+
+        last_node_global = last_node
+        while last_node.evicted:
+            last_node = last_node.parent
+
+        if include_evicted:
+            return value, last_node, last_node_global
+        else:
+            return value, last_node
+
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.time()
         value = []
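With include_evicted=True the matcher returns two anchors: last_node_global is the deepest match regardless of where its KV cache lives, while last_node is walked back up to the deepest ancestor still resident on the device, so value only covers tokens that are usable immediately. For example (illustrative), for a matched path root -> A (on device) -> B -> C (both evicted to host):

    value, last_node, last_node_global = tree_cache.match_prefix(
        key=token_ids, include_evicted=True
    )
    # value            -> device indices for the part of the prefix still on GPU (through A)
    # last_node        -> deepest ancestor whose KV is on the device (A)
    # last_node_global -> deepest match overall, possibly host-only (C)

The scheduler stores last_node_global on the request and later calls init_load_back from add_one_req to bring the evicted tail back before prefill.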
......
@@ -206,6 +206,8 @@ class MHATokenToKVPool(KVCache):
         self.layer_num = layer_num
         self._create_buffers()

+        self.layer_transfer_counter = None
+
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
@@ -267,12 +269,28 @@ class MHATokenToKVPool(KVCache):
             self.k_buffer[i][indices] = k_data[i]
             self.v_buffer[i][indices] = v_data[i]

+    def register_layer_transfer_counter(self, layer_transfer_counter):
+        self.layer_transfer_counter = layer_transfer_counter
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        k_data, v_data = flat_data[0], flat_data[1]
+        self.k_buffer[layer_id][indices] = k_data
+        self.v_buffer[layer_id][indices] = v_data
+
     def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]

     def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]
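Because the loader increments the layer counter right after each transfer_per_layer call and the attention backend fetches KV buffers per layer, the copy of layer i+1 can overlap with compute on layer i; a forward layer only blocks if its own transfer has not landed yet. Roughly (illustrative only):

    # inside the attention backend for layer i (simplified)
    k = kv_pool.get_key_buffer(i)    # blocks until layer_done_counter has passed layer i
    v = kv_pool.get_value_buffer(i)  # already satisfied at this point
    # attention for layer i now runs while the H2D copy of layer i+1 proceeds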
@@ -530,6 +548,9 @@ class MHATokenToKVPoolHost:
     def get_flat_data(self, indices):
         return self.kv_buffer[:, :, indices]

+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id, indices]
+
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data
......