Unverified commit 7ee6c259, authored by Lianmin Zheng and committed by GitHub

Simplify the event loop and expose `--num-continuous-decode-steps` as an argument (#1652)

parent 9610fcd4
@@ -19,7 +19,6 @@ class GlobalConfig:
         self.new_token_ratio_decay = 0.001
 
         # Runtime constants: others
-        self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
         self.flashinfer_workspace_size = os.environ.get(
             "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
...
@@ -831,6 +831,22 @@ class ScheduleBatch:
             sampling_info=self.sampling_info,
         )
 
+    def copy(self):
+        return ScheduleBatch(
+            reqs=self.reqs,
+            req_to_token_pool=self.req_to_token_pool,
+            token_to_kv_pool=self.token_to_kv_pool,
+            tree_cache=self.tree_cache,
+            forward_mode=self.forward_mode,
+            output_token_ids=self.output_token_ids,
+        )
+
+    def __str__(self):
+        return (
+            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"#req={(len(self.reqs))})"
+        )
+
 
 @dataclass
 class ModelWorkerBatch:
...
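Note: the new __str__ gives a compact one-line batch summary for logging. Based on the f-string above, a decode batch holding eight requests (illustrative numbers, not taken from this commit) would print roughly as:

    ScheduleBatch(forward_mode=DECODE, #req=8)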
@@ -20,6 +20,7 @@ import logging
 import os
 import time
 import warnings
+from types import SimpleNamespace
 from typing import List, Optional, Union
 
 import torch
@@ -106,7 +107,8 @@ class Scheduler:
             self.send_to_detokenizer = context.socket(zmq.PUSH)
             self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
         else:
-            self.recv_from_tokenizer = self.send_to_detokenizer = None
+            self.recv_from_tokenizer = None
+            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
         # Init tokenizer
         self.model_config = ModelConfig(
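The SimpleNamespace stand-in in the hunk above is a null-object: when the scheduler does not create a real ZMQ socket (the else branch), it substitutes an object exposing the same send_pyobj attribute as a no-op, so later code can call send_to_detokenizer.send_pyobj(...) unconditionally instead of buffering into out_pyobjs. A minimal, self-contained sketch of the same trick (illustrative only, not sglang code; RealSender and make_sender are made-up names):

    from types import SimpleNamespace

    class RealSender:
        def send_pyobj(self, obj):
            print("sent:", obj)

    def make_sender(owns_socket: bool):
        # The rank that owns the socket gets a real sender; the others get a
        # no-op with the same interface, so callers never need a rank check.
        return RealSender() if owns_socket else SimpleNamespace(send_pyobj=lambda x: None)

    make_sender(True).send_pyobj({"rid": 1})   # prints "sent: {'rid': 1}"
    make_sender(False).send_pyobj({"rid": 1})  # silently ignored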
@@ -190,7 +192,6 @@ class Scheduler:
         # Init running status
         self.waiting_queue: List[Req] = []
         self.running_batch: ScheduleBatch = None
-        self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0
@@ -247,13 +248,30 @@ class Scheduler:
 
     @torch.inference_mode()
     def event_loop(self):
+        self.last_batch = None
+
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
 
-            self.run_step()
+            batch = self.get_next_batch_to_run()
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result(batch, result)
+
+                # Decode multiple steps to reduce the overhead
+                if batch.forward_mode.is_decode():
+                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
+                        if not self.running_batch:
+                            break
+                        self.update_running_batch()
+                        if not self.running_batch:
+                            break
+                        result = self.run_batch(batch)
+                        self.process_batch_result(batch, result)
 
-            self.send_results()
+            self.last_batch = batch
 
     def recv_requests(self):
         if self.tp_rank == 0:
@@ -286,7 +304,9 @@ class Scheduler:
             self.abort_request(recv_req)
         elif isinstance(recv_req, UpdateWeightReqInput):
             success, message = self.update_weights(recv_req)
-            self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
+            self.send_to_detokenizer.send_pyobj(
+                UpdateWeightReqOutput(success, message)
+            )
         elif isinstance(recv_req, ProfileReq):
             if recv_req == ProfileReq.START_PROFILE:
                 self.start_profile()
@@ -384,12 +404,6 @@ class Scheduler:
             self.waiting_queue.append(req)
 
-    def send_results(self):
-        if self.tp_rank == 0:
-            for obj in self.out_pyobjs:
-                self.send_to_detokenizer.send_pyobj(obj)
-            self.out_pyobjs = []
-
     def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -427,41 +441,29 @@
             )
             exit(1) if crash_on_warning else None
 
-    def run_step(self):
-        new_batch = self.get_new_batch_prefill()
-        if new_batch is not None:
-            # Run a new prefill batch
-            # replace run_batch with the uncommented line to use pytorch profiler
-            # result = pytorch_profile(
-            #    "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
-            # )
-            result = self.run_batch(new_batch)
-            self.process_batch_result(new_batch, result)
-        else:
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(global_config.num_continue_decode_steps):
-                    batch = self.get_new_batch_decode()
-
-                    if batch:
-                        # replace run_batch with the uncommented line to use pytorch profiler
-                        # result = pytorch_profile(
-                        #    "profile_decode_step",
-                        #    self.run_batch,
-                        #    batch,
-                        #    data_size=len(batch.reqs),
-                        # )
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
-
-                        if self.running_batch.is_empty():
-                            self.running_batch = None
-
-                    if self.running_batch is None:
-                        break
-
-                    if self.out_pyobjs and self.running_batch.has_stream:
-                        break
-            else:
-                self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+    def get_next_batch_to_run(self):
+        # Merge prefill to the running batch
+        if (
+            self.last_batch
+            and not self.last_batch.forward_mode.is_decode()
+            and not self.last_batch.is_empty()
+        ):
+            if self.running_batch is None:
+                self.running_batch = self.last_batch
+            else:
+                self.running_batch.merge_batch(self.last_batch)
+
+        # Prefill first
+        new_batch = self.get_new_batch_prefill()
+        if new_batch is not None:
+            return new_batch
+
+        # Run decode
+        if self.running_batch is not None:
+            self.update_running_batch()
+            if not self.running_batch:
+                return None
+            return self.running_batch
+        else:
+            self.check_memory()
+            self.new_token_ratio = global_config.init_new_token_ratio
@@ -607,7 +609,7 @@
 
         return new_batch
 
-    def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
+    def update_running_batch(self):
         batch = self.running_batch
 
         # Check if decode out of memory
@@ -636,11 +638,11 @@
             if jump_forward_reqs:
                 self.batch_is_full = False
             if batch.is_empty():
-                return None
+                self.running_batch = None
+                return
 
         # Update batch tensors
         batch.prepare_for_decode()
-        return batch
 
     def run_batch(self, batch: ScheduleBatch):
         if self.is_generation:
@@ -657,16 +659,19 @@
                 )
             else:
                 next_token_ids = torch.full((batch.batch_size(),), 0)
-            return logits_output, next_token_ids
+            ret = logits_output, next_token_ids
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            return embeddings
+            ret = embeddings
+        return ret
 
     def process_batch_result(self, batch: ScheduleBatch, result):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
+            if batch.is_empty():
+                self.running_batch = None
         else:
             self.process_batch_result_prefill(batch, result)
@@ -728,7 +733,7 @@
             )
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
-            embeddings = result
+            embeddings = result.tolist()
 
         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -750,12 +755,6 @@
 
         self.handle_finished_requests(batch)
 
-        if not batch.is_empty():
-            if self.running_batch is None:
-                self.running_batch = batch
-            else:
-                self.running_batch.merge_batch(batch)
-
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
         if batch.sampling_info.penalizer_orchestrator:
@@ -951,7 +950,7 @@
         # Send to detokenizer
         if output_rids:
             if self.is_generation:
-                self.out_pyobjs.append(
+                self.send_to_detokenizer.send_pyobj(
                     BatchTokenIDOut(
                         output_rids,
                         output_vids,
@@ -965,7 +964,7 @@
                     )
                 )
             else:  # embedding or reward model
-                self.out_pyobjs.append(
+                self.send_to_detokenizer.send_pyobj(
                     BatchEmbeddingOut(
                         output_rids,
                         output_embeddings,
...
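For readers skimming the diff, the following annotated sketch restates the control flow of the new event_loop in plain form. The method and attribute names come from this diff; the wrapper function and its comments are illustrative, not part of the commit:

    def event_loop_sketch(scheduler):
        # Schematic restatement of Scheduler.event_loop after this change.
        scheduler.last_batch = None
        while True:
            # 1) Drain incoming requests and enqueue them.
            scheduler.process_input_requests(scheduler.recv_requests())

            # 2) Choose the next batch: fold the previous prefill batch into the
            #    running batch, prefer a fresh prefill batch, otherwise decode.
            batch = scheduler.get_next_batch_to_run()

            if batch:
                # 3) Run one forward step and push results straight to the
                #    detokenizer (the out_pyobjs buffer is removed).
                result = scheduler.run_batch(batch)
                scheduler.process_batch_result(batch, result)

                # 4) Optionally run N-1 extra decode steps before polling for new
                #    requests again, trading time-to-first-token for lower
                #    scheduling overhead.
                if batch.forward_mode.is_decode():
                    for _ in range(scheduler.server_args.num_continuous_decode_steps - 1):
                        if not scheduler.running_batch:
                            break
                        scheduler.update_running_batch()
                        if not scheduler.running_batch:
                            break
                        result = scheduler.run_batch(batch)
                        scheduler.process_batch_result(batch, result)

            scheduler.last_batch = batch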
@@ -118,7 +118,7 @@ class TpModelWorker:
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        embeddings = logits_output.embeddings.tolist()
+        embeddings = logits_output.embeddings
         return embeddings
 
     def update_weights(self, recv_req: UpdateWeightReqInput):
...
@@ -111,6 +111,7 @@ class ServerArgs:
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
+    num_continuous_decode_steps: int = 1
 
     def __post_init__(self):
         # Set missing default values
@@ -559,6 +560,14 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
            "This only affects Triton attention kernels.",
        )
+        parser.add_argument(
+            "--num-continuous-decode-steps",
+            type=int,
+            default=ServerArgs.num_continuous_decode_steps,
+            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
+            "This can potentially increase throughput but may also increase time-to-first-token latency. "
+            "The default value is 1, meaning only run one decoding step at a time.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
...
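The new flag is parsed like any other ServerArgs option, so a launch command would look roughly like: python -m sglang.launch_server --model-path <model> --num-continuous-decode-steps 4 (the sglang.launch_server entry point and the model path are assumptions for illustration; only --num-continuous-decode-steps itself comes from this commit). Values above 1 let the scheduler run several decode steps back to back before polling for new requests, which lowers per-step scheduling overhead but can delay the first token of newly arriving prompts.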