# train_cv.py
import argparse
import time

4
5
6
7
8
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
9
10
11
12
import tqdm
from torch.utils.data import DataLoader

import dgl
13
14
15
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl.data import RedditDataset
16

17
18
19
20
21
22
23
24
25

class SAGEConvWithCV(nn.Module):
    """GraphSAGE mean-aggregation layer with a control-variate estimator.

    Training mode: the neighbor aggregate is the aggregated history of the
    destination nodes (``agg_HBar_dst``, supplied by the caller) plus the mean,
    over the block's (sampled) neighbors, of the correction ``H_src - HBar_src``.
    Eval mode: a plain mean of ``H_src`` over the block's neighbors.

    In both modes the aggregate is concatenated with the destination nodes' own
    representation, projected by ``self.W`` and passed through the optional
    activation.
    """

    def __init__(self, in_feats, out_feats, activation):
        super().__init__()
        # Input of W is [own representation || aggregated neighbors].
        self.W = nn.Linear(in_feats * 2, out_feats)
        self.activation = activation
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weights (ReLU gain) and a zero bias."""
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_uniform_(self.W.weight, gain=gain)
        nn.init.constant_(self.W.bias, 0)

    def _combine(self, H_dst, h_neigh):
        """Project ``[H_dst || h_neigh]`` and apply the optional activation."""
        h = self.W(th.cat([H_dst, h_neigh], 1))
        if self.activation is not None:
            h = self.activation(h)
        return h

    def forward(self, block, H, HBar=None):
        """
        block : bipartite block to aggregate over.
        H : ``(H_src, H_dst)`` current representations of source and
            destination nodes.
        HBar : ``(HBar_src, agg_HBar_dst)`` history of the source nodes and the
            aggregated history of the destination nodes. Required in training
            mode, ignored in eval mode.

        Raises TypeError when called in training mode without ``HBar``.
        """
        H_src, H_dst = H
        if self.training:
            if HBar is None:
                raise TypeError(
                    "HBar is required when SAGEConvWithCV is in training mode"
                )
            HBar_src, agg_HBar_dst = HBar
            with block.local_scope():
                block.srcdata["hdelta"] = H_src - HBar_src
                block.update_all(
                    fn.copy_u("hdelta", "m"), fn.mean("m", "hdelta_new")
                )
                h_neigh = agg_HBar_dst + block.dstdata["hdelta_new"]
        else:
            with block.local_scope():
                block.srcdata["h"] = H_src
                block.update_all(fn.copy_u("h", "m"), fn.mean("m", "h_new"))
                h_neigh = block.dstdata["h_new"]
        return self._combine(H_dst, h_neigh)


class SAGE(nn.Module):
    """GraphSAGE model built from stacked ``SAGEConvWithCV`` layers.

    in_feats : input feature size.
    n_hidden : output size of every layer except the last.
    n_classes : output size of the last layer.
    n_layers : number of layers (must be >= 1).
    activation : applied after every layer except the last.
    """

    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation):
        super().__init__()
        if n_layers < 1:
            raise ValueError("n_layers must be >= 1, got {}".format(n_layers))
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        if n_layers == 1:
            # A single layer maps the inputs straight to class scores.
            self.layers.append(SAGEConvWithCV(in_feats, n_classes, None))
            return
        self.layers.append(SAGEConvWithCV(in_feats, n_hidden, activation))
        for _ in range(1, n_layers - 1):
            self.layers.append(SAGEConvWithCV(n_hidden, n_hidden, activation))
        self.layers.append(SAGEConvWithCV(n_hidden, n_classes, None))

    def forward(self, blocks):
        """Training forward pass over sampled ``blocks`` (input layer first).

        Reads ``blocks[0].srcdata["features"]`` as the input, and for every
        block ``srcdata["hist"]`` / ``dstdata["agg_hist"]`` as the history
        inputs (``load_subtensor`` prepares these). Each layer's output is also
        stored in ``block.dstdata["h_new"]`` so the history can be refreshed
        afterwards.
        """
        h = blocks[0].srcdata["features"]
        for layer, block in zip(self.layers, blocks):
            # The destination (RHS) nodes' representation is taken from the
            # first number_of_dst_nodes() rows of the source (LHS) side:
            # h is (num_nodes_LHS, D), h_dst is (num_nodes_RHS, D).
            h_dst = h[: block.number_of_dst_nodes()]
            hbar_src = block.srcdata["hist"]
            agg_hbar_dst = block.dstdata["agg_hist"]
            # h becomes (num_nodes_RHS, D_out).
            h = layer(block, (h, h_dst), (hbar_src, agg_hbar_dst))
            block.dstdata["h_new"] = h
        return h

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without
        neighbor sampling).

        g : the entire graph.
        x : the input of the entire node set.
        batch_size : number of destination nodes per block.
        device : device used for the per-batch computation.

        Returns ``(y, ys)``: the last layer's output for every node and a list
        holding every layer's output (all stored on the CPU). The layers must
        be in eval mode: in training mode they require history inputs, which
        are not passed here.
        """
        # Multi-layer blocks per batch would recompute the lower layers many
        # times, so all nodes are computed layer by layer, in batches.
        nodes = th.arange(g.number_of_nodes())
        ys = []
        for layer_idx, layer in enumerate(self.layers):
            is_last = layer_idx == len(self.layers) - 1
            y = th.zeros(
                g.number_of_nodes(),
                self.n_classes if is_last else self.n_hidden,
            )

            for start in tqdm.trange(0, len(nodes), batch_size):
                end = start + batch_size
                batch_nodes = nodes[start:end]
                block = dgl.to_block(
                    dgl.in_subgraph(g, batch_nodes), batch_nodes
                )
                block = block.int().to(device)
                induced_nodes = block.srcdata[dgl.NID]

                # The block (and its NID index) now lives on `device`, while x
                # may be on the CPU: move the index to x's device first.
                h = x[induced_nodes.to(device=x.device, dtype=th.int64)].to(
                    device
                )
                h_dst = h[: block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))

                y[start:end] = h.cpu()

            ys.append(y)
            x = y
        return y, ys


class NeighborSampler:
    """Build per-layer blocks for a minibatch of seed nodes.

    For every fanout (one per layer, output layer first) two blocks are built
    over the same seeds: one from ``fanout`` sampled in-neighbors and one from
    the full in-neighborhood (used to aggregate historical activations).
    """

    def __init__(self, g, fanouts):
        self.g = g
        self.fanouts = fanouts

    def sample_blocks(self, seeds):
        """Return ``(blocks, hist_blocks)``, both ordered input layer first."""
        graph = self.g
        layer_seeds = th.LongTensor(seeds)
        sampled_blocks, full_blocks = [], []
        for fanout in self.fanouts:
            sampled_frontier = dgl.sampling.sample_neighbors(
                graph, layer_seeds, fanout
            )
            full_frontier = dgl.in_subgraph(graph, layer_seeds)
            # Compact each frontier into a bipartite block for message passing.
            sampled_block = dgl.to_block(sampled_frontier, layer_seeds)
            full_block = dgl.to_block(full_frontier, layer_seeds)
            # This layer's source nodes are the seeds of the layer below.
            layer_seeds = sampled_block.srcdata[dgl.NID]
            sampled_blocks.append(sampled_block)
            full_blocks.append(full_block)
        # Blocks were collected from the output layer downwards.
        sampled_blocks.reverse()
        full_blocks.reverse()
        return sampled_blocks, full_blocks


def compute_acc(pred, labels):
    """
    Return the fraction of rows of ``pred`` whose arg-max index equals the
    corresponding entry of ``labels``, as a 0-d float tensor.
    """
    hits = th.argmax(pred, dim=1).eq(labels)
    return hits.float().sum() / pred.shape[0]


def evaluate(model, g, labels, val_mask, batch_size, device):
    """
    Evaluate the model on the validation nodes selected by ``val_mask``.

    Runs full-neighbor inference for every node in ``g`` and scores only the
    selected nodes.

    model : model exposing ``inference(g, x, batch_size, device)``.
    g : The entire graph; node inputs are read from ``g.ndata["features"]``.
    labels : The labels of all the nodes.
    val_mask : Index over the node dimension selecting the nodes to score;
        either a boolean mask or a tensor of node IDs (``run`` passes IDs).
    batch_size : Number of nodes to compute at the same time.
    device : The device to run inference on.

    The model is left in training mode on return, even if inference raises.
    """
    model.eval()
    try:
        with th.no_grad():
            inputs = g.ndata["features"]
            pred, _ = model.inference(g, inputs, batch_size, device)
    finally:
        model.train()
    return compute_acc(pred[val_mask], labels[val_mask])


def load_subtensor(
    g, labels, blocks, hist_blocks, dev_id, aggregation_on_device=False
):
    """
    Attach inputs, labels and history to the sampled blocks and move them to
    ``dev_id``.

    - ``blocks[0].srcdata["features"]`` gets the input features and
      ``blocks[-1].dstdata["label"]`` the labels.
    - For layer ``i`` the history column of ``g.ndata`` is ``"features"`` when
      ``i == 0`` and ``"hist_<i>"`` otherwise. It is gathered into
      ``srcdata["hist"]`` of both the sampled block and the full-neighbor
      hist block. Its mean over the hist block becomes
      ``dstdata["agg_hist"]`` of the sampled block.
    - ``aggregation_on_device``: when True the hist block is moved to
      ``dev_id`` before the aggregation, otherwise after it.

    Returns ``(blocks, hist_blocks)`` on ``dev_id``; the sampled blocks are
    also converted with ``.int()``.
    """
    first_block, last_block = blocks[0], blocks[-1]
    first_block.srcdata["features"] = g.ndata["features"][
        first_block.srcdata[dgl.NID]
    ]
    last_block.dstdata["label"] = labels[last_block.dstdata[dgl.NID]]

    moved_blocks, moved_hist_blocks = [], []
    for layer_idx, (block, hist_block) in enumerate(zip(blocks, hist_blocks)):
        column = "hist_%d" % layer_idx if layer_idx else "features"
        history = g.ndata[column]
        block.srcdata["hist"] = history[block.srcdata[dgl.NID]]
        hist_block.srcdata["hist"] = history[hist_block.srcdata[dgl.NID]]

        # Aggregate the history over the full neighborhood.
        if aggregation_on_device:
            hist_block = hist_block.to(dev_id)
        hist_block.update_all(fn.copy_u("hist", "m"), fn.mean("m", "agg_hist"))

        block = block.int().to(dev_id)
        if not aggregation_on_device:
            hist_block = hist_block.to(dev_id)
        block.dstdata["agg_hist"] = hist_block.dstdata["agg_hist"]

        moved_blocks.append(block)
        moved_hist_blocks.append(hist_block)
    return moved_blocks, moved_hist_blocks


def init_history(g, model, dev_id, batch_size=1000):
    """
    Seed the per-layer history columns ``g.ndata["hist_<i>"]``.

    Runs full-neighbor inference with ``model`` on ``g.ndata["features"]`` and
    stores the output of layer ``i`` (1-based) under ``"hist_<i>"``.

    The model is put in eval mode for the inference, because its layers need
    history inputs in training mode. Its previous mode is restored afterwards,
    even if inference raises.

    dev_id : device used for the inference.
    batch_size : number of nodes processed per inference batch.
    """
    was_training = model.training
    model.eval()
    try:
        with th.no_grad():
            _, history = model.inference(
                g, g.ndata["features"], batch_size, dev_id
            )
    finally:
        model.train(was_training)
    # Iterate over what inference produced instead of a global layer count.
    for layer_no, layer_output in enumerate(history, start=1):
        g.ndata["hist_%d" % layer_no] = layer_output


def update_history(g, blocks):
    """
    Write each block's latest output back into the graph's history columns.

    Block ``i`` (0-based) provides ``dstdata["h_new"]`` (stored by
    ``SAGE.forward``). The rows ``dstdata[dgl.NID]`` of
    ``g.ndata["hist_<i+1>"]`` are overwritten in place with it. Runs without
    recording gradients.
    """
    with th.no_grad():
        for layer_no, block in enumerate(blocks, start=1):
            dst_ids = block.dstdata[dgl.NID].cpu()
            column = "hist_%d" % layer_no
            fresh = block.dstdata["h_new"].cpu()
            g.ndata[column][dst_ids] = fresh


def run(args, dev_id, data):
    """
    Train the control-variate GraphSAGE model on one device and report timing.

    args : parsed command-line options (see the ``__main__`` block).
    dev_id : CUDA device index used for the model and the sampled blocks.
    data : tuple ``(train_mask, val_mask, in_feats, labels, n_classes, g)``.

    The first ``warmup_epochs`` epochs are excluded from the reported
    average epoch time.
    """
    th.cuda.set_device(dev_id)

    # Unpack data
    train_mask, val_mask, in_feats, labels, n_classes, g = data
    train_nid = train_mask.nonzero().squeeze()
    val_nid = val_mask.nonzero().squeeze()

    # Create sampler
    sampler = NeighborSampler(
        g, [int(fanout) for fanout in args.fan_out.split(",")]
    )

    # Create PyTorch DataLoader for constructing blocks
    dataloader = DataLoader(
        dataset=train_nid.numpy(),
        batch_size=args.batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers_per_gpu,
    )

    # Define model
    model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu)

    # Move the model to GPU and define optimizer
    model = model.to(dev_id)
    loss_fcn = nn.CrossEntropyLoss().to(dev_id)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Compute the history tensors before training; the layers need eval mode
    # for full-neighbor inference.
    model.eval()
    init_history(g, model, dev_id)
    model.train()

    # Training loop
    warmup_epochs = 5
    total_time = 0.0
    timed_epochs = 0
    iter_tput = []
    for epoch in range(args.num_epochs):
        tic = time.time()
        model.train()

        tic_step = time.time()
        for step, (blocks, hist_blocks) in enumerate(dataloader):
            # The output (seed) nodes are the RHS of the last block.
            seeds = blocks[-1].dstdata[dgl.NID]

            blocks, hist_blocks = load_subtensor(
                g, labels, blocks, hist_blocks, dev_id, True
            )

            # forward
            batch_pred = model(blocks)
            # update history with this step's activations
            update_history(g, blocks)
            # compute loss
            batch_labels = blocks[-1].dstdata["label"]
            loss = loss_fcn(batch_pred, batch_labels)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            iter_tput.append(len(seeds) / (time.time() - tic_step))
            if step % args.log_every == 0:
                acc = compute_acc(batch_pred, batch_labels)
                # The first 3 steps are excluded as warm-up; until more are
                # available the speed is reported as nan.
                speed = (
                    np.mean(iter_tput[3:])
                    if len(iter_tput) > 3
                    else float("nan")
                )
                print(
                    "Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}".format(
                        epoch,
                        step,
                        loss.item(),
                        acc.item(),
                        speed,
                    )
                )
            tic_step = time.time()

        toc = time.time()
        print("Epoch Time(s): {:.4f}".format(toc - tic))
        if epoch >= warmup_epochs:
            total_time += toc - tic
            timed_epochs += 1
        if epoch % args.eval_every == 0 and epoch != 0:
            # evaluate() switches to eval mode and back to train mode itself.
            eval_acc = evaluate(
                model, g, labels, val_nid, args.val_batch_size, dev_id
            )
            print("Eval Acc {:.4f}".format(eval_acc))

    if timed_epochs > 0:
        print("Avg epoch time: {}".format(total_time / timed_epochs))
    else:
        print(
            "Avg epoch time: n/a (needs more than {} epochs)".format(
                warmup_epochs
            )
        )


if __name__ == "__main__":
    argparser = argparse.ArgumentParser("multi-gpu training")
    # (flag, type, default) of every command-line option, in declaration order.
    for flag, arg_type, default in (
        ("--gpu", str, "0"),
        ("--num-epochs", int, 20),
        ("--num-hidden", int, 16),
        ("--num-layers", int, 2),
        ("--fan-out", str, "1,1"),
        ("--batch-size", int, 1000),
        ("--val-batch-size", int, 1000),
        ("--log-every", int, 20),
        ("--eval-every", int, 5),
        ("--lr", float, 0.003),
        ("--num-workers-per-gpu", int, 0),
    ):
        argparser.add_argument(flag, type=arg_type, default=default)
    args = argparser.parse_args()

    # Load the Reddit dataset (with self-loops) and unpack its node data.
    data = RedditDataset(self_loop=True)
    n_classes = data.num_classes
    g = data[0]
    features = g.ndata["feat"]
    in_feats = features.shape[1]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    # The training code reads node inputs from the "features" column.
    g.ndata["features"] = features
    # Eagerly create the graph's sparse formats before training starts.
    g.create_formats_()

    # Pack data
    data = train_mask, val_mask, in_feats, labels, n_classes, g

    run(args, int(args.gpu), data)