Unverified commit 24f3e151 authored by Lianmin Zheng, committed by GitHub

[Minor] Improve style (#1666)

parent 6790240c
@@ -203,6 +203,7 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
+        self.is_inflight_req = 0
 
         # Logprobs (arguments)
         self.return_logprob = False
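This is the substantive part of the otherwise style-focused commit: the scheduler previously identified the in-flight (chunked-prefill) request by identity (`req is self.current_inflight_req`); it now keeps a small reference counter on the request itself. A minimal sketch of the pattern, assuming the motivation is that the same request can be re-selected as inflight before the previous round's result has been processed (`on_selected_as_inflight`, `on_batch_result`, and `free_slot` are hypothetical stand-ins, not sglang APIs):

```python
class Req:
    """Sketch of the counter this commit adds (not the full sglang class)."""

    def __init__(self, rid: str):
        self.rid = rid
        # Number of pending inflight (chunked-prefill) rounds that still
        # reference this request. An identity check against a single
        # `current_inflight_req` cannot represent two overlapping rounds;
        # a counter can.
        self.is_inflight_req = 0

def on_selected_as_inflight(req: Req) -> None:
    req.is_inflight_req += 1  # scheduler will prefill another chunk

def on_batch_result(req: Req, free_slot) -> None:
    if req.is_inflight_req > 0:
        req.is_inflight_req -= 1
        free_slot(req)  # the request gets a new req idx next round

# Two rounds overlap without losing track of the request:
r = Req("abc")
on_selected_as_inflight(r)
on_selected_as_inflight(r)
on_batch_result(r, free_slot=lambda req: None)
assert r.is_inflight_req == 1
```

The increments happen in the prefill-batch hunk below (`is_inflight_req += 1`), and the matching decrements in the two `process_batch_result` hunks.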
@@ -45,7 +45,7 @@ class SchedulePolicy:
     def calc_priority(self, waiting_queue: List[Req]):
         # Compute matched prefix length
         prefix_computed = False
-        if self.policy in ["lpm", "dfs-weight"]:
+        if self.policy == "lpm" or self.policy == "dfs-weight":
             for r in waiting_queue:
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
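The two spellings are behaviorally identical for exact string matches, and CPython typically folds a constant list on the right of `in` into a tuple constant anyway, so this reads as a readability choice rather than an optimization. A quick equivalence check:

```python
policy = "lpm"

a = policy in ["lpm", "dfs-weight"]            # old spelling
b = policy == "lpm" or policy == "dfs-weight"  # new spelling
c = policy in ("lpm", "dfs-weight")            # common third option
assert a == b == c
```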
@@ -194,7 +194,7 @@ class Scheduler:
         # Init running status
         self.waiting_queue: List[Req] = []
-        self.running_batch: ScheduleBatch = None
+        self.running_batch: Optional[ScheduleBatch] = None
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0
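The old annotation claimed `running_batch` is always a `ScheduleBatch` while immediately assigning `None`, which strict type checkers such as mypy reject; `Optional[...]` makes the None state explicit. A minimal illustration (toy classes, not the real ones):

```python
from typing import List, Optional

class Req: ...
class ScheduleBatch: ...

class Scheduler:
    def __init__(self) -> None:
        self.waiting_queue: List[Req] = []
        # Under mypy, `self.running_batch: ScheduleBatch = None` is an
        # "incompatible types" error; Optional admits the None state.
        self.running_batch: Optional[ScheduleBatch] = None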
@@ -273,6 +273,9 @@ class Scheduler:
                     break
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
+            else:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
 
             self.last_batch = batch
@@ -468,8 +471,6 @@ class Scheduler:
         # Check memory
         if self.running_batch is None:
-            self.check_memory()
-            self.new_token_ratio = global_config.init_new_token_ratio
             return
 
         # Run decode
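Together with the hunk above, this moves the idle-time sanity check out of the decode path: it now fires only when the scheduler produced no batch at all this iteration, not merely when `running_batch` is None. A toy reproduction of the resulting control flow (`ToyScheduler` and `INIT_NEW_TOKEN_RATIO` are stand-ins for the real scheduler and `global_config.init_new_token_ratio`):

```python
from typing import Optional

INIT_NEW_TOKEN_RATIO = 0.7  # stand-in for global_config.init_new_token_ratio

class ToyScheduler:
    def __init__(self) -> None:
        self.new_token_ratio = INIT_NEW_TOKEN_RATIO
        self.last_batch: Optional[str] = None

    def get_next_batch_to_run(self) -> Optional[str]:
        return None  # pretend the system is idle this iteration

    def check_memory(self) -> None:
        print("idle: verify all KV-cache tokens are free")

    def step(self) -> None:
        batch = self.get_next_batch_to_run()
        if batch:
            pass  # run_batch / process_batch_result in the real loop
        else:
            # Moved here: fires only when there is no batch at all, not
            # merely when the decode batch happens to be None.
            self.check_memory()
            self.new_token_ratio = INIT_NEW_TOKEN_RATIO
        self.last_batch = batch

ToyScheduler().step()
```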
@@ -489,9 +490,7 @@ class Scheduler:
         ) and self.current_inflight_req is None:
             return None
 
-        running_bs = (
-            len(self.running_batch.reqs) if self.running_batch is not None else 0
-        )
+        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
         if running_bs >= self.max_running_requests:
             self.batch_is_full = True
             return None
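Collapsing the ternary relies on object truthiness: for a class that defines neither `__bool__` nor `__len__`, `if self.running_batch` is equivalent to `if self.running_batch is not None`, so the shorter form is safe here. One caveat worth knowing, sketched with a toy class:

```python
class ScheduleBatchLike:
    """No __bool__ or __len__, so truthiness degrades to `is not None`."""

    def __init__(self, reqs):
        self.reqs = reqs

running_batch = None
assert (len(running_batch.reqs) if running_batch else 0) == 0

running_batch = ScheduleBatchLike(reqs=["r1", "r2"])
assert (len(running_batch.reqs) if running_batch else 0) == 2

# Caveat: if the class ever grows a __len__, an *empty* batch would also be
# falsy, silently turning the None-check into an emptiness check.
```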
@@ -512,7 +511,7 @@ class Scheduler:
         )
 
         has_inflight = self.current_inflight_req is not None
-        if self.current_inflight_req is not None:
+        if has_inflight:
             self.current_inflight_req.init_next_round_input(
                 None if prefix_computed else self.tree_cache
             )
@@ -520,7 +519,7 @@ class Scheduler:
                 self.current_inflight_req
             )
 
-        if self.lora_paths is not None:
+        if self.lora_paths:
             lora_set = (
                 set([req.lora_path for req in self.running_batch.reqs])
                 if self.running_batch is not None
@@ -529,7 +528,7 @@ class Scheduler:
         for req in self.waiting_queue:
             if (
-                self.lora_paths is not None
+                self.lora_paths
                 and len(
                     lora_set
                     | set([req.lora_path for req in adder.can_run_list])
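Here, unlike the `running_batch` case above, the truthiness change is potentially behavior-changing: an empty `lora_paths` container was truthy under `is not None` but is falsy under `if self.lora_paths:`, so the LoRA bookkeeping is now also skipped when no adapters are configured, which appears to be the intent. Assuming `lora_paths` is a list or dict of adapter paths:

```python
for lora_paths in (None, [], ["adapter_a"]):
    old_style = lora_paths is not None  # True for [] -> did the LoRA work
    new_style = bool(lora_paths)        # False for [] -> skips it
    print(lora_paths, old_style, new_style)
```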
@@ -551,16 +550,20 @@ class Scheduler:
                 self.batch_is_full = True
                 break
 
+        # Update waiting queue
         can_run_list = adder.can_run_list
+        if len(can_run_list) == 0:
+            return None
+        self.waiting_queue = [
+            x for x in self.waiting_queue if x not in set(can_run_list)
+        ]
+
         if adder.new_inflight_req is not None:
             assert self.current_inflight_req is None
             self.current_inflight_req = adder.new_inflight_req
-        if len(can_run_list) == 0:
-            return None
-        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
+
+        if self.current_inflight_req:
+            self.current_inflight_req.is_inflight_req += 1
 
         # Print stats
         if self.tp_rank == 0:
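Two notes on this hunk. First, the early `return None` now runs before the inflight bookkeeping, and the new `is_inflight_req += 1` is the counterpart of the decrements in the `process_batch_result` hunks below. Second, filtering the waiting queue through a set makes each membership test O(1) instead of O(len(can_run_list)); note, though, that `set(can_run_list)` written inside the comprehension's condition is rebuilt for every element, so the conventional form hoists it (a sketch, not the committed code):

```python
waiting_queue = ["a", "b", "c", "d"]
can_run_list = ["b", "d"]

can_run_set = set(can_run_list)  # build once, O(1) lookups afterwards
waiting_queue = [x for x in waiting_queue if x not in can_run_set]
assert waiting_queue == ["a", "c"]
```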
@@ -613,13 +616,13 @@ class Scheduler:
         new_batch.prepare_for_extend(self.model_config.vocab_size)
 
         # Mixed-style chunked prefill
-        decoding_reqs = []
         if self.is_mixed_chunk and self.running_batch is not None:
             self.running_batch.prepare_for_decode()
             new_batch.mix_with_running(self.running_batch)
-            decoding_reqs = self.running_batch.reqs
+            new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
-        new_batch.decoding_reqs = decoding_reqs
+        else:
+            new_batch.decoding_reqs = None
 
         return new_batch
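Storing `decoding_reqs` directly on the batch (a list when mixed chunked prefill ran, `None` otherwise) pairs with the `not batch.decoding_reqs or ...` guard in the next hunk, where `not` treats `None` and an empty list identically:

```python
for decoding_reqs in (None, [], ["req1"]):
    # Guard from the next hunk: cache as prefill unless req was decoding.
    should_cache_prefill = not decoding_reqs or "req0" not in decoding_reqs
    assert should_cache_prefill  # "req0" is never a decoding req here
```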
@@ -738,12 +741,12 @@ class Scheduler:
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
-            elif req not in batch.decoding_reqs:
+            elif not batch.decoding_reqs or req not in batch.decoding_reqs:
+                # To reduce overhead, only cache prefill reqs
                 self.tree_cache.cache_unfinished_req(req)
 
-            if req is self.current_inflight_req:
+            if req.is_inflight_req > 0:
                 # Inflight request would get a new req idx
+                req.is_inflight_req -= 1
                 self.req_to_token_pool.free(req.req_pool_idx)
 
             if req.return_logprob:
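The freed slot is what the comment "Inflight request would get a new req idx" refers to: after each partial prefill round, the request releases its `req_pool_idx` and is assigned a fresh one next round. A toy pool illustrating that lifecycle (hypothetical stand-in, not the real `req_to_token_pool` API):

```python
class ToyPool:
    """Hypothetical slot pool; FIFO reuse so freed slots rotate."""

    def __init__(self, size: int):
        self.free_slots = list(range(size))

    def alloc(self) -> int:
        return self.free_slots.pop(0)

    def free(self, idx: int) -> None:
        self.free_slots.append(idx)

pool = ToyPool(4)
idx_round1 = pool.alloc()   # slot for the first prefill chunk
pool.free(idx_round1)       # is_inflight_req was decremented; slot released
idx_round2 = pool.alloc()   # next chunk runs under a different slot
assert idx_round1 != idx_round2
```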
@@ -768,8 +771,9 @@ class Scheduler:
             else:
                 self.tree_cache.cache_unfinished_req(req)
 
-            if req is self.current_inflight_req:
+            if req.is_inflight_req > 0:
                 # Inflight request would get a new req idx
+                req.is_inflight_req -= 1
                 self.req_to_token_pool.free(req.req_pool_idx)
 
         self.stream_output(batch)
@@ -906,13 +910,11 @@ class Scheduler:
         else:  # embedding or reward model
             output_embeddings = []
 
+        is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
+
         for req in batch.reqs:
             if req.finished() or (
-                req.stream
-                and (
-                    self.decode_forward_ct % self.stream_interval == 0
-                    or len(req.output_ids) == 1
-                )
+                req.stream and (is_stream_iter or len(req.output_ids) == 1)
             ):
                 output_rids.append(req.rid)
                 output_finished_reason.append(req.finished_reason)
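`self.decode_forward_ct % self.stream_interval == 0` does not depend on `req`, so hoisting it out of the loop evaluates it once per batch instead of once per request and flattens the condition. The general pattern, with made-up values:

```python
decode_forward_ct, stream_interval = 16, 8
reqs = [
    {"stream": True, "output_ids": [1, 2]},
    {"stream": False, "output_ids": [3]},
]

is_stream_iter = decode_forward_ct % stream_interval == 0  # loop-invariant, hoisted
streamed = [
    r for r in reqs
    if r["stream"] and (is_stream_iter or len(r["output_ids"]) == 1)
]
assert len(streamed) == 1
```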