"vscode:/vscode.git/clone" did not exist on "408eba247911f1942266fb2e7705f0e1db19a6ee"
train_sampling.py 9.91 KB
Newer Older
1
2
3
4
5
6
7
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
8
from torch.utils.data import DataLoader
9
10
11
12
13
14
15
16
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
from _thread import start_new_thread
from functools import wraps
from dgl.data import RedditDataset
import tqdm
17
import traceback
18

19
20
from load_graph import load_reddit, load_ogb

#### Neighbor sampler

class NeighborSampler(object):
    """Build the per-layer message-passing blocks for a minibatch of seed nodes.

    Intended for use as the ``collate_fn`` of a PyTorch ``DataLoader``: the
    loader hands a list of seed node IDs to :meth:`sample_blocks`, which returns
    one block per entry of ``fanouts``.
    """

    def __init__(self, g, fanouts, sample_neighbors, replace=True):
        """
        g : the graph to sample from.
        fanouts : number of neighbors to sample per seed node, one entry per layer.
            ``fanouts[0]`` is applied to the output seeds, so it produces the
            last block returned by :meth:`sample_blocks`.
        sample_neighbors : the sampling function, called as
            ``sample_neighbors(g, seeds, fanout, replace=replace)``
            (e.g. ``dgl.sampling.sample_neighbors``).
        replace : whether neighbors are sampled with replacement. Defaults to
            ``True``, which was previously hard-coded.
        """
        self.g = g
        self.fanouts = fanouts
        self.sample_neighbors = sample_neighbors
        self.replace = replace

    def sample_blocks(self, seeds):
        """Sample blocks for ``seeds``, ordered from the input layer to the output layer.

        seeds : node IDs whose outputs are needed (anything ``np.asarray`` accepts).

        Returns a list with one block per fanout.
        """
        # ``as_tensor`` with an explicit dtype also accepts non-int64 arrays
        # (e.g. the float64 array numpy builds from an empty list).
        seeds = th.as_tensor(np.asarray(seeds), dtype=th.int64)
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` neighbors.
            frontier = self.sample_neighbors(self.g, seeds, fanout, replace=self.replace)
            # Then we compact the frontier into a bipartite graph for message passing.
            block = dgl.to_block(frontier, seeds)
            # The source nodes of this block are the seeds of the next (lower) layer.
            seeds = block.srcdata[dgl.NID]
            # Sampling proceeds from the output layer downwards, so prepend.
            blocks.insert(0, block)
        return blocks

class SAGE(nn.Module):
    """Multi-layer GraphSAGE model built from ``dglnn.SAGEConv`` layers with mean aggregation."""

    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        """
        in_feats : input feature size.
        n_hidden : feature size of every intermediate layer.
        n_classes : output size of the last layer.
        n_layers : total number of SAGEConv layers; must be at least 1.
        activation : callable applied after every layer except the last.
        dropout : dropout probability applied after every layer except the last.

        Raises ValueError if ``n_layers < 1``.
        """
        super().__init__()
        if n_layers < 1:
            raise ValueError('n_layers must be at least 1, got {}'.format(n_layers))
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        if n_layers == 1:
            # A single layer maps input features straight to class scores.
            self.layers.append(dglnn.SAGEConv(in_feats, n_classes, 'mean'))
        else:
            self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
            for _ in range(1, n_layers - 1):
                self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
            self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        """Compute outputs for the destination nodes of the last block.

        blocks : list of blocks ordered from the input layer to the output layer,
            one per layer. If ``len(blocks)`` differs from the number of layers,
            ``zip`` silently uses only the shorter prefix.
        x : input features of the source nodes of ``blocks[0]``.
        """
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            # We need to first copy the representation of nodes on the RHS from the
            # appropriate nodes on the LHS.
            # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
            # would be (num_nodes_RHS, D)
            h_dst = h[:block.number_of_dst_nodes()]
            # Then we compute the updated representation on the RHS.
            # The shape of h now becomes (num_nodes_RHS, D)
            h = layer(block, (h, h_dst))
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).

        g : the entire graph.
        x : the input of entire node set.
        batch_size : number of destination nodes processed per step.
        device : device on which each layer is evaluated; per-layer results are
            gathered on CPU.

        The inference code is written in a fashion that it could handle any number of nodes and
        layers.
        """
        # During inference with sampling, multi-layer blocks are very inefficient because
        # lots of computations in the first few layers are repeated.
        # Therefore, we compute the representation of all nodes layer by layer.  The nodes
        # on each layer are of course split into batches.
        # TODO: can we standardize this?
        nodes = th.arange(g.number_of_nodes())
        for l, layer in enumerate(self.layers):
            y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

            for start in tqdm.trange(0, len(nodes), batch_size):
                end = start + batch_size
                batch_nodes = nodes[start:end]
                # Block over every in-edge of ``batch_nodes`` (no sampling).
                block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
                input_nodes = block.srcdata[dgl.NID]

                h = x[input_nodes].to(device)
                h_dst = h[:block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if l != len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)

                y[start:end] = h.cpu()

            # The output of this layer is the input of the next one.
            x = y
        return y

def prepare_mp(g):
    """Eagerly build the CSR, CSC and COO representations of ``g``.

    Materializing them up front lets sampler workers and GPU trainers share
    them via copy-on-write. Each call below is made only for its side effect;
    the results are discarded.

    This is a workaround until heterogeneous graphs fully support shared memory.
    """
    warmups = ((g.in_degree, 0), (g.out_degree, 0), (g.find_edges, [0]))
    for query, arg in warmups:
        query(arg)

def compute_acc(pred, labels):
    """
    Return the fraction of rows of ``pred`` whose arg-max over dim 1 equals ``labels``.
    """
    predicted_class = pred.argmax(dim=1)
    n_correct = (predicted_class == labels).float().sum()
    return n_correct / len(pred)

def evaluate(model, g, inputs, labels, val_nid, batch_size, device):
    """
    Evaluate the model on the validation set specified by ``val_nid``.

    g : The entire graph.
    inputs : The features of all the nodes.
    labels : The labels of all the nodes.
    val_nid : the node Ids for validation.
    batch_size : Number of nodes to compute at the same time.
    device : The device (CPU or GPU) to run inference on.

    Returns the validation accuracy computed by ``compute_acc``. The model's
    previous train/eval mode is restored afterwards, even if inference raises.
    """
    was_training = model.training
    model.eval()
    try:
        with th.no_grad():
            pred = model.inference(g, inputs, batch_size, device)
    finally:
        # Restore whatever mode the caller had, rather than forcing train mode.
        model.train(was_training)
    return compute_acc(pred[val_nid], labels[val_nid])


def load_subtensor(g, seeds, input_nodes, device):
    """
    Copy the features of ``input_nodes`` and the labels of ``seeds`` onto ``device``.

    g : graph whose ``ndata`` holds ``'features'`` and ``'labels'``.
    seeds : IDs of the output nodes whose labels are needed.
    input_nodes : IDs of the input nodes whose features are needed.
    device : target device (CPU or GPU).

    Returns ``(batch_inputs, batch_labels)``.
    """
    batch_inputs = g.ndata['features'][input_nodes].to(device)
    batch_labels = g.ndata['labels'][seeds].to(device)
    return batch_inputs, batch_labels

#### Entry point
def run(args, device, data):
    """
    Train a GraphSAGE model with neighbor sampling and periodically evaluate it.

    args : parsed command-line options; reads ``fan_out``, ``batch_size``,
        ``num_workers``, ``num_hidden``, ``num_layers``, ``dropout``, ``lr``,
        ``num_epochs``, ``log_every`` and ``eval_every``.
    device : device holding the model, the loss and the minibatch tensors.
    data : tuple ``(train_mask, val_mask, in_feats, n_classes, g)``.
    """
    # Unpack data
    train_mask, val_mask, in_feats, n_classes, g = data
    train_nid = th.nonzero(train_mask, as_tuple=True)[0]
    val_nid = th.nonzero(val_mask, as_tuple=True)[0]

    # Create sampler
    sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
                              dgl.sampling.sample_neighbors)

    # Create PyTorch DataLoader for constructing blocks
    dataloader = DataLoader(
        dataset=train_nid.numpy(),
        batch_size=args.batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers)

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Report peak memory of the device actually used for training, not the
    # current default CUDA device.
    use_cuda = th.device(device).type == 'cuda'

    # Training loop
    avg = 0
    n_timed_epochs = 0  # epochs accumulated into ``avg``; the first 5 are warm-up
    iter_tput = []
    for epoch in range(args.num_epochs):
        tic = time.time()

        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        for step, blocks in enumerate(dataloader):
            tic_step = time.time()

            # The nodes for input lies at the LHS side of the first block.
            # The nodes for output lies at the RHS side of the last block.
            input_nodes = blocks[0].srcdata[dgl.NID]
            seeds = blocks[-1].dstdata[dgl.NID]

            # Load the input features as well as output labels
            batch_inputs, batch_labels = load_subtensor(g, seeds, input_nodes, device)

            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_tput.append(len(seeds) / (time.time() - tic_step))
            if step % args.log_every == 0:
                acc = compute_acc(batch_pred, batch_labels)
                # max_memory_allocated() returns bytes; divide by 2**20 to match the "MiB" label.
                gpu_mem_alloc = th.cuda.max_memory_allocated(device) / (1024 * 1024) if use_cuda else 0
                # The first 3 iterations are excluded as warm-up; report nan (without
                # numpy's empty-slice warning) until there is data to average.
                speed = np.mean(iter_tput[3:]) if len(iter_tput) > 3 else float('nan')
                print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB'.format(
                    epoch, step, loss.item(), acc.item(), speed, gpu_mem_alloc))

        toc = time.time()
        print('Epoch Time(s): {:.4f}'.format(toc - tic))
        if epoch >= 5:
            avg += toc - tic
            n_timed_epochs += 1
        if epoch % args.eval_every == 0 and epoch != 0:
            eval_acc = evaluate(model, g, g.ndata['features'], g.ndata['labels'], val_nid, args.batch_size, device)
            print('Eval Acc {:.4f}'.format(eval_acc))

    # Dividing by the number of timed epochs avoids the ZeroDivisionError /
    # negative divisor / NameError the old ``epoch - 4`` hit for <= 5 epochs.
    if n_timed_epochs > 0:
        print('Avg epoch time: {}'.format(avg / n_timed_epochs))
    else:
        print('Avg epoch time: n/a (fewer than 6 epochs)')

if __name__ == '__main__':
    # The first positional argument of ArgumentParser is ``prog``, so the text
    # must be passed as ``description``.
    argparser = argparse.ArgumentParser(description="GraphSAGE training with neighbor sampling")
    argparser.add_argument('--gpu', type=int, default=0,
        help="GPU device ID. Use -1 for CPU training")
    argparser.add_argument('--dataset', type=str, default='reddit',
        choices=['reddit', 'ogb-product'])
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--fan-out', type=str, default='10,25',
        help="Comma-separated number of sampled neighbors, one value per layer.")
    argparser.add_argument('--batch-size', type=int, default=1000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=5)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=0,
        help="Number of sampling processes. Use 0 for no extra process.")
    args = argparser.parse_args()

    # The model consumes one sampled block per layer; a mismatch would make
    # ``zip`` in SAGE.forward silently drop layers or blocks.
    if len(args.fan_out.split(',')) != args.num_layers:
        argparser.error('--fan-out must list exactly --num-layers values '
                        '(got {!r} for {} layers)'.format(args.fan_out, args.num_layers))

    if args.gpu >= 0:
        device = th.device('cuda:%d' % args.gpu)
    else:
        device = th.device('cpu')

    if args.dataset == 'reddit':
        g, n_classes = load_reddit()
    elif args.dataset == 'ogb-product':
        g, n_classes = load_ogb('ogbn-products')
    else:
        # Not reachable from the CLI because of ``choices``; kept as a safeguard.
        raise Exception('unknown dataset')
    g = dgl.as_heterograph(g)
    in_feats = g.ndata['features'].shape[1]
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    # Build the graph's sparse formats once, before DataLoader workers are started.
    prepare_mp(g)
    # Pack data
    data = train_mask, val_mask, in_feats, n_classes, g

    run(args, device, data)