import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from dgl import AddSelfLoop
import argparse

class SAGE(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # two-layer GraphSAGE with the 'gcn' aggregator
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'gcn'))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, 'gcn'))
        self.dropout = nn.Dropout(0.5)
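        # SAGEConv also supports 'mean', 'pool', and 'lstm' aggregators;
        # e.g. dglnn.SAGEConv(in_size, hid_size, 'mean') would give the
        # classic GraphSAGE-mean variant instead of 'gcn'.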

    def forward(self, graph, x):
        h = self.dropout(x)
        for i, layer in enumerate(self.layers):
            h = layer(graph, h)
            if i != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h
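
# Note: SAGE.forward runs full-graph message passing on every call. For graphs
# that exceed GPU memory, DGL's sampling-based training (dgl.dataloading with
# a neighbor sampler) is the usual alternative, but it is out of scope here.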

def evaluate(g, features, labels, mask, model):
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        # predicted class = argmax over the class logits
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

def train(g, features, labels, masks, model):
    # define train/val samples, loss function and optimizer
    train_mask, val_mask = masks
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
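    # lr=1e-2 / weight_decay=5e-4 are the settings commonly used for the
    # citation benchmarks (cf. the original GCN reference implementation)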

    # training loop
    for epoch in range(200):
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        # standard backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = evaluate(g, features, labels, val_mask, model)
        print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f}"
              .format(epoch, loss.item(), acc))
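
    # A common refinement (not implemented here): keep a copy of the weights
    # with the best validation accuracy (needs `import copy`), e.g.
    #   if acc > best_acc:
    #       best_acc, best_state = acc, copy.deepcopy(model.state_dict())
    # and reload best_state before final testing.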

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GraphSAGE')
    parser.add_argument("--dataset", type=str, default="cora",
                        help="Dataset name ('cora', 'citeseer', 'pubmed')")
    args = parser.parse_args()
    print('Training with DGL built-in GraphSAGE module')

    # load and preprocess dataset
    transform = AddSelfLoop()  # AddSelfLoop first removes any existing self-loops, so none are duplicated
    if args.dataset == 'cora':
        data = CoraGraphDataset(transform=transform)
    elif args.dataset == 'citeseer':
        data = CiteseerGraphDataset(transform=transform)
    elif args.dataset == 'pubmed':
        data = PubmedGraphDataset(transform=transform)
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    g = data[0]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
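    # .int() casts the graph's index dtype to int32 (saves memory); .to(device)
    # moves the graph structure together with its ndata tensors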
    g = g.int().to(device)
    features = g.ndata['feat']
    labels = g.ndata['label']
    masks = g.ndata['train_mask'], g.ndata['val_mask']
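    # the boolean train/val/test masks mark the standard split that ships
    # with each citation dataset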

    # create GraphSAGE model
    in_size = features.shape[1]
    out_size = data.num_classes
    model = SAGE(in_size, 16, out_size).to(device)
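    # hid_size=16 keeps the model small; larger hidden sizes (e.g. 64) are a
    # plausible but untested alternative for the bigger pubmed graph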

    # model training
    print('Training...')
    train(g, features, labels, masks, model)

    # test the model
    print('Testing...')
    acc = evaluate(g, features, labels, g.ndata['test_mask'], model)
    print("Test accuracy {:.4f}".format(acc))
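
    # to persist the trained weights, one option (hypothetical filename):
    #   torch.save(model.state_dict(), 'sage_{}.pt'.format(args.dataset))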