"scripts/ci/ci_install_rust.sh" did not exist on "a5e0defb5a560a6d42882008c1dd8a739002ab7d"
main.py 11.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import dgl
from functools import partial
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm
from ogb.nodeproppred import DglNodePropPredDataset

from sampler import ClusterIter, subgraph_collate_fn
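# A typical invocation (illustrative values; all flags and their defaults are
# defined in the argparser below):
#   python main.py --gpu 0 --num-epochs 20 --batch-size 32 --num_partitions 15000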

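# A multi-head GAT: every hidden GATConv concatenates its attention heads
# (flatten(1)), while the output layer keeps the heads separate so they can be
# averaged before the final log-softmax.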
class GAT(nn.Module):
    def __init__(self,
                 in_feats,
                 num_heads,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout=0.):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.num_heads = num_heads
        self.layers.append(dglnn.GATConv(in_feats,
                                         n_hidden,
                                         num_heads=num_heads,
                                         feat_drop=dropout,
                                         attn_drop=dropout,
                                         activation=activation,
                                         negative_slope=0.2))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.GATConv(n_hidden * num_heads,
                                             n_hidden,
                                             num_heads=num_heads,
                                             feat_drop=dropout,
                                             attn_drop=dropout,
                                             activation=activation,
                                             negative_slope=0.2))
        self.layers.append(dglnn.GATConv(n_hidden * num_heads,
                                         n_classes,
                                         num_heads=num_heads,
                                         feat_drop=dropout,
                                         attn_drop=dropout,
                                         activation=None,
                                         negative_slope=0.2))

    def forward(self, g, x):
        h = x
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l < len(self.layers) - 1:
                # Concatenate the heads of every hidden layer.
                h = h.flatten(1)
        # Average the heads of the output layer.
        h = h.mean(1)
        return h.log_softmax(dim=-1)

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GAT model on full neighbors (i.e. without neighbor sampling).
        g : the entire graph.
        x : the input features of the entire node set.
        The inference code is written so that it can handle any number of nodes and
        layers.
        """
        num_heads = self.num_heads
        for l, layer in enumerate(self.layers):
            if l < self.n_layers - 1:
                y = th.zeros(g.num_nodes(), self.n_hidden * num_heads)
            else:
                y = th.zeros(g.num_nodes(), self.n_classes)
            # Full-neighbor sampling, one layer at a time over node batches.
            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
            dataloader = dgl.dataloading.DataLoader(
                    g,
                    th.arange(g.num_nodes()),
                    sampler,
                    batch_size=batch_size,
                    shuffle=False,
                    drop_last=False,
                    num_workers=args.num_workers)

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0].int().to(device)
                h = x[input_nodes].to(device)
                if l < self.n_layers - 1:
                    # Hidden layers concatenate the attention heads.
                    h = layer(block, h).flatten(1)
                else:
                    # The output layer averages the heads and applies log-softmax.
                    h = layer(block, h)
                    h = h.mean(1)
                    h = h.log_softmax(dim=-1)

                y[output_nodes] = h.cpu()
            x = y
        return y

def compute_acc(pred, labels):
    """
    Compute the accuracy of prediction given the labels.
    """
    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)

def evaluate(model, g, nfeat, labels, val_nid, test_nid, batch_size, device):
    """
    Evaluate the model on the validation set specified by ``val_mask``.
    g : The entire graph.
    inputs : The features of all the nodes.
    labels : The labels of all the nodes.
    val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for.
    batch_size : Number of nodes to compute at the same time.
    device : The GPU device to evaluate on.
    """
    model.eval()
    with th.no_grad():
        pred = model.inference(g, nfeat, batch_size, device)
    model.train()
    return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(pred[test_nid], labels[test_nid]), pred

def model_param_summary(model):
    """ Count the model parameters """
    cnt = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Total Params {}".format(cnt))

#### Entry point
def run(args, device, data):
    # Unpack data
    train_nid, val_nid, test_nid, in_feats, labels, n_classes, g, cluster_iterator = data
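    # Move all labels and node features to the training device once, so each
    # step below can slice them without a host-to-device copy.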
    labels = labels.to(device)
    nfeat = g.ndata.pop('feat').to(device)

    # Define model and optimizer
    model = GAT(in_feats, args.num_heads, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
    model_param_summary(model)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # Training loop
    avg = 0
    best_eval_acc = 0
    best_test_acc = 0
    for epoch in range(args.num_epochs):
        iter_load = 0
        iter_far = 0
        iter_back = 0
        tic = time.time()

        # Loop over the cluster iterator; each step yields one subgraph built
        # from a batch of graph clusters (Cluster-GCN style mini-batching).
        tic_start = time.time()
        for step, cluster in enumerate(cluster_iterator):
            mask = cluster.ndata.pop('train_mask')
            if mask.sum() == 0:
                continue
            cluster.edata.pop(dgl.EID)
            cluster = cluster.int().to(device)
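            # dgl.NID stores, for every node of the cluster subgraph, its ID in
            # the full graph, so it can index the global feature/label tensors.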
            input_nodes = cluster.ndata[dgl.NID]
            batch_inputs = nfeat[input_nodes]
            batch_labels = labels[input_nodes]
            tic_step = time.time()

            # Compute loss and prediction
            batch_pred = model(cluster, batch_inputs)
            batch_pred = batch_pred[mask]
            batch_labels = batch_labels[mask]
            loss = F.nll_loss(batch_pred, batch_labels)
            optimizer.zero_grad()
            tic_far = time.time()
            loss.backward()
            optimizer.step()
            tic_back = time.time()
            iter_load += (tic_step - tic_start)
            iter_far += (tic_far - tic_step)
            iter_back += (tic_back - tic_far)

            if step % args.log_every == 0:
                acc = compute_acc(batch_pred, batch_labels)
                gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
                print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | GPU {:.1f} MB'.format(
                    epoch, step, loss.item(), acc.item(), gpu_mem_alloc))
            # Reset the load timer every iteration so that iter_load only
            # measures dataloading time (not the previous steps' compute).
            tic_start = time.time()

        toc = time.time()
        print('Epoch Time(s): {:.4f} Load {:.4f} Forward {:.4f} Backward {:.4f}'.format(toc - tic, iter_load, iter_far, iter_back))
        if epoch >= 5:
            avg += toc - tic

        if epoch % args.eval_every == 0 and epoch != 0:
            eval_acc, test_acc, pred = evaluate(model, g, nfeat, labels, val_nid, test_nid, args.val_batch_size, device)
            model = model.to(device)
            if args.save_pred:
                np.savetxt(args.save_pred + '%02d' % epoch, pred.argmax(1).cpu().numpy(), '%d')
            print('Eval Acc {:.4f}'.format(eval_acc))
            if eval_acc > best_eval_acc:
                best_eval_acc = eval_acc
                best_test_acc = test_acc
            print('Best Eval Acc {:.4f} Test Acc {:.4f}'.format(best_eval_acc, best_test_acc))
    print('Avg epoch time: {}'.format(avg / (epoch - 4)))
    return best_test_acc

if __name__ == '__main__':
    argparser = argparse.ArgumentParser("Cluster-GAT training on ogbn-products")
    argparser.add_argument('--gpu', type=int, default=0,
            help="GPU device ID. Use -1 for CPU training")
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=128)
    argparser.add_argument('--num-layers', type=int, default=3)
    argparser.add_argument('--num-heads', type=int, default=8)
    argparser.add_argument('--batch-size', type=int, default=32)
    argparser.add_argument('--val-batch-size', type=int, default=2000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=1)
    argparser.add_argument('--lr', type=float, default=0.001)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--save-pred', type=str, default='')
    argparser.add_argument('--wd', type=float, default=0)
    argparser.add_argument('--num_partitions', type=int, default=15000)
    argparser.add_argument('--num-workers', type=int, default=0)
    argparser.add_argument('--data-cpu', action='store_true',
                           help="By default the script puts all node features and labels "
                                "on GPU when using it to save time for data copy. This may "
                                "be undesired if they cannot fit in GPU memory at once. "
                                "This flag disables that.")
    args = argparser.parse_args()

    if args.gpu >= 0:
        device = th.device('cuda:%d' % args.gpu)
    else:
        device = th.device('cpu')

    # load ogbn-products data
    data = DglNodePropPredDataset(name='ogbn-products')
    splitted_idx = data.get_idx_split()
    train_idx, val_idx, test_idx = splitted_idx['train'], splitted_idx['valid'], splitted_idx['test']
    graph, labels = data[0]
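    # OGB provides labels with shape (N, 1); flatten to a 1-D vector.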
    labels = labels[:, 0]
    print('Total edges before adding self-loop {}'.format(graph.num_edges()))
    graph = dgl.remove_self_loop(graph)
    graph = dgl.add_self_loop(graph)
    print('Total edges after adding self-loop {}'.format(graph.num_edges()))
    num_nodes = train_idx.shape[0] + val_idx.shape[0] + test_idx.shape[0]
    assert num_nodes == graph.num_nodes()
    mask = th.zeros(num_nodes, dtype=th.bool)
    mask[train_idx] = True
    graph.ndata['train_mask'] = mask

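    # Query degrees and edges once so DGL materializes all of its sparse
    # formats (CSC/CSR/COO) now, before DataLoader workers are forked.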
    graph.in_degrees(0)
    graph.out_degrees(0)
    graph.find_edges(0)

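    # ClusterIter (see sampler.py) splits the graph into `num_partitions`
    # clusters, Cluster-GCN style; each DataLoader batch then merges
    # `batch_size` clusters into one training subgraph via subgraph_collate_fn.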
    cluster_iter_data = ClusterIter(
            'ogbn-products', graph, args.num_partitions, args.batch_size)
    cluster_iterator = DataLoader(cluster_iter_data, batch_size=args.batch_size, shuffle=True,
                                  pin_memory=True, num_workers=4,
                                  collate_fn=partial(subgraph_collate_fn, graph))

    in_feats = graph.ndata['feat'].shape[1]
    n_classes = (labels.max() + 1).item()
    # Pack data
    data = train_idx, val_idx, test_idx, in_feats, labels, n_classes, graph, cluster_iterator

    # Run 10 times
    test_accs = []
    for i in range(10):
        test_accs.append(run(args, device, data))
        print('Average test accuracy:', np.mean(test_accs), '±', np.std(test_accs))