# train_sampling.py
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm

12
from model import SAGE
13
from load_graph import load_reddit, inductive_split, load_ogb
14

15
16
17
18
def compute_acc(pred, labels):
    """
    Compute the accuracy of prediction given the labels.

    pred : Tensor of per-class scores; the predicted class of each row is the
        argmax over dim 1.
    labels : Tensor of ground-truth class indices, one per row of ``pred``.

    Returns a 0-dim float tensor holding the fraction of correct predictions.
    """
    # Cast so the comparison works regardless of the integer dtype of labels.
    labels = labels.long()
    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)

22
def evaluate(model, g, nfeat, labels, val_nid, device, batch_size=None, num_workers=None):
    """
    Evaluate the model on the node subset specified by ``val_nid``.

    model : Object providing ``eval()``, ``train()`` and
        ``inference(g, nfeat, device, batch_size, num_workers)``.
    g : The entire graph.
    nfeat : The features of all the nodes.
    labels : The labels of all the nodes.
    val_nid : The node IDs to evaluate on.
    device : The device to run inference on.
    batch_size, num_workers : Forwarded to ``model.inference``. When ``None``
        they fall back to ``args.batch_size`` / ``args.num_workers`` from the
        module-level ``args`` parsed in the ``__main__`` block.

    Returns the accuracy over ``val_nid`` as a 0-dim tensor.
    """
    if batch_size is None:
        batch_size = args.batch_size
    if num_workers is None:
        num_workers = args.num_workers
    model.eval()
    try:
        with th.no_grad():
            pred = model.inference(g, nfeat, device, batch_size, num_workers)
    finally:
        # Put the model back into training mode even if inference raises.
        model.train()
    return compute_acc(pred[val_nid], labels[val_nid].to(pred.device))
36

37
def load_subtensor(nfeat, labels, seeds, input_nodes, device):
    """
    Extracts features and labels for a subset of nodes.

    nfeat : Features of all nodes, indexed by node ID.
    labels : Labels of all nodes, indexed by node ID.
    seeds : IDs of the output nodes whose labels are needed.
    input_nodes : IDs of the input nodes whose features are needed.
    device : Device the extracted tensors are moved to.

    Returns ``(batch_inputs, batch_labels)``.
    """
    batch_inputs = nfeat[input_nodes].to(device)
    batch_labels = labels[seeds].to(device)
    return batch_inputs, batch_labels

#### Entry point
46
def run(args, device, data):
    """
    Train the ``SAGE`` model on mini-batches produced by multi-layer neighbor
    sampling, periodically reporting validation and test accuracy.

    args : Parsed command-line arguments (see the ``__main__`` block).
    device : Device the model and mini-batches are moved to.
    data : Tuple ``(n_classes, train_g, val_g, test_g, train_nfeat, train_labels,
        val_nfeat, val_labels, test_nfeat, test_labels)``.
    """
    # Unpack data
    n_classes, train_g, val_g, test_g, train_nfeat, train_labels, \
    val_nfeat, val_labels, test_nfeat, test_labels = data
    in_feats = train_nfeat.shape[1]
    train_nid = th.nonzero(train_g.ndata['train_mask'], as_tuple=True)[0]
    val_nid = th.nonzero(val_g.ndata['val_mask'], as_tuple=True)[0]
    # Test nodes are all nodes that are neither training nor validation nodes.
    test_nid = th.nonzero(~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']), as_tuple=True)[0]

    dataloader_device = th.device('cpu')
    if args.sample_gpu:
        train_nid = train_nid.to(device)
        # copy only the csc to the GPU
        train_g = train_g.formats(['csc'])
        train_g = train_g.to(device)
        dataloader_device = device

    # Create PyTorch DataLoader for constructing blocks
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(',')])
    dataloader = dgl.dataloading.NodeDataLoader(
        train_g,
        train_nid,
        sampler,
        device=dataloader_device,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers)

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    # Epochs before this index are treated as warm-up and excluded from the
    # average epoch time.
    warmup_epochs = 5
    avg = 0
    timed_epochs = 0
    iter_tput = []
    for epoch in range(args.num_epochs):
        tic = time.time()

        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        tic_step = time.time()
        for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
            # Load the input features as well as output labels
            batch_inputs, batch_labels = load_subtensor(train_nfeat, train_labels,
                                                        seeds, input_nodes, device)
            blocks = [block.int().to(device) for block in blocks]

            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_tput.append(len(seeds) / (time.time() - tic_step))
            if step % args.log_every == 0:
                acc = compute_acc(batch_pred, batch_labels)
                gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
                # Skip the first few (warm-up) steps; report nan until enough
                # samples exist instead of averaging an empty slice.
                speed = np.mean(iter_tput[3:]) if len(iter_tput) > 3 else float('nan')
                print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB'.format(
                    epoch, step, loss.item(), acc.item(), speed, gpu_mem_alloc))
            tic_step = time.time()

        toc = time.time()
        print('Epoch Time(s): {:.4f}'.format(toc - tic))
        if epoch >= warmup_epochs:
            avg += toc - tic
            timed_epochs += 1
        if epoch % args.eval_every == 0 and epoch != 0:
            eval_acc = evaluate(model, val_g, val_nfeat, val_labels, val_nid, device)
            print('Eval Acc {:.4f}'.format(eval_acc))
            test_acc = evaluate(model, test_g, test_nfeat, test_labels, test_nid, device)
            print('Test Acc: {:.4f}'.format(test_acc))

    # Divide by the number of epochs actually timed. This avoids a
    # ZeroDivisionError, a negative divisor, or an unbound ``epoch`` when
    # num_epochs <= warmup_epochs.
    if timed_epochs > 0:
        print('Avg epoch time: {}'.format(avg / timed_epochs))
    else:
        print('Avg epoch time: N/A (need more than {} epochs)'.format(warmup_epochs))
123
124

if __name__ == '__main__':
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--gpu', type=int, default=0,
                           help="GPU device ID. Use -1 for CPU training")
    argparser.add_argument('--dataset', type=str, default='reddit')
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=1000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=5)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=4,
                           help="Number of sampling processes. Use 0 for no extra process.")
    argparser.add_argument('--sample-gpu', action='store_true',
                           help="Perform the sampling process on the GPU. Must have 0 workers.")
    argparser.add_argument('--inductive', action='store_true',
                           help="Inductive learning setting")
    argparser.add_argument('--data-cpu', action='store_true',
                           help="By default the script puts all node features and labels "
                                "on GPU when using it to save time for data copy. This may "
                                "be undesired if they cannot fit in GPU memory at once. "
                                "This flag disables that.")
    args = argparser.parse_args()

    # Enforce the constraint documented in the --sample-gpu help text up front
    # instead of failing later inside the data loader.
    if args.sample_gpu and args.num_workers != 0:
        argparser.error('--sample-gpu requires --num-workers 0')

    if args.gpu >= 0:
        device = th.device('cuda:%d' % args.gpu)
    else:
        device = th.device('cpu')

    if args.dataset == 'reddit':
        g, n_classes = load_reddit()
    elif args.dataset == 'ogbn-products':
        g, n_classes = load_ogb('ogbn-products')
    else:
        raise Exception('unknown dataset')

    if args.inductive:
        train_g, val_g, test_g = inductive_split(g)
        train_nfeat = train_g.ndata.pop('features')
        val_nfeat = val_g.ndata.pop('features')
        test_nfeat = test_g.ndata.pop('features')
        train_labels = train_g.ndata.pop('labels')
        val_labels = val_g.ndata.pop('labels')
        test_labels = test_g.ndata.pop('labels')
    else:
        # Transductive setting: all three splits share the same graph and tensors.
        train_g = val_g = test_g = g
        train_nfeat = val_nfeat = test_nfeat = g.ndata.pop('features')
        train_labels = val_labels = test_labels = g.ndata.pop('labels')

    if not args.data_cpu:
        train_nfeat = train_nfeat.to(device)
        train_labels = train_labels.to(device)

    # Pack data
    data = n_classes, train_g, val_g, test_g, train_nfeat, train_labels, \
           val_nfeat, val_labels, test_nfeat, test_labels

    run(args, device, data)