Unverified Commit e56685ac authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

Upstreaming hicache bug fixes (#7267)

parent c26d7349
...@@ -239,7 +239,7 @@ class WorkloadGenerator: ...@@ -239,7 +239,7 @@ class WorkloadGenerator:
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
dataset_path=args.dataset_path, dataset_path=args.dataset_path,
) )
self.candidate_inputs = [i[0] for i in self.candidate_inputs] self.candidate_inputs = [i.prompt for i in self.candidate_inputs]
init_requests = [ init_requests = [
(i, gen_payload(self.candidate_inputs[i], args.output_length)) (i, gen_payload(self.candidate_inputs[i], args.output_length))
......
...@@ -30,22 +30,37 @@ logger = logging.getLogger(__name__) ...@@ -30,22 +30,37 @@ logger = logging.getLogger(__name__)
class LayerDoneCounter:
    """Track how many model layers have finished KV-cache transfer (host -> device).

    Maintains ``num_counters`` independent (counter, condition) slots so that,
    in overlap mode, the transfer thread (producer) can start filling the next
    slot while a forward pass (consumer) is still waiting on the previous one.
    """

    def __init__(self, num_layers):
        self.num_layers = num_layers
        # Extra producer and consumer counters for overlap mode: one in-flight
        # batch can consume slot i while the loader produces into the next slot.
        self.num_counters = 3
        # Each slot starts at num_layers ("all layers done") so consumers that
        # never trigger a load do not block in wait_until().
        self.counters = [num_layers] * self.num_counters
        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
        self.producer_index = 0
        self.consumer_index = 0

    def next_producer(self):
        """Return the slot index the next produce cycle will use (without advancing)."""
        return (self.producer_index + 1) % self.num_counters

    def update_producer(self):
        """Advance to the next producer slot and return its index."""
        self.producer_index = self.next_producer()
        return self.producer_index

    def set_consumer(self, index):
        """Select which slot subsequent wait_until() calls watch."""
        self.consumer_index = index

    def increment(self):
        """Mark one more layer done in the current producer slot and wake waiters."""
        with self.conditions[self.producer_index]:
            self.counters[self.producer_index] += 1
            self.conditions[self.producer_index].notify_all()

    def wait_until(self, threshold):
        """Block until the consumer slot's counter is strictly greater than ``threshold``."""
        with self.conditions[self.consumer_index]:
            while self.counters[self.consumer_index] <= threshold:
                self.conditions[self.consumer_index].wait()

    def reset(self):
        """Zero the current producer slot before starting a new layer-wise transfer."""
        with self.conditions[self.producer_index]:
            self.counters[self.producer_index] = 0
class CacheOperation: class CacheOperation:
...@@ -296,7 +311,6 @@ class HiCacheController: ...@@ -296,7 +311,6 @@ class HiCacheController:
while not self.stop_event.is_set(): while not self.stop_event.is_set():
try: try:
operation = self.load_queue.get(block=True, timeout=1) operation = self.load_queue.get(block=True, timeout=1)
# time.sleep(18e-6 * len(operation.host_indices))
operation.data = self.mem_pool_host.get_flat_data( operation.data = self.mem_pool_host.get_flat_data(
operation.host_indices operation.host_indices
) )
...@@ -320,6 +334,7 @@ class HiCacheController: ...@@ -320,6 +334,7 @@ class HiCacheController:
if not self.load_cache_event.is_set(): if not self.load_cache_event.is_set():
continue continue
self.load_cache_event.clear() self.load_cache_event.clear()
self.layer_done_counter.update_producer()
batch_operation = None batch_operation = None
while self.load_queue.qsize() > 0: while self.load_queue.qsize() > 0:
...@@ -331,6 +346,7 @@ class HiCacheController: ...@@ -331,6 +346,7 @@ class HiCacheController:
if batch_operation is None: if batch_operation is None:
continue continue
# start layer-wise KV cache transfer from CPU to GPU
self.layer_done_counter.reset() self.layer_done_counter.reset()
for i in range(self.mem_pool_host.layer_num): for i in range(self.mem_pool_host.layer_num):
if self.page_size == 1: if self.page_size == 1:
...@@ -466,6 +482,7 @@ class HiCacheController: ...@@ -466,6 +482,7 @@ class HiCacheController:
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
# todo (zhiqiang): double buffering to be deprecated
def write_thread_func_buffer(self): def write_thread_func_buffer(self):
aux_thread = threading.Thread(target=self.write_aux_func, daemon=True) aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
aux_thread.start() aux_thread.start()
......
...@@ -659,14 +659,6 @@ class Req: ...@@ -659,14 +659,6 @@ class Req:
self.prefix_indices, self.last_node = tree_cache.match_prefix( self.prefix_indices, self.last_node = tree_cache.match_prefix(
rid=self.rid, key=self.adjust_max_prefix_ids() rid=self.rid, key=self.adjust_max_prefix_ids()
) )
elif enable_hierarchical_cache:
# in case last_node is evicted during scheduling, we need to update the prefix_indices
while self.last_node.evicted:
self.prefix_indices = self.prefix_indices[
: -len(self.last_node.host_value)
]
self.last_node = self.last_node.parent
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
def adjust_max_prefix_ids(self): def adjust_max_prefix_ids(self):
...@@ -909,6 +901,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -909,6 +901,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Whether to return hidden states # Whether to return hidden states
return_hidden_states: bool = False return_hidden_states: bool = False
# hicache pointer for synchronizing data loading from CPU to GPU
hicache_consumer_index: int = 0
@classmethod @classmethod
def init_new( def init_new(
cls, cls,
...@@ -1735,6 +1730,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1735,6 +1730,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
token_type_ids=self.token_type_ids, token_type_ids=self.token_type_ids,
spec_algorithm=self.spec_algorithm, spec_algorithm=self.spec_algorithm,
spec_info=self.spec_info, spec_info=self.spec_info,
hicache_consumer_index=self.hicache_consumer_index,
capture_hidden_mode=( capture_hidden_mode=(
CaptureHiddenMode.FULL CaptureHiddenMode.FULL
if self.return_hidden_states if self.return_hidden_states
...@@ -1839,6 +1835,7 @@ class ModelWorkerBatch: ...@@ -1839,6 +1835,7 @@ class ModelWorkerBatch:
# If set, the output of the batch contains the hidden states of the run. # If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None capture_hidden_mode: CaptureHiddenMode = None
spec_num_draft_tokens: Optional[int] = None spec_num_draft_tokens: Optional[int] = None
hicache_consumer_index: int = 0
# Overlap event # Overlap event
launch_done: Optional[threading.Event] = None launch_done: Optional[threading.Event] = None
......
...@@ -565,6 +565,10 @@ class Scheduler( ...@@ -565,6 +565,10 @@ class Scheduler(
hicache_size=server_args.hicache_size, hicache_size=server_args.hicache_size,
hicache_write_policy=server_args.hicache_write_policy, hicache_write_policy=server_args.hicache_write_policy,
) )
self.tp_worker.register_hicache_layer_transfer_counter(
self.tree_cache.cache_controller.layer_done_counter
)
else: else:
self.tree_cache = RadixCache( self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool, req_to_token_pool=self.req_to_token_pool,
...@@ -1514,8 +1518,13 @@ class Scheduler( ...@@ -1514,8 +1518,13 @@ class Scheduler(
self.running_batch.batch_is_full = True self.running_batch.batch_is_full = True
break break
# bypass prefix_computed if enable_hierarchical_cache
req.init_next_round_input( req.init_next_round_input(
None if prefix_computed else self.tree_cache, (
None
if (prefix_computed and not self.enable_hierarchical_cache)
else self.tree_cache
),
self.enable_hierarchical_cache, self.enable_hierarchical_cache,
) )
...@@ -1548,9 +1557,6 @@ class Scheduler( ...@@ -1548,9 +1557,6 @@ class Scheduler(
x for x in self.waiting_queue if x not in set(can_run_list) x for x in self.waiting_queue if x not in set(can_run_list)
] ]
if self.enable_hierarchical_cache:
self.tree_cache.ready_to_load_cache()
if adder.new_chunked_req is not None: if adder.new_chunked_req is not None:
assert self.chunked_req is None assert self.chunked_req is None
self.chunked_req = adder.new_chunked_req self.chunked_req = adder.new_chunked_req
...@@ -1574,6 +1580,10 @@ class Scheduler( ...@@ -1574,6 +1580,10 @@ class Scheduler(
self.server_args.enable_custom_logit_processor, self.server_args.enable_custom_logit_processor,
chunked_req=self.chunked_req, chunked_req=self.chunked_req,
) )
if self.enable_hierarchical_cache:
# todo (zhiqiang): disable cuda graph execution if hicache loading triggered
new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache()
new_batch.prepare_for_extend() new_batch.prepare_for_extend()
# Mixed-style chunked prefill # Mixed-style chunked prefill
...@@ -1649,6 +1659,11 @@ class Scheduler( ...@@ -1649,6 +1659,11 @@ class Scheduler(
if self.is_generation: if self.is_generation:
if self.spec_algorithm.is_none(): if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
# update the consumer index of hicache to the running batch
self.tp_worker.set_hicache_consumer(
model_worker_batch.hicache_consumer_index
)
if self.pp_group.is_last_rank: if self.pp_group.is_last_rank:
logits_output, next_token_ids, can_run_cuda_graph = ( logits_output, next_token_ids, can_run_cuda_graph = (
self.tp_worker.forward_batch_generation(model_worker_batch) self.tp_worker.forward_batch_generation(model_worker_batch)
......
...@@ -147,6 +147,15 @@ class TpModelWorker: ...@@ -147,6 +147,15 @@ class TpModelWorker:
# A reference make this class has the same member as TpModelWorkerClient # A reference make this class has the same member as TpModelWorkerClient
self.worker = self self.worker = self
self.hicache_layer_transfer_counter = None
def register_hicache_layer_transfer_counter(self, counter):
    """Attach the shared hicache LayerDoneCounter used to synchronize layer-wise loading."""
    self.hicache_layer_transfer_counter = counter
def set_hicache_consumer(self, consumer_index):
    """Point the hicache layer counter at the slot this batch consumes (no-op if hicache is disabled)."""
    counter = self.hicache_layer_transfer_counter
    if counter is not None:
        counter.set_consumer(consumer_index)
def get_worker_info(self): def get_worker_info(self):
return ( return (
self.max_total_num_tokens, self.max_total_num_tokens,
......
...@@ -88,6 +88,15 @@ class TpModelWorkerClient: ...@@ -88,6 +88,15 @@ class TpModelWorkerClient:
if self.device == "cpu": if self.device == "cpu":
self.scheduler_stream.synchronize = lambda: None # No-op for CPU self.scheduler_stream.synchronize = lambda: None # No-op for CPU
self.hicache_layer_transfer_counter = None
def register_hicache_layer_transfer_counter(self, counter):
    """Attach the shared hicache LayerDoneCounter used to synchronize layer-wise loading."""
    self.hicache_layer_transfer_counter = counter
def set_hicache_consumer(self, consumer_index):
    """Point the hicache layer counter at the slot this batch consumes (no-op if hicache is disabled)."""
    counter = self.hicache_layer_transfer_counter
    if counter is not None:
        counter.set_consumer(consumer_index)
def get_worker_info(self): def get_worker_info(self):
return self.worker.get_worker_info() return self.worker.get_worker_info()
...@@ -146,6 +155,8 @@ class TpModelWorkerClient: ...@@ -146,6 +155,8 @@ class TpModelWorkerClient:
input_ids = model_worker_batch.input_ids input_ids = model_worker_batch.input_ids
resolve_future_token_ids(input_ids, self.future_token_ids_map) resolve_future_token_ids(input_ids, self.future_token_ids_map)
# update the consumer index of hicache to the running batch
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
# Run forward # Run forward
logits_output, next_token_ids, can_run_cuda_graph = ( logits_output, next_token_ids, can_run_cuda_graph = (
self.worker.forward_batch_generation( self.worker.forward_batch_generation(
......
...@@ -307,7 +307,9 @@ class HiRadixCache(RadixCache): ...@@ -307,7 +307,9 @@ class HiRadixCache(RadixCache):
return last_node, prefix_indices return last_node, prefix_indices
def ready_to_load_cache(self):
    """Signal the cache controller to start loading queued host->device transfers.

    Returns:
        The producer slot index the loader will fill next, so the scheduler can
        stamp it on the batch as its ``hicache_consumer_index``.
    """
    # NOTE(review): this relies on the loader thread calling update_producer()
    # when it wakes on load_cache_event, so next_producer() here matches the
    # slot the loader will actually use — confirm against the controller.
    producer_index = self.cache_controller.layer_done_counter.next_producer()
    self.load_cache_event.set()
    return producer_index
def match_prefix(self, key: List[int], include_evicted=False, **kwargs): def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device) empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
...@@ -372,6 +374,7 @@ class HiRadixCache(RadixCache): ...@@ -372,6 +374,7 @@ class HiRadixCache(RadixCache):
new_node.lock_ref = child.lock_ref new_node.lock_ref = child.lock_ref
new_node.key = child.key[:split_len] new_node.key = child.key[:split_len]
new_node.loading = child.loading new_node.loading = child.loading
new_node.hit_count = child.hit_count
# split value and host value if exists # split value and host value if exists
if child.evicted: if child.evicted:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment