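"""Train a two-layer GAT on a citation graph (Cora, Citeseer, or Pubmed)
with DGL's built-in GATConv module.

Example:
    python train.py --dataset cora
"""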
import argparse

import dgl
import dgl.nn as dglnn

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset


class GAT(nn.Module):
    def __init__(self, in_size, hid_size, out_size, heads):
        super().__init__()
        self.gat_layers = nn.ModuleList()
        # two-layer GAT
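        # first layer: project in_size features to hid_size per attention
        # head; the heads[0] head outputs are concatenated in forward()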
        self.gat_layers.append(
            dglnn.GATConv(
                in_size,
                hid_size,
                heads[0],
                feat_drop=0.6,
                attn_drop=0.6,
                activation=F.elu,
            )
        )
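        # second layer: map the concatenated heads to class scores; its
        # heads[1] outputs are averaged in forward()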
        self.gat_layers.append(
            dglnn.GATConv(
                hid_size * heads[0],
                out_size,
                heads[1],
                feat_drop=0.6,
                attn_drop=0.6,
                activation=None,
            )
        )

    def forward(self, g, inputs):
        h = inputs
        for i, layer in enumerate(self.gat_layers):
            h = layer(g, h)
            if i == 1:  # last layer: average the output heads
                h = h.mean(1)
            else:  # hidden layers: concatenate the head outputs
                h = h.flatten(1)
        return h


def evaluate(g, features, labels, mask, model):
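    # compute accuracy of the model on the nodes selected by the boolean mask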
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() / len(labels)


def train(g, features, labels, masks, model):
    # define train/val samples, loss function and optimizer
    train_mask = masks[0]
    val_mask = masks[1]
    loss_fcn = nn.CrossEntropyLoss()
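    # Adam with the learning rate and weight decay used in the original GAT paper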
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)

    # training loop
    for epoch in range(200):
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = evaluate(g, features, labels, val_mask, model)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed').",
    )
    parser.add_argument(
        "--dt",
        type=str,
        default="float",
        help="data type(float, bfloat16)",
    )
    args = parser.parse_args()
    print(f"Training with DGL built-in GATConv module.")

    # load and preprocess dataset
    # AddSelfLoop first removes existing self-loops to prevent duplication
    transform = AddSelfLoop()
    if args.dataset == "cora":
        data = CoraGraphDataset(transform=transform)
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset(transform=transform)
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset(transform=transform)
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
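    # cast the graph's index tensors to int32 and move it to the device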
    g = g.int().to(device)
    features = g.ndata["feat"]
    labels = g.ndata["label"]
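    # boolean node masks defining the transductive train/val/test splits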
    masks = g.ndata["train_mask"], g.ndata["val_mask"], g.ndata["test_mask"]

    # create GAT model
    in_size = features.shape[1]
    out_size = data.num_classes
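    # 8 hidden units per head, 8 attention heads in the first layer, 1 in the second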
    model = GAT(in_size, 8, out_size, heads=[8, 1]).to(device)

    # convert model and graph to bfloat16 if needed
    if args.dt == "bfloat16":
        g = dgl.to_bfloat16(g)
        features = features.to(dtype=torch.bfloat16)
        model = model.to(dtype=torch.bfloat16)

    # model training
    print("Training...")
    train(g, features, labels, masks, model)

    # test the model
    print("Testing...")
    acc = evaluate(g, features, labels, masks[2], model)
    print("Test accuracy {:.4f}".format(acc))