train_dist.py 14.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import os
os.environ['DGLBACKEND']='pytorch'
from multiprocessing import Process
import argparse, time, math
import numpy as np
from functools import wraps
import tqdm

import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.data.utils import load_graphs
import dgl.function as fn
import dgl.nn.pytorch as dglnn
15
from dgl.distributed import DistDataLoader
16
17
18
19
20
21
22

import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
23
import socket
24

25
def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
    """
    Gather the input features and the labels needed for one mini-batch.

    g : graph whose ``ndata`` holds ``'features'`` and ``'labels'``.
    seeds : node IDs whose labels are fetched.
    input_nodes : node IDs whose features are fetched.
    device : device the fetched tensors are moved to.
    load_feat : when False, features are not fetched and ``None`` is returned for them.

    Returns ``(batch_inputs, batch_labels)``.
    """
    if not load_feat:
        batch_inputs = None
    else:
        batch_inputs = g.ndata['features'][input_nodes].to(device)
    batch_labels = g.ndata['labels'][seeds].to(device)
    return batch_inputs, batch_labels

33
class NeighborSampler(object):
    """
    Turns a batch of seed nodes into a list of DGL blocks for message passing.

    ``sample_blocks`` is passed to ``DistDataLoader`` as its ``collate_fn``.
    """

    def __init__(self, g, fanouts, sample_neighbors, device, load_feat=True):
        """
        g : graph to sample from; its ``ndata`` must hold ``'labels'`` (and
            ``'features'`` when ``load_feat`` is True).
        fanouts : neighbors sampled per node, one entry per block. The first entry is
            used around the seed nodes, i.e. for the last (output) block.
        sample_neighbors : called as ``sample_neighbors(g, seeds, fanout, replace=True)``.
        device : kept on the instance; ``sample_blocks`` always loads data onto "cpu".
        load_feat : whether to attach input features to the first block.
        """
        self.g, self.fanouts = g, fanouts
        self.sample_neighbors = sample_neighbors
        self.device = device
        self.load_feat = load_feat

    def sample_blocks(self, seeds):
        """
        Sample one block per fanout around ``seeds``.

        Returns the blocks ordered from the input layer to the output layer. The last
        block carries ``dstdata['labels']``; the first carries ``srcdata['features']``
        when ``load_feat`` is set.
        """
        frontier_seeds = th.LongTensor(np.asarray(seeds))
        out_to_in = []
        for fanout in self.fanouts:
            frontier = self.sample_neighbors(self.g, frontier_seeds, fanout, replace=True)
            # Compact the sampled frontier into a bipartite block for message passing.
            block = dgl.to_block(frontier, frontier_seeds)
            # This block's source nodes are the destinations of the next (inner) block.
            frontier_seeds = block.srcdata[dgl.NID]
            out_to_in.append(block)
        blocks = out_to_in[::-1]

        first, last = blocks[0], blocks[-1]
        features, labels = load_subtensor(self.g, last.dstdata[dgl.NID],
                                          first.srcdata[dgl.NID], "cpu", self.load_feat)
        if self.load_feat:
            first.srcdata['features'] = features
        last.dstdata['labels'] = labels
        return blocks
61

62
class DistSAGE(nn.Module):
    """GraphSAGE model (mean aggregation) for mini-batch training and layer-wise inference."""

    def __init__(self, in_feats, n_hidden, n_classes, n_layers,
                 activation, dropout):
        """
        in_feats : size of the input node features.
        n_hidden : output size of every non-final layer.
        n_classes : output size of the final layer.
        n_layers : number of SAGEConv layers (must be >= 1).
        activation : callable applied after every layer except the last.
        dropout : dropout probability applied after every layer except the last.
        """
        super().__init__()
        if n_layers < 1:
            raise ValueError('n_layers must be >= 1, got {}'.format(n_layers))
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        # Layer widths: in_feats -> n_hidden -> ... -> n_hidden -> n_classes.
        # With n_layers == 1 this is a single in_feats -> n_classes layer, so the number
        # of layers always matches n_layers (and the number of sampled blocks).
        dims = [in_feats] + [n_hidden] * (n_layers - 1) + [n_classes]
        self.layers = nn.ModuleList(
            [dglnn.SAGEConv(d_in, d_out, 'mean') for d_in, d_out in zip(dims[:-1], dims[1:])])
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        """
        Run the layers over ``blocks`` (ordered input -> output) starting from features ``x``.
        """
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
        g : the entire graph.
        x : the input of entire node set.
        batch_size : number of destination nodes processed per mini-batch.
        device : device each mini-batch is computed on; results are written back via CPU.

        The code handles any number of nodes and layers. Returns the DistTensor holding
        the last layer's output for all nodes.
        """
        # During inference with sampling, multi-layer blocks are very inefficient because
        # lots of computations in the first few layers are repeated.
        # Therefore, we compute the representation of all nodes layer by layer. The nodes
        # on each layer are of course split into batches.
        nodes = dgl.distributed.node_split(np.arange(g.number_of_nodes()),
                                           g.get_partition_book(), force_even=True)
        # The sampler and loader are identical for every layer, so build them once.
        # Raw features are not attached (load_feat=False): each layer reads its input
        # from ``x`` instead.
        sampler = NeighborSampler(g, [-1], dgl.distributed.sample_neighbors, device,
                                  load_feat=False)
        print('|V|={}, eval batch size: {}'.format(g.number_of_nodes(), batch_size))
        # Create the DataLoader that constructs the blocks.
        dataloader = DistDataLoader(
            dataset=nodes,
            batch_size=batch_size,
            collate_fn=sampler.sample_blocks,
            shuffle=False,
            drop_last=False)

        last = len(self.layers) - 1
        y = None
        for l, layer in enumerate(self.layers):
            if l == last:
                y = dgl.distributed.DistTensor((g.number_of_nodes(), self.n_classes),
                                               th.float32, 'h_last', persistent=True)
            else:
                # Alternate between two hidden buffers. With a single buffer, x and y
                # would be the same tensor from the second layer on, and a layer would
                # overwrite rows it still has to read.
                y = dgl.distributed.DistTensor((g.number_of_nodes(), self.n_hidden),
                                               th.float32, 'h' if l % 2 == 0 else 'h_alt',
                                               persistent=True)

            for blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(device)
                input_nodes = block.srcdata[dgl.NID]
                output_nodes = block.dstdata[dgl.NID]
                h = x[input_nodes].to(device)
                # dgl.to_block places the destination nodes first among the source nodes.
                h_dst = h[:block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if l != last:
                    h = self.activation(h)
                    h = self.dropout(h)

                y[output_nodes] = h.cpu()

            x = y
            # Every trainer must finish writing y before anyone reads it as the next input.
            g.barrier()
        return y

136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def compute_acc(pred, labels):
    """
    Return the fraction of rows of ``pred`` whose arg-max column equals the label.

    pred : 2-D tensor of scores, one row per sample.
    labels : 1-D tensor of class ids (cast to int64 before comparing).
    """
    predicted = pred.argmax(dim=1)
    hits = predicted == labels.long()
    return hits.float().sum() / len(pred)

def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
    """
    Compute validation and test accuracy with full-neighbor inference.

    model : the (unwrapped) model; must provide ``inference``.
    g : the entire graph.
    inputs : the features of all the nodes.
    labels : the labels of all the nodes.
    val_nid : node IDs used for validation accuracy.
    test_nid : node IDs used for test accuracy.
    batch_size : number of nodes computed at the same time.
    device : device inference runs on.

    Returns ``(val_acc, test_acc)``. The model is switched back to training mode
    before returning.
    """
    model.eval()
    with th.no_grad():
        pred = model.inference(g, inputs, batch_size, device)
    model.train()
    val_acc = compute_acc(pred[val_nid], labels[val_nid])
    test_acc = compute_acc(pred[test_nid], labels[test_nid])
    return val_acc, test_acc

159

160
161
def run(args, device, data):
    """
    Train a DistSAGE model on this trainer's share of the training nodes.

    args : parsed command-line arguments (fan_out, batch_size, num_hidden, num_layers,
           dropout, standalone, num_gpus, lr, num_epochs, log_every, eval_every and
           batch_size_eval are read here).
    device : torch device the model and mini-batches are placed on.
    data : tuple ``(train_nid, val_nid, test_nid, in_feats, n_classes, g)``.
    """
    import contextlib  # only needed for the standalone (non-DDP) path

    # Unpack data
    train_nid, val_nid, test_nid, in_feats, n_classes, g = data
    shuffle = True
    # Create sampler
    sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
                              dgl.distributed.sample_neighbors, device)

    # Create DataLoader for constructing blocks
    dataloader = DistDataLoader(
        dataset=train_nid.numpy(),
        batch_size=args.batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=shuffle,
        drop_last=False)

    # Define model and optimizer
    model = DistSAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
    if not args.standalone:
        if args.num_gpus == -1:
            model = th.nn.parallel.DistributedDataParallel(model)
        else:
            model = th.nn.parallel.DistributedDataParallel(model, device_ids=[device], output_device=device)
    # In standalone mode the model is a plain nn.Module and has neither ``.module``
    # nor ``.join()``, so resolve the unwrapped model once here.
    is_ddp = isinstance(model, th.nn.parallel.DistributedDataParallel)
    base_model = model.module if is_ddp else model
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    iter_tput = []
    for epoch in range(args.num_epochs):
        tic = time.time()

        sample_time = 0
        forward_time = 0
        backward_time = 0
        update_time = 0
        num_seeds = 0
        num_inputs = 0
        start = time.time()
        step_time = []
        # DDP's join() lets trainers that run out of batches early keep the collective
        # calls of the others matched. Standalone mode has nothing to join.
        join_ctx = model.join() if is_ddp else contextlib.nullcontext()
        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        with join_ctx:
            for step, blocks in enumerate(dataloader):
                tic_step = time.time()
                sample_time += tic_step - start

                num_seeds += len(blocks[-1].dstdata[dgl.NID])
                num_inputs += len(blocks[0].srcdata[dgl.NID])
                blocks = [block.to(device) for block in blocks]
                # The nodes for input lie at the LHS side of the first block; the nodes for
                # output lie at the RHS side of the last block. Read both from the moved
                # blocks so inputs and labels live on the same device as the model.
                batch_inputs = blocks[0].srcdata['features']
                batch_labels = blocks[-1].dstdata['labels'].long()
                # Compute loss and prediction
                start = time.time()
                batch_pred = model(blocks, batch_inputs)
                loss = loss_fcn(batch_pred, batch_labels)
                forward_end = time.time()
                optimizer.zero_grad()
                loss.backward()
                compute_end = time.time()
                forward_time += forward_end - start
                backward_time += compute_end - forward_end

                optimizer.step()
                update_time += time.time() - compute_end

                step_t = time.time() - tic_step
                step_time.append(step_t)
                iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)
                if step % args.log_every == 0:
                    acc = compute_acc(batch_pred, batch_labels)
                    gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
                    print('Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB | time {:.3f} s'.format(
                        g.rank(), epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc, np.sum(step_time[-args.log_every:])))
                start = time.time()

        toc = time.time()
        print('Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}'.format(
            g.rank(), toc - tic, sample_time, forward_time, backward_time, update_time, num_seeds, num_inputs))

        # Evaluate after every ``eval_every``-th completed epoch (1-based count).
        if (epoch + 1) % args.eval_every == 0:
            start = time.time()
            val_acc, test_acc = evaluate(base_model, g, g.ndata['features'],
                                         g.ndata['labels'], val_nid, test_nid, args.batch_size_eval, device)
            print('Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}'.format(g.rank(), val_acc, test_acc,
                                                                                  time.time() - start))
257
258

def main(args):
    """
    Entry point of one trainer process.

    Initializes DGL's distributed runtime (plus the PyTorch process group unless
    ``args.standalone``), opens the DistGraph, splits train/val/test nodes, selects the
    device, determines the number of classes, and calls ``run``.
    """
    host = socket.gethostname()
    print(host, 'Initializing DGL dist')
    dgl.distributed.initialize(args.ip_config, net_type=args.net_type)
    if not args.standalone:
        print(host, 'Initializing DGL process group')
        th.distributed.init_process_group(backend=args.backend)
    print(host, 'Initializing DistGraph')
    g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
    print(host, 'rank:', g.rank())

    pb = g.get_partition_book()
    # Pass per-node trainer assignments to node_split when the graph stores them.
    split_kwargs = {}
    if 'trainer_id' in g.ndata:
        split_kwargs['node_trainer_ids'] = g.ndata['trainer_id']
    train_nid, val_nid, test_nid = [
        dgl.distributed.node_split(g.ndata[mask_name], pb, force_even=True, **split_kwargs)
        for mask_name in ('train_mask', 'val_mask', 'test_mask')
    ]

    # Report how many of this trainer's nodes belong to its own partition.
    local_nid = pb.partid2nids(pb.partid).detach().numpy()
    print('part {}, train: {} (local: {}), val: {} (local: {}), test: {} (local: {})'.format(
        g.rank(), len(train_nid), len(np.intersect1d(train_nid.numpy(), local_nid)),
        len(val_nid), len(np.intersect1d(val_nid.numpy(), local_nid)),
        len(test_nid), len(np.intersect1d(test_nid.numpy(), local_nid))))
    del local_nid

    if args.num_gpus == -1:
        device = th.device('cpu')
    else:
        # Trainers are spread over the GPUs by rank modulo num_gpus.
        device = th.device('cuda:' + str(g.rank() % args.num_gpus))

    n_classes = args.n_classes
    if n_classes == -1:
        # Count distinct non-NaN labels; this reads every node's label.
        all_labels = g.ndata['labels'][np.arange(g.number_of_nodes())]
        n_classes = len(th.unique(all_labels[th.logical_not(th.isnan(all_labels))]))
        del all_labels
    print('#labels:', n_classes)

    in_feats = g.ndata['features'].shape[1]
    run(args, device, (train_nid, val_nid, test_nid, in_feats, n_classes, g))
    print("parent ends")

if __name__ == '__main__':
    # Command-line interface; arguments are registered in the same order as before.
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)

    # Cluster, partition and runtime settings.
    # NOTE(review): --id, --num_clients, --local_rank and --pad-data are parsed but
    # not read by main() or run() in this file.
    parser.add_argument('--graph_name', type=str, help='graph name')
    parser.add_argument('--id', type=int, help='the partition id')
    parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
    parser.add_argument('--part_config', type=str, help='The path to the partition config file')
    parser.add_argument('--num_clients', type=int, help='The number of clients')
    parser.add_argument('--n_classes', type=int, default=-1,
                        help=('The number of classes. If not specified, this'
                              ' value will be calculated via scaning all the labels'
                              ' in the dataset which probably causes memory burst.'))
    parser.add_argument('--backend', type=str, default='gloo',
                        help='pytorch distributed backend')
    parser.add_argument('--num_gpus', type=int, default=-1,
                        help="the number of GPU device. Use -1 for CPU training")

    # Plain hyper-parameters without help text: (flag, type, default).
    for _flag, _kind, _default in (
            ('--num_epochs', int, 20),
            ('--num_hidden', int, 16),
            ('--num_layers', int, 2),
            ('--fan_out', str, '10,25'),
            ('--batch_size', int, 1000),
            ('--batch_size_eval', int, 100000),
            ('--log_every', int, 20),
            ('--eval_every', int, 5),
            ('--lr', float, 0.003),
            ('--dropout', float, 0.5)):
        parser.add_argument(_flag, type=_kind, default=_default)

    parser.add_argument('--local_rank', type=int, help='get rank of the process')
    parser.add_argument('--standalone', action='store_true', help='run in the standalone mode')
    parser.add_argument('--pad-data', default=False, action='store_true',
                        help='Pad train nid to the same length across machine, to ensure num of batches to be the same.')
    parser.add_argument('--net_type', type=str, default='socket',
                        help="backend net type, 'socket' or 'tensorpipe'")

    args = parser.parse_args()
    print(args)
    main(args)