Unverified Commit 769bf11c authored by Lianmin Zheng, committed by GitHub

Fix the race condition in overlap mode (#1712)

parent 3db43d1b
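
The diff below touches the scheduler, the TP worker, and the prefix caches. One way to picture the race being fixed (a minimal, hypothetical sketch, not code from the repository): with overlap scheduling, the CPU scheduler prepares the next batch while a worker is still consuming tensors from the previous one, so any tensor the scheduler mutates in place (such as `seq_lens`) has to be snapshotted before hand-off, which is what the `seq_lens=self.seq_lens.clone()` change below does.

```python
# Hypothetical illustration, not sglang code: a tensor shared between the
# scheduler and an asynchronous worker must be copied before hand-off,
# otherwise the worker can observe the scheduler's later in-place updates.
import threading

import torch

seq_lens = torch.tensor([5, 7, 9])
seen = []

def worker(lens):
    # Runs concurrently with the scheduler's next step.
    seen.append(lens.tolist())

# Racy hand-off: worker and scheduler share the same storage,
# so the worker may see either [5, 7, 9] or [6, 8, 10].
t = threading.Thread(target=worker, args=(seq_lens,))
t.start()
seq_lens += 1  # scheduler advances sequence lengths for the next step
t.join()

# Safe hand-off: snapshot first, mirroring `seq_lens=self.seq_lens.clone()`.
t = threading.Thread(target=worker, args=(seq_lens.clone(),))
t.start()
seq_lens += 1
t.join()

print(seen)  # the second entry is guaranteed to be [6, 8, 10]
```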
@@ -405,9 +405,9 @@ class ScheduleBatch:
     # Request, memory pool, and cache
     reqs: List[Req]
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    tree_cache: BasePrefixCache
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    tree_cache: BasePrefixCache = None
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
@@ -874,12 +874,9 @@ class ScheduleBatch:
     def copy(self):
         return ScheduleBatch(
             reqs=self.reqs,
-            req_to_token_pool=self.req_to_token_pool,
-            token_to_kv_pool=self.token_to_kv_pool,
-            tree_cache=self.tree_cache,
             forward_mode=self.forward_mode,
-            output_ids=self.output_ids,
-            sampling_info=self.sampling_info,
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
         )
@@ -929,7 +926,7 @@ class ModelWorkerBatch:
             forward_mode=self.forward_mode,
             input_ids=self.input_ids.clone(),
             req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens,
+            seq_lens=self.seq_lens.clone(),
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
......
@@ -261,12 +261,7 @@ class Scheduler:
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
             )
-
-            def cache_finished_req(req):
-                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
-                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
-
-            self.cache_finished_req = cache_finished_req
+            self.cache_finished_req = self.tree_cache.cache_finished_req
         else:
             self.forward_batch_generation = self.tp_worker.forward_batch_generation
             self.resolve_next_token_ids = lambda bid, x: x.tolist()
@@ -798,7 +793,6 @@ class Scheduler:
                         i, req, logprob_pt, next_token_ids, logits_output
                     )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
             embeddings, bid = result
             embeddings = embeddings.tolist()
@@ -838,6 +832,7 @@ class Scheduler:
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if self.server_args.enable_overlap_schedule and req.finished():
+                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            req.completion_tokens_wo_jump_forward += 1
......
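
The added `self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])` line above replaces the `free_delta` plumbing removed from the caches further down: under overlap scheduling, a request that has just finished was already issued one more speculative decode step, so the scheduler now returns that step's KV slot to the pool directly. A toy sketch of the idea (the pool class here is a made-up stand-in, not sglang's `BaseTokenToKVPool`):

```python
# Hypothetical sketch: free the KV slot reserved for the extra decode step
# of a request that finished, one index per finished request in the batch.
import torch

class ToyTokenToKVPool:
    """Toy pool that only tracks which KV slots are free."""

    def __init__(self, size):
        self.free_slots = list(range(size))

    def alloc(self, n):
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return torch.tensor(out)

    def free(self, indices):
        self.free_slots.extend(indices.tolist())

pool = ToyTokenToKVPool(8)
out_cache_loc = pool.alloc(3)        # one slot per request in the batch
finished = [False, True, False]      # request 1 finished before this step
for i, done in enumerate(finished):
    if done:
        pool.free(out_cache_loc[i : i + 1])  # mirrors the added line above
print(pool.free_slots)               # slot 1 is back in the pool
```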
@@ -149,14 +149,12 @@ class TpModelWorker:
         )

         # Resolve future tokens in the input
-        # logger.info(f"raw input {model_worker_batch.input_ids=}")
         tic2 = time.time()
         resolved_input_ids = model_worker_batch.input_ids
         future_mask = resolved_input_ids < 0
         resolved_input_ids[future_mask] = self.future_token_ids_map[
             -resolved_input_ids[future_mask]
         ]
-        # logger.info(f"resolved input {model_worker_batch.input_ids=}")

         # Run forward
         logits_output, next_token_ids = self.forward_batch_generation(
@@ -215,12 +213,13 @@ class TpModelWorker:
         self.future_logits_output_ct += 1

         bs = len(model_worker_batch.seq_lens)
-        future_next_token_ids = -torch.arange(
-            self.future_token_ids_ct + 1,
-            self.future_token_ids_ct + 1 + bs,
-            dtype=torch.int32,
-            device=self.device,
-        )
+        with torch.cuda.stream(self.forward_stream):
+            future_next_token_ids = -torch.arange(
+                self.future_token_ids_ct + 1,
+                self.future_token_ids_ct + 1 + bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
......
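
The future-token machinery touched above works by writing negative placeholder IDs into the next batch's `input_ids` before the previous step's sampling has finished; the worker later swaps them for the real tokens recorded in `future_token_ids_map`, as in the resolution hunk earlier in this file. A self-contained sketch of that lookup (sizes and token values are made up for illustration):

```python
# Hypothetical sketch of the placeholder resolution: -k in the input means
# "the token that the previous step will write into slot k of the map".
import torch

future_token_ids_map = torch.zeros(16, dtype=torch.int64)
# Suppose the previous forward recorded its sampled tokens in slots 1..3.
future_token_ids_map[1:4] = torch.tensor([101, 202, 303])

input_ids = torch.tensor([11, -1, -2, -3])
future_mask = input_ids < 0
input_ids[future_mask] = future_token_ids_map[-input_ids[future_mask]]
print(input_ids.tolist())  # [11, 101, 202, 303]
```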
@@ -38,16 +38,14 @@ class ChunkCache(BasePrefixCache):
         max_prefix_len = len(key)

         return entry.value[:max_prefix_len], entry

-    def cache_finished_req(
-        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
-    ):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
             token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
         else:
             token_id_len = len(token_ids)

         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : token_id_len + free_delta
+            req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool.free(kv_indices)
......
@@ -97,9 +97,7 @@ class RadixCache(BasePrefixCache):
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

-    def cache_finished_req(
-        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
-    ):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it finishes."""
         if self.disable:
             if token_ids is None:
@@ -108,7 +106,7 @@ class RadixCache(BasePrefixCache):
                 token_ids_len = len(token_ids)

             kv_indices = self.req_to_token_pool.req_to_token[
-                req.req_pool_idx, : token_ids_len + free_delta
+                req.req_pool_idx, :token_ids_len
             ]
             self.token_to_kv_pool.free(kv_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
@@ -123,12 +121,6 @@ class RadixCache(BasePrefixCache):
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
         self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
-        if free_delta:
-            self.token_to_kv_pool.free(
-                self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, len(token_ids) : len(token_ids) + 1
-                ]
-            )

         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
......
@@ -542,6 +542,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
             kill_child_process(pid, including_parent=False)
             return

+    # logger.info(f"{res.json()=}")
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
......