import dgl
import dgl.nn as dglnn
import sklearn.linear_model as lm
import sklearn.metrics as skm
import torch as th
import torch.nn.functional as F
import torch.nn as nn
import tqdm


class SAGE(nn.Module):
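    """GraphSAGE for node classification.

    Stacks ``n_layers`` SAGEConv layers with mean aggregation, applying
    ``activation`` and dropout between layers. Designed for mini-batch
    training on message flow graphs (blocks) produced by a DGL sampler.
    """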
    def __init__(
        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
    ):
        super().__init__()
        self.init(in_feats, n_hidden, n_classes, n_layers, activation, dropout)

    def init(
        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
    ):
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        if n_layers > 1:
            self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
            for i in range(1, n_layers - 1):
                self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
            self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        else:
            self.layers.append(dglnn.SAGEConv(in_feats, n_classes, "mean"))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
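        """Mini-batch forward pass.

        ``blocks`` is a list of message flow graphs (one per layer) from a
        DGL neighbor sampler, and ``x`` holds the input features of the
        source nodes of ``blocks[0]``. If a block stores ``edge_weights``
        in its edge data, they are forwarded to the matching SAGEConv layer.
        """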
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(
                block,
                h,
                edge_weight=block.edata["edge_weights"]
                if "edge_weights" in block.edata
                else None,
            )
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size, use_uva, num_workers):
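        """Offline, layer-wise inference over the full graph ``g``.

        Rather than sampling multi-hop neighborhoods per output node, each
        layer is applied to all nodes (with all of their neighbors) before
        moving to the next, so every intermediate representation is computed
        exactly once. Callers should wrap this in ``th.no_grad()``.
        """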
        # The difference between this inference function and the one in the official
        # example is that the intermediate results can also benefit from prefetching.
        g.ndata["h"] = g.ndata["features"]
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(
            1, prefetch_node_feats=["h"]
        )
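        # When using UVA with a graph that lives on a different device than
        # the computation, keep the per-layer output buffer in pinned host
        # memory so copies back from the device are cheaper.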
        pin_memory = g.device != device and use_uva
        dataloader = dgl.dataloading.DataLoader(
            g,
            th.arange(g.num_nodes(), dtype=g.idtype, device=g.device),
            sampler,
            device=device,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            use_uva=use_uva,
            num_workers=num_workers,
            persistent_workers=(num_workers > 0),
        )

        self.eval()

        for l, layer in enumerate(self.layers):
            y = th.empty(
                g.num_nodes(),
                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
                dtype=g.ndata["h"].dtype,
                device=g.device,
                pin_memory=pin_memory,
            )
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = blocks[0].srcdata["h"]
                h = layer(blocks[0], x)
                if l < len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)
                # by design, our output nodes are contiguous
                y[output_nodes[0].item() : output_nodes[-1].item() + 1] = h.to(
                    y.device
                )
            g.ndata["h"] = y
        return y
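

# Minimal usage sketch (not part of the original file). It assumes a small
# built-in dataset (Cora) purely for illustration; the hyperparameters below
# are hypothetical and not tuned.
if __name__ == "__main__":
    dataset = dgl.data.CoraGraphDataset()
    g = dataset[0]
    # inference() expects node features under g.ndata["features"].
    g.ndata["features"] = g.ndata["feat"]

    model = SAGE(
        in_feats=g.ndata["features"].shape[1],
        n_hidden=16,
        n_classes=dataset.num_classes,
        n_layers=2,
        activation=F.relu,
        dropout=0.5,
    )

    # CPU-only, layer-wise inference; use_uva is only useful when the graph
    # stays in host memory while computation runs on a GPU.
    with th.no_grad():
        logits = model.inference(
            g,
            device=th.device("cpu"),
            batch_size=1024,
            use_uva=False,
            num_workers=0,
        )
    print(logits.shape)  # (num_nodes, n_classes)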