""" The main file to train an ARMA model using a full graph """

import argparse
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from tqdm import trange

from model import ARMA4NC

def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load from DGL dataset
19
    if args.dataset == "Cora":
20
        dataset = CoraGraphDataset()
21
    elif args.dataset == "Citeseer":
22
        dataset = CiteseerGraphDataset()
23
    elif args.dataset == "Pubmed":
24
25
        dataset = PubmedGraphDataset()
    else:
26
27
        raise ValueError("Dataset {} is invalid.".format(args.dataset))

28
29
30
    graph = dataset[0]

    # check cuda
31
32
33
34
35
    device = (
        f"cuda:{args.gpu}"
        if args.gpu >= 0 and torch.cuda.is_available()
        else "cpu"
    )
36
37
38
39
40

    # retrieve the number of classes
    n_classes = dataset.num_classes

    # retrieve labels of ground truth
41
    labels = graph.ndata.pop("label").to(device).long()
42
43

    # Extract node features
44
    feats = graph.ndata.pop("feat").to(device)
45
46
47
    n_features = feats.shape[-1]

    # retrieve masks for train/validation/test
48
49
50
    train_mask = graph.ndata.pop("train_mask")
    val_mask = graph.ndata.pop("val_mask")
    test_mask = graph.ndata.pop("test_mask")
51
52
53
54
55
56
57
58

    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)
    val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze().to(device)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device)

    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
59
60
61
62
63
64
65
66
67
68
    model = ARMA4NC(
        in_dim=n_features,
        hid_dim=args.hid_dim,
        out_dim=n_classes,
        num_stacks=args.num_stacks,
        num_layers=args.num_layers,
        activation=nn.ReLU(),
        dropout=args.dropout,
    ).to(device)

69
70
71
72
73
74
75
76
77
    best_model = copy.deepcopy(model)

    # Step 3: Create training components ===================================================== #
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.lamb)

    # Step 4: training epoches =============================================================== #
    acc = 0
    no_improvement = 0
78
    epochs = trange(args.epochs, desc="Accuracy & Loss")
79
80
81
82
83
84
85
86
87

    for _ in epochs:
        # Training using a full graph
        model.train()

        logits = model(graph, feats)

        # compute loss
        train_loss = loss_fn(logits[train_idx], labels[train_idx])
88
89
90
        train_acc = torch.sum(
            logits[train_idx].argmax(dim=1) == labels[train_idx]
        ).item() / len(train_idx)
91
92
93
94
95
96
97
98
99
100
101

        # backward
        opt.zero_grad()
        train_loss.backward()
        opt.step()

        # Validation using a full graph
        model.eval()

        with torch.no_grad():
            valid_loss = loss_fn(logits[val_idx], labels[val_idx])
102
103
104
            valid_acc = torch.sum(
                logits[val_idx].argmax(dim=1) == labels[val_idx]
            ).item() / len(val_idx)
105
106

        # Print out performance
107
108
109
110
111
112
        epochs.set_description(
            "Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}".format(
                train_acc, train_loss.item(), valid_acc, valid_loss.item()
            )
        )

113
114
115
        if valid_acc < acc:
            no_improvement += 1
            if no_improvement == args.early_stopping:
116
                print("Early stop.")
117
118
119
120
121
122
123
124
                break
        else:
            no_improvement = 0
            acc = valid_acc
            best_model = copy.deepcopy(model)

    best_model.eval()
    logits = best_model(graph, feats)
125
126
127
    test_acc = torch.sum(
        logits[test_idx].argmax(dim=1) == labels[test_idx]
    ).item() / len(test_idx)
128
129
130
131

    print("Test Acc {:.4f}".format(test_acc))
    return test_acc

if __name__ == "__main__":
    """
    ARMA Model Hyperparameters
    """
137
    parser = argparse.ArgumentParser(description="ARMA GCN")
138
139

    # data source params
140
141
142
    parser.add_argument(
        "--dataset", type=str, default="Cora", help="Name of dataset."
    )
143
    # cuda params
144
145
146
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
    )
147
    # training params
148
149
150
151
152
153
154
155
156
157
158
    parser.add_argument(
        "--epochs", type=int, default=2000, help="Training epochs."
    )
    parser.add_argument(
        "--early-stopping",
        type=int,
        default=100,
        help="Patient epochs to wait before early stopping.",
    )
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument("--lamb", type=float, default=5e-4, help="L2 reg.")
159
    # model params
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
    parser.add_argument(
        "--hid-dim", type=int, default=16, help="Hidden layer dimensionalities."
    )
    parser.add_argument(
        "--num-stacks", type=int, default=2, help="Number of K."
    )
    parser.add_argument(
        "--num-layers", type=int, default=1, help="Number of T."
    )
    parser.add_argument(
        "--dropout",
        type=float,
        default=0.75,
        help="Dropout applied at all layers.",
    )
175
176
177
178
179
180
181
182
183
184
185

    args = parser.parse_args()
    print(args)

    acc_lists = []

    for _ in range(100):
        acc_lists.append(main(args))

    mean = np.around(np.mean(acc_lists, axis=0), decimals=3)
    std = np.around(np.std(acc_lists, axis=0), decimals=3)
186
187
188
    print("Total acc: ", acc_lists)
    print("mean", mean)
    print("std", std)