loader.py 4.17 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from typing import NamedTuple, List, Tuple
rusty1s's avatar
rusty1s committed
2
3
4
5
6
7
8
9
10

import time

import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor
from torch_geometric.data import Data

rusty1s's avatar
rusty1s committed
11
# Custom operator registered under `torch.ops.torch_geometric_autoscale`.
# The extension that registers it must already be loaded when this module is
# imported (presumably done by the package's compiled extension — TODO confirm).
# Used by `SubgraphLoader.compute_subgraph` to relabel the CSR adjacency of a
# mini-batch and its 1-hop neighborhood.
relabel_fn = torch.ops.torch_geometric_autoscale.relabel_one_hop
rusty1s's avatar
rusty1s committed
12
13
14


class SubData(NamedTuple):
    r"""Mini-batch subgraph emitted by :class:`SubgraphLoader`.

    Calling :meth:`to` only transfers :obj:`data`; the bookkeeping tensors
    (:obj:`n_id`, :obj:`offset`, :obj:`count`) are left where they are.
    """
    # The subgraph: relabeled `adj_t` plus node attributes sliced by `n_id`.
    data: Data
    # Number of nodes that belong to the sampled partitions themselves
    # (computed before the 1-hop neighbors are added).
    batch_size: int
    # Global node ids of every node in the subgraph.
    n_id: Tensor
    # `ptr[i]` for every sampled partition `i`, i.e. its first global node id.
    offset: Tensor
    # Number of nodes in every sampled partition.
    count: Tensor

    def to(self, *args, **kwargs):
        """Return a copy whose :obj:`data` is moved via ``Data.to(*args, **kwargs)``."""
        moved = self.data.to(*args, **kwargs)
        return self._replace(data=moved)


class SubgraphLoader(DataLoader):
    r"""A simple subgraph loader that, given a pre-partitioned :obj:`data`
    object, generates subgraphs (including their 1-hop neighbors) from
    mini-batches of the partitions described by :obj:`ptr`.

    Args:
        data (Data): The full graph. Must provide an ``adj_t`` attribute
            supporting ``.csr()`` (a :class:`torch_sparse.SparseTensor`).
        ptr (Tensor): Partition boundaries; partition ``i`` holds the nodes
            ``ptr[i]:ptr[i + 1]``. Assumes ``ptr[0] == 0`` and
            ``ptr[-1] == data.num_nodes``.
        batch_size (int, optional): Number of partitions per mini-batch.
            If ``1``, every subgraph is computed once at construction time
            and cached. (default: ``1``)
        bipartite (bool, optional): Forwarded to the ``relabel_one_hop``
            op (presumably selects a bipartite/rectangular adjacency —
            TODO confirm). (default: ``True``)
        log (bool, optional): If ``True``, reports pre-processing progress
            when ``batch_size=1``. (default: ``True``)
        **kwargs: Additional arguments of
            :class:`torch.utils.data.DataLoader`.
    """
    def __init__(self, data: Data, ptr: Tensor, batch_size: int = 1,
                 bipartite: bool = True, log: bool = True, **kwargs):

        self.data = data
        self.ptr = ptr
        self.bipartite = bipartite
        self.log = log

        # Split the node range into the partitions given by `ptr`, keeping
        # each partition's index so `compute_subgraph` can look up its offset
        # and size in `ptr`.
        n_id = torch.arange(data.num_nodes)
        batches = list(enumerate(n_id.split((ptr[1:] - ptr[:-1]).tolist())))

        if batch_size > 1:
            super().__init__(
                batches,
                batch_size=batch_size,
                collate_fn=self.compute_subgraph,
                **kwargs,
            )
            return

        # If `batch_size=1`, we pre-process the subgraph generation once, so
        # that iterating over the loader only hands out cached subgraphs:
        if log:
            t = time.perf_counter()
            print('Pre-processing subgraphs...', end=' ', flush=True)

        data_list = list(
            DataLoader(batches, collate_fn=self.compute_subgraph,
                       batch_size=batch_size, **kwargs))

        if log:
            print(f'Done! [{time.perf_counter() - t:.2f}s]')

        super().__init__(
            data_list,
            batch_size=batch_size,
            # Each "batch" is a one-element list holding a cached `SubData`.
            collate_fn=lambda x: x[0],
            **kwargs,
        )

    def compute_subgraph(self, batches: List[Tuple[int, Tensor]]) -> SubData:
        r"""Builds the subgraph for a mini-batch of partitions.

        Args:
            batches: ``(partition index, node ids of that partition)`` pairs.

        Returns:
            :class:`SubData` holding the relabeled subgraph, the number of
            in-mini-batch nodes, the subgraph's global node ids, and each
            partition's offset and node count.
        """
        batch_ids, n_ids = zip(*batches)
        n_id = torch.cat(n_ids, dim=0)
        batch_id = torch.tensor(batch_ids)

        # We collect the in-mini-batch size (`batch_size`), the offset of each
        # partition in the mini-batch (`offset`), and the number of nodes in
        # each partition (`count`):
        batch_size = n_id.numel()
        offset = self.ptr[batch_id]
        count = self.ptr[batch_id + 1] - offset

        # `relabel_fn` returns the relabeled CSR structure together with the
        # global ids of all subgraph nodes (presumably the in-batch nodes
        # followed by their 1-hop neighbors — TODO confirm against the op).
        rowptr, col, value = self.data.adj_t.csr()
        rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
                                              self.bipartite)

        adj_t = SparseTensor(rowptr=rowptr, col=col, value=value,
                             sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
                             is_sorted=True)

        # Slice every node-level tensor attribute down to the subgraph nodes.
        # 0-dim tensors are skipped: `size(0)` would raise on them.
        num_nodes = self.data.num_nodes
        data = self.data.__class__(adj_t=adj_t)
        for key, item in self.data:
            if (isinstance(item, Tensor) and item.dim() > 0
                    and item.size(0) == num_nodes):
                data[key] = item.index_select(0, n_id)

        return SubData(data, batch_size, n_id, offset, count)

    def __repr__(self):
        return f'{self.__class__.__name__}()'


class EvalSubgraphLoader(SubgraphLoader):
    r"""A simple subgraph loader that, given a pre-partitioned :obj:`data`
    object, generates subgraphs (including their 1-hop neighbors) from
    mini-batches in :obj:`ptr`.
    In contrast to :class:`SubgraphLoader`, this loader does not generate
    subgraphs from randomly sampled mini-batches, and should therefore only be
    used for evaluation.

    Args:
        data (Data): The full graph.
        ptr (Tensor): Partition boundaries (see :class:`SubgraphLoader`).
        batch_size (int, optional): Number of consecutive partitions merged
            into one (pre-computed) subgraph. (default: ``1``)
        bipartite (bool, optional): Forwarded to :class:`SubgraphLoader`.
            (default: ``True``)
        log (bool, optional): Forwarded to :class:`SubgraphLoader`.
            (default: ``True``)
        **kwargs: Additional arguments of
            :class:`torch.utils.data.DataLoader` (``shuffle`` and
            ``num_workers`` are fixed by this class).
    """
    def __init__(self, data: Data, ptr: Tensor, batch_size: int = 1,
                 bipartite: bool = True, log: bool = True, **kwargs):

        # Merge every `batch_size` consecutive partitions into one by keeping
        # only every `batch_size`-th boundary.
        ptr = ptr[::batch_size]
        if int(ptr[-1]) != data.num_nodes:
            # Striding may drop the final boundary; re-append it. It must be a
            # 1-element tensor (`torch.cat` rejects 0-dim tensors) on the same
            # dtype/device as `ptr`.
            ptr = torch.cat([ptr, ptr.new_tensor([data.num_nodes])], dim=0)

        # The merged partitions are pre-computed once (`batch_size=1`) and
        # iterated in a fixed order.
        super().__init__(
            data=data,
            ptr=ptr,
            batch_size=1,
            bipartite=bipartite,
            log=log,
            shuffle=False,
            num_workers=0,
            **kwargs,
        )