# train_dist.py: distributed GraphSAGE training example built on dgl.distributed.
import os
os.environ['DGLBACKEND']='pytorch'
from multiprocessing import Process
import argparse, time, math
import numpy as np
from functools import wraps
import tqdm

import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.data.utils import load_graphs
import dgl.function as fn
import dgl.nn.pytorch as dglnn
15
from dgl.distributed import DistDataLoader
16
17
18
19
20
21
22
23

import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader

def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
    """
    Copy the features of ``input_nodes`` and the labels of ``seeds`` onto ``device``.

    g : the graph whose ``ndata`` holds ``'features'`` and ``'labels'``.
    seeds : IDs of the output (seed) nodes whose labels are fetched.
    input_nodes : IDs of the input nodes whose features are fetched.
    device : target device, passed to ``Tensor.to``.
    load_feat : when False, features are not fetched and ``None`` is returned
        in their place (labels are always fetched).

    Returns ``(batch_inputs, batch_labels)``.
    """
    batch_inputs = g.ndata['features'][input_nodes].to(device) if load_feat else None
    batch_labels = g.ndata['labels'][seeds].to(device)
    return batch_inputs, batch_labels


class NeighborSampler(object):
    """
    Build the list of message-passing blocks for a mini-batch of seed nodes.

    ``sample_blocks`` is used as the ``collate_fn`` of a ``DistDataLoader``: it
    receives a batch of seed node IDs and returns one block per entry of
    ``fanouts``, ordered from the input layer to the output layer.
    """

    def __init__(self, g, fanouts, sample_neighbors, device, load_feat=True):
        """
        g : the graph to sample from (also the source of features and labels).
        fanouts : neighbors to sample per layer, in the order they are applied,
            i.e. the first entry is used for the seed (output) layer.
        sample_neighbors : sampling function, called as
            ``sample_neighbors(g, seeds, fanout, replace=True)``.
        device : stored on the instance only; see the note below.
        load_feat : whether to attach input features to the first block.
        """
        self.g = g
        self.fanouts = fanouts
        self.sample_neighbors = sample_neighbors
        # NOTE(review): not read by sample_blocks, which always loads tensors on "cpu";
        # callers move the returned blocks to their device themselves.
        self.device = device
        self.load_feat = load_feat

    def sample_blocks(self, seeds):
        """
        Sample a multi-layer computation graph rooted at ``seeds``.

        Returns a list of blocks where ``blocks[0]`` consumes the input nodes and
        ``blocks[-1]`` produces the seed nodes. ``blocks[-1].dstdata['labels']`` is
        always set; ``blocks[0].srcdata['features']`` is set only if ``load_feat``.
        """
        seeds = th.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` neighbors.
            frontier = self.sample_neighbors(self.g, seeds, fanout, replace=True)
            # Compact the frontier into a bipartite block for message passing.
            block = dgl.to_block(frontier, seeds)
            # The source nodes of this block are the seeds of the next (lower) layer.
            seeds = block.srcdata[dgl.NID]
            # Sampling walks from the output layer downward, so prepend to keep the
            # list ordered input -> output.
            blocks.insert(0, block)

        input_nodes = blocks[0].srcdata[dgl.NID]
        seeds = blocks[-1].dstdata[dgl.NID]
        batch_inputs, batch_labels = load_subtensor(self.g, seeds, input_nodes, "cpu", self.load_feat)
        if self.load_feat:
            blocks[0].srcdata['features'] = batch_inputs
        blocks[-1].dstdata['labels'] = batch_labels
        return blocks


class DistSAGE(nn.Module):
    """
    Multi-layer GraphSAGE model (mean aggregator) for distributed training.

    in_feats : input feature size.
    n_hidden : hidden representation size.
    n_classes : output size (number of classes).
    n_layers : number of SAGEConv layers.
    activation : activation applied after every layer except the last.
    dropout : dropout probability applied after every layer except the last.
    """

    def __init__(self, in_feats, n_hidden, n_classes, n_layers,
                 activation, dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        if n_layers == 1:
            # A single layer must map input features straight to class scores;
            # otherwise a one-block forward pass would return hidden features.
            self.layers.append(dglnn.SAGEConv(in_feats, n_classes, 'mean'))
        else:
            self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
            for _ in range(1, n_layers - 1):
                self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
            self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        """
        Mini-batch forward pass.

        blocks : one block per layer, ordered from input to output.
        x : features of the source nodes of ``blocks[0]``.

        Returns the last layer's output for the destination nodes of the final block.
        """
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).

        g : the entire graph.
        x : the input of the entire node set.
        batch_size : number of output nodes computed per mini-batch.
        device : device the per-batch computation runs on.

        Returns the ``'h_last'`` DistTensor holding the last layer's output for
        every node. The inference code handles any number of nodes and layers.
        """
        # During inference with sampling, multi-layer blocks are very inefficient because
        # lots of computations in the first few layers are repeated.
        # Therefore, we compute the representation of all nodes layer by layer. The nodes
        # on each layer are of course split into batches.
        # TODO: can we standardize this?
        nodes = dgl.distributed.node_split(np.arange(g.number_of_nodes()),
                                           g.get_partition_book(), force_even=True)

        # The sampler and loader do not depend on the layer, so build them once.
        # Features are read from ``x`` below, so the sampler does not fetch them.
        sampler = NeighborSampler(g, [-1], dgl.distributed.sample_neighbors, device,
                                  load_feat=False)
        print('|V|={}, eval batch size: {}'.format(g.number_of_nodes(), batch_size))
        # Create PyTorch DataLoader for constructing blocks
        dataloader = DistDataLoader(
            dataset=nodes,
            batch_size=batch_size,
            collate_fn=sampler.sample_blocks,
            shuffle=False,
            drop_last=False)

        y = None
        for l, layer in enumerate(self.layers):
            is_last = l == len(self.layers) - 1
            if is_last:
                y = dgl.distributed.DistTensor((g.number_of_nodes(), self.n_classes),
                                               th.float32, 'h_last', persistent=True)
            else:
                # Alternate between two buffers so a hidden layer never writes into
                # the tensor it reads its input from (with a single shared 'h'
                # buffer, that happens for models with three or more layers).
                y = dgl.distributed.DistTensor((g.number_of_nodes(), self.n_hidden),
                                               th.float32, 'h' if l % 2 == 0 else 'h_alt',
                                               persistent=True)

            for blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(device)
                input_nodes = block.srcdata[dgl.NID]
                output_nodes = block.dstdata[dgl.NID]
                h = x[input_nodes].to(device)
                # Relies on dgl.to_block listing the destination nodes first among
                # the source nodes.
                h_dst = h[:block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if not is_last:
                    h = self.activation(h)
                    h = self.dropout(h)

                y[output_nodes] = h.cpu()

            x = y
            # Every rank must finish writing this layer before anyone reads it.
            g.barrier()
        return y


def compute_acc(pred, labels):
    """
    Return the fraction of rows of ``pred`` whose highest-scoring column equals the label.

    pred : 2-D tensor of per-class scores, one row per sample.
    labels : one label per row; cast to ``long`` before comparison.
    """
    predicted_class = pred.argmax(dim=1)
    hits = predicted_class.eq(labels.long())
    return hits.float().sum() / len(pred)

def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
    """
    Evaluate the model on the validation and test nodes.

    The model is switched to eval mode for full-neighbor inference over the whole
    graph and restored to train mode afterwards, even if inference raises.

    g : The entire graph.
    inputs : The features of all the nodes.
    labels : The labels of all the nodes.
    val_nid : the node IDs for validation.
    test_nid : the node IDs for testing.
    batch_size : Number of nodes to compute at the same time.
    device : The device to run inference on.

    Returns ``(val_acc, test_acc)``.
    """
    model.eval()
    try:
        with th.no_grad():
            pred = model.inference(g, inputs, batch_size, device)
    finally:
        # Always hand the model back to the training loop in train mode.
        model.train()
    return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(pred[test_nid], labels[test_nid])


def run(args, device, data):
    """
    Train a DistSAGE model on this trainer's share of the training nodes.

    args : parsed command-line arguments (fan_out, batch_size, num_hidden, num_layers,
        dropout, lr, num_epochs, log_every, eval_every, batch_size_eval, num_gpus,
        standalone).
    device : device the model and mini-batches are moved to.
    data : tuple ``(train_nid, val_nid, test_nid, in_feats, n_classes, g)``.
    """
    # Unpack data
    train_nid, val_nid, test_nid, in_feats, n_classes, g = data
    # Create sampler; fan_out is a comma-separated list such as "10,25".
    sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
                              dgl.distributed.sample_neighbors, device)

    # Create DataLoader for constructing blocks
    dataloader = DistDataLoader(
        dataset=train_nid.numpy(),
        batch_size=args.batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=True,
        drop_last=False)

    # Define model and optimizer
    model = DistSAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
    if not args.standalone:
        if args.num_gpus == -1:
            model = th.nn.parallel.DistributedDataParallel(model)
        else:
            dev_id = g.rank() % args.num_gpus
            model = th.nn.parallel.DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
    # ``inference`` is defined on DistSAGE itself. Under DDP it is reached through
    # ``.module``; in standalone mode the model is not wrapped and has no ``.module``.
    if isinstance(model, th.nn.parallel.DistributedDataParallel):
        eval_model = model.module
    else:
        eval_model = model
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    iter_tput = []
    for epoch in range(args.num_epochs):
        tic = time.time()

        sample_time = 0
        forward_time = 0
        backward_time = 0
        update_time = 0
        num_seeds = 0
        num_inputs = 0
        start = time.time()
        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks. ``start`` marks the end of the previous step, so the wait for the next
        # batch is accounted as sampling time.
        step_time = []
        for step, blocks in enumerate(dataloader):
            tic_step = time.time()
            sample_time += tic_step - start

            # The nodes for input lie at the LHS side of the first block.
            # The nodes for output lie at the RHS side of the last block.
            batch_inputs = blocks[0].srcdata['features']
            batch_labels = blocks[-1].dstdata['labels'].long()

            num_seeds += len(blocks[-1].dstdata[dgl.NID])
            num_inputs += len(blocks[0].srcdata[dgl.NID])
            blocks = [block.to(device) for block in blocks]
            # The sampler loads features and labels on CPU; move them to the same
            # device as the blocks and the model.
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            # Compute loss and prediction
            start = time.time()
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)
            forward_end = time.time()
            optimizer.zero_grad()
            loss.backward()
            compute_end = time.time()
            forward_time += forward_end - start
            backward_time += compute_end - forward_end

            optimizer.step()
            update_time += time.time() - compute_end

            step_t = time.time() - tic_step
            step_time.append(step_t)
            iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)
            if step % args.log_every == 0:
                acc = compute_acc(batch_pred, batch_labels)
                gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
                print('Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB | time {:.3f} s'.format(
                    g.rank(), epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc, np.sum(step_time[-args.log_every:])))
            start = time.time()

        toc = time.time()
        print('Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}'.format(
            g.rank(), toc - tic, sample_time, forward_time, backward_time, update_time, num_seeds, num_inputs))

        # Evaluate after every ``eval_every`` completed epochs.
        if (epoch + 1) % args.eval_every == 0:
            start = time.time()
            val_acc, test_acc = evaluate(eval_model, g, g.ndata['features'],
                                         g.ndata['labels'], val_nid, test_nid, args.batch_size_eval, device)
            print('Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}'.format(g.rank(), val_acc, test_acc,
                                                                                  time.time() - start))


def main(args):
    """
    Set up the distributed environment for this trainer and start training.

    Initializes DGL's distributed runtime (plus a gloo process group unless
    ``--standalone``), opens the partitioned DistGraph, splits the train/val/test
    nodes across trainers, picks the device and hands everything to ``run``.
    """
    dgl.distributed.initialize(args.ip_config)
    if not args.standalone:
        th.distributed.init_process_group(backend='gloo')
    g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
    print('rank:', g.rank())

    pb = g.get_partition_book()
    train_nid = dgl.distributed.node_split(g.ndata['train_mask'], pb, force_even=True)
    val_nid = dgl.distributed.node_split(g.ndata['val_mask'], pb, force_even=True)
    test_nid = dgl.distributed.node_split(g.ndata['test_mask'], pb, force_even=True)
    # Node IDs owned by this process's partition; used only to report how many of
    # this trainer's nodes are local.
    local_nid = pb.partid2nids(pb.partid).detach().numpy()
    print('part {}, train: {} (local: {}), val: {} (local: {}), test: {} (local: {})'.format(
        g.rank(), len(train_nid), len(np.intersect1d(train_nid.numpy(), local_nid)),
        len(val_nid), len(np.intersect1d(val_nid.numpy(), local_nid)),
        len(test_nid), len(np.intersect1d(test_nid.numpy(), local_nid))))
    if args.num_gpus == -1:
        device = th.device('cpu')
    else:
        device = th.device('cuda:'+str(g.rank() % args.num_gpus))

    # Honor --n_classes when given; otherwise infer it from the labels. Inferring
    # pulls every node's label, so it is only done when needed.
    n_classes = getattr(args, 'n_classes', None)
    if not n_classes:
        labels = g.ndata['labels'][np.arange(g.number_of_nodes())]
        # NaN labels are excluded from the count (presumably unlabeled nodes).
        n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
        del labels
    print('#labels:', n_classes)

    # Pack data
    in_feats = g.ndata['features'].shape[1]
    data = train_nid, val_nid, test_nid, in_feats, n_classes, g
    run(args, device, data)
    print("parent ends")

if __name__ == '__main__':
    # Command-line entry point: parse the arguments and launch training.
    parser = argparse.ArgumentParser(description='GraphSAGE')
    register_data_args(parser)
    parser.add_argument('--graph_name', type=str, help='graph name')
    # NOTE(review): --id, --num_clients and --local_rank are not read in this file;
    # presumably they are accepted because the launch script passes them.
    parser.add_argument('--id', type=int, help='the partition id')
    parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
    parser.add_argument('--part_config', type=str, help='The path to the partition config file')
    parser.add_argument('--num_clients', type=int, help='The number of clients')
    parser.add_argument('--n_classes', type=int, help='the number of classes')
    parser.add_argument('--num_gpus', type=int, default=-1,
                        help="the number of GPU device. Use -1 for CPU training")
    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--num_hidden', type=int, default=16)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--fan_out', type=str, default='10,25')
    parser.add_argument('--batch_size', type=int, default=1000)
    parser.add_argument('--batch_size_eval', type=int, default=100000)
    parser.add_argument('--log_every', type=int, default=20)
    parser.add_argument('--eval_every', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--local_rank', type=int, help='get rank of the process')
    parser.add_argument('--standalone', action='store_true', help='run in the standalone mode')
    args = parser.parse_args()

    print(args)
    main(args)