Unverified Commit 7ee6c259 authored by Lianmin Zheng, committed by GitHub

Simplify the event loop and expose `--num-continuous-decode-steps` as an argument (#1652)

parent 9610fcd4
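The core idea of the commit is easiest to see as plain control flow. The sketch below is a minimal, self-contained toy (the `ToyScheduler` class and its token counters are made up for illustration and are not SGLang's classes); only the loop structure mirrors the diff: poll for new requests, pick one batch to run, and, if that batch is a decode batch, run up to `num_continuous_decode_steps` decode steps back-to-back to amortize scheduling overhead.

```python
# Minimal toy sketch of the simplified event loop (illustrative names only).
from collections import deque


class ToyScheduler:
    def __init__(self, num_continuous_decode_steps: int = 1):
        self.num_continuous_decode_steps = num_continuous_decode_steps
        self.waiting = deque()  # (request_id, tokens_left) waiting for prefill
        self.running = {}       # request_id -> tokens_left, currently decoding

    def get_next_batch_to_run(self):
        if self.waiting:  # prefill has priority over decode
            rid, tokens_left = self.waiting.popleft()
            self.running[rid] = tokens_left
            return ("prefill", [rid])
        if self.running:
            return ("decode", list(self.running))
        return None

    def run_decode_step(self):
        for rid in list(self.running):
            self.running[rid] -= 1      # emit one token per running request
            if self.running[rid] == 0:
                del self.running[rid]   # request finished

    def event_loop_once(self, new_requests):
        self.waiting.extend(new_requests)   # recv_requests / process_input_requests
        batch = self.get_next_batch_to_run()
        if batch is None:
            return
        mode, _ = batch
        if mode == "decode":
            self.run_decode_step()
            # Decode multiple steps per loop iteration to reduce overhead.
            for _ in range(self.num_continuous_decode_steps - 1):
                if not self.running:
                    break
                self.run_decode_step()


sched = ToyScheduler(num_continuous_decode_steps=4)
sched.event_loop_once([("req-0", 8)])  # prefill iteration
sched.event_loop_once([])              # one iteration = up to 4 decode steps
print(sched.running)                   # {'req-0': 4}
```

The trade-off stated in the new `--num-continuous-decode-steps` help text applies to the toy as well: running several decode steps per iteration raises throughput but can delay the next prefill, increasing time-to-first-token for newly arrived requests.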
@@ -19,7 +19,6 @@ class GlobalConfig:
self.new_token_ratio_decay = 0.001
# Runtime constants: others
self.num_continue_decode_steps = 10
self.retract_decode_steps = 20
self.flashinfer_workspace_size = os.environ.get(
"FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
......
@@ -831,6 +831,22 @@ class ScheduleBatch:
sampling_info=self.sampling_info,
)
def copy(self):
return ScheduleBatch(
reqs=self.reqs,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool,
tree_cache=self.tree_cache,
forward_mode=self.forward_mode,
output_token_ids=self.output_token_ids,
)
def __str__(self):
return (
f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
f"#req={(len(self.reqs))})"
)
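A note on the `copy()` added above: it is a shallow copy, so the new batch shares the same `reqs` list (and pools/cache) as the original, while batch-level fields such as `forward_mode` can diverge. The toy dataclass below (not SGLang's `ScheduleBatch`) illustrates that behavior.

```python
# Toy illustration of shallow-copy semantics like ScheduleBatch.copy() above.
from dataclasses import dataclass
from typing import List


@dataclass
class ToyBatch:
    reqs: List[str]
    forward_mode: str

    def copy(self) -> "ToyBatch":
        return ToyBatch(reqs=self.reqs, forward_mode=self.forward_mode)


a = ToyBatch(reqs=["r0"], forward_mode="EXTEND")
b = a.copy()
b.forward_mode = "DECODE"   # batch-level field diverges
b.reqs.append("r1")         # shared list: the change is visible through `a` too
print(a)                    # ToyBatch(reqs=['r0', 'r1'], forward_mode='EXTEND')
```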
@dataclass
class ModelWorkerBatch:
......
@@ -20,6 +20,7 @@ import logging
import os
import time
import warnings
from types import SimpleNamespace
from typing import List, Optional, Union
import torch
@@ -106,7 +107,8 @@ class Scheduler:
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
else:
self.recv_from_tokenizer = self.send_to_detokenizer = None
self.recv_from_tokenizer = None
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
# Init tokenizer
self.model_config = ModelConfig(
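The `SimpleNamespace(send_pyobj=lambda x: None)` above is a null-object stand-in: on TP ranks other than 0 the detokenizer socket is replaced with a no-op sender, so the rest of the scheduler can call `send_pyobj(...)` unconditionally instead of buffering results in `out_pyobjs` and checking the rank before flushing. A standalone toy illustration of the pattern (not SGLang's code):

```python
# Toy illustration of the null-object sender used for non-zero TP ranks.
from types import SimpleNamespace


class RealSender:
    def send_pyobj(self, obj):
        print("sent:", obj)


def make_sender(tp_rank: int):
    if tp_rank == 0:
        return RealSender()                              # rank 0 really sends
    return SimpleNamespace(send_pyobj=lambda x: None)    # other ranks: no-op


make_sender(0).send_pyobj({"rid": "req-0"})  # sent: {'rid': 'req-0'}
make_sender(1).send_pyobj({"rid": "req-0"})  # silently dropped
```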
@@ -190,7 +192,6 @@
# Init running status
self.waiting_queue: List[Req] = []
self.running_batch: ScheduleBatch = None
self.out_pyobjs = []
self.decode_forward_ct = 0
self.stream_interval = server_args.stream_interval
self.num_generated_tokens = 0
@@ -247,13 +248,30 @@
@torch.inference_mode()
def event_loop(self):
self.last_batch = None
while True:
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
self.run_step()
batch = self.get_next_batch_to_run()
if batch:
result = self.run_batch(batch)
self.process_batch_result(batch, result)
self.send_results()
# Decode multiple steps to reduce the overhead
if batch.forward_mode.is_decode():
for _ in range(self.server_args.num_continuous_decode_steps - 1):
if not self.running_batch:
break
self.update_running_batch()
if not self.running_batch:
break
result = self.run_batch(batch)
self.process_batch_result(batch, result)
self.last_batch = batch
def recv_requests(self):
if self.tp_rank == 0:
@@ -286,7 +304,9 @@
self.abort_request(recv_req)
elif isinstance(recv_req, UpdateWeightReqInput):
success, message = self.update_weights(recv_req)
self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
self.send_to_detokenizer.send_pyobj(
UpdateWeightReqOutput(success, message)
)
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
self.start_profile()
@@ -384,12 +404,6 @@
self.waiting_queue.append(req)
def send_results(self):
if self.tp_rank == 0:
for obj in self.out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
self.out_pyobjs = []
def print_decode_stats(self):
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -427,44 +441,32 @@
)
exit(1) if crash_on_warning else None
def run_step(self):
def get_next_batch_to_run(self):
# Merge prefill to the running batch
if (
self.last_batch
and not self.last_batch.forward_mode.is_decode()
and not self.last_batch.is_empty()
):
if self.running_batch is None:
self.running_batch = self.last_batch
else:
self.running_batch.merge_batch(self.last_batch)
# Prefill first
new_batch = self.get_new_batch_prefill()
if new_batch is not None:
# Run a new prefill batch
# replace run_batch with the uncommented line to use pytorch profiler
# result = pytorch_profile(
# "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
# )
result = self.run_batch(new_batch)
self.process_batch_result(new_batch, result)
else:
if self.running_batch is not None:
# Run a few decode batches continuously for reducing overhead
for _ in range(global_config.num_continue_decode_steps):
batch = self.get_new_batch_decode()
if batch:
# replace run_batch with the uncommented line to use pytorch profiler
# result = pytorch_profile(
# "profile_decode_step",
# self.run_batch,
# batch,
# data_size=len(batch.reqs),
# )
result = self.run_batch(batch)
self.process_batch_result(batch, result)
if self.running_batch.is_empty():
self.running_batch = None
return new_batch
if self.running_batch is None:
break
if self.out_pyobjs and self.running_batch.has_stream:
break
else:
self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio
# Run decode
if self.running_batch is not None:
self.update_running_batch()
if not self.running_batch:
return None
return self.running_batch
else:
self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# Handle the cases where prefill is not allowed
@@ -607,7 +609,7 @@
return new_batch
def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
def update_running_batch(self):
batch = self.running_batch
# Check if decode out of memory
@@ -636,11 +638,11 @@
if jump_forward_reqs:
self.batch_is_full = False
if batch.is_empty():
return None
self.running_batch = None
return
# Update batch tensors
batch.prepare_for_decode()
return batch
def run_batch(self, batch: ScheduleBatch):
if self.is_generation:
@@ -657,16 +659,19 @@
)
else:
next_token_ids = torch.full((batch.batch_size(),), 0)
return logits_output, next_token_ids
ret = logits_output, next_token_ids
else: # embedding or reward model
assert batch.extend_num_tokens != 0
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
return embeddings
ret = embeddings
return ret
def process_batch_result(self, batch: ScheduleBatch, result):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
if batch.is_empty():
self.running_batch = None
else:
self.process_batch_result_prefill(batch, result)
@@ -728,7 +733,7 @@
)
else: # embedding or reward model
assert batch.extend_num_tokens != 0
embeddings = result
embeddings = result.tolist()
# Check finish conditions
for i, req in enumerate(batch.reqs):
@@ -750,12 +755,6 @@
self.handle_finished_requests(batch)
if not batch.is_empty():
if self.running_batch is None:
self.running_batch = batch
else:
self.running_batch.merge_batch(batch)
def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids = result
if batch.sampling_info.penalizer_orchestrator:
@@ -951,7 +950,7 @@
# Send to detokenizer
if output_rids:
if self.is_generation:
self.out_pyobjs.append(
self.send_to_detokenizer.send_pyobj(
BatchTokenIDOut(
output_rids,
output_vids,
@@ -965,7 +964,7 @@
)
)
else: # embedding or reward model
self.out_pyobjs.append(
self.send_to_detokenizer.send_pyobj(
BatchEmbeddingOut(
output_rids,
output_embeddings,
......
@@ -118,7 +118,7 @@ class TpModelWorker:
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings.tolist()
embeddings = logits_output.embeddings
return embeddings
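This pairs with the scheduler-side `result.tolist()` change earlier in the diff: the worker now returns the embedding tensor unconverted, and the conversion to plain Python lists happens in the scheduler's result processing instead of inside the worker. A toy sketch of that split (not SGLang's classes):

```python
# Toy sketch: the worker returns a tensor; the caller converts it when needed.
import torch


def forward_batch_embedding_toy() -> torch.Tensor:
    # Stays a tensor inside the worker (mirrors the tp_worker change above).
    return torch.zeros(2, 3)


result = forward_batch_embedding_toy()
embeddings = result.tolist()  # scheduler-side conversion, as in the diff
print(embeddings)             # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```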
def update_weights(self, recv_req: UpdateWeightReqInput):
......
@@ -111,6 +111,7 @@ class ServerArgs:
torchao_config: str = ""
enable_p2p_check: bool = False
triton_attention_reduce_in_fp32: bool = False
num_continuous_decode_steps: int = 1
def __post_init__(self):
# Set missing default values
@@ -559,6 +560,14 @@ class ServerArgs:
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
"This only affects Triton attention kernels.",
)
parser.add_argument(
"--num-continuous-decode-steps",
type=int,
default=ServerArgs.num_continuous_decode_steps,
help="Run multiple continuous decoding steps to reduce scheduling overhead. "
"This can potentially increase throughput but may also increase time-to-first-token latency. "
"The default value is 1, meaning only run one decoding step at a time.",
)
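For reference, a standalone sketch of the flag wiring, using a toy parser and a `ToyServerArgs` dataclass rather than SGLang's `ServerArgs.add_cli_args`, to show the dataclass default being reused as the argparse default as in the hunk above:

```python
# Toy sketch of the dataclass-default-as-argparse-default pattern shown above.
import argparse
from dataclasses import dataclass


@dataclass
class ToyServerArgs:
    num_continuous_decode_steps: int = 1  # mirrors the new ServerArgs field


parser = argparse.ArgumentParser()
parser.add_argument(
    "--num-continuous-decode-steps",
    type=int,
    default=ToyServerArgs.num_continuous_decode_steps,
    help="Run multiple continuous decoding steps to reduce scheduling overhead.",
)
args = parser.parse_args(["--num-continuous-decode-steps", "4"])
print(ToyServerArgs(num_continuous_decode_steps=args.num_continuous_decode_steps))
# ToyServerArgs(num_continuous_decode_steps=4)
```

At launch time the option would simply be appended to the usual server command line, e.g. `--num-continuous-decode-steps 4`.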
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......