# node_classification.py
1
2
3
4
5
6
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
7
8
from dgl.data import AsNodePredDataset
from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler
9
from ogb.nodeproppred import DglNodePropPredDataset
10
11
import tqdm
import argparse
12
13

class SAGE(nn.Module):
    """Three-layer GraphSAGE network built from ``dglnn.SAGEConv`` ('mean').

    Args:
        in_size: Width of the input node features.
        hid_size: Width of the two hidden representations.
        out_size: Width of the final layer's output.
    """

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        # three-layer GraphSAGE-mean
        self.layers = nn.ModuleList([
            dglnn.SAGEConv(in_size, hid_size, 'mean'),
            dglnn.SAGEConv(hid_size, hid_size, 'mean'),
            dglnn.SAGEConv(hid_size, out_size, 'mean'),
        ])
        self.dropout = nn.Dropout(0.5)
        # Kept so that inference() can size its per-layer output buffers.
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        """Run the layers over a list of sampled blocks.

        Args:
            blocks: One block per layer; ``blocks[i]`` is passed to layer ``i``.
            x: Input features handed to the first layer.

        Returns:
            The output of the last layer. ReLU and dropout are applied after
            every layer except the last one.
        """
        last = len(self.layers) - 1
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != last:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size):
        """Conduct layer-wise inference to get all the node embeddings.

        Each layer is evaluated for every node of ``g`` (using a one-layer
        full-neighbour sampler) before the next layer starts; the result of a
        layer becomes the input features of the following one.

        Args:
            g: Graph whose ``ndata['feat']`` holds the input features.
            device: Device on which the layers are evaluated.
            batch_size: Number of output nodes per mini-batch.

        Returns:
            A tensor with ``g.num_nodes()`` rows and ``out_size`` columns,
            stored on the CPU buffer device.
        """
        # Normalise so that the comparison with buffer_device below is a
        # device-to-device comparison even when a string was passed in.
        device = torch.device(device)
        feat = g.ndata['feat']
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
        dataloader = DataLoader(
            g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
            batch_size=batch_size, shuffle=False, drop_last=False,
            num_workers=0)
        buffer_device = torch.device('cpu')
        # Pin the CPU buffer only when results are copied from another device.
        pin_memory = (buffer_device != device)

        last = len(self.layers) - 1
        for l, layer in enumerate(self.layers):
            y = torch.empty(
                g.num_nodes(), self.hid_size if l != last else self.out_size,
                device=buffer_device, pin_memory=pin_memory)
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = feat[input_nodes]
                h = layer(blocks[0], x)  # len(blocks) = 1
                if l != last:
                    h = F.relu(h)
                    h = self.dropout(h)
                # by design, our output nodes are contiguous
                y[output_nodes[0]:output_nodes[-1] + 1] = h.to(buffer_device)
            # The full output of this layer feeds the next layer.
            feat = y
        return y
60

61
def evaluate(model, graph, dataloader):
    """Compute accuracy of ``model`` over the mini-batches of ``dataloader``.

    Args:
        model: Module called as ``model(blocks, x)``; switched to eval mode.
        graph: Not used by this function; kept for interface compatibility.
        dataloader: Iterable yielding ``(input_nodes, output_nodes, blocks)``.
            Features are read from ``blocks[0].srcdata['feat']`` and labels
            from ``blocks[-1].dstdata['label']``.

    Returns:
        The result of ``MF.accuracy`` over all predictions and labels
        concatenated across batches.
    """
    model.eval()
    ys = []
    y_hats = []
    # No gradients are needed anywhere in the evaluation loop.
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in dataloader:
            x = blocks[0].srcdata['feat']
            ys.append(blocks[-1].dstdata['label'])
            y_hats.append(model(blocks, x))
    return MF.accuracy(torch.cat(y_hats), torch.cat(ys))

72
def layerwise_infer(device, graph, nid, model, batch_size):
    """Score ``model`` on the nodes ``nid`` using full layer-wise inference.

    Args:
        device: Device passed through to ``model.inference``.
        graph: Graph providing the features (via ``model.inference``) and
            ``ndata['label']``.
        nid: Indices of the nodes to score.
        model: Model exposing ``inference(graph, device, batch_size)``.
        batch_size: Mini-batch size passed through to ``model.inference``.

    Returns:
        The result of ``MF.accuracy`` on the selected predictions and labels.
    """
    model.eval()
    with torch.no_grad():
        pred = model.inference(graph, device, batch_size)
        # Index each tensor with a copy of nid on that tensor's own device so
        # predictions and labels may live on different devices.
        pred = pred[torch.as_tensor(nid, device=pred.device)]
        labels = graph.ndata['label']
        label = labels[torch.as_tensor(nid, device=labels.device)].to(pred.device)
        return MF.accuracy(pred, label)

def train(args, device, g, dataset, model):
    """Train ``model`` with neighbour sampling and report validation accuracy.

    Runs 10 epochs of mini-batch training (batch size 1024, Adam optimizer)
    and prints the mean training loss and the validation accuracy after each
    epoch.

    Args:
        args: Parsed command-line arguments; only ``args.mode`` is read.
        device: Device the index tensors are moved to and the dataloaders
            produce their batches on.
        g: Graph to sample from.
        dataset: Object providing ``train_idx`` and ``val_idx``.
        model: Model called as ``model(blocks, x)``.
    """
    # create sampler & dataloader
    train_idx = dataset.train_idx.to(device)
    val_idx = dataset.val_idx.to(device)
    sampler = NeighborSampler([10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]
                              prefetch_node_feats=['feat'],
                              prefetch_labels=['label'])
    # UVA is requested only in 'mixed' mode.
    use_uva = (args.mode == 'mixed')
    train_dataloader = DataLoader(g, train_idx, sampler, device=device,
                                  batch_size=1024, shuffle=True,
                                  drop_last=False, num_workers=0,
                                  use_uva=use_uva)

    val_dataloader = DataLoader(g, val_idx, sampler, device=device,
                                batch_size=1024, shuffle=True,
                                drop_last=False, num_workers=0,
                                use_uva=use_uva)

    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

    for epoch in range(10):
        model.train()
        total_loss = 0
        num_batches = 0
        for input_nodes, output_nodes, blocks in train_dataloader:
            x = blocks[0].srcdata['feat']
            y = blocks[-1].dstdata['label']
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
            num_batches += 1
        acc = evaluate(model, g, val_dataloader)
        # Count batches explicitly: the loop variable would be undefined if
        # the dataloader produced no batches at all.
        print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} "
              .format(epoch, total_loss / max(num_batches, 1), acc.item()))

if __name__ == '__main__':
    # Script entry point: parse the mode, load ogbn-products, train, then test.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default='mixed', choices=['cpu', 'mixed', 'puregpu'],
                        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
                             "'puregpu' for pure-GPU training.")
    args = parser.parse_args()
    # Fall back to CPU training when CUDA is not available.
    if not torch.cuda.is_available():
        args.mode = 'cpu'
    print(f'Training in {args.mode} mode.')

    # load and preprocess dataset
    print('Loading data')
    dataset = AsNodePredDataset(DglNodePropPredDataset('ogbn-products'))
    g = dataset[0]
    # The graph lives on the GPU only in 'puregpu' mode; computation runs on
    # the GPU in every mode except 'cpu'.
    g = g.to('cuda' if args.mode == 'puregpu' else 'cpu')
    device = torch.device('cpu' if args.mode == 'cpu' else 'cuda')

    # create GraphSAGE model
    in_size = g.ndata['feat'].shape[1]
    out_size = dataset.num_classes
    model = SAGE(in_size, 256, out_size).to(device)

    # model training
    print('Training...')
    train(args, device, g, dataset, model)

    # test the model
    print('Testing...')
    acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=4096)
    print(f"Test Accuracy {acc.item():.4f}")