"""The main file to train an ARMA model using a full graph."""

import argparse
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from tqdm import trange

from model import ARMA4NC


def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load from DGL dataset
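    # DGL downloads and caches each citation dataset on first use.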
    if args.dataset == "Cora":
        dataset = CoraGraphDataset()
    elif args.dataset == "Citeseer":
        dataset = CiteseerGraphDataset()
    elif args.dataset == "Pubmed":
        dataset = PubmedGraphDataset()
    else:
        raise ValueError("Dataset {} is invalid.".format(args.dataset))

    graph = dataset[0]

    # check cuda availability and select the device
    device = (
        f"cuda:{args.gpu}"
        if args.gpu >= 0 and torch.cuda.is_available()
        else "cpu"
    )

    # retrieve the number of classes
    n_classes = dataset.num_classes

    # retrieve labels of ground truth
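    # `ndata.pop` removes each tensor from the graph, so the later
    # `graph.to(device)` call will not move duplicate copies of them.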
    labels = graph.ndata.pop("label").to(device).long()

    # Extract node features
    feats = graph.ndata.pop("feat").to(device)
    n_features = feats.shape[-1]

    # retrieve masks for train/validation/test
    train_mask = graph.ndata.pop("train_mask")
    val_mask = graph.ndata.pop("val_mask")
    test_mask = graph.ndata.pop("test_mask")
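    # These are boolean per-node masks from the standard citation split;
    # `torch.nonzero` below converts each one to an index tensor.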

    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)
    val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze().to(device)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device)

    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
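    # ARMA approximates a rational graph spectral filter with K parallel
    # stacks of T recursive layers (Bianchi et al., "Graph Neural Networks
    # with Convolutional ARMA Filters"); num_stacks and num_layers map to K
    # and T here, assuming model.py follows the DGL ARMA example.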
    model = ARMA4NC(
        in_dim=n_features,
        hid_dim=args.hid_dim,
        out_dim=n_classes,
        num_stacks=args.num_stacks,
        num_layers=args.num_layers,
        activation=nn.ReLU(),
        dropout=args.dropout,
    ).to(device)

    best_model = copy.deepcopy(model)

    # Step 3: Create training components ===================================================== #
    loss_fn = nn.CrossEntropyLoss()
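    # Adam's weight_decay term applies the L2 penalty controlled by --lamb.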
    opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.lamb)

    # Step 4: Training epochs =============================================================== #
    acc = 0
    no_improvement = 0
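    # trange is tqdm's progress-bar wrapper around range; its description
    # is overwritten each epoch with the latest metrics.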
    epochs = trange(args.epochs, desc="Accuracy & Loss")

    for _ in epochs:
        # Training using a full graph
        model.train()

        logits = model(graph, feats)

        # compute loss on the labeled training nodes only
        train_loss = loss_fn(logits[train_idx], labels[train_idx])
        train_acc = torch.sum(
            logits[train_idx].argmax(dim=1) == labels[train_idx]
        ).item() / len(train_idx)

        # backward
        opt.zero_grad()
        train_loss.backward()
        opt.step()

        # Validation using a full graph
        model.eval()

        with torch.no_grad():
            # Recompute logits in eval mode so dropout is disabled and the
            # freshly updated parameters are used for validation.
            logits = model(graph, feats)
            valid_loss = loss_fn(logits[val_idx], labels[val_idx])
            valid_acc = torch.sum(
                logits[val_idx].argmax(dim=1) == labels[val_idx]
            ).item() / len(val_idx)

        # Print out performance
        epochs.set_description(
            "Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}".format(
                train_acc, train_loss.item(), valid_acc, valid_loss.item()
            )
        )

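        # Early stopping: matching or beating the best validation accuracy
        # resets the patience counter; otherwise training stops after
        # `early_stopping` consecutive non-improving epochs.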
        if valid_acc < acc:
            no_improvement += 1
            if no_improvement == args.early_stopping:
                print("Early stop.")
                break
        else:
            no_improvement = 0
            acc = valid_acc
            best_model = copy.deepcopy(model)

    # Evaluate the best checkpoint on the test split without tracking gradients.
    best_model.eval()
    with torch.no_grad():
        logits = best_model(graph, feats)
    test_acc = torch.sum(
        logits[test_idx].argmax(dim=1) == labels[test_idx]
    ).item() / len(test_idx)

    print("Test Acc {:.4f}".format(test_acc))
    return test_acc


if __name__ == "__main__":
    """
    ARMA Model Hyperparameters
    """
    parser = argparse.ArgumentParser(description="ARMA GCN")

    # data source params
    parser.add_argument(
        "--dataset", type=str, default="Cora", help="Name of dataset."
    )
    # cuda params
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
    )
    # training params
    parser.add_argument(
        "--epochs", type=int, default=2000, help="Training epochs."
    )
    parser.add_argument(
        "--early-stopping",
        type=int,
        default=100,
        help="Patience (number of epochs) to wait before early stopping.",
    )
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument("--lamb", type=float, default=5e-4, help="L2 reg.")
    # model params
    parser.add_argument(
        "--hid-dim", type=int, default=16, help="Hidden layer dimensionalities."
    )
    parser.add_argument(
        "--num-stacks", type=int, default=2, help="Number of ARMA stacks (K)."
    )
    parser.add_argument(
        "--num-layers",
        type=int,
        default=1,
        help="Number of ARMA layers per stack (T).",
    )
    parser.add_argument(
        "--dropout",
        type=float,
        default=0.75,
        help="Dropout applied at all layers.",
    )

    args = parser.parse_args()
    print(args)

    acc_lists = []
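    # Train 100 times and report the mean/std of the test accuracy, since
    # individual runs vary with random initialization.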

    for _ in range(100):
        acc_lists.append(main(args))

    mean = np.around(np.mean(acc_lists, axis=0), decimals=3)
    std = np.around(np.std(acc_lists, axis=0), decimals=3)
    print("Total acc: ", acc_lists)
    print("mean", mean)
    print("std", std)