import argparse

import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset


class SAGE(nn.Module):
    """Two-layer GraphSAGE node-classification model.

    NOTE(review): the original comment said "GraphSAGE-mean", but both
    ``SAGEConv`` layers are constructed with the ``"gcn"`` aggregator;
    the documentation here reflects the actual code.

    Parameters
    ----------
    in_size : int
        Dimensionality of the input node features.
    hid_size : int
        Dimensionality of the hidden layer.
    out_size : int
        Number of output classes.
    """

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Two SAGEConv layers using the "gcn" aggregator.
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "gcn"))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, "gcn"))
        self.dropout = nn.Dropout(0.5)

    def forward(self, graph, x):
        """Return per-node logits for ``graph`` given node features ``x``.

        Dropout is applied to the input features and between layers; no
        activation follows the final layer (raw logits are returned).
        """
        h = self.dropout(x)
        for l, layer in enumerate(self.layers):
            h = layer(graph, h)
            if l != len(self.layers) - 1:  # skip relu/dropout after last layer
                h = F.relu(h)
                h = self.dropout(h)
        return h

def evaluate(g, features, labels, mask, model):
    """Return the classification accuracy of ``model`` on the nodes in ``mask``.

    Runs a full forward pass in eval mode without tracking gradients and
    compares the argmax class of each masked node against its label.
    """
    model.eval()
    with torch.no_grad():
        predictions = model(g, features)
        masked_logits = predictions[mask]
        masked_labels = labels[mask]
        _, predicted_classes = torch.max(masked_logits, dim=1)
        num_correct = torch.sum(predicted_classes == masked_labels)
        return num_correct.item() * 1.0 / len(masked_labels)

def train(
    g, features, labels, masks, model, num_epochs=200, lr=1e-2, weight_decay=5e-4
):
    """Train ``model`` on the graph with full-batch gradient descent.

    Generalized from the original: the epoch count, learning rate and
    weight decay are now keyword parameters whose defaults reproduce the
    original hard-coded values (200 epochs, lr=1e-2, weight_decay=5e-4).

    Parameters
    ----------
    g : DGLGraph
        The (preprocessed) input graph.
    features : Tensor
        Node feature matrix.
    labels : Tensor
        Ground-truth class per node.
    masks : tuple(Tensor, Tensor)
        Boolean (train_mask, val_mask) node masks.
    model : nn.Module
        The model to optimize in place.
    num_epochs : int, optional
        Number of training epochs.
    lr : float, optional
        Adam learning rate.
    weight_decay : float, optional
        Adam weight decay (L2 penalty).
    """
    # Define train/val samples, loss function and optimizer.
    train_mask, val_mask = masks
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Training loop: full-batch forward, loss on train nodes only.
    for epoch in range(num_epochs):
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Report validation accuracy every epoch.
        acc = evaluate(g, features, labels, val_mask, model)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )

def main(args):
    """Load the chosen citation dataset, train GraphSAGE, and report test accuracy.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI options: ``dataset`` in {cora, citeseer, pubmed} and
        ``dt`` in {float, bfloat16}.
    """
    # Original used an f-string with no placeholders; plain string is equivalent.
    print("Training with DGL built-in GraphSage module")

    # Load and preprocess dataset. By default AddSelfLoop first removes
    # existing self-loops to prevent duplication, then adds one per node.
    transform = AddSelfLoop()
    if args.dataset == "cora":
        data = CoraGraphDataset(transform=transform)
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset(transform=transform)
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset(transform=transform)
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    g = g.int().to(device)
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    masks = g.ndata["train_mask"], g.ndata["val_mask"]

    # Create GraphSAGE model with a 16-unit hidden layer.
    in_size = features.shape[1]
    out_size = data.num_classes
    model = SAGE(in_size, 16, out_size).to(device)

    # Convert model, features and graph to bfloat16 if requested.
    if args.dt == "bfloat16":
        g = dgl.to_bfloat16(g)
        features = features.to(dtype=torch.bfloat16)
        model = model.to(dtype=torch.bfloat16)

    # Model training.
    print("Training...")
    train(g, features, labels, masks, model)

    # Test the trained model on the held-out test split.
    print("Testing...")
    acc = evaluate(g, features, labels, g.ndata["test_mask"], model)
    print("Test accuracy {:.4f}".format(acc))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GraphSAGE")
    parser.add_argument(
        "--dataset",
        type=str,
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed')",
    )
    parser.add_argument(
        "--dt",
        type=str,
        default="float",
        help="data type(float, bfloat16)",
    )
    main(parser.parse_args())