train_sampling.py 9.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm

12
from load_graph import load_reddit, inductive_split
13

14
15
16
17
18
19
20
21
22
23
24
25
26
class SAGE(nn.Module):
    """GraphSAGE model: a stack of ``dglnn.SAGEConv`` layers with mean aggregation.

    in_feats : dimensionality of the input node features.
    n_hidden : width of every hidden layer.
    n_classes : number of output classes (width of the last layer).
    n_layers : total number of ``SAGEConv`` layers; must be at least 1.
    activation : non-linearity applied after every layer except the last.
    dropout : dropout probability applied after every layer except the last.
    """

    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super().__init__()
        if n_layers < 1:
            raise ValueError('n_layers must be at least 1, got {}'.format(n_layers))
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        # Layer widths: in_feats -> n_hidden (repeated n_layers - 1 times) -> n_classes.
        # This yields exactly ``n_layers`` layers, including the n_layers == 1 case
        # (a single in_feats -> n_classes layer), which the previous construction got
        # wrong by always emitting at least two layers.
        dims = [in_feats] + [n_hidden] * (n_layers - 1) + [n_classes]
        self.layers = nn.ModuleList(
            [dglnn.SAGEConv(d_in, d_out, 'mean') for d_in, d_out in zip(dims[:-1], dims[1:])])
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        """Mini-batch forward pass over sampled blocks.

        blocks : list of blocks, one per layer; ``blocks[i]`` is consumed by layer ``i``.
        x : input features for the first block (in ``run`` these are the features of the
            sampled ``input_nodes``).

        Returns the output of the last layer.

        Raises ValueError if the number of blocks differs from the number of layers
        (e.g. the fan-out list does not match ``n_layers``), instead of silently
        skipping layers.
        """
        if len(blocks) != len(self.layers):
            raise ValueError(
                'expected {} blocks (one per layer), got {}; check that the number of '
                'fan-outs matches the number of layers'.format(len(self.layers), len(blocks)))
        h = x
        last = len(self.layers) - 1
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            # No activation/dropout on the output layer: it produces raw logits.
            if l != last:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, device, batch_size=None, num_workers=None):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
        g : the entire graph.
        x : the input of entire node set.
        device : device each mini-batch is computed on; per-layer results are gathered
            on the CPU.
        batch_size : number of output nodes per mini-batch. When None, uses
            ``args.batch_size`` if a module-level ``args`` exists (i.e. when run as a
            script), otherwise 1000.
        num_workers : number of sampling processes for the data loader. When None, uses
            ``args.num_workers`` if a module-level ``args`` exists, otherwise 0.

        The inference code is written in a fashion that it could handle any number of nodes and
        layers. Returns the last layer's output for every node as a CPU tensor of shape
        ``(g.num_nodes(), n_classes)``.
        """
        # During inference with sampling, multi-layer blocks are very inefficient because
        # lots of computations in the first few layers are repeated.
        # Therefore, we compute the representation of all nodes layer by layer.  The nodes
        # on each layer are of course split in batches.
        # TODO: can we standardize this?

        # Keep the original behaviour of reading the script's command-line options, but
        # without failing with NameError when this module is imported rather than run.
        cli_args = globals().get('args')
        if batch_size is None:
            batch_size = cli_args.batch_size if cli_args is not None else 1000
        if num_workers is None:
            num_workers = cli_args.num_workers if cli_args is not None else 0

        # Each layer needs the same full 1-hop neighborhood of every node, so one
        # sampler/loader serves all layers.
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
        dataloader = dgl.dataloading.NodeDataLoader(
            g,
            th.arange(g.num_nodes()),
            sampler,
            batch_size=batch_size,
            # Results are written back by node ID, so order is irrelevant; iterate
            # deterministically.
            shuffle=False,
            drop_last=False,
            num_workers=num_workers)

        last = len(self.layers) - 1
        for l, layer in enumerate(self.layers):
            y = th.zeros(g.num_nodes(), self.n_hidden if l != last else self.n_classes)

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0].int().to(device)
                h = x[input_nodes].to(device)
                h = layer(block, h)
                if l != last:
                    h = self.activation(h)
                    h = self.dropout(h)
                y[output_nodes] = h.cpu()

            # The output of this layer is the input of the next one.
            x = y
        return y

def compute_acc(pred, labels):
    """
    Return the fraction of rows in ``pred`` whose highest-scoring column equals
    the corresponding entry of ``labels`` (labels are cast to int64 first).
    """
    hits = pred.argmax(dim=1).eq(labels.long())
    return hits.float().sum() / len(pred)

92
def evaluate(model, g, nfeat, labels, val_nid, device):
    """
    Evaluate the model on the validation set specified by ``val_nid``.
    model : the model; its ``inference(g, nfeat, device)`` is run under ``th.no_grad()``.
    g : The entire graph.
    nfeat : The features of all the nodes.
    labels : The labels of all the nodes.
    val_nid : the node Ids for validation.
    device : The GPU device to evaluate on.

    Returns the value of ``compute_acc`` on the ``val_nid`` rows.
    The model is put in eval mode for inference and switched back to training mode
    afterwards, even if inference raises.
    """
    model.eval()
    try:
        with th.no_grad():
            pred = model.inference(g, nfeat, device)
    finally:
        # Always restore training mode so a failed evaluation does not leave
        # dropout disabled for subsequent training steps.
        model.train()
    return compute_acc(pred[val_nid], labels[val_nid].to(pred.device))
106

107
def load_subtensor(nfeat, labels, seeds, input_nodes, device):
    """
    Gather the rows of ``nfeat`` for ``input_nodes`` and the entries of ``labels``
    for ``seeds``, moving both to ``device``.

    Returns a ``(batch_inputs, batch_labels)`` tuple.
    """
    return (
        nfeat[input_nodes].to(device),
        labels[seeds].to(device),
    )

#### Entry point
116
def run(args, device, data):
    """
    Train a GraphSAGE model with neighbor sampling and periodically evaluate it.

    args : parsed command-line options; reads ``fan_out``, ``batch_size``,
        ``num_workers``, ``num_hidden``, ``num_layers``, ``dropout``, ``lr``,
        ``num_epochs``, ``log_every`` and ``eval_every``.
    device : device the model and mini-batches are placed on.
    data : tuple ``(n_classes, train_g, val_g, test_g, train_nfeat, train_labels,
        val_nfeat, val_labels, test_nfeat, test_labels)``.
    """
    # Unpack data
    n_classes, train_g, val_g, test_g, train_nfeat, train_labels, \
    val_nfeat, val_labels, test_nfeat, test_labels = data
    in_feats = train_nfeat.shape[1]
    train_nid = th.nonzero(train_g.ndata['train_mask'], as_tuple=True)[0]
    val_nid = th.nonzero(val_g.ndata['val_mask'], as_tuple=True)[0]
    # Test nodes are all nodes that are neither training nor validation nodes.
    test_nid = th.nonzero(~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']), as_tuple=True)[0]

    # Create PyTorch DataLoader for constructing blocks
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(',')])
    dataloader = dgl.dataloading.NodeDataLoader(
        train_g,
        train_nid,
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers)

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Epochs with index < warmup_epochs are excluded from the average epoch time.
    warmup_epochs = 5
    avg = 0
    timed_epochs = 0
    iter_tput = []
    for epoch in range(args.num_epochs):
        tic = time.time()

        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        tic_step = time.time()
        for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
            # Load the input features as well as output labels
            batch_inputs, batch_labels = load_subtensor(train_nfeat, train_labels,
                                                        seeds, input_nodes, device)
            blocks = [block.int().to(device) for block in blocks]

            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_tput.append(len(seeds) / (time.time() - tic_step))
            if step % args.log_every == 0:
                acc = compute_acc(batch_pred, batch_labels)
                gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
                # The first 3 iterations are excluded from the throughput average; until
                # more exist, report nan without triggering numpy's empty-slice warning.
                speed = np.mean(iter_tput[3:]) if len(iter_tput) > 3 else float('nan')
                print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB'.format(
                    epoch, step, loss.item(), acc.item(), speed, gpu_mem_alloc))
            tic_step = time.time()

        toc = time.time()
        print('Epoch Time(s): {:.4f}'.format(toc - tic))
        if epoch >= warmup_epochs:
            avg += toc - tic
            timed_epochs += 1
        if epoch % args.eval_every == 0 and epoch != 0:
            eval_acc = evaluate(model, val_g, val_nfeat, val_labels, val_nid, device)
            print('Eval Acc {:.4f}'.format(eval_acc))
            test_acc = evaluate(model, test_g, test_nfeat, test_labels, test_nid, device)
            print('Test Acc: {:.4f}'.format(test_acc))

    # Divide by the number of epochs actually timed; the old ``epoch - 4`` divisor
    # crashed or went negative when num_epochs <= 5.
    if timed_epochs > 0:
        print('Avg epoch time: {}'.format(avg / timed_epochs))
    else:
        print('Avg epoch time: n/a (needs more than {} epochs)'.format(warmup_epochs))
184
185
186

if __name__ == '__main__':
    # The parser was previously created with prog="multi-gpu training", which mislabeled
    # this single-device script in --help/usage output.
    argparser = argparse.ArgumentParser(
        description="GraphSAGE training with neighbor sampling on a single device")
    argparser.add_argument('--gpu', type=int, default=0,
                           help="GPU device ID. Use -1 for CPU training")
    argparser.add_argument('--dataset', type=str, default='reddit', choices=['reddit'])
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=1000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=5)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=4,
                           help="Number of sampling processes. Use 0 for no extra process.")
    argparser.add_argument('--inductive', action='store_true',
                           help="Inductive learning setting")
    argparser.add_argument('--data-cpu', action='store_true',
                           help="By default the script puts the training node features and labels "
                                "on GPU when using it to save time for data copy. This may "
                                "be undesired if they cannot fit in GPU memory at once. "
                                "This flag disables that.")
    args = argparser.parse_args()

    if args.gpu >= 0:
        device = th.device('cuda:%d' % args.gpu)
    else:
        device = th.device('cpu')

    if args.dataset == 'reddit':
        g, n_classes = load_reddit()
    else:
        # Unreachable through the CLI thanks to ``choices``; kept as a guard.
        raise ValueError('unknown dataset')

    if args.inductive:
        train_g, val_g, test_g = inductive_split(g)
        train_nfeat = train_g.ndata.pop('features')
        val_nfeat = val_g.ndata.pop('features')
        test_nfeat = test_g.ndata.pop('features')
        train_labels = train_g.ndata.pop('labels')
        val_labels = val_g.ndata.pop('labels')
        test_labels = test_g.ndata.pop('labels')
    else:
        train_g = val_g = test_g = g
        train_nfeat = val_nfeat = test_nfeat = g.ndata.pop('features')
        train_labels = val_labels = test_labels = g.ndata.pop('labels')

    # Only the training features/labels are moved; validation/test tensors stay where
    # they were loaded.
    if not args.data_cpu:
        train_nfeat = train_nfeat.to(device)
        train_labels = train_labels.to(device)

    # Create csr/coo/csc formats before any sampling worker processes are launched.
    # This avoids creating certain formats in each sub-process, which saves memory and CPU.
    train_g.create_formats_()
    val_g.create_formats_()
    test_g.create_formats_()
    # Pack data
    data = n_classes, train_g, val_g, test_g, train_nfeat, train_labels, \
           val_nfeat, val_labels, test_nfeat, test_labels

    run(args, device, data)