cache.py

import torch


class WindowedCache:
    def __init__(self, cache_v_shape, cache_k_shape, max_seq_len, device):
        """
        The window size is the same as max_seq_len. The window will
        automatically roll once max_seq_len is exceeded.
        """
        # [batch_size, n_kv_heads, max_seq_len, head_dim]
        self.v = torch.zeros(cache_v_shape, dtype=torch.float16, device=device)
        # [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
        self.k = torch.zeros(cache_k_shape, dtype=torch.float16, device=device)
        self.max_seq_len = max_seq_len

    def get_kv(self, batch_size, start_pos, seqlen, head_dim):
        """
        Gets the key-value store in correct shapes.
        """
        xv = (
            self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()
        )
        xk = (
            self.k[:batch_size, :, :, : start_pos + seqlen, :]
            .transpose(2, 3)
            .contiguous()
        )
        xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()

        return xv, xk

    def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
        """
        Updates the values in the key-value store.
        """
        self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store
        self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store

    def roll_kv_n_steps(self, start_pos, n=100):
        """
        Roll cache n to the left.
        """
        n = min(n, self.max_seq_len)
        # Roll cache to the left
        self.v = torch.roll(self.v, shifts=-n, dims=2)
        self.k = torch.roll(self.k, shifts=-n, dims=3)

        # Zero out the new part
        self.v[:, :, -n:, :] = 0
        self.k[:, :, :, -n:, :] = 0

        return start_pos - n

    def to(self, device):
        self.k = self.k.to(device)
        self.v = self.v.to(device)

    def increase_batch_size(self, to_bsz):
        """Dynamically allocate new kv when batch size changes."""
        self.v = torch.zeros(
            to_bsz, *self.v.shape[1:], dtype=self.v.dtype, device=self.v.device
        )
        self.k = torch.zeros(
            to_bsz, *self.k.shape[1:], dtype=self.k.dtype, device=self.k.device
        )

    def decrease_batch_size(self, to_bsz):
        """Dynamically remove part of cache if batch size changes."""
        self.v = self.v[:to_bsz, :, :, :]
        self.k = self.k[:to_bsz, :, :, :, :]
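

# Minimal usage sketch, assuming a small illustrative config (the sizes below
# are not taken from any particular model):
if __name__ == "__main__":
    batch_size, n_kv_heads, max_seq_len, head_dim, pack_factor = 1, 8, 64, 128, 8

    cache_v_shape = (batch_size, n_kv_heads, max_seq_len, head_dim)
    cache_k_shape = (
        batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor
    )
    cache = WindowedCache(cache_v_shape, cache_k_shape, max_seq_len, device="cpu")

    # Write a 4-token prefill at position 0, then read it back.
    start_pos, seqlen = 0, 4
    values = torch.ones(
        batch_size, n_kv_heads, seqlen, head_dim, dtype=torch.float16
    )
    keys = torch.ones(
        batch_size, n_kv_heads, head_dim // pack_factor, seqlen, pack_factor,
        dtype=torch.float16,
    )
    cache.update_kv(values, keys, batch_size, start_pos, seqlen)

    xv, xk = cache.get_kv(batch_size, start_pos, seqlen, head_dim)
    print(xv.shape, xk.shape)  # both: [batch_size, seqlen, n_kv_heads, head_dim]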