# loader.py

import time
from typing import NamedTuple, List, Tuple

import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor
from torch_geometric.data import Data

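# Custom C++ operator shipped with this package: given the CSR representation
# of the full graph and the global indices of the mini-batch nodes, it extracts
# and relabels their one-hop neighborhood (see `compute_subgraph` below).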
relabel_fn = torch.ops.torch_geometric_autoscale.relabel_one_hop


class SubData(NamedTuple):
    data: Data
    batch_size: int
    n_id: Tensor  # The indices of mini-batched nodes
    offset: Tensor  # The offset of contiguous mini-batched nodes
    count: Tensor  # The number of contiguous mini-batched nodes

    def to(self, *args, **kwargs):
        return SubData(self.data.to(*args, **kwargs), self.batch_size,
                       self.n_id, self.offset, self.count)


class SubgraphLoader(DataLoader):
    r"""A simple subgraph loader that, given a pre-partitioned :obj:`data`
    object, generates subgraphs from mini-batches in :obj:`ptr` (including
    their 1-hop neighbors)."""
    def __init__(self, data: Data, ptr: Tensor, batch_size: int = 1,
                 bipartite: bool = True, log: bool = True, **kwargs):

        self.data = data
        self.ptr = ptr
        self.bipartite = bipartite
        self.log = log

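        # Split the contiguous node indices into partitions according to `ptr`
        # and pair every partition with its partition id: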
        n_id = torch.arange(data.num_nodes)
        batches = n_id.split((ptr[1:] - ptr[:-1]).tolist())
        batches = [(i, batches[i]) for i in range(len(batches))]

        if batch_size > 1:
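            # Mini-batches merge several partitions, so their subgraphs are
            # computed on-the-fly in the collate function: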
            super().__init__(batches, batch_size=batch_size,
                             collate_fn=self.compute_subgraph, **kwargs)

        else:  # If `batch_size=1`, we pre-process the subgraph generation:
            if log:
                t = time.perf_counter()
                print('Pre-processing subgraphs...', end=' ', flush=True)

            data_list = list(
                DataLoader(batches, collate_fn=self.compute_subgraph,
                           batch_size=batch_size, **kwargs))

            if log:
                print(f'Done! [{time.perf_counter() - t:.2f}s]')

            super().__init__(data_list, batch_size=batch_size,
                             collate_fn=lambda x: x[0], **kwargs)

    def compute_subgraph(self, batches: List[Tuple[int, Tensor]]) -> SubData:
        batch_ids, n_ids = zip(*batches)
        n_id = torch.cat(n_ids, dim=0)
        batch_id = torch.tensor(batch_ids)

        # Collect the number of in-batch nodes (`batch_size`), the global
        # offset at which each partition starts (`offset`), and the number of
        # nodes in each partition (`count`):
        batch_size = n_id.numel()
        offset = self.ptr[batch_id]
        count = self.ptr[batch_id.add_(1)].sub_(offset)

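        # Relabel the one-hop neighborhood of the mini-batch nodes from global
        # to local indices; in the bipartite case, out-of-batch neighbors are
        # appended to the returned `n_id`: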
        rowptr, col, value = self.data.adj_t.csr()
        rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
                                              self.bipartite)

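        # Re-assemble the relabeled CSR data into the sparse adjacency matrix
        # of the sampled subgraph (mini-batch nodes as rows, all nodes in
        # `n_id` as columns):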
        adj_t = SparseTensor(rowptr=rowptr, col=col, value=value,
                             sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
                             is_sorted=True)

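        # Create a new data object holding the subgraph adjacency and narrow
        # all node-level attributes of the original data to the sampled nodes: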
        data = self.data.__class__(adj_t=adj_t)
        for k, v in self.data:
            if isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:
                data[k] = v.index_select(0, n_id)

        return SubData(data, batch_size, n_id, offset, count)

    def __repr__(self):
        return f'{self.__class__.__name__}()'


class EvalSubgraphLoader(SubgraphLoader):
    r"""A simple subgraph loader that, given a pre-partitioned :obj:`data`
    object, generates subgraphs from mini-batches in :obj:`ptr` (including
    their 1-hop neighbors).

    In contrast to :class:`SubgraphLoader`, this loader does not generate
    subgraphs from randomly sampled mini-batches, and should therefore only be
    used for evaluation.
    """
    def __init__(self, data: Data, ptr: Tensor, batch_size: int = 1,
                 bipartite: bool = True, log: bool = True, **kwargs):

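        # Merge every `batch_size` consecutive partitions into a single
        # mini-batch by keeping only every `batch_size`-th pointer, making sure
        # the final pointer still covers all nodes: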
        ptr = ptr[::batch_size]
        if int(ptr[-1]) != data.num_nodes:
            ptr = torch.cat([ptr, torch.tensor([data.num_nodes])], dim=0)

        super().__init__(data=data, ptr=ptr, batch_size=1, bipartite=bipartite,
                         log=log, shuffle=False, num_workers=0, **kwargs)
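

# A minimal usage sketch (not part of the module): it assumes `data` stores its
# adjacency as `data.adj_t` (a `SparseTensor`) and that `ptr` holds partition
# boundaries, e.g. obtained from a METIS partitioning of the graph; `model` is
# a hypothetical GNN operating on the sampled subgraph:
#
#     train_loader = SubgraphLoader(data, ptr, batch_size=4, shuffle=True)
#     for sub in train_loader:
#         sub = sub.to('cuda')
#         # The first `sub.batch_size` entries of `sub.n_id` are the in-batch
#         # nodes; the remaining ones are their 1-hop neighbors.
#         out = model(sub.data.x, sub.data.adj_t)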