Unverified Commit da1ffed6 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Add output_ids into ScheduleBatch (#1659)

parent 48761171
......@@ -232,17 +232,18 @@ def extend(reqs, model_runner):
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
next_token_ids = model_runner.sample(logits_output, forward_batch)
return next_token_ids, logits_output.next_token_logits, batch
@torch.inference_mode()
def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids)
batch.output_ids = input_token_ids
batch.prepare_for_decode()
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
next_token_ids = model_runner.sample(logits_output, forward_batch)
return next_token_ids, logits_output.next_token_logits
......@@ -252,6 +253,7 @@ def correctness_test(
bench_args,
tp_rank,
):
configure_logger(server_args, prefix=f" TP{tp_rank}")
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
# Load the model
......@@ -279,8 +281,9 @@ def correctness_test(
output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
for _ in range(bench_args.output_len[0] - 1):
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
next_token_ids_list = next_token_ids.tolist()
for i in range(len(reqs)):
output_ids[i].append(next_token_ids[i])
output_ids[i].append(next_token_ids_list[i])
# Print
for i in range(len(reqs)):
......
......@@ -410,6 +410,8 @@ class ScheduleBatch:
seq_lens: torch.Tensor = None
out_cache_loc: torch.Tensor = None
output_ids: torch.Tensor = None
# For processing logprobs
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
......@@ -720,19 +722,12 @@ class ScheduleBatch:
return jump_forward_reqs
def prepare_for_decode(self, input_ids=None):
def prepare_for_decode(self):
self.forward_mode = ForwardMode.DECODE
if input_ids is None:
input_ids = [
r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
for r in self.reqs
]
self.input_ids = torch.tensor(
input_ids, dtype=torch.int32, device=self.seq_lens.device
)
self.input_ids = self.output_ids
self.seq_lens.add_(1)
self.output_ids = None
# Alloc mem
bs = len(self.reqs)
......@@ -759,6 +754,7 @@ class ScheduleBatch:
self.req_pool_indices = self.req_pool_indices[new_indices]
self.seq_lens = self.seq_lens[new_indices]
self.out_cache_loc = None
self.output_ids = self.output_ids[new_indices]
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob:
self.top_logprobs_nums = [
......@@ -783,6 +779,8 @@ class ScheduleBatch:
)
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
self.out_cache_loc = None
if self.output_ids is not None:
self.output_ids = torch.concat([self.output_ids, other.output_ids])
if self.return_logprob and other.return_logprob:
self.top_logprobs_nums.extend(other.top_logprobs_nums)
elif self.return_logprob:
......@@ -838,7 +836,9 @@ class ScheduleBatch:
token_to_kv_pool=self.token_to_kv_pool,
tree_cache=self.tree_cache,
forward_mode=self.forward_mode,
output_token_ids=self.output_token_ids,
output_ids=self.output_ids,
sampling_info=self.sampling_info,
decoding_reqs=self.decoding_reqs,
)
def __str__(self):
......
......@@ -247,7 +247,7 @@ class Scheduler:
)
@torch.inference_mode()
def event_loop(self):
def event_loop_normal(self):
self.last_batch = None
while True:
......@@ -411,9 +411,10 @@ class Scheduler:
throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
logger.info(
f"Decode batch. "
f"#running-req: {len(self.running_batch.reqs)}, "
f"#running-req: {num_running_reqs}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
......@@ -659,6 +660,7 @@ class Scheduler:
)
else:
next_token_ids = torch.full((batch.batch_size(),), 0)
batch.output_ids = next_token_ids
ret = logits_output, next_token_ids
else: # embedding or reward model
assert batch.extend_num_tokens != 0
......@@ -753,7 +755,7 @@ class Scheduler:
# Inflight request would get a new req idx
self.req_to_token_pool.free(req.req_pool_idx)
self.handle_finished_requests(batch)
self.stream_output(batch)
def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids = result
......@@ -793,7 +795,7 @@ class Scheduler:
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
self.handle_finished_requests(batch)
self.stream_output(batch)
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
......@@ -872,7 +874,7 @@ class Scheduler:
return num_input_logprobs
def handle_finished_requests(self, batch: ScheduleBatch):
def stream_output(self, batch: ScheduleBatch):
output_rids = []
output_meta_info = []
output_finished_reason: List[BaseFinishReason] = []
......@@ -949,6 +951,9 @@ class Scheduler:
}
output_meta_info.append(meta_info)
# Remove finished reqs: update batch tensors
batch.filter_batch(unfinished_indices)
# Send to detokenizer
if output_rids:
if self.is_generation:
......@@ -976,9 +981,6 @@ class Scheduler:
)
)
# Remove finished reqs: update batch tensors
batch.filter_batch(unfinished_indices)
def flush_cache(self):
if len(self.waiting_queue) == 0 and (
self.running_batch is None or len(self.running_batch.reqs) == 0
......@@ -1060,7 +1062,7 @@ def run_scheduler_process(
try:
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
pipe_writer.send("ready")
scheduler.event_loop()
scheduler.event_loop_normal()
except Exception:
msg = get_exception_traceback()
logger.error(msg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment