Unverified commit 24f3e151, authored by Lianmin Zheng, committed by GitHub

[Minor] Improve style (#1666)

parent 6790240c
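
Most of the diff below is mechanical cleanup: `running_batch` gets an honest `Optional` annotation, `is not None` guards on values that are either `None` or non-empty become plain truthiness checks, a loop-invariant streaming condition is hoisted out of the output loop, and an `else` branch is replaced by initializing `decoding_reqs` to a default. A minimal standalone sketch of these patterns, using toy stand-in types rather than the repo's real `ScheduleBatch`:

```python
from typing import List, Optional


class MiniScheduler:
    """Toy stand-in illustrating the style patterns this commit applies."""

    def __init__(self) -> None:
        # Before: `self.running_batch: ScheduleBatch = None`, which contradicts
        # its own annotation. After: the field is declared Optional up front.
        self.running_batch: Optional[List[int]] = None

    def running_bs(self) -> int:
        # Before: `len(...) if self.running_batch is not None else 0`.
        # After: a truthiness check. Note the subtle widening: an *empty*
        # batch now also yields 0, which is what the call sites want anyway.
        return len(self.running_batch) if self.running_batch else 0
```

The one change with real behavioral weight is the new `Req.is_inflight_req` counter for chunked prefill; see the sketch after the diff.
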
@@ -203,6 +203,7 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
+        self.is_inflight_req = 0

         # Logprobs (arguments)
         self.return_logprob = False

@@ -45,7 +45,7 @@ class SchedulePolicy:
     def calc_priority(self, waiting_queue: List[Req]):
         # Compute matched prefix length
         prefix_computed = False
-        if self.policy in ["lpm", "dfs-weight"]:
+        if self.policy == "lpm" or self.policy == "dfs-weight":
             for r in waiting_queue:
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(

@@ -194,7 +194,7 @@ class Scheduler:
         # Init running status
         self.waiting_queue: List[Req] = []
-        self.running_batch: ScheduleBatch = None
+        self.running_batch: Optional[ScheduleBatch] = None
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0

@@ -273,6 +273,9 @@
                     break
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
+            else:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio

             self.last_batch = batch

@@ -468,8 +471,6 @@
         # Check memory
         if self.running_batch is None:
-            self.check_memory()
-            self.new_token_ratio = global_config.init_new_token_ratio
             return

         # Run decode

@@ -489,9 +490,7 @@
         ) and self.current_inflight_req is None:
             return None

-        running_bs = (
-            len(self.running_batch.reqs) if self.running_batch is not None else 0
-        )
+        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
         if running_bs >= self.max_running_requests:
             self.batch_is_full = True
             return None

@@ -512,7 +511,7 @@
         )
         has_inflight = self.current_inflight_req is not None
-        if self.current_inflight_req is not None:
+        if has_inflight:
            self.current_inflight_req.init_next_round_input(
                None if prefix_computed else self.tree_cache
            )

@@ -520,7 +519,7 @@
                 self.current_inflight_req
             )

-        if self.lora_paths is not None:
+        if self.lora_paths:
             lora_set = (
                 set([req.lora_path for req in self.running_batch.reqs])
                 if self.running_batch is not None

@@ -529,7 +528,7 @@
         for req in self.waiting_queue:
             if (
-                self.lora_paths is not None
+                self.lora_paths
                 and len(
                     lora_set
                     | set([req.lora_path for req in adder.can_run_list])

@@ -551,16 +550,20 @@
                     self.batch_is_full = True
                     break

         # Update waiting queue
         can_run_list = adder.can_run_list
+        if len(can_run_list) == 0:
+            return None
+        self.waiting_queue = [
+            x for x in self.waiting_queue if x not in set(can_run_list)
+        ]

         if adder.new_inflight_req is not None:
             assert self.current_inflight_req is None
             self.current_inflight_req = adder.new_inflight_req
-        if len(can_run_list) == 0:
-            return None
-        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
+        if self.current_inflight_req:
+            self.current_inflight_req.is_inflight_req += 1

         # Print stats
         if self.tp_rank == 0:

@@ -613,13 +616,13 @@ class Scheduler:
         new_batch.prepare_for_extend(self.model_config.vocab_size)

         # Mixed-style chunked prefill
+        decoding_reqs = []
         if self.is_mixed_chunk and self.running_batch is not None:
             self.running_batch.prepare_for_decode()
             new_batch.mix_with_running(self.running_batch)
-            new_batch.decoding_reqs = self.running_batch.reqs
+            decoding_reqs = self.running_batch.reqs
             self.running_batch = None
-        else:
-            new_batch.decoding_reqs = None
+        new_batch.decoding_reqs = decoding_reqs

         return new_batch

@@ -738,12 +741,12 @@
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
-            elif req not in batch.decoding_reqs:
-                # To reduce overhead, only cache prefill reqs
+            elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                 self.tree_cache.cache_unfinished_req(req)

-            if req is self.current_inflight_req:
+            if req.is_inflight_req > 0:
                 # Inflight request would get a new req idx
+                req.is_inflight_req -= 1
                 self.req_to_token_pool.free(req.req_pool_idx)

             if req.return_logprob:

@@ -768,8 +771,9 @@
             else:
                 self.tree_cache.cache_unfinished_req(req)

-            if req is self.current_inflight_req:
+            if req.is_inflight_req > 0:
                 # Inflight request would get a new req idx
+                req.is_inflight_req -= 1
                 self.req_to_token_pool.free(req.req_pool_idx)

         self.stream_output(batch)

@@ -906,13 +910,11 @@
             else:  # embedding or reward model
                 output_embeddings = []

+            is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
+
             for req in batch.reqs:
                 if req.finished() or (
-                    req.stream
-                    and (
-                        self.decode_forward_ct % self.stream_interval == 0
-                        or len(req.output_ids) == 1
-                    )
+                    req.stream and (is_stream_iter or len(req.output_ids) == 1)
                 ):
                     output_rids.append(req.rid)
                     output_finished_reason.append(req.finished_reason)

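
The substantive part of the commit replaces the identity test `req is self.current_inflight_req` with a per-request counter: the scheduler increments `is_inflight_req` when it schedules a request as the inflight (chunked-prefill) request, and the result-processing paths decrement it before freeing the request's pool slot. This stays correct even if `current_inflight_req` has already been reassigned by the time the batch result is processed. A toy sketch of the counter pattern; `free_req_idx` is a hypothetical stand-in for `self.req_to_token_pool.free`:

```python
from typing import Callable


class Req:
    def __init__(self) -> None:
        # Number of scheduled rounds that still reference this request's
        # current req_pool_idx as the inflight (partially prefilled) request.
        self.is_inflight_req = 0
        self.req_pool_idx = -1


def on_batch_result(req: Req, free_req_idx: Callable[[int], None]) -> None:
    # Before: `if req is self.current_inflight_req:` -- fragile, because the
    # scheduler may have moved current_inflight_req to another request by the
    # time this result is handled. After: the request tracks it itself.
    if req.is_inflight_req > 0:
        req.is_inflight_req -= 1
        # The inflight request gets a new req_pool_idx on its next round,
        # so its old slot can be freed now.
        free_req_idx(req.req_pool_idx)
```
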