Unverified Commit 62757db6 authored by Liangsheng Yin, committed by GitHub

Reduce the overhead when cache is disabled (#1010)

parent 73fa2d49
@@ -18,44 +18,40 @@ limitations under the License.
 import random
 from collections import defaultdict
 from contextlib import contextmanager
+from typing import List
 
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 
 
 class PolicyScheduler:
-    def __init__(
-        self,
-        policy,
-        max_running_seqs,
-        max_prefill_num_tokens,
-        max_total_num_tokens,
-        tree_cache,
-    ):
-        if tree_cache.disable and policy == "lpm":
-            # LPM is meaningless when the tree cache is disabled.
+    def __init__(self, policy, tree_cache):
+        if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
+            # LPM and DFS-weight are meaningless when the tree cache is disabled.
             policy = "fcfs"
         self.policy = policy
-        self.max_running_seqs = max_running_seqs
-        self.max_prefill_num_tokens = max_prefill_num_tokens
-        self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
-    def get_priority_queue(self, waiting_queue):
+    def calc_priority(self, waiting_queue: List[Req]):
+        if self.policy in ["lpm", "dfs-weight"]:
+            # Compute the matched prefix length for each request.
+            for r in waiting_queue:
+                # NOTE: the prefix_indices must always be aligned with last_node
+                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                    rid=r.rid, key=r.adjust_max_prefix_ids()
+                )
+
         if self.policy == "lpm":
-            # longest prefix match
+            # Longest Prefix Match
             waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
-            return waiting_queue
         elif self.policy == "fcfs":
             # first come first serve
-            return waiting_queue
+            pass
         elif self.policy == "lof":
             # longest output first
             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-            return waiting_queue
         elif self.policy == "random":
             random.shuffle(waiting_queue)
-            return waiting_queue
         elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
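
To make the fallback guard above concrete, here is a minimal, hedged sketch; StubCache and resolve_policy are illustrative stand-ins, not repo code:

class StubCache:
    def __init__(self, disable: bool):
        self.disable = disable

def resolve_policy(policy: str, tree_cache: StubCache) -> str:
    # Mirrors the guard in PolicyScheduler.__init__: prefix-aware policies
    # degrade to "fcfs" when the radix cache is disabled, so calc_priority
    # never calls match_prefix on this fast path.
    if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
        return "fcfs"
    return policy

assert resolve_policy("lpm", StubCache(disable=True)) == "fcfs"
assert resolve_policy("dfs-weight", StubCache(disable=True)) == "fcfs"
assert resolve_policy("lpm", StubCache(disable=False)) == "lpm"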
@@ -66,12 +62,13 @@ class PolicyScheduler:
                 node_to_weight[node] = len(last_node_to_reqs[node])
             self.calc_weight(self.tree_cache.root_node, node_to_weight)
 
-            q = []
+            waiting_queue.clear()
             self.get_dfs_priority(
-                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
+                self.tree_cache.root_node,
+                node_to_weight,
+                last_node_to_reqs,
+                waiting_queue,
             )
-            assert len(q) == len(waiting_queue)
-            return q
         else:
             raise ValueError(f"Unknown schedule_policy: {self.policy}")
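
The dfs-weight branch now reuses the caller's list instead of building and returning a fresh one. A small sketch of that in-place pattern (illustrative names; sorting stands in for the DFS traversal):

def reorder_in_place(items: list) -> None:
    # Stand-in for get_dfs_priority: compute the new order, then clear and
    # refill the same list object so every alias observes the update.
    ranked = sorted(items)
    items.clear()
    items.extend(ranked)

queue = [3, 1, 2]
alias = queue                 # e.g. self.waiting_queue held elsewhere
reorder_in_place(queue)
assert alias == [1, 2, 3]     # no reassignment needed at the call site

This is why the ModelTpServer call site further down can drop the `self.waiting_queue = ...` reassignment.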
@@ -139,8 +136,6 @@ class PrefillAdder:
         self.log_input_tokens += extend_input_len
 
     def add_inflight_req(self, req: Req):
-        req.input_ids = req.origin_input_ids + req.output_ids
-        req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.input_ids = req.input_ids[: len(req.prefix_indices) + req.extend_input_len]

...
@@ -164,7 +164,12 @@ class Req:
     def finished(self) -> bool:
         return self.finished_reason is not None
 
+    def init_next_round_input(self):
+        self.input_ids = self.origin_input_ids + self.output_ids
+        self.extend_input_len = len(self.input_ids) - len(self.prefix_indices)
+
     def adjust_max_prefix_ids(self):
-        self.input_ids = self.origin_input_ids + self.output_ids
         input_len = len(self.input_ids)
         max_prefix_len = input_len

...
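
A worked example of what init_next_round_input computes, with illustrative token counts (not from the commit):

origin_input_ids = list(range(10))  # a 10-token prompt
output_ids = [100, 101, 102]        # 3 tokens generated so far
prefix_indices = list(range(8))     # 8 tokens already cached as a matched prefix

# Same arithmetic as init_next_round_input:
input_ids = origin_input_ids + output_ids
extend_input_len = len(input_ids) - len(prefix_indices)
assert len(input_ids) == 13
assert extend_input_len == 5  # only these 5 tokens need the extend/prefill pass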
@@ -165,13 +165,7 @@ class ModelTpServer:
             disable=server_args.disable_radix_cache,
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = PolicyScheduler(
-            self.schedule_policy,
-            self.max_running_requests,
-            self.max_prefill_tokens,
-            self.max_total_num_tokens,
-            self.tree_cache,
-        )
+        self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache)
         self.req_to_token_pool = self.model_runner.req_to_token_pool
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool
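
For reference, a hypothetical stub mirroring the slimmer two-argument construction; the batching limits (max_running_requests, max_prefill_tokens, max_total_num_tokens) stay on ModelTpServer rather than being copied onto the scheduler:

class StubTreeCache:
    disable = False

class StubPolicyScheduler:
    # Mirrors the reduced __init__(self, policy, tree_cache) signature.
    def __init__(self, policy: str, tree_cache: StubTreeCache):
        self.policy = policy
        self.tree_cache = tree_cache

scheduler = StubPolicyScheduler("lpm", StubTreeCache())
assert scheduler.policy == "lpm"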
@@ -373,17 +367,8 @@ class ModelTpServer:
         if running_bs >= self.max_running_requests:
             return None
 
-        # Compute matched prefix length
-        for req in self.waiting_queue:
-            req.input_ids = req.origin_input_ids + req.output_ids
-            # NOTE: the prefix_indices must always be aligned with last_node
-            req.prefix_indices, req.last_node = self.tree_cache.match_prefix(
-                rid=req.rid, key=req.adjust_max_prefix_ids()
-            )
-            req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
-
         # Get priority queue
-        self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
+        self.scheduler.calc_priority(self.waiting_queue)
 
         adder = PrefillAdder(
             self.tree_cache,
@@ -397,12 +382,13 @@ class ModelTpServer:
         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
+            self.current_inflight_req.init_next_round_input()
             self.current_inflight_req = adder.add_inflight_req(
                 self.current_inflight_req
             )
 
         for req in self.waiting_queue:
+            req.init_next_round_input()
             res = adder.add_one_req(req)
             if (
                 not res

...
@@ -169,6 +169,9 @@ class RadixCache(BasePrefixCache):
             heapq.heappush(leaves, x.parent)
 
     def inc_lock_ref(self, node: TreeNode):
+        if self.disable:
+            return 0
+
         delta = 0
         while node != self.root_node:
             if node.lock_ref == 0:
@@ -179,6 +182,9 @@ class RadixCache(BasePrefixCache):
         return delta
 
     def dec_lock_ref(self, node: TreeNode):
+        if self.disable:
+            return 0
+
         delta = 0
         while node != self.root_node:
             if node.lock_ref == 1:

...
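
A minimal sketch of the early-return pattern these two hunks add; DemoCache and Node are hypothetical, and the loop body is simplified relative to the real RadixCache:

class Node:
    def __init__(self, parent=None):
        self.parent = parent
        self.lock_ref = 0

class DemoCache:
    def __init__(self, disable: bool):
        self.disable = disable
        self.root_node = Node()

    def inc_lock_ref(self, node: Node) -> int:
        if self.disable:
            # With the cache off there is nothing to protect from eviction,
            # so skipping the walk to the root removes per-request overhead.
            return 0
        delta = 0
        while node != self.root_node:
            node.lock_ref += 1
            node = node.parent
            delta += 1  # simplified; the real code tracks evictable sizes
        return delta

cache = DemoCache(disable=True)
leaf = Node(parent=cache.root_node)
assert cache.inc_lock_ref(leaf) == 0  # early return, no tree traversal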