"vscode:/vscode.git/clone" did not exist on "24c062aaa19f5626d03d058daf8afffa2dfd49f7"
Unverified Commit 0c1e8796 authored by Lianmin Zheng, committed by GitHub

Move filter_batch out of stream_output (#1663)

parent 869f1c02
@@ -659,7 +659,7 @@ class ScheduleBatch:
     def check_for_jump_forward(self, pad_input_ids_func):
         jump_forward_reqs = []
-        filter_indices = [i for i in range(len(self.reqs))]
+        keep_indices = set(i for i in range(len(self.reqs)))

         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
@@ -719,9 +719,9 @@ class ScheduleBatch:
                 )
                 jump_forward_reqs.append(req)
-                filter_indices.remove(i)
+                keep_indices.remove(i)

-        self.filter_batch(filter_indices)
+        self.filter_batch(keep_indices=list(keep_indices))

         return jump_forward_reqs
@@ -740,19 +740,31 @@ class ScheduleBatch:
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc

-    def filter_batch(self, unfinished_indices: List[int]):
-        if unfinished_indices is None or len(unfinished_indices) == 0:
+    def filter_batch(
+        self,
+        current_inflight_req: Optional[Req] = None,
+        keep_indices: Optional[List[int]] = None,
+    ):
+        if keep_indices is None:
+            keep_indices = [
+                i
+                for i in range(len(self.reqs))
+                if not self.reqs[i].finished()
+                and self.reqs[i] is not current_inflight_req
+            ]
+
+        if keep_indices is None or len(keep_indices) == 0:
             # Filter out all requests
             self.reqs = []
             return

-        if len(unfinished_indices) == len(self.reqs):
+        if len(keep_indices) == len(self.reqs):
             # No need to filter
             return

-        self.reqs = [self.reqs[i] for i in unfinished_indices]
+        self.reqs = [self.reqs[i] for i in keep_indices]
         new_indices = torch.tensor(
-            unfinished_indices, dtype=torch.int32, device=self.seq_lens.device
+            keep_indices, dtype=torch.int32, device=self.seq_lens.device
         )
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
@@ -760,16 +772,14 @@ class ScheduleBatch:
         self.output_ids = self.output_ids[new_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
-            self.top_logprobs_nums = [
-                self.top_logprobs_nums[i] for i in unfinished_indices
-            ]
+            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
         else:
             self.top_logprobs_nums = None

         self.has_stream = any(req.stream for req in self.reqs)
         self.has_regex = any(req.regex_fsm for req in self.reqs)

-        self.sampling_info.filter_batch(unfinished_indices, new_indices)
+        self.sampling_info.filter_batch(keep_indices, new_indices)

     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
......
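As a reading aid, here is a minimal, standalone sketch of the new filter_batch contract introduced above. MockReq and MockBatch are invented stand-ins (not the real sglang Req/ScheduleBatch classes); the point is only the new default: when keep_indices is not supplied, the batch keeps every request that has not finished and is not the current in-flight request.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class MockReq:
    # Invented stand-in for a request: only what the sketch needs.
    rid: str
    done: bool = False

    def finished(self) -> bool:
        return self.done


class MockBatch:
    # Invented stand-in for ScheduleBatch: just a list of requests.
    def __init__(self, reqs: List[MockReq]):
        self.reqs = reqs

    def filter_batch(
        self,
        current_inflight_req: Optional[MockReq] = None,
        keep_indices: Optional[List[int]] = None,
    ) -> None:
        # Mirror of the default computed in ScheduleBatch.filter_batch above.
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] is not current_inflight_req
            ]
        if len(keep_indices) == 0:
            self.reqs = []  # filter out everything
            return
        if len(keep_indices) == len(self.reqs):
            return  # nothing to filter
        self.reqs = [self.reqs[i] for i in keep_indices]


if __name__ == "__main__":
    a, b, c = MockReq("a", done=True), MockReq("b"), MockReq("c")
    batch = MockBatch([a, b, c])
    batch.filter_batch(current_inflight_req=c)  # drops "a" (finished) and "c" (in-flight)
    print([r.rid for r in batch.reqs])          # -> ['b']

Callers that already know which indices to keep, such as check_for_jump_forward above, pass keep_indices explicitly and bypass this default.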
@@ -446,31 +446,41 @@ class Scheduler:
             exit(1) if crash_on_warning else None

     def get_next_batch_to_run(self):
-        # Merge prefill to the running batch
+        # Merge the prefill batch into the running batch
         if (
             self.last_batch
             and not self.last_batch.forward_mode.is_decode()
             and not self.last_batch.is_empty()
         ):
-            if self.running_batch is None:
-                self.running_batch = self.last_batch
-            else:
-                self.running_batch.merge_batch(self.last_batch)
+            if self.current_inflight_req:
+                self.last_batch.filter_batch(self.current_inflight_req)
+                self.batch_is_full = False
+            if not self.last_batch.is_empty():
+                if self.running_batch is None:
+                    self.running_batch = self.last_batch
+                else:
+                    self.running_batch.merge_batch(self.last_batch)

         # Prefill first
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
             return new_batch

-        # Run decode
-        if self.running_batch is not None:
-            self.update_running_batch()
-            if not self.running_batch:
-                return None
-            return self.running_batch
-        else:
+        # Check memory
+        if self.running_batch is None:
             self.check_memory()
             self.new_token_ratio = global_config.init_new_token_ratio
+            return
+
+        # Run decode
+        before_bs = self.running_batch.batch_size()
+        self.update_running_batch()
+        if not self.running_batch:
+            self.batch_is_full = False
+            return None
+        if before_bs != self.running_batch.batch_size():
+            self.batch_is_full = False
+        return self.running_batch

     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Handle the cases where prefill is not allowed
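As a reading aid, the following is a heavily simplified, hypothetical sketch of the reworked get_next_batch_to_run decision order (MiniBatch and MiniScheduler are invented names; prefill, memory checks, retraction, and decode internals are stubbed out): drop the in-flight chunked request before merging the last prefill batch, try a new prefill batch first, otherwise decode, and clear batch_is_full whenever the running batch shrinks or disappears.

from typing import List, Optional


class MiniBatch:
    # Invented stand-in for ScheduleBatch: just a list of request objects.
    def __init__(self, reqs: List[object]):
        self.reqs = reqs

    def is_empty(self) -> bool:
        return len(self.reqs) == 0

    def batch_size(self) -> int:
        return len(self.reqs)

    def merge_batch(self, other: "MiniBatch") -> None:
        self.reqs.extend(other.reqs)


class MiniScheduler:
    def __init__(self):
        self.last_batch: Optional[MiniBatch] = None       # set by the event loop
        self.running_batch: Optional[MiniBatch] = None
        self.current_inflight_req: Optional[object] = None
        self.batch_is_full = False

    def get_next_batch_to_run(self) -> Optional[MiniBatch]:
        # 1. Merge the last prefill batch into the running batch, dropping the
        #    still-in-flight (chunked) request first so it is not decoded early.
        if self.last_batch and not self.last_batch.is_empty():
            if self.current_inflight_req is not None:
                self.last_batch.reqs = [
                    r for r in self.last_batch.reqs
                    if r is not self.current_inflight_req
                ]
                self.batch_is_full = False
            if not self.last_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = self.last_batch
                else:
                    self.running_batch.merge_batch(self.last_batch)

        # 2. Prefill gets priority (stubbed out in this sketch).
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            return new_batch

        # 3. Nothing running: nothing to decode this round.
        #    (The real scheduler also runs its memory check here.)
        if self.running_batch is None:
            return None

        # 4. Decode; whenever the running batch shrinks or empties,
        #    there is room for new requests again, so clear batch_is_full.
        before_bs = self.running_batch.batch_size()
        self.update_running_batch()
        if self.running_batch is None:
            self.batch_is_full = False
            return None
        if before_bs != self.running_batch.batch_size():
            self.batch_is_full = False
        return self.running_batch

    def get_new_batch_prefill(self) -> Optional[MiniBatch]:
        return None  # stub: the real scheduler builds a prefill batch here

    def update_running_batch(self) -> None:
        pass  # stub: the real scheduler filters finished requests and may retract


if __name__ == "__main__":
    sched = MiniScheduler()
    sched.last_batch = MiniBatch(["req-0", "req-1"])
    batch = sched.get_next_batch_to_run()
    print(batch.batch_size() if batch else None)  # -> 2 (decode the merged batch)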
@@ -617,6 +627,11 @@ class Scheduler:
         global test_retract
         batch = self.running_batch

+        batch.filter_batch()
+        if batch.is_empty():
+            self.running_batch = None
+            return
+
         # Check if decode out of memory
         if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
             old_ratio = self.new_token_ratio
@@ -640,8 +655,6 @@ class Scheduler:
         if not self.disable_regex_jump_forward:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
-            if jump_forward_reqs:
-                self.batch_is_full = False
             if batch.is_empty():
                 self.running_batch = None
                 return
@@ -892,14 +905,8 @@ class Scheduler:
             output_no_stop_trim = []
         else:  # embedding or reward model
             output_embeddings = []

-        unfinished_indices = []
-        for i, req in enumerate(batch.reqs):
-            if not req.finished() and req is not self.current_inflight_req:
-                unfinished_indices.append(i)
-            else:
-                self.batch_is_full = False
-
+        for req in batch.reqs:
             if req.finished() or (
                 req.stream
                 and (
@@ -955,9 +962,6 @@ class Scheduler:
                 }
                 output_meta_info.append(meta_info)

-        # Remove finished reqs: update batch tensors
-        batch.filter_batch(unfinished_indices)
-
         # Send to detokenizer
         if output_rids:
             if self.is_generation:
......
"""
python3 -m unittest test_json_constrained.TestJSONConstrained.test_json_generate
"""
import json import json
import unittest import unittest
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
......