import torch

from typing import Dict, Optional, TypeVar

from text_generation_server.models.types import Batch

B = TypeVar("B", bound=Batch)

class Cache:
    """In-memory store of generation batches keyed by their batch id.

    Keeps ``Batch`` objects alive between calls; removing entries also
    asks the CUDA caching allocator to release unused memory so freed
    batch tensors are actually returned to the device.
    """

    def __init__(self):
        # batch_id -> Batch; populated via `set`, drained via
        # `pop`/`delete`/`clear`.
        self.cache: Dict[int, B] = {}

    def pop(self, batch_id: int) -> Optional[B]:
        """Remove and return the batch stored under `batch_id`, or None if absent."""
        return self.cache.pop(batch_id, None)

    def set(self, entry: B):
        """Store `entry` under its own `batch_id`; a None entry is ignored."""
        if entry is not None:
            self.cache[entry.batch_id] = entry

    def delete(self, batch_id: int):
        """Drop the batch stored under `batch_id` and free cached CUDA memory."""
        batch = self.pop(batch_id)
        if batch is not None:
            # Drop the last local reference so the batch (and the tensors it
            # holds) becomes collectible before we empty the allocator cache.
            del batch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def clear(self):
        """Delete every cached batch, releasing CUDA memory for each."""
        # Snapshot the keys first: `delete` mutates self.cache while we loop.
        for batch_id in list(self.cache.keys()):
            self.delete(batch_id)

    def __len__(self):
        # len() of the dict itself -- no need to materialize the keys view.
        return len(self.cache)