from typing import Optional

import torch
from torch import Tensor


class History(torch.nn.Module):
    r"""A historical embedding storage module."""
    def __init__(self, num_embeddings: int, embedding_dim: int, device=None):
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

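        # Keep the full embedding matrix on a fixed storage device. If no
        # device is given (or it is the CPU), allocate it in pinned memory so
        # that copies to an accelerator can run asynchronously.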
        pin_memory = device is None or str(device) == 'cpu'
        self.emb = torch.empty(num_embeddings, embedding_dim, device=device,
                               pin_memory=pin_memory)

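        # The compute device onto which `pull` moves its results; updated via
        # the `_apply` override below (e.g., by calling `.to(...)`).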
        self._device = torch.device('cpu')

        self.reset_parameters()

    def reset_parameters(self):
        self.emb.fill_(0)

    def _apply(self, fn):
        # Set the `_device` of the module without transferring `self.emb`,
        # so that the (possibly pinned) storage stays in place.
        self._device = fn(torch.zeros(1)).device
        return self

    @torch.no_grad()
    def pull(self, n_id: Optional[Tensor] = None) -> Tensor:
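        # Return the historical embeddings of the nodes in `n_id` (or the
        # full matrix if `n_id` is None), moved onto the compute device.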
        out = self.emb
        if n_id is not None:
            assert n_id.device == self.emb.device
            out = out.index_select(0, n_id)
        return out.to(device=self._device)

    @torch.no_grad()
    def push(self, x: Tensor, n_id: Optional[Tensor] = None,
             offset: Optional[Tensor] = None, count: Optional[Tensor] = None):

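        # Write `x` back into the history. Without `n_id`, `x` must cover all
        # embeddings and replaces the buffer in a single copy. With `n_id`
        # alone, the given rows are updated via a scattered write. With
        # `offset` and `count`, `x` is copied over in contiguous chunks.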
        if n_id is None and x.size(0) != self.num_embeddings:
            raise ValueError(f"Expected {self.num_embeddings} rows in 'x' "
                             f"when 'n_id' is not given (got {x.size(0)})")

        elif n_id is None and x.size(0) == self.num_embeddings:
            self.emb.copy_(x)

        elif offset is None or count is None:
            assert n_id.device == self.emb.device
            self.emb[n_id] = x.to(self.emb.device)

        else:  # Push in chunks:
            src_o = 0
            x = x.to(self.emb.device)
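            # `offset[i]` is the destination row of chunk `i` in the history
            # and `count[i]` its length; `x` holds all chunks back-to-back.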
            for dst_o, c in zip(offset.tolist(), count.tolist()):
                self.emb[dst_o:dst_o + c] = x[src_o:src_o + c]
                src_o += c

    def forward(self, *args, **kwargs):
        """"""
        raise NotImplementedError

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.num_embeddings}, '
                f'{self.embedding_dim}, emb_device={self.emb.device}, '
                f'device={self._device})')
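

# A minimal usage sketch, assuming a CUDA-capable PyTorch build (the default
# constructor allocates the buffer in pinned memory): the storage stays on
# the CPU while `pull` and `push` move mini-batch slices to and from the
# compute device.
if __name__ == '__main__':
    history = History(num_embeddings=100, embedding_dim=16)

    n_id = torch.tensor([0, 5, 42])    # nodes of a mini-batch
    x = torch.randn(n_id.numel(), 16)  # their freshly computed embeddings

    history.push(x, n_id=n_id)         # write rows 0, 5 and 42
    out = history.pull(n_id)           # read them back on `history._device`
    assert torch.equal(out.cpu(), x.cpu())

    if torch.cuda.is_available():
        history = history.to('cuda')   # only `_device` changes, not `emb`
        out = history.pull(n_id)       # now returned on the GPU
        assert out.is_cuda and history.emb.device.type == 'cpu'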