from __future__ import annotations

"""Cache for chunked prefill, used when RadixCache is disabled."""

from typing import TYPE_CHECKING, Any, Callable, List, Tuple

import torch

from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req


class ChunkCache(BasePrefixCache):
    """Prefix cache used for chunked prefill when RadixCache is disabled.

    This cache never shares KV entries between requests:
    ``match_prefix`` always reports an empty match, and when a request
    finishes its KV indices are returned straight to the allocator. The
    only state carried between prefill chunks is stored on the request
    itself (``req.prefix_indices``). Eviction and lock-ref operations
    are no-ops.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        page_size: int,
    ):
        """Store the pools that this cache reads from and frees into.

        Args:
            req_to_token_pool: Pool whose ``req_to_token`` table is
                indexed by ``req.req_pool_idx`` to find a request's KV
                indices. Its ``free`` releases a request slot.
            token_to_kv_pool_allocator: Allocator that a finished
                request's KV indices are freed back to.
            page_size: Stored on the instance. No method of this class
                reads it.
        """
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = page_size

    def reset(self) -> None:
        """Do nothing: this cache holds no entries of its own to clear."""
        pass

    def match_prefix(self, **unused_kwargs) -> MatchResult:
        """Report that no prefix is cached.

        All keyword arguments are accepted and ignored. Presumably this
        lets callers pass the same arguments they use with other
        ``BasePrefixCache`` implementations.

        Returns:
            A ``MatchResult`` with an empty int64 ``device_indices``
            tensor (allocated on torch's default device) and no
            last device/host node.
        """
        return MatchResult(
            device_indices=torch.empty((0,), dtype=torch.int64),
            last_device_node=None,
            last_host_node=None,
        )

    def cache_finished_req(self, req: Req) -> None:
        """Free a finished request's slot and the KV indices recorded for it.

        Nothing is kept for reuse. The freed KV range covers every
        prompt token (``req.origin_input_ids``) plus every output token
        except the last one.
        """
        # The final output token is left out of the freed range.
        # Presumably its KV entry was never written (NOTE(review): confirm).
        # max(..., 0) stops the range from shrinking below the prompt when
        # there are no outputs yet. For decode server: if req.output_ids is
        # empty, we want to free all req.origin_input_ids.
        num_kv_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, :num_kv_tokens
        ]
        # The indices are read before the slot is released.
        self.req_to_token_pool.free(req.req_pool_idx)
        self.token_to_kv_pool_allocator.free(kv_indices)

    def cache_unfinished_req(self, req: Req) -> None:
        """Record the KV indices filled so far for an in-progress request.

        Nothing is freed. The indices covering the first
        ``len(req.fill_ids)`` tokens are stored on ``req.prefix_indices``.
        """
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(req.fill_ids)
        ]

        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
        req.prefix_indices = kv_indices

    def evict(self, num_tokens: int) -> None:
        """Do nothing: no cached entries exist to evict."""
        pass

    def inc_lock_ref(self, node: Any) -> int:
        """Do nothing and return 0; this cache has no nodes to lock."""
        return 0

    def dec_lock_ref(self, node: Any) -> int:
        """Do nothing and return 0; this cache has no nodes to unlock."""
        return 0

    def pretty_print(self) -> str:
        """Return an empty string, since there is no cache structure to show."""
        return ""