Unverified Commit 7de60345 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Fix the prefix indices (#1037)

parent d84c5e70
......@@ -43,11 +43,14 @@ class PolicyScheduler:
def calc_priority(self, waiting_queue: List[Req]):
# Compute matched prefix length
for r in waiting_queue:
# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=r.adjust_max_prefix_ids()
)
prefix_computed = False
if self.policy in ["lpm", "dfs-weight"]:
for r in waiting_queue:
# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=r.adjust_max_prefix_ids()
)
prefix_computed = True
if self.policy == "lpm":
# Longest Prefix Match
......@@ -80,6 +83,8 @@ class PolicyScheduler:
else:
raise ValueError(f"Unknown schedule_policy: {self.policy}")
return prefix_computed
def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
for child in cur_node.children.values():
self.calc_weight(child, node_to_weight)
......
......@@ -18,9 +18,8 @@ limitations under the License.
import logging
import warnings
from dataclasses import dataclass
from typing import List, Union
from typing import List, Optional, Union
import numpy as np
import torch
from flashinfer.sampling import top_k_top_p_sampling_from_probs
......@@ -28,9 +27,9 @@ import sglang.srt.sampling.penaltylib as penaltylib
from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixCache
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
......@@ -164,8 +163,12 @@ class Req:
def finished(self) -> bool:
return self.finished_reason is not None
def init_next_round_input(self):
def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
self.fill_ids = self.origin_input_ids + self.output_ids
if tree_cache is not None:
self.prefix_indices, self.last_node = tree_cache.match_prefix(
rid=self.rid, key=self.adjust_max_prefix_ids()
)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
def adjust_max_prefix_ids(self):
......@@ -312,7 +315,7 @@ class ScheduleBatch:
reqs: List[Req]
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: BaseTokenToKVPool
tree_cache: RadixCache
tree_cache: BasePrefixCache
# Batched arguments to model runner
input_ids: torch.Tensor = None
......@@ -534,7 +537,7 @@ class ScheduleBatch:
residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
req.prefix_indices = None
req.prefix_indices = []
req.last_node = None
req.extend_input_len = 0
......
......@@ -369,7 +369,7 @@ class ModelTpServer:
return None
# Get priority queue
self.scheduler.calc_priority(self.waiting_queue)
prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
adder = PrefillAdder(
self.tree_cache,
......@@ -383,13 +383,15 @@ class ModelTpServer:
has_inflight = self.current_inflight_req is not None
if self.current_inflight_req is not None:
self.current_inflight_req.init_next_round_input()
self.current_inflight_req.init_next_round_input(
None if prefix_computed else self.tree_cache
)
self.current_inflight_req = adder.add_inflight_req(
self.current_inflight_req
)
for req in self.waiting_queue:
req.init_next_round_input()
req.init_next_round_input(None if prefix_computed else self.tree_cache)
res = adder.add_one_req(req)
if (
not res
......
......@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
import heapq
import time
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, List, Optional
import torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment