Unverified Commit 62757db6 authored by Liangsheng Yin, committed by GitHub

Reduce the overhead when cache is disabled (#1010)

parent 73fa2d49
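The diff touches four spots: `PolicyScheduler` (slimmer constructor; prefix matching moves into the renamed `calc_priority`), `PrefillAdder.add_inflight_req` (input bookkeeping moves out), `Req` (new `init_next_round_input` helper), and `RadixCache` (`inc_lock_ref`/`dec_lock_ref` become no-ops when the cache is disabled). A minimal sketch of the policy fallback the diff extends; `StubCache` and `pick_policy` are hypothetical stand-ins, not sglang code:

```python
# Hypothetical stand-ins illustrating the fallback check from the diff:
# cache-aware policies degrade to FCFS when the radix cache is off.
class StubCache:
    disable = True  # mirrors RadixCache.disable (set via --disable-radix-cache)

def pick_policy(policy: str, tree_cache) -> str:
    # same condition the new __init__ uses
    if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
        return "fcfs"
    return policy

assert pick_policy("dfs-weight", StubCache()) == "fcfs"
assert pick_policy("lof", StubCache()) == "lof"  # cache-free policies are untouched
```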
@@ -18,44 +18,40 @@ limitations under the License.
 import random
 from collections import defaultdict
 from contextlib import contextmanager
+from typing import List
 
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 
 
 class PolicyScheduler:
-    def __init__(
-        self,
-        policy,
-        max_running_seqs,
-        max_prefill_num_tokens,
-        max_total_num_tokens,
-        tree_cache,
-    ):
-        if tree_cache.disable and policy == "lpm":
-            # LPM is meaningless when the tree cache is disabled.
+    def __init__(self, policy, tree_cache):
+        if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
+            # LPM and DFS-weight are meaningless when the tree cache is disabled.
             policy = "fcfs"
         self.policy = policy
-        self.max_running_seqs = max_running_seqs
-        self.max_prefill_num_tokens = max_prefill_num_tokens
-        self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
-    def get_priority_queue(self, waiting_queue):
+    def calc_priority(self, waiting_queue: List[Req]):
+        if self.policy in ["lpm", "dfs-weight"]:
+            # Compute matched prefix length
+            for r in waiting_queue:
+                # NOTE: the prefix_indices must always be aligned with last_node
+                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                    rid=r.rid, key=r.adjust_max_prefix_ids()
+                )
+
         if self.policy == "lpm":
-            # longest prefix match
+            # Longest Prefix Match
             waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
-            return waiting_queue
         elif self.policy == "fcfs":
             # first come first serve
-            return waiting_queue
+            pass
         elif self.policy == "lof":
             # longest output first
             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-            return waiting_queue
         elif self.policy == "random":
             random.shuffle(waiting_queue)
-            return waiting_queue
         elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
@@ -66,12 +62,13 @@ class PolicyScheduler:
                 node_to_weight[node] = len(last_node_to_reqs[node])
             self.calc_weight(self.tree_cache.root_node, node_to_weight)
 
-            q = []
+            waiting_queue.clear()
             self.get_dfs_priority(
-                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
+                self.tree_cache.root_node,
+                node_to_weight,
+                last_node_to_reqs,
+                waiting_queue,
             )
-            assert len(q) == len(waiting_queue)
-            return q
         else:
             raise ValueError(f"Unknown schedule_policy: {self.policy}")
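A hypothetical call site for the renamed entry point (the import path is assumed from the repo layout at this commit): `calc_priority` mutates the queue in place and returns `None`, whereas the old `get_priority_queue` returned a list that callers had to reassign.

```python
from sglang.srt.managers.policy_scheduler import PolicyScheduler  # assumed path

class StubCache:
    disable = True  # stand-in for a RadixCache with caching turned off

scheduler = PolicyScheduler(policy="random", tree_cache=StubCache())

queue = ["req_a", "req_b", "req_c"]  # placeholder items; really a List[Req]
scheduler.calc_priority(queue)       # shuffles the same list object; no return value
```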
@@ -139,8 +136,6 @@ class PrefillAdder:
         self.log_input_tokens += extend_input_len
 
     def add_inflight_req(self, req: Req):
-        req.input_ids = req.origin_input_ids + req.output_ids
-        req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.input_ids = req.input_ids[: len(req.prefix_indices) + req.extend_input_len]
......
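For context, a worked example of the truncation arithmetic that stays in `add_inflight_req` (the sizes below are made up): when the tokens left to prefill exceed the remaining chunk budget, the request is clipped to the budget and continues as an inflight request next round.

```python
# Hypothetical sizes, mirroring the arithmetic kept in add_inflight_req.
prefix_len = 100          # tokens already cached (len(req.prefix_indices))
total_len = 700           # len(req.input_ids) after origin + outputs
rem_chunk_tokens = 512    # remaining budget in this prefill chunk

extend_input_len = total_len - prefix_len            # 600 tokens still to prefill
truncated = extend_input_len > rem_chunk_tokens      # True: it does not fit
extend_input_len = min(extend_input_len, rem_chunk_tokens)   # prefill 512 now
input_ids_kept = prefix_len + extend_input_len       # 612 of 700 tokens this round
assert (truncated, extend_input_len, input_ids_kept) == (True, 512, 612)
```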
@@ -164,7 +164,12 @@ class Req:
     def finished(self) -> bool:
         return self.finished_reason is not None
 
+    def init_next_round_input(self):
+        self.input_ids = self.origin_input_ids + self.output_ids
+        self.extend_input_len = len(self.input_ids) - len(self.prefix_indices)
+
     def adjust_max_prefix_ids(self):
+        self.input_ids = self.origin_input_ids + self.output_ids
         input_len = len(self.input_ids)
         max_prefix_len = input_len
......
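A worked example of the new `Req.init_next_round_input` bookkeeping, using a bare-bones stand-in for the `Req` fields involved (the method body is copied from the diff; `MiniReq` itself is hypothetical):

```python
class MiniReq:
    def __init__(self, origin_input_ids, output_ids, prefix_indices):
        self.origin_input_ids = origin_input_ids
        self.output_ids = output_ids
        self.prefix_indices = prefix_indices

    def init_next_round_input(self):
        # full sequence so far, minus whatever prefix is already cached
        self.input_ids = self.origin_input_ids + self.output_ids
        self.extend_input_len = len(self.input_ids) - len(self.prefix_indices)

r = MiniReq(origin_input_ids=[1, 2, 3, 4], output_ids=[5, 6], prefix_indices=[1, 2, 3])
r.init_next_round_input()
assert r.extend_input_len == 3  # 6 tokens total, 3 already cached
```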
@@ -165,13 +165,7 @@ class ModelTpServer:
             disable=server_args.disable_radix_cache,
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = PolicyScheduler(
-            self.schedule_policy,
-            self.max_running_requests,
-            self.max_prefill_tokens,
-            self.max_total_num_tokens,
-            self.tree_cache,
-        )
+        self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache)
         self.req_to_token_pool = self.model_runner.req_to_token_pool
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool
@@ -373,17 +367,8 @@ class ModelTpServer:
         if running_bs >= self.max_running_requests:
             return None
 
-        # Compute matched prefix length
-        for req in self.waiting_queue:
-            req.input_ids = req.origin_input_ids + req.output_ids
-            # NOTE: the prefix_indices must always be aligned with last_node
-            req.prefix_indices, req.last_node = self.tree_cache.match_prefix(
-                rid=req.rid, key=req.adjust_max_prefix_ids()
-            )
-            req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
-
-        # Get priority queue
-        self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
+        self.scheduler.calc_priority(self.waiting_queue)
 
         adder = PrefillAdder(
             self.tree_cache,
@@ -397,12 +382,13 @@ class ModelTpServer:
         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
+            self.current_inflight_req.init_next_round_input()
             self.current_inflight_req = adder.add_inflight_req(
                 self.current_inflight_req
             )
 
         for req in self.waiting_queue:
+            req.init_next_round_input()
             res = adder.add_one_req(req)
             if (
                 not res
......
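A sketch of why moving prefix matching into `calc_priority` cuts overhead (the names below are hypothetical, not sglang's): the old server loop matched prefixes for every waiting request on every scheduling step, while the new code only does so under the cache-aware policies, which are never selected when the cache is disabled.

```python
def schedule_step(policy: str, waiting_queue: list, match_prefix) -> None:
    if policy in ["lpm", "dfs-weight"]:     # only cache-aware policies pay
        for r in waiting_queue:
            r["prefix_len"] = match_prefix(r["ids"])
    # "fcfs" (the forced fallback when the cache is off) touches nothing

queue = [{"ids": [1, 2, 3]}, {"ids": [1, 2, 4]}]
schedule_step("fcfs", queue, match_prefix=lambda ids: 0)
assert "prefix_len" not in queue[0]         # no per-request matching happened
```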
@@ -169,6 +169,9 @@ class RadixCache(BasePrefixCache):
                 heapq.heappush(leaves, x.parent)
 
     def inc_lock_ref(self, node: TreeNode):
+        if self.disable:
+            return 0
+
         delta = 0
         while node != self.root_node:
             if node.lock_ref == 0:
@@ -179,6 +182,9 @@
         return delta
 
     def dec_lock_ref(self, node: TreeNode):
+        if self.disable:
+            return 0
+
         delta = 0
         while node != self.root_node:
             if node.lock_ref == 1:
......
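Finally, a toy demonstration of the two new guards (`ToyCache` is a stand-in; only the `disable` check mirrors the diff): lock-reference maintenance becomes a constant-time no-op instead of a walk from the node up to the root.

```python
class ToyCache:
    def __init__(self, disable: bool):
        self.disable = disable
        self.walks = 0  # counts how often the tree is actually traversed

    def inc_lock_ref(self, node) -> int:
        if self.disable:
            return 0        # the guard added by this commit: skip the walk
        self.walks += 1     # placeholder for the real node-to-root traversal
        return 1

cache = ToyCache(disable=True)
assert cache.inc_lock_ref(object()) == 0
assert cache.walks == 0     # no traversal work when the cache is disabled
```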