cache.py 1.19 KB
Newer Older
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch

from dataclasses import dataclass
from typing import Dict, Optional, List

from bloom_inference.pb import generate_pb2
from bloom_inference.utils import NextTokenChooser, StoppingCriteria


@dataclass
class CacheEntry:
    """Per-batch record held by ``Cache``.

    Groups the request ids of a batch with the data kept alongside them and
    can summarise itself as a ``generate_pb2.CacheEntry`` message.
    """

    # Identifier of the batch; sent as ``id`` by ``to_pb`` and used by
    # ``Cache.set`` as the storage key.
    batch_id: int
    # Ids of the requests in this batch; its length is the entry's ``len()``.
    request_ids: List[int]
    # Tensors keyed by name -- presumably the batched model inputs
    # (TODO confirm against the caller).
    input_ids: Dict[str, torch.Tensor]
    # One tensor per request -- presumably every token id seen so far
    # (TODO confirm); ``to_pb`` reports the longest one.
    all_input_ids: List[torch.Tensor]
    # Per-request helper objects -- assumed to be index-aligned with
    # ``request_ids`` (TODO confirm).
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    def __len__(self):
        """Return the number of requests in the entry."""
        return len(self.request_ids)

    def to_pb(self):
        """Build the ``generate_pb2.CacheEntry`` message for this entry.

        ``sequence_length`` is the largest ``len()`` among the items of
        ``all_input_ids``.

        Raises:
            ValueError: if ``all_input_ids`` is empty (``max`` of nothing).
        """
        longest = max(map(len, self.all_input_ids))
        return generate_pb2.CacheEntry(
            id=self.batch_id,
            request_ids=self.request_ids,
            sequence_length=longest,
        )


class Cache:
    """In-memory store of ``CacheEntry`` objects keyed by ``batch_id``."""

    def __init__(self):
        """Create an empty cache."""
        # Keys are ``CacheEntry.batch_id`` values, which are declared as int.
        self.cache: Dict[int, CacheEntry] = {}

    def pop(self, batch_id: int) -> Optional[CacheEntry]:
        """Remove and return the entry stored under ``batch_id``.

        Returns:
            The stored entry, or ``None`` when no entry has that id.
        """
        return self.cache.pop(batch_id, None)

    def set(self, entry: Optional[CacheEntry]):
        """Store ``entry`` under its own ``batch_id``.

        An existing entry with the same id is overwritten. ``None`` is
        accepted and ignored, so callers can pass an optional result
        straight through.
        """
        if entry is not None:
            self.cache[entry.batch_id] = entry

    def delete(self, batch_id: int):
        """Drop the entry stored under ``batch_id``.

        Raises:
            KeyError: if no entry has that id (unlike ``pop``, which
                returns ``None`` in that case).
        """
        del self.cache[batch_id]

    def clear(self):
        """Remove every entry."""
        self.cache.clear()

    def __len__(self):
        """Return the number of stored entries."""
        return len(self.cache)