Unverified commit e56685ac authored by Zhiqiang Xie, committed by GitHub

Upstreaming hicache bug fixes (#7267)

parent c26d7349
@@ -239,7 +239,7 @@ class WorkloadGenerator:
                 tokenizer=self.tokenizer,
                 dataset_path=args.dataset_path,
             )
-            self.candidate_inputs = [i[0] for i in self.candidate_inputs]
+            self.candidate_inputs = [i.prompt for i in self.candidate_inputs]
 
         init_requests = [
             (i, gen_payload(self.candidate_inputs[i], args.output_length))
......
@@ -30,22 +30,37 @@ logger = logging.getLogger(__name__)
 class LayerDoneCounter:
     def __init__(self, num_layers):
-        self.counter = num_layers
-        self.condition = threading.Condition()
+        self.num_layers = num_layers
+        # extra producer and consumer counters for overlap mode
+        self.num_counters = 3
+        self.counters = [num_layers] * self.num_counters
+        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
+        self.producer_index = 0
+        self.consumer_index = 0
+
+    def next_producer(self):
+        return (self.producer_index + 1) % self.num_counters
+
+    def update_producer(self):
+        self.producer_index = self.next_producer()
+        return self.producer_index
+
+    def set_consumer(self, index):
+        self.consumer_index = index
 
     def increment(self):
-        with self.condition:
-            self.counter += 1
-            self.condition.notify_all()
+        with self.conditions[self.producer_index]:
+            self.counters[self.producer_index] += 1
+            self.conditions[self.producer_index].notify_all()
 
     def wait_until(self, threshold):
-        with self.condition:
-            while self.counter <= threshold:
-                self.condition.wait()
+        with self.conditions[self.consumer_index]:
+            while self.counters[self.consumer_index] <= threshold:
+                self.conditions[self.consumer_index].wait()
 
     def reset(self):
-        with self.condition:
-            self.counter = 0
+        with self.conditions[self.producer_index]:
+            self.counters[self.producer_index] = 0
 
 
 class CacheOperation:
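For orientation, a minimal usage sketch of the triple-buffered counter follows (not part of the diff; it assumes the `LayerDoneCounter` reconstructed above, and the transfer/forward loops are hypothetical stand-ins for the load thread and the attention layers). With three slots, the forward pass of one batch can wait on its own counter while the load thread rotates to a fresh slot for the next batch.

```python
import threading

num_layers = 4
counter = LayerDoneCounter(num_layers)  # class from the hunk above

# scheduler side: rotate to a fresh slot for this batch
producer_index = counter.update_producer()
counter.reset()                       # no layers loaded yet in this slot
counter.set_consumer(producer_index)  # the forward pass polls the same slot

def transfer():
    # producer: copy KV cache host -> device one layer at a time
    for _ in range(num_layers):
        counter.increment()  # one more layer is now resident on GPU

def forward():
    # consumer: each layer blocks until its KV cache has arrived
    for layer_id in range(num_layers):
        counter.wait_until(layer_id)  # returns once the count exceeds layer_id

t = threading.Thread(target=transfer)
t.start()
forward()
t.join()
```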
@@ -296,7 +311,6 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.load_queue.get(block=True, timeout=1)
-                # time.sleep(18e-6 * len(operation.host_indices))
                 operation.data = self.mem_pool_host.get_flat_data(
                     operation.host_indices
                 )
@@ -320,6 +334,7 @@
             if not self.load_cache_event.is_set():
                 continue
             self.load_cache_event.clear()
+            self.layer_done_counter.update_producer()
 
             batch_operation = None
             while self.load_queue.qsize() > 0:
@@ -331,6 +346,7 @@
             if batch_operation is None:
                 continue
 
             # start layer-wise KV cache transfer from CPU to GPU
+            self.layer_done_counter.reset()
             for i in range(self.mem_pool_host.layer_num):
                 if self.page_size == 1:
@@ -466,6 +482,7 @@
             except Exception as e:
                 logger.error(e)
 
+    # todo (zhiqiang): double buffering to be deprecated
     def write_thread_func_buffer(self):
         aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
         aux_thread.start()
......
@@ -659,14 +659,6 @@ class Req:
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
-        elif enable_hierarchical_cache:
-            # in case last_node is evicted during scheduling, we need to update the prefix_indices
-            while self.last_node.evicted:
-                self.prefix_indices = self.prefix_indices[
-                    : -len(self.last_node.host_value)
-                ]
-                self.last_node = self.last_node.parent
-
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
@@ -909,6 +901,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Whether to return hidden states
     return_hidden_states: bool = False
 
+    # hicache pointer for synchronizing data loading from CPU to GPU
+    hicache_consumer_index: int = 0
+
     @classmethod
     def init_new(
         cls,
@@ -1735,6 +1730,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             token_type_ids=self.token_type_ids,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
+            hicache_consumer_index=self.hicache_consumer_index,
             capture_hidden_mode=(
                 CaptureHiddenMode.FULL
                 if self.return_hidden_states
@@ -1839,6 +1835,7 @@ class ModelWorkerBatch:
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
     spec_num_draft_tokens: Optional[int] = None
+    hicache_consumer_index: int = 0
 
     # Overlap event
     launch_done: Optional[threading.Event] = None
......
@@ -565,6 +565,10 @@ class Scheduler(
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
             )
+            self.tp_worker.register_hicache_layer_transfer_counter(
+                self.tree_cache.cache_controller.layer_done_counter
+            )
+
         else:
             self.tree_cache = RadixCache(
                 req_to_token_pool=self.req_to_token_pool,
@@ -1514,8 +1518,13 @@
                 self.running_batch.batch_is_full = True
                 break
 
+            # bypass prefix_computed if enable_hierarchical_cache
             req.init_next_round_input(
-                None if prefix_computed else self.tree_cache,
+                (
+                    None
+                    if (prefix_computed and not self.enable_hierarchical_cache)
+                    else self.tree_cache
+                ),
                 self.enable_hierarchical_cache,
             )
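Restated as a hypothetical standalone helper (for clarity only, not part of the diff): with hierarchical caching enabled, a request always re-matches against the tree, because nodes matched earlier may have been evicted to host memory while the request sat in the queue. This is the same concern the deleted `elif enable_hierarchical_cache` block in `Req` used to handle by walking up evicted parents.

```python
def cache_for_next_round(prefix_computed, enable_hierarchical_cache, tree_cache):
    # Mirrors the condition added above: skip re-matching only when the
    # prefix is already computed AND the cache has a single (GPU) tier.
    if prefix_computed and not enable_hierarchical_cache:
        return None
    return tree_cache
```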
@@ -1548,9 +1557,6 @@
                 x for x in self.waiting_queue if x not in set(can_run_list)
             ]
 
-        if self.enable_hierarchical_cache:
-            self.tree_cache.ready_to_load_cache()
-
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
             self.chunked_req = adder.new_chunked_req
@@ -1574,6 +1580,10 @@
             self.server_args.enable_custom_logit_processor,
             chunked_req=self.chunked_req,
         )
+        if self.enable_hierarchical_cache:
+            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
+            new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache()
+
         new_batch.prepare_for_extend()
 
         # Mixed-style chunked prefill
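Taken together, these scheduler changes thread a consumer index from batch formation to the forward pass: `ready_to_load_cache()` reserves a counter slot and stamps it on the batch, and the worker points the counter at that slot before running. A toy sketch of the plumbing, assuming the `LayerDoneCounter` above (`MiniCache` and `MiniWorker` are stand-ins, not SGLang classes):

```python
import threading

class MiniCache:  # stand-in for HiRadixCache
    def __init__(self, counter):
        self.layer_done_counter = counter
        self.load_cache_event = threading.Event()

    def ready_to_load_cache(self):
        # reserve the slot the load thread will rotate into, then wake it
        producer_index = self.layer_done_counter.next_producer()
        self.load_cache_event.set()
        return producer_index

class MiniWorker:  # stand-in for TpModelWorker
    def __init__(self):
        self.hicache_layer_transfer_counter = None

    def register_hicache_layer_transfer_counter(self, counter):
        self.hicache_layer_transfer_counter = counter

    def set_hicache_consumer(self, consumer_index):
        if self.hicache_layer_transfer_counter is not None:
            self.hicache_layer_transfer_counter.set_consumer(consumer_index)

counter = LayerDoneCounter(num_layers=2)
cache, worker = MiniCache(counter), MiniWorker()
worker.register_hicache_layer_transfer_counter(counter)

# batch formation: stamp the batch with the reserved slot ...
hicache_consumer_index = cache.ready_to_load_cache()
# ... and just before the forward pass, point the counter at that slot
worker.set_hicache_consumer(hicache_consumer_index)
```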
@@ -1649,6 +1659,11 @@
         if self.is_generation:
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
+
+                # update the consumer index of hicache to the running batch
+                self.tp_worker.set_hicache_consumer(
+                    model_worker_batch.hicache_consumer_index
+                )
                 if self.pp_group.is_last_rank:
                     logits_output, next_token_ids, can_run_cuda_graph = (
                         self.tp_worker.forward_batch_generation(model_worker_batch)
......
@@ -147,6 +147,15 @@ class TpModelWorker:
         # A reference make this class has the same member as TpModelWorkerClient
         self.worker = self
 
+        self.hicache_layer_transfer_counter = None
+
+    def register_hicache_layer_transfer_counter(self, counter):
+        self.hicache_layer_transfer_counter = counter
+
+    def set_hicache_consumer(self, consumer_index):
+        if self.hicache_layer_transfer_counter is not None:
+            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
+
     def get_worker_info(self):
         return (
             self.max_total_num_tokens,
......
@@ -88,6 +88,15 @@ class TpModelWorkerClient:
         if self.device == "cpu":
             self.scheduler_stream.synchronize = lambda: None  # No-op for CPU
 
+        self.hicache_layer_transfer_counter = None
+
+    def register_hicache_layer_transfer_counter(self, counter):
+        self.hicache_layer_transfer_counter = counter
+
+    def set_hicache_consumer(self, consumer_index):
+        if self.hicache_layer_transfer_counter is not None:
+            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
+
     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -146,6 +155,8 @@ class TpModelWorkerClient:
         input_ids = model_worker_batch.input_ids
         resolve_future_token_ids(input_ids, self.future_token_ids_map)
 
+        # update the consumer index of hicache to the running batch
+        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
 
         # Run forward
         logits_output, next_token_ids, can_run_cuda_graph = (
             self.worker.forward_batch_generation(
......
@@ -307,7 +307,9 @@ class HiRadixCache(RadixCache):
         return last_node, prefix_indices
 
     def ready_to_load_cache(self):
+        producer_index = self.cache_controller.layer_done_counter.next_producer()
         self.load_cache_event.set()
+        return producer_index
 
     def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
@@ -372,6 +374,7 @@
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
         new_node.loading = child.loading
+        new_node.hit_count = child.hit_count
 
         # split value and host value if exists
         if child.evicted:
......
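The one-line `hit_count` fix above is easy to miss but matters when the write policy uses hit counts to decide which entries to back up to host memory: before it, the prefix node created by a split restarted at zero and looked cold. A toy illustration (`ToyNode` and `split` are simplified stand-ins, not SGLang's `TreeNode`):

```python
from dataclasses import dataclass
from typing import List

@dataclass
class ToyNode:  # stand-in for the radix tree's TreeNode
    key: List[int]
    hit_count: int = 0
    lock_ref: int = 0
    loading: bool = False

def split(child: ToyNode, split_len: int) -> ToyNode:
    new_node = ToyNode(key=child.key[:split_len])
    new_node.lock_ref = child.lock_ref
    new_node.loading = child.loading
    new_node.hit_count = child.hit_count  # the line the diff adds
    child.key = child.key[split_len:]
    return new_node

node = ToyNode(key=[1, 2, 3, 4], hit_count=7)
prefix = split(node, 2)
assert prefix.hit_count == 7  # bookkeeping survives the split
```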