# train.py 3.34 KB
# Newer Older
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as F
4
import dgl
5
import dgl.nn as dglnn
6
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
7
8
from dgl import AddSelfLoop
import argparse
9

10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class GCN(nn.Module):
    """Two-layer graph convolutional network built from ``dglnn.GraphConv``.

    The first convolution maps ``in_size`` to ``hid_size`` and applies ReLU;
    the second maps ``hid_size`` to ``out_size`` with no activation. Dropout
    (p=0.5) is applied to the input of every layer except the first.
    """

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        # Layers are constructed in the same order they are applied.
        hidden_conv = dglnn.GraphConv(in_size, hid_size, activation=F.relu)
        output_conv = dglnn.GraphConv(hid_size, out_size)
        self.layers = nn.ModuleList([hidden_conv, output_conv])
        self.dropout = nn.Dropout(0.5)

    def forward(self, g, features):
        """Run ``features`` through every layer on graph ``g`` and return the result."""
        h = features
        for index, conv in enumerate(self.layers):
            # The raw input features are never dropped out; only the
            # activations passed between layers are.
            layer_input = h if index == 0 else self.dropout(h)
            h = conv(g, layer_input)
        return h
    
def evaluate(g, features, labels, mask, model):
    """Return the classification accuracy of ``model`` on the masked nodes.

    Args:
        g: Graph passed through to ``model`` unchanged.
        features: Node feature tensor passed to ``model``.
        labels: Ground-truth class index per node.
        mask: Index/boolean mask selecting the nodes to score.
        model: Callable ``model(g, features)`` returning per-node logits.

    Returns:
        Fraction of selected nodes whose arg-max prediction equals the label,
        as a Python float. Returns 0.0 when the mask selects no nodes.

    Note:
        The model is switched to eval mode and is left in that mode.
    """
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
        selected_logits = logits[mask]
        selected_labels = labels[mask]
        total = len(selected_labels)
        if total == 0:
            # Nothing to score; avoid dividing by zero below.
            return 0.0
        predictions = selected_logits.argmax(dim=1)
        correct = torch.sum(predictions == selected_labels)
        return correct.item() * 1.0 / total

37

38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def train(g, features, labels, masks, model, *, epochs=200, lr=1e-2,
          weight_decay=5e-4):
    """Train ``model`` full-batch with Adam and cross-entropy loss.

    After every epoch the validation accuracy is computed with ``evaluate``
    and printed together with the training loss.

    Args:
        g: Graph passed through to ``model``.
        features: Node feature tensor passed to ``model``.
        labels: Ground-truth class index per node.
        masks: Sequence whose element 0 is the training mask and element 1
            the validation mask; any further elements are ignored here.
        model: ``nn.Module`` called as ``model(g, features)``; updated in place.
        epochs: Number of optimisation steps (default 200).
        lr: Adam learning rate (default 1e-2).
        weight_decay: Adam weight decay (default 5e-4).

    Returns:
        None.
    """
    # define train/val samples, loss function and optimizer
    train_mask = masks[0]
    val_mask = masks[1]
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay)

    # training loop
    for epoch in range(epochs):
        # evaluate() leaves the model in eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = evaluate(g, features, labels, val_mask, model)
        print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} "
              .format(epoch, loss.item(), acc))

if __name__ == '__main__':
    # Script entry point: parse the command line, load the chosen citation
    # graph, train a GCN on it and report accuracy on the test split.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="cora",
                        help="Dataset name ('cora', 'citeseer', 'pubmed').")
    args = parser.parse_args()
    print('Training with DGL built-in GraphConv module.')

    # load and preprocess dataset
    transform = AddSelfLoop()  # by default, it will first remove self-loops to prevent duplication
    dataset_classes = {
        'cora': CoraGraphDataset,
        'citeseer': CiteseerGraphDataset,
        'pubmed': PubmedGraphDataset,
    }
    if args.dataset not in dataset_classes:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    data = dataset_classes[args.dataset](transform=transform)
    g = data[0]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    g = g.int().to(device)
    features = g.ndata['feat']
    labels = g.ndata['label']
    masks = g.ndata['train_mask'], g.ndata['val_mask'], g.ndata['test_mask']

    # normalization: store deg^-0.5 per node under g.ndata['norm'], mapping
    # the infinite values produced by zero in-degree nodes to 0.
    # NOTE(review): nothing in this file reads g.ndata['norm'] -- confirm
    # whether it is still needed.
    degs = g.in_degrees().float()
    norm = torch.pow(degs, -0.5).to(device)
    norm[torch.isinf(norm)] = 0
    g.ndata['norm'] = norm.unsqueeze(1)

    # create GCN model
    in_size = features.shape[1]
    out_size = data.num_classes
    model = GCN(in_size, 16, out_size).to(device)

    # model training
    print('Training...')
    train(g, features, labels, masks, model)

    # test the model
    print('Testing...')
    acc = evaluate(g, features, labels, masks[2], model)
    print("Test accuracy {:.4f}".format(acc))