"vscode:/vscode.git/clone" did not exist on "88c344876f39a07ed3650c89264a8f9b20a9179c"
pubmed.py 4.46 KB
Newer Older
1
2
3
4
5
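"""Train GeniePath / GeniePath-Lazy for node classification on the DGL Pubmed
citation dataset.

Illustrative invocation (flags are defined at the bottom of this file):

    python pubmed.py --gpu 0 --hid_dim 16 --num_layers 2 --lazy
"""
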
import argparse

import torch as th
import torch.optim as optim

from dgl.data import PubmedGraphDataset
from sklearn.metrics import accuracy_score

from model import GeniePath, GeniePathLazy


def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test indices =========================== #
    # Load dataset
    dataset = PubmedGraphDataset()
    graph = dataset[0]

    # check cuda
    if args.gpu >= 0 and th.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    num_classes = dataset.num_classes

    # retrieve label of ground truth
    label = graph.ndata["label"].to(device)

    # Extract node features
    feat = graph.ndata["feat"].to(device)

    # retrieve masks for train/validation/test
    train_mask = graph.ndata["train_mask"]
    val_mask = graph.ndata["val_mask"]
    test_mask = graph.ndata["test_mask"]

    # Convert boolean masks into 1-D tensors of node indices for each split
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)
    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)

    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
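    # Both model variants below take the same constructor arguments. Per the
    # GeniePath paper, GeniePath-Lazy applies the depth (LSTM) function after
    # gathering the outputs of all breadth (attention) layers; see model.py
    # for the actual definitions used here.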
    if args.lazy:
        model = GeniePathLazy(
            in_dim=feat.shape[-1],
            out_dim=num_classes,
            hid_dim=args.hid_dim,
            num_layers=args.num_layers,
            num_heads=args.num_heads,
            residual=args.residual,
        )
    else:
        model = GeniePath(
            in_dim=feat.shape[-1],
            out_dim=num_classes,
            hid_dim=args.hid_dim,
            num_layers=args.num_layers,
            num_heads=args.num_heads,
            residual=args.residual,
        )

    model = model.to(device)

    # Step 3: Create training components ===================================================== #
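    # Cross-entropy on the raw logits plus a plain Adam optimizer with the
    # learning rate taken from --lr.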
    loss_fn = th.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Step 4: training epochs =============================================================== #
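    # Full-graph training: every epoch runs a single forward pass over the
    # whole graph, and train/validation metrics are read from that same pass.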
    for epoch in range(args.max_epoch):
        # Training and validation
        model.train()
        logits = model(graph, feat)

        # compute loss
        tr_loss = loss_fn(logits[train_idx], label[train_idx])
        tr_acc = accuracy_score(
            label[train_idx].cpu(), logits[train_idx].argmax(dim=1).cpu()
        )

        # validation
        valid_loss = loss_fn(logits[val_idx], label[val_idx])
        valid_acc = accuracy_score(
            label[val_idx].cpu(), logits[val_idx].argmax(dim=1).cpu()
        )

        # backward
        optimizer.zero_grad()
        tr_loss.backward()
        optimizer.step()

        # Print out performance
        print(
            "In epoch {}, Train ACC: {:.4f} | Train Loss: {:.4f}; Valid ACC: {:.4f} | Valid loss: {:.4f}".format(
                epoch, tr_acc, tr_loss.item(), valid_acc, valid_loss.item()
            )
        )

    # Test after all epochs
    model.eval()

    # forward
    logits = model(graph, feat)

    # compute loss
    test_loss = loss_fn(logits[test_idx], label[test_idx])
    test_acc = accuracy_score(
        label[test_idx].cpu(), logits[test_idx].argmax(dim=1).cpu()
    )

    print(
        "Test ACC: {:.4f} | Test loss: {:.4f}".format(
            test_acc, test_loss.item()
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GeniePath")
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU Index. Default: -1, using CPU."
    )
    parser.add_argument(
        "--hid_dim",
        type=int,
        default=16,
        help="Hidden layer dimension. Default: 16",
    )
    parser.add_argument(
        "--num_layers",
        type=int,
        default=2,
        help="Number of GeniePath layers. Default: 2",
    )
    parser.add_argument(
        "--max_epoch",
        type=int,
        default=300,
        help="The max number of epochs. Default: 300",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0004,
        help="Learning rate. Default: 0.0004",
    )
    parser.add_argument(
        "--num_heads",
        type=int,
        default=1,
        help="Number of attention heads in the breadth function. Default: 1",
    )
    parser.add_argument(
        "--residual",
        action="store_true",
        help="Use residual connections in the GAT (breadth) layers.",
    )
    parser.add_argument(
        "--lazy",
        action="store_true",
        help="Use the GeniePath-Lazy variant.",
    )

    args = parser.parse_args()
    th.manual_seed(16)
    print(args)
    main(args)