Unverified commit 58d1082e authored by Lianmin Zheng, committed by GitHub

Clean up event loop (#1586)

parent 4d086719
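The hunks below split the old `forward_step` / `forward_prefill_batch` / `forward_decode_batch` trio into smaller pieces: `recv_requests` and `process_input_requests` on the input side, `get_new_batch_prefill` / `get_new_batch_decode` for scheduling, `run_batch` for the forward pass, and `process_batch_result_prefill` / `process_batch_result_decode` on the output side, with `event_loop` reduced to a thin driver. A minimal, hedged sketch of that control flow follows; the method names mirror the diff, but the bodies are trivial placeholders and nothing here is sglang's actual implementation.

```python
# Reviewer sketch, not sglang code: a stubbed-out scheduler showing the control
# flow this commit introduces. Bodies are placeholders so the skeleton runs.
from typing import Any, List, Optional


class SchedulerSkeleton:
    def __init__(self) -> None:
        self.waiting_queue: List[str] = []
        self.running_batch: Optional[List[str]] = None
        self.out_pyobjs: List[Any] = []

    def recv_requests(self) -> List[str]:
        # real code: rank 0 drains ZMQ without blocking, then broadcast_pyobj
        return ["req-1", "req-2"]

    def process_input_requests(self, recv_reqs: List[str]) -> None:
        self.waiting_queue.extend(recv_reqs)  # real code dispatches by request type

    def run_step(self) -> None:
        # real code: try a prefill batch first, otherwise run a few decode steps
        new_batch = self.waiting_queue[:2] or None  # stand-in for get_new_batch_prefill
        self.waiting_queue = self.waiting_queue[2:]
        if new_batch is not None:
            result = self.run_batch(new_batch)
            self.process_batch_result(new_batch, result)
            self.running_batch = new_batch

    def run_batch(self, batch: List[str]) -> List[str]:
        return [f"token-for-{r}" for r in batch]  # stand-in for the forward pass

    def process_batch_result(self, batch: List[str], result: List[str]) -> None:
        self.out_pyobjs.append(result)  # real code builds BatchTokenIDOut, logprobs, ...


if __name__ == "__main__":
    s = SchedulerSkeleton()
    # one iteration of the event loop
    s.process_input_requests(s.recv_requests())
    s.run_step()
    print(s.out_pyobjs)  # [['token-for-req-1', 'token-for-req-2']]
```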
@@ -228,20 +228,14 @@ class Scheduler:
self.new_token_ratio_decay = global_config.new_token_ratio_decay
self.batch_is_full = False
@torch.inference_mode()
def event_loop(self):
while True:
# Receive requests
if self.tp_rank == 0:
recv_reqs = self.recv_requests_from_zmq()
else:
recv_reqs = None
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
# Process requests
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
self.process_requests(recv_reqs)
# Forward
self.forward_step()
# Run one step
self.run_step()
# Send results
if self.tp_rank == 0:
@@ -249,19 +243,23 @@ class Scheduler:
self.send_to_detokenizer.send_pyobj(obj)
self.out_pyobjs = []
def recv_requests_from_zmq(self):
recv_reqs = []
def recv_requests(self):
if self.tp_rank == 0:
recv_reqs = []
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
recv_reqs.append(recv_req)
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
recv_reqs.append(recv_req)
else:
recv_reqs = None
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
return recv_reqs
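In the new `recv_requests` above, only rank 0 talks to ZMQ: it drains everything currently queued with non-blocking receives, and the resulting list is then shared with the other tensor-parallel ranks through `broadcast_pyobj`. A standalone sketch of the drain idiom, with the socket setup assumed (it is not shown in the diff):

```python
# Hedged sketch of the rank-0 drain loop in recv_requests(); `socket` is any
# connected pyzmq socket, analogous to self.recv_from_tokenizer.
from typing import Any, List

import zmq


def drain_nonblocking(socket: "zmq.Socket") -> List[Any]:
    """Return every object currently queued on the socket without blocking."""
    reqs: List[Any] = []
    while True:
        try:
            reqs.append(socket.recv_pyobj(zmq.NOBLOCK))
        except zmq.ZMQError:
            break  # nothing left to read right now; the event loop polls again
    return reqs
```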
def process_requests(self, recv_reqs: List):
def process_input_requests(self, recv_reqs: List):
for recv_req in recv_reqs:
if isinstance(recv_req, TokenizedGenerateReqInput):
self.handle_generate_request(recv_req)
@@ -279,83 +277,6 @@ class Scheduler:
else:
raise ValueError(f"Invalid request: {recv_req}")
@torch.inference_mode()
def forward_step(self):
if (
self.batch_is_full or len(self.waiting_queue) == 0
) and self.current_inflight_req is None:
new_batch = None
else:
new_batch = self.get_new_prefill_batch()
if new_batch is not None:
# Run a new prefill batch
self.forward_prefill_batch(new_batch)
if not new_batch.is_empty():
if self.running_batch is None:
self.running_batch = new_batch
else:
self.running_batch.merge_batch(new_batch)
else:
# Run a decode batch
if self.running_batch is not None:
# Run a few decode batches continuously for reducing overhead
for _ in range(global_config.num_continue_decode_steps):
self.num_generated_tokens += len(self.running_batch.reqs)
self.forward_decode_batch(self.running_batch)
# Print stats
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
self.print_decode_stats()
if self.running_batch.is_empty():
self.running_batch = None
break
if self.out_pyobjs and self.running_batch.has_stream:
break
else:
self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio
def print_decode_stats(self):
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
f"Decode batch. "
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}"
)
def check_memory(self):
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
if available_size != self.max_total_num_tokens:
warnings.warn(
"Warning: "
f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
"KV cache pool leak detected!"
)
exit(1) if crash_on_warning else None
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
warnings.warn(
"Warning: "
f"available req slots={len(self.req_to_token_pool.free_slots)}, "
f"total slots={self.req_to_token_pool.size}\n"
"Memory pool leak detected!"
)
exit(1) if crash_on_warning else None
def handle_generate_request(
self,
recv_req: TokenizedGenerateReqInput,
@@ -445,7 +366,88 @@ class Scheduler:
self.waiting_queue.append(req)
def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
def run_step(self):
new_batch = self.get_new_batch_prefill()
if new_batch is not None:
# Run a new prefill batch
result = self.run_batch(new_batch)
self.process_batch_result(new_batch, result)
if not new_batch.is_empty():
if self.running_batch is None:
self.running_batch = new_batch
else:
self.running_batch.merge_batch(new_batch)
else:
# Run a decode batch
if self.running_batch is not None:
# Run a few decode batches continuously for reducing overhead
for _ in range(global_config.num_continue_decode_steps):
batch = self.get_new_batch_decode()
if batch:
result = self.run_batch(batch)
self.process_batch_result(batch, result)
# Print stats
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
self.print_decode_stats()
if self.running_batch.is_empty():
self.running_batch = None
break
if self.out_pyobjs and self.running_batch.has_stream:
break
else:
self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio
def print_decode_stats(self):
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
f"Decode batch. "
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}"
)
def check_memory(self):
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
if available_size != self.max_total_num_tokens:
warnings.warn(
"Warning: "
f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
"KV cache pool leak detected!"
)
exit(1) if crash_on_warning else None
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
warnings.warn(
"Warning: "
f"available req slots={len(self.req_to_token_pool.free_slots)}, "
f"total slots={self.req_to_token_pool.size}\n"
"Memory pool leak detected!"
)
exit(1) if crash_on_warning else None
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# Handle the cases where prefill is not allowed
if (
self.batch_is_full or len(self.waiting_queue) == 0
) and self.current_inflight_req is None:
return None
running_bs = (
len(self.running_batch.reqs) if self.running_batch is not None else 0
)
@@ -456,8 +458,8 @@ class Scheduler:
# Get priority queue
prefix_computed = self.policy.calc_priority(self.waiting_queue)
# Prefill policy
num_mixed_running = running_bs if self.is_mixed_chunk else 0
adder = PrefillAdder(
self.tree_cache,
self.running_batch,
@@ -517,6 +519,8 @@ class Scheduler:
if len(can_run_list) == 0:
return None
self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
# Print stats
if self.tp_rank == 0:
if isinstance(self.tree_cache, RadixCache):
@@ -544,7 +548,7 @@ class Scheduler:
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
)
else:
logger.info(
@@ -555,41 +559,97 @@ class Scheduler:
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
)
# Return the new batch
# Create a new batch
new_batch = ScheduleBatch.init_new(
can_run_list,
self.req_to_token_pool,
self.token_to_kv_pool,
self.tree_cache,
)
self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
return new_batch
def forward_prefill_batch(self, batch: ScheduleBatch):
# Build batch tensors
batch.prepare_for_extend(self.model_config.vocab_size)
new_batch.prepare_for_extend(self.model_config.vocab_size)
# Mixed-style chunked prefill
decoding_reqs = []
if self.is_mixed_chunk and self.running_batch is not None:
self.running_batch.prepare_for_decode()
batch.mix_with_running(self.running_batch)
new_batch.mix_with_running(self.running_batch)
decoding_reqs = self.running_batch.reqs
self.running_batch = None
new_batch.decoding_reqs = decoding_reqs
return new_batch
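A small behavioral detail in `get_new_batch_prefill` above: the waiting queue is now filtered immediately after `can_run_list` is selected, before the stats are logged, which is why the `#queue-req` figure drops the `- len(can_run_list)` term. A tiny check of the equivalence, with made-up queue contents:

```python
# Made-up queue contents; both counts describe the same quantity, the requests
# still waiting after this prefill batch is carved out.
waiting_queue = ["a", "b", "c", "d"]
can_run_list = ["a", "c"]
has_inflight = 0

old_count = len(waiting_queue) - len(can_run_list) + has_inflight  # old log line
waiting_queue = [x for x in waiting_queue if x not in can_run_list]
new_count = len(waiting_queue) + has_inflight                      # new log line
assert old_count == new_count == 2
```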
def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
batch = self.running_batch
# Check if decode out of memory
if not batch.check_decode_mem():
old_ratio = self.new_token_ratio
retracted_reqs, new_token_ratio = batch.retract_decode()
self.new_token_ratio = new_token_ratio
logger.info(
"Decode out of memory happened. "
f"#retracted_reqs: {len(retracted_reqs)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
)
self.waiting_queue.extend(retracted_reqs)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_decay,
self.min_new_token_ratio,
)
# Check for jump-forward
if not self.disable_regex_jump_forward:
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
self.waiting_queue.extend(jump_forward_reqs)
if batch.is_empty():
return None
# Update batch tensors
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
batch.prepare_for_decode()
return batch
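`get_new_batch_decode` above carries over the retraction logic from the old `forward_decode_batch`: if the decode step would run out of KV-cache memory, some requests are retracted back to the waiting queue and the new-token ratio is reset to whatever `retract_decode` reports; otherwise the ratio decays toward its floor each step. A hedged sketch of just that bookkeeping, with illustrative parameter names rather than the real config fields:

```python
# Hedged sketch of the new-token-ratio update; decay/floor stand in for
# new_token_ratio_decay / min_new_token_ratio, and ratio_after_retract stands
# in for the value ScheduleBatch.retract_decode() reports.
def next_new_token_ratio(ratio: float, decay: float, floor: float,
                         out_of_memory: bool, ratio_after_retract: float) -> float:
    if out_of_memory:
        return ratio_after_retract      # reset after retracting requests
    return max(ratio - decay, floor)    # otherwise decay toward the floor


assert next_new_token_ratio(0.50, 0.25, 0.20, False, 0.90) == 0.25
assert next_new_token_ratio(0.25, 0.25, 0.20, False, 0.90) == 0.20
assert next_new_token_ratio(0.50, 0.25, 0.20, True, 0.90) == 0.90
```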
def run_batch(self, batch: ScheduleBatch):
if self.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
else:
logits_output = None
if self.tokenizer is not None:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
else:
next_token_ids = [0] * len(batch.reqs)
return logits_output, next_token_ids
else: # embedding or reward model
assert batch.extend_num_tokens != 0
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
return embeddings
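In `run_batch` above, the forward pass is skipped only for a prefill batch whose `extend_num_tokens` is zero; in that case each request gets a placeholder next token (the tokenizer's EOS id when available, else 0) so the result still contains one token per request. A small illustration of that fallback; the function name and EOS id below are made up:

```python
# Illustrative fallback only; mirrors the else-branch of run_batch() when a
# prefill batch has no tokens to extend. EOS_ID is a hypothetical value.
from typing import List, Optional

EOS_ID = 2  # hypothetical tokenizer.eos_token_id


def placeholder_next_tokens(num_reqs: int, eos_id: Optional[int]) -> List[int]:
    return [eos_id if eos_id is not None else 0] * num_reqs


assert placeholder_next_tokens(3, EOS_ID) == [2, 2, 2]
assert placeholder_next_tokens(2, None) == [0, 0]
```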
def process_batch_result(self, batch: ScheduleBatch, result):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
else:
self.process_batch_result_prefill(batch, result)
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
if self.is_generation:
logits_output, next_token_ids = result
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
if logits_output:
# Move logprobs to cpu
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
@@ -607,16 +667,7 @@ class Scheduler:
logits_output.normalized_prompt_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
else:
if self.tokenizer is None:
next_token_ids = []
for req in batch.reqs:
next_token_ids.append(
next(iter(req.sampling_params.stop_token_ids))
)
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
next_token_ids = next_token_ids.tolist()
# Check finish conditions
logprob_pt = 0
@@ -634,7 +685,7 @@ class Scheduler:
if req.finished():
self.tree_cache.cache_finished_req(req)
elif req not in decoding_reqs:
elif req not in batch.decoding_reqs:
# To reduce overhead, only cache prefill reqs
self.tree_cache.cache_unfinished_req(req)
@@ -646,10 +697,9 @@ class Scheduler:
logprob_pt += self.add_logprob_return_values(
i, req, logprob_pt, next_token_ids, logits_output
)
else:
else: # embedding or reward model
assert batch.extend_num_tokens != 0
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
embeddings = result
# Check finish conditions
for i, req in enumerate(batch.reqs):
@@ -671,6 +721,45 @@ class Scheduler:
self.handle_finished_requests(batch)
def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids = result
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
self.num_generated_tokens += len(batch.reqs)
# Move logprobs to cpu
if logits_output.next_token_logprobs is not None:
next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids,
].tolist()
next_token_ids = next_token_ids.tolist()
# Check finish condition
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()
if req.regex_fsm is not None:
req.regex_fsm_state = req.regex_fsm.get_next_state(
req.regex_fsm_state, next_token_id
)
if req.finished():
self.tree_cache.cache_finished_req(req)
if req.return_logprob:
req.output_token_logprobs.append(
(next_token_logprobs[i], next_token_id)
)
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
self.handle_finished_requests(batch)
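`process_batch_result_decode` above gathers each request's logprob for the token it actually sampled with one advanced-indexing lookup before moving the values to the CPU. A standalone sketch of that gather, assuming `[batch, vocab]` logprobs and a `[batch]` tensor of token ids:

```python
# Hedged sketch of the logprob gather used after a decode forward pass.
import torch


def gather_chosen_logprobs(next_token_logprobs: torch.Tensor,
                           next_token_ids: torch.Tensor) -> list:
    """next_token_logprobs: [batch, vocab]; next_token_ids: [batch] (long)."""
    rows = torch.arange(len(next_token_ids), device=next_token_ids.device)
    return next_token_logprobs[rows, next_token_ids].tolist()


logprobs = torch.log_softmax(torch.randn(2, 5), dim=-1)
ids = torch.tensor([3, 1])
print(gather_chosen_logprobs(logprobs, ids))  # two floats, one per request
```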
def add_logprob_return_values(
self,
i: int,
@@ -744,80 +833,6 @@ class Scheduler:
return num_input_logprobs
def forward_decode_batch(self, batch: ScheduleBatch):
# Check if decode out of memory
if not batch.check_decode_mem():
old_ratio = self.new_token_ratio
retracted_reqs, new_token_ratio = batch.retract_decode()
self.new_token_ratio = new_token_ratio
logger.info(
"Decode out of memory happened. "
f"#retracted_reqs: {len(retracted_reqs)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
)
self.waiting_queue.extend(retracted_reqs)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_decay,
self.min_new_token_ratio,
)
# Check for jump-forward
if not self.disable_regex_jump_forward:
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
self.waiting_queue.extend(jump_forward_reqs)
if batch.is_empty():
return
# Update batch tensors
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
batch.prepare_for_decode()
# Forward and sample the next tokens
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
# Move logprobs to cpu
if logits_output.next_token_logprobs is not None:
next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids,
].tolist()
next_token_ids = next_token_ids.tolist()
# Check finish condition
has_finished = False
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()
if req.regex_fsm is not None:
req.regex_fsm_state = req.regex_fsm.get_next_state(
req.regex_fsm_state, next_token_id
)
if req.finished():
self.tree_cache.cache_finished_req(req)
has_finished = True
if req.return_logprob:
req.output_token_logprobs.append(
(next_token_logprobs[i], next_token_id)
)
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
self.handle_finished_requests(batch)
def handle_finished_requests(self, batch: ScheduleBatch):
output_rids = []
output_meta_info = []
@@ -829,7 +844,7 @@ class Scheduler:
output_read_offsets = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
else: # for embedding model
else: # embedding or reward model
output_embeddings = []
unfinished_indices = []
@@ -886,7 +901,7 @@ class Scheduler:
req.normalized_prompt_logprob,
)
output_meta_info.append(meta_info)
else: # for embedding model
else: # embedding or reward model
output_embeddings.append(req.embedding)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
@@ -909,7 +924,7 @@ class Scheduler:
output_finished_reason,
)
)
else: # for embedding model
else: # embedding or reward model
self.out_pyobjs.append(
BatchEmbeddingOut(
output_rids,