"server/vscode:/vscode.git/clone" did not exist on "d14eaacacab9ca3056a9d001d0ca2dc0a36edfde"
train_sampling_multi_gpu.py 10.5 KB
Newer Older
1
2
3
4
5
6
7
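"""Multi-GPU, mini-batch GraphSAGE training on the Reddit dataset with neighbor sampling.

One training process is spawned per GPU listed in ``--gpu``. A typical invocation
(the GPU IDs below are only an example; use the devices available on your machine):

    python train_sampling_multi_gpu.py --gpu 0,1,2,3
"""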
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import math
import argparse
from dgl.data import RedditDataset
from torch.nn.parallel import DistributedDataParallel
import tqdm
import traceback

from utils import thread_wrapped_func
from load_graph import load_reddit, inductive_split

class SAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
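        # ``blocks`` is the list of message flow graphs produced by the sampler,
        # one per GraphSAGE layer; ``x`` holds the input features of the source
        # nodes of the first block.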
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
        g : the entire graph.
        x : the input of entire node set.

        The inference code is written so that it can handle any number of nodes and
        layers.
        """
        # During inference with sampling, multi-layer blocks are very inefficient because
        # lots of computations in the first few layers are repeated.
        # Therefore, we compute the representation of all nodes layer by layer.  The nodes
        # on each layer are of course split into batches.
        # TODO: can we standardize this?
        nodes = th.arange(g.number_of_nodes())
        for l, layer in enumerate(self.layers):
            y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

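            # One full-neighbor block per batch: each pass over the dataloader
            # propagates every node's representation through a single layer.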
            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
            dataloader = dgl.dataloading.NodeDataLoader(
                g,
                th.arange(g.number_of_nodes()),
                sampler,
                batch_size=batch_size,
                shuffle=True,
                drop_last=False,
                num_workers=args.num_workers)

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0]

                block = block.int().to(device)
                h = x[input_nodes].to(device)
                h = layer(block, h)
                if l != len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)

                y[output_nodes] = h.cpu()

            x = y
        return y

def compute_acc(pred, labels):
    """
    Compute the accuracy of prediction given the labels.
    """
    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)

def evaluate(model, g, inputs, labels, val_nid, batch_size, device):
    """
    Evaluate the model on the validation set specified by ``val_nid``.
    g : The entire graph.
    inputs : The features of all the nodes.
    labels : The labels of all the nodes.
    val_nid : A node ID tensor indicating which nodes we actually compute the accuracy for.
    batch_size : Number of nodes to compute at the same time.
    device : The GPU device to evaluate on.
    """
    model.eval()
    with th.no_grad():
        pred = model.inference(g, inputs, batch_size, device)
    model.train()
    return compute_acc(pred[val_nid], labels[val_nid])

def load_subtensor(g, labels, seeds, input_nodes, dev_id):
    """
    Copies features and labels of a set of nodes onto the GPU.
    """
    batch_inputs = g.ndata['features'][input_nodes].to(dev_id)
    batch_labels = labels[seeds].to(dev_id)
    return batch_inputs, batch_labels

#### Entry point
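# ``run`` is executed once per GPU: ``proc_id`` indexes into ``devices`` and also
# serves as the rank for distributed training.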

def run(proc_id, n_gpus, args, devices, data):
    # Start up distributed training, if enabled.
    dev_id = devices[proc_id]
    if n_gpus > 1:
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port='12345')
        world_size = n_gpus
        th.distributed.init_process_group(backend="nccl",
                                          init_method=dist_init_method,
                                          world_size=world_size,
                                          rank=proc_id)
    th.cuda.set_device(dev_id)

    # Unpack data
    in_feats, n_classes, train_g, val_g, test_g = data
    train_mask = train_g.ndata['train_mask']
    val_mask = val_g.ndata['val_mask']
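    # Test nodes are those that belong to neither the training nor the validation set.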
    test_mask = ~(test_g.ndata['train_mask'] | test_g.ndata['val_mask'])
    train_nid = train_mask.nonzero().squeeze()
    val_nid = val_mask.nonzero().squeeze()
    test_nid = test_mask.nonzero().squeeze()

    # Split train_nid evenly across the GPUs; each process trains on its own shard.
    train_nid = th.split(train_nid, math.ceil(len(train_nid) / n_gpus))[proc_id]

    # Create PyTorch DataLoader for constructing blocks
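    # One sampling fan-out per GraphSAGE layer, parsed from ``--fan-out``
    # (e.g. '10,25' for a two-layer model).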
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(',')])
    dataloader = dgl.dataloading.NodeDataLoader(
        train_g,
        train_nid,
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers)

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
    model = model.to(dev_id)
    if n_gpus > 1:
        model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(dev_id)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    avg = 0
    iter_tput = []
    for epoch in range(args.num_epochs):
        tic = time.time()

        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
            if proc_id == 0:
                tic_step = time.time()

            # Load the input features as well as output labels
            batch_inputs, batch_labels = load_subtensor(train_g, train_g.ndata['labels'], seeds, input_nodes, dev_id)
            blocks = [block.int().to(dev_id) for block in blocks]
            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if proc_id == 0:
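                # Record the approximate aggregate training throughput
                # (samples/sec across all GPUs), measured on process 0 only.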
                iter_tput.append(len(seeds) * n_gpus / (time.time() - tic_step))
            if step % args.log_every == 0 and proc_id == 0:
                acc = compute_acc(batch_pred, batch_labels)
                print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB'.format(
                    epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), th.cuda.max_memory_allocated() / 1024 / 1024))

        if n_gpus > 1:
            th.distributed.barrier()

        toc = time.time()
        if proc_id == 0:
            print('Epoch Time(s): {:.4f}'.format(toc - tic))
            if epoch >= 5:
                avg += toc - tic
            if epoch % args.eval_every == 0 and epoch != 0:
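                # Evaluation runs single-process full-neighbor inference; with more
                # than one GPU the SAGE module is unwrapped from
                # DistributedDataParallel via ``model.module``.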
                if n_gpus == 1:
                    eval_acc = evaluate(
                        model, val_g, val_g.ndata['features'], val_g.ndata['labels'], val_nid, args.batch_size, devices[0])
                    test_acc = evaluate(
                        model, test_g, test_g.ndata['features'], test_g.ndata['labels'], test_nid, args.batch_size, devices[0])
                else:
                    eval_acc = evaluate(
                        model.module, val_g, val_g.ndata['features'], val_g.ndata['labels'], val_nid, args.batch_size, devices[0])
                    test_acc = evaluate(
                        model.module, test_g, test_g.ndata['features'], test_g.ndata['labels'], test_nid, args.batch_size, devices[0])
                print('Eval Acc {:.4f}'.format(eval_acc))
                print('Test Acc: {:.4f}'.format(test_acc))

    if n_gpus > 1:
        th.distributed.barrier()
    if proc_id == 0:
        print('Avg epoch time: {}'.format(avg / (epoch - 4)))

if __name__ == '__main__':
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument('--gpu', type=str, default='0',
        help="Comma separated list of GPU device IDs.")
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=1000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=5)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=0,
        help="Number of sampling processes. Use 0 for no extra process.")
    argparser.add_argument('--inductive', action='store_true',
        help="Inductive learning setting")
    args = argparser.parse_args()
    
    devices = list(map(int, args.gpu.split(',')))
    n_gpus = len(devices)

    g, n_classes = load_reddit()
    # Construct graph
    g = dgl.as_heterograph(g)
    in_feats = g.ndata['features'].shape[1]

    if args.inductive:
        train_g, val_g, test_g = inductive_split(g)
    else:
        train_g = val_g = test_g = g

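    # Materialize the sparse formats once before spawning training processes so
    # that each subprocess does not rebuild them.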
    train_g.create_formats_()
    val_g.create_formats_()
    test_g.create_formats_()
    # Pack data
    data = in_feats, n_classes, train_g, val_g, test_g

    if n_gpus == 1:
        run(0, n_gpus, args, devices, data)
    else:
        procs = []
        for proc_id in range(n_gpus):
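            # thread_wrapped_func (from utils) runs ``run`` inside a thread of the
            # child process to avoid a known fork/OpenMP deadlock.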
            p = mp.Process(target=thread_wrapped_func(run),
                           args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()