train_full.py
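"""Full-graph GraphSAGE node classification on Cora, Citeseer, or Pubmed with DGL.

Example:
    python train_full.py --dataset cora
"""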
import argparse

import dgl.nn as dglnn

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset


class SAGE(nn.Module):
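    """Two-layer GraphSAGE model for full-graph node classification."""
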
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # two-layer GraphSAGE with the "gcn" aggregator
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "gcn"))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, "gcn"))
        self.dropout = nn.Dropout(0.5)

    def forward(self, graph, x):
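        # input dropout, then the SAGEConv layers with ReLU + dropout between them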
        h = self.dropout(x)
        for l, layer in enumerate(self.layers):
            h = layer(graph, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h


def evaluate(g, features, labels, mask, model):
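    """Return the classification accuracy of `model` on the nodes selected by `mask`."""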
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


def train(g, features, labels, masks, model):
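    """Run 200 epochs of full-graph training, printing loss and validation accuracy."""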
    # define train/val samples, loss function and optimizer
    train_mask, val_mask = masks
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

    # training loop
    for epoch in range(200):
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = evaluate(g, features, labels, val_mask, model)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GraphSAGE")
    parser.add_argument(
        "--dataset",
        type=str,
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed')",
    )
    args = parser.parse_args()
    print(f"Training with DGL built-in GraphSage module")

    # load and preprocess dataset
    transform = (
        AddSelfLoop()
    )  # by default, it will first remove self-loops to prevent duplication
    if args.dataset == "cora":
        data = CoraGraphDataset(transform=transform)
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset(transform=transform)
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset(transform=transform)
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
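    # cast the graph's index dtype to int32 and move it (with its node data) to the device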
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    g = g.int().to(device)
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    masks = g.ndata["train_mask"], g.ndata["val_mask"]

    # create GraphSAGE model
    in_size = features.shape[1]
    out_size = data.num_classes
    model = SAGE(in_size, 16, out_size).to(device)

    # model training
    print("Training...")
    train(g, features, labels, masks, model)

    # test the model
    print("Testing...")
    acc = evaluate(g, features, labels, g.ndata["test_mask"], model)
    print("Test accuracy {:.4f}".format(acc))