"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fc7a867ae5cb6dfc6a7394d9571f5a0cd3bd05da"
Unverified commit da1ffed6, authored by Lianmin Zheng, committed by GitHub

Add output_ids into ScheduleBatch (#1659)

parent 48761171
@@ -232,17 +232,18 @@ def extend(reqs, model_runner):
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch


 @torch.inference_mode()
 def decode(input_token_ids, batch, model_runner):
-    batch.prepare_for_decode(input_token_ids)
+    batch.output_ids = input_token_ids
+    batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits
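The hunk above keeps the sampled token ids as a device tensor instead of calling `.tolist()` after every `sample()`, and `decode()` now hands them to the batch through `output_ids` before calling `prepare_for_decode()`. A minimal sketch of the old-versus-new pattern, using plain `torch` with made-up shapes (not sglang code):

```python
# Sketch only: why dropping the per-step .tolist() matters, with fake logits.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(4, 32000, device=device)   # pretend next-token logits for 4 requests

# Old pattern: .tolist() forces a device->host copy and sync on every decode step.
next_token_ids_old = torch.argmax(logits, dim=-1).tolist()

# New pattern: keep the ids on-device; the next step reads the tensor directly.
next_token_ids_new = torch.argmax(logits, dim=-1)   # shape [4], stays on the device
batch_output_ids = next_token_ids_new               # what decode() now stores on the batch
```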
@@ -252,6 +253,7 @@ def correctness_test(
     bench_args,
     tp_rank,
 ):
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model
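The added `configure_logger(server_args, prefix=f" TP{tp_rank}")` call tags the correctness test's log output with the tensor-parallel rank. A rough stand-in using only the standard `logging` module; the helper name and prefix come from the diff, but the implementation below is an assumption, not sglang's:

```python
import logging

def rank_logger(tp_rank: int) -> logging.Logger:
    # Prefix every record with the TP rank so interleaved multi-rank output stays readable.
    logger = logging.getLogger(f"bench.tp{tp_rank}")
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(f"[%(asctime)s TP{tp_rank}] %(message)s"))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    return logger

rank_logger(0).info("model loaded")   # -> [<time> TP0] model loaded
```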
@@ -279,8 +281,9 @@ def correctness_test(
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
     for _ in range(bench_args.output_len[0] - 1):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        next_token_ids_list = next_token_ids.tolist()
         for i in range(len(reqs)):
-            output_ids[i].append(next_token_ids[i])
+            output_ids[i].append(next_token_ids_list[i])

     # Print
     for i in range(len(reqs)):
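Because `decode()` now returns a tensor, the benchmark converts it to a Python list once per step and appends plain ints, instead of indexing the device tensor inside the per-request loop. A small sketch of that pattern with toy values (not sglang code):

```python
import torch

next_token_ids = torch.tensor([11, 42, 7])    # stands in for one step's sampler output
output_ids = [[1, 2], [3, 4], [5, 6]]         # accumulated ids per request

next_token_ids_list = next_token_ids.tolist() # one host transfer per step
for i in range(len(output_ids)):
    output_ids[i].append(next_token_ids_list[i])  # plain Python ints from here on
```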
...
@@ -410,6 +410,8 @@ class ScheduleBatch:
     seq_lens: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
+    output_ids: torch.Tensor = None

     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
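`ScheduleBatch` gains an `output_ids` tensor holding the most recently sampled token for each request in the batch. A stripped-down sketch of the idea, with only a few fields shown (this is not the real class):

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class TinyBatch:
    input_ids: Optional[torch.Tensor] = None   # tokens fed to the model this step
    seq_lens: Optional[torch.Tensor] = None    # current length of each request
    output_ids: Optional[torch.Tensor] = None  # last sampled token per request (new field)
```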
@@ -720,19 +722,12 @@ class ScheduleBatch:
         return jump_forward_reqs

-    def prepare_for_decode(self, input_ids=None):
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE

-        if input_ids is None:
-            input_ids = [
-                r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
-                for r in self.reqs
-            ]
-        self.input_ids = torch.tensor(
-            input_ids, dtype=torch.int32, device=self.seq_lens.device
-        )
+        self.input_ids = self.output_ids
         self.seq_lens.add_(1)
+        self.output_ids = None

         # Alloc mem
         bs = len(self.reqs)
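With `output_ids` stored on the batch, `prepare_for_decode()` no longer gathers the last token of every request into a Python list and rebuilds a tensor; it reuses the previous step's sampled ids directly. A toy sketch of the new flow using plain tensors (not the real batch object):

```python
import torch

seq_lens = torch.tensor([5, 9, 3])
output_ids = torch.tensor([11, 42, 7])   # written after the previous forward pass

# prepare_for_decode(), simplified:
input_ids = output_ids                   # next step's inputs are last step's outputs
seq_lens.add_(1)                         # every running request grows by one token
output_ids = None                        # refilled once the next sample() runs
```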
@@ -759,6 +754,7 @@ class ScheduleBatch:
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
+        self.output_ids = self.output_ids[new_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
             self.top_logprobs_nums = [
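`filter_batch()` must keep `output_ids` aligned with the surviving requests, so it is re-indexed with the same `new_indices` used for the other per-request tensors. A minimal illustration with toy values:

```python
import torch

seq_lens    = torch.tensor([5, 9, 3, 7])
output_ids  = torch.tensor([11, 42, 7, 99])
new_indices = torch.tensor([0, 2, 3])    # requests 0, 2 and 3 are still running

seq_lens   = seq_lens[new_indices]       # tensor([5, 3, 7])
output_ids = output_ids[new_indices]     # tensor([11, 7, 99]); rows stay aligned
```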
@@ -783,6 +779,8 @@ class ScheduleBatch:
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
         self.out_cache_loc = None
+        if self.output_ids is not None:
+            self.output_ids = torch.concat([self.output_ids, other.output_ids])
         if self.return_logprob and other.return_logprob:
             self.top_logprobs_nums.extend(other.top_logprobs_nums)
         elif self.return_logprob:
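Merging two batches likewise concatenates their `output_ids`, guarded because `output_ids` may be unset at merge time. A toy illustration of the concatenation:

```python
import torch

self_output_ids  = torch.tensor([11, 42])   # running batch
other_output_ids = torch.tensor([7])        # batch being merged in

if self_output_ids is not None:
    merged = torch.concat([self_output_ids, other_output_ids])   # tensor([11, 42, 7])
```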
@@ -838,7 +836,9 @@ class ScheduleBatch:
             token_to_kv_pool=self.token_to_kv_pool,
             tree_cache=self.tree_cache,
             forward_mode=self.forward_mode,
-            output_token_ids=self.output_token_ids,
+            output_ids=self.output_ids,
+            sampling_info=self.sampling_info,
+            decoding_reqs=self.decoding_reqs,
         )

     def __str__(self):
...
@@ -247,7 +247,7 @@ class Scheduler:
         )

     @torch.inference_mode()
-    def event_loop(self):
+    def event_loop_normal(self):
         self.last_batch = None

         while True:
@@ -411,9 +411,10 @@ class Scheduler:
             throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
             self.num_generated_tokens = 0
             self.last_stats_tic = time.time()
+            num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
             logger.info(
                 f"Decode batch. "
-                f"#running-req: {len(self.running_batch.reqs)}, "
+                f"#running-req: {num_running_reqs}, "
                 f"#token: {num_used}, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"gen throughput (token/s): {throughput:.2f}, "
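The decode-throughput log line previously dereferenced `self.running_batch.reqs` unconditionally; the new `num_running_reqs` guard reports 0 when there is no running batch instead of raising. A two-line illustration of the guard:

```python
running_batch = None   # e.g. every request just finished

num_running_reqs = len(running_batch.reqs) if running_batch else 0
print(f"Decode batch. #running-req: {num_running_reqs}")   # -> #running-req: 0
```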
@@ -659,6 +660,7 @@ class Scheduler:
                 )
             else:
                 next_token_ids = torch.full((batch.batch_size(),), 0)
+            batch.output_ids = next_token_ids
             ret = logits_output, next_token_ids
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
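This is the producer side of the handoff: after sampling, the scheduler writes the ids onto `batch.output_ids`, which the next `prepare_for_decode()` (see the ScheduleBatch hunk above) turns into that step's `input_ids`. A self-contained sketch of the two-step handoff with a stand-in class (not the real ScheduleBatch):

```python
import torch

class Batch:                      # stand-in for illustration only
    def __init__(self):
        self.input_ids = None
        self.output_ids = None
    def prepare_for_decode(self):
        self.input_ids = self.output_ids
        self.output_ids = None

batch = Batch()
next_token_ids = torch.tensor([11, 42, 7])   # produced by sample()
batch.output_ids = next_token_ids            # producer: batch-result processing
batch.prepare_for_decode()                   # consumer: next decode step
assert torch.equal(batch.input_ids, next_token_ids)
```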
@@ -753,7 +755,7 @@ class Scheduler:
                 # Inflight request would get a new req idx
                 self.req_to_token_pool.free(req.req_pool_idx)

-        self.handle_finished_requests(batch)
+        self.stream_output(batch)

     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
@@ -793,7 +795,7 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

-        self.handle_finished_requests(batch)
+        self.stream_output(batch)

         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
@@ -872,7 +874,7 @@ class Scheduler:
         return num_input_logprobs

-    def handle_finished_requests(self, batch: ScheduleBatch):
+    def stream_output(self, batch: ScheduleBatch):
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
@@ -949,6 +951,9 @@ class Scheduler:
                 }
                 output_meta_info.append(meta_info)

+        # Remove finished reqs: update batch tensors
+        batch.filter_batch(unfinished_indices)
+
         # Send to detokenizer
         if output_rids:
             if self.is_generation:
@@ -976,9 +981,6 @@ class Scheduler:
                     )
                 )

-        # Remove finished reqs: update batch tensors
-        batch.filter_batch(unfinished_indices)
-
     def flush_cache(self):
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
@@ -1060,7 +1062,7 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
         pipe_writer.send("ready")
-        scheduler.event_loop()
+        scheduler.event_loop_normal()
     except Exception:
         msg = get_exception_traceback()
         logger.error(msg)
...