chunk_cache.py 2 KB
Newer Older
1
2
from __future__ import annotations

3
"""Cache for chunked prefill, used when RadixCache is disabled."""
Lianmin Zheng's avatar
Lianmin Zheng committed
4
5

from typing import TYPE_CHECKING, Any, Callable, List, Tuple
6
7

import torch
8
9

from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
10
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
11
12
13

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req
14
15
16


class ChunkCacheEntry:
    """Plain record pairing a request id with a tensor value."""

    def __init__(self, rid: str, value: torch.Tensor):
        """Store both arguments unchanged on the instance.

        Args:
            rid: Request id the entry belongs to.
            value: Tensor kept for that request (stored by reference, not
                copied).
        """
        self.rid = rid
        self.value = value


class ChunkCache(BasePrefixCache):
    """Prefix-cache stand-in for chunked prefill when RadixCache is disabled.

    Nothing is kept for reuse: `match_prefix` always reports an empty match,
    and `reset`, `evict`, `inc_lock_ref` and `dec_lock_ref` do no work. The
    class only returns a finished request's slots to the two pools and
    publishes the KV indices of a still-running request on the request
    object itself.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        page_size: int,
    ):
        """Keep references to the two pools and remember the page size."""
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = page_size
        # NOTE(review): presumably signals to callers that prefix caching is
        # switched off — confirm against the readers of `disable`.
        self.disable = True

    def reset(self):
        """Do nothing; this class holds no cached entries to clear."""
        pass

    def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
        """Report that no prefix is cached.

        All keyword arguments are accepted and ignored.

        Returns:
            An empty index list and `None`. NOTE(review): the annotation
            declares the second element as `int`, yet `None` is what is
            returned — confirm which one callers rely on.
        """
        return [], None

    def cache_finished_req(self, req: Req):
        """Release the request's slot and its KV indices back to the pools.

        Args:
            req: Finished request; `req_pool_idx`, `origin_input_ids` and
                `output_ids` are read from it.
        """
        # For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
        num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
        # Read the indices before the request slot is handed back.
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, :num_tokens
        ]
        self.req_to_token_pool.free(req.req_pool_idx)
        self.token_to_kv_pool_allocator.free(kv_indices)

    def cache_unfinished_req(self, req: Req):
        """Publish the KV indices of a still-running request on `req`.

        Takes the first `len(req.fill_ids)` entries of the request's row in
        `req_to_token` and assigns them to `req.prefix_indices`. Nothing is
        freed here.

        Args:
            req: Unfinished request; `req_pool_idx` and `fill_ids` are read,
                `prefix_indices` is overwritten.
        """
        num_filled = len(req.fill_ids)
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, :num_filled
        ]

        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
        req.prefix_indices = kv_indices

    def insert(self):
        """Unsupported operation.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def evict(self, num_tokens: int):
        """Do nothing; `num_tokens` is ignored."""
        pass

    def inc_lock_ref(self, node: Any):
        """Ignore `node` and return 0; no lock references are tracked."""
        return 0

    def dec_lock_ref(self, node: Any):
        """Ignore `node` and return 0; no lock references are tracked."""
        return 0

    def pretty_print(self):
        """Return an empty string."""
        return ""