import argparse

import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
from torch.nn import functional as F
from tqdm import trange
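# evaluate, generate_random_seeds and set_random_state are assumed to be
# local helpers in a utils.py shipped alongside this example script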
from utils import evaluate, generate_random_seeds, set_random_state

import dgl.function as fn
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset


class DAGNNConv(nn.Module):
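    """Adaptive-depth propagation layer used by DAGNN.

    The input features are propagated k times with the symmetrically
    normalized adjacency D^{-1/2} A D^{-1/2}. The k+1 hop representations
    are stacked, each hop is scored per node by a learned projection `s`
    followed by a sigmoid, and the output is the score-weighted sum over hops.
    """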
    def __init__(self, in_dim, k):
        super(DAGNNConv, self).__init__()

        self.s = Parameter(torch.FloatTensor(in_dim, 1))
        self.k = k

        self.reset_parameters()

    def reset_parameters(self):
        gain = nn.init.calculate_gain("sigmoid")
        nn.init.xavier_uniform_(self.s, gain=gain)

    def forward(self, graph, feats):

        with graph.local_scope():
            results = [feats]

            # per-node symmetric-normalization factor D^{-1/2} (self-loops keep every degree >= 1)
            degs = graph.in_degrees().float()
            norm = torch.pow(degs, -0.5)
            norm = norm.to(feats.device).unsqueeze(1)

            for _ in range(self.k):
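                # one hop of propagation: feats <- D^{-1/2} A D^{-1/2} feats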
                feats = feats * norm
                graph.ndata["h"] = feats
                graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                feats = graph.ndata["h"]
                feats = feats * norm
                results.append(feats)

            # stack the k+1 hop representations: H has shape (N, k+1, C)
            H = torch.stack(results, dim=1)
            # per-node, per-hop retention scores from the learned projection s
            S = torch.sigmoid(torch.matmul(H, self.s))
            S = S.permute(0, 2, 1)
            # adaptive combination: score-weighted sum over hops, (N, 1, C) -> (N, C)
            H = torch.matmul(S, H).squeeze(1)

            return H


class MLPLayer(nn.Module):
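    """Linear layer with optional activation; dropout is applied to the input features."""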
    def __init__(self, in_dim, out_dim, bias=True, activation=None, dropout=0):
        super(MLPLayer, self).__init__()

        self.linear = nn.Linear(in_dim, out_dim, bias=bias)
        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        gain = 1.0
        if self.activation is F.relu:
            gain = nn.init.calculate_gain("relu")
        nn.init.xavier_uniform_(self.linear.weight, gain=gain)
        if self.linear.bias is not None:
            nn.init.zeros_(self.linear.bias)

    def forward(self, feats):

        feats = self.dropout(feats)
        feats = self.linear(feats)
        if self.activation:
            feats = self.activation(feats)

        return feats


class DAGNN(nn.Module):
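    """DAGNN: a two-layer MLP feature transform followed by DAGNNConv propagation."""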
    def __init__(
        self,
        k,
        in_dim,
        hid_dim,
        out_dim,
        bias=True,
        activation=F.relu,
        dropout=0,
    ):
        super(DAGNN, self).__init__()
        self.mlp = nn.ModuleList()
        self.mlp.append(
            MLPLayer(
                in_dim=in_dim,
                out_dim=hid_dim,
                bias=bias,
                activation=activation,
                dropout=dropout,
            )
        )
        self.mlp.append(
            MLPLayer(
                in_dim=hid_dim,
                out_dim=out_dim,
                bias=bias,
                activation=None,
                dropout=dropout,
            )
        )
        self.dagnn = DAGNNConv(in_dim=out_dim, k=k)

    def forward(self, graph, feats):
        for layer in self.mlp:
            feats = layer(feats)
        feats = self.dagnn(graph, feats)
        return feats


def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test indices =========================== #
    # Load from DGL dataset
    if args.dataset == "Cora":
        dataset = CoraGraphDataset()
    elif args.dataset == "Citeseer":
        dataset = CiteseerGraphDataset()
    elif args.dataset == "Pubmed":
        dataset = PubmedGraphDataset()
    else:
        raise ValueError("Dataset {} is invalid.".format(args.dataset))

    graph = dataset[0]
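    # add self-loops so every node has a nonzero in-degree for the symmetric normalization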
    graph = graph.add_self_loop()

    # check cuda
    if args.gpu >= 0 and torch.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    # retrieve the number of classes
    n_classes = dataset.num_classes

    # retrieve labels of ground truth
    labels = graph.ndata.pop("label").to(device).long()

    # Extract node features
    feats = graph.ndata.pop("feat").to(device)
    n_features = feats.shape[-1]

    # retrieve masks for train/validation/test
    train_mask = graph.ndata.pop("train_mask")
    val_mask = graph.ndata.pop("val_mask")
    test_mask = graph.ndata.pop("test_mask")

    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)
    val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze().to(device)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device)

    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = DAGNN(
        k=args.k,
        in_dim=n_features,
        hid_dim=args.hid_dim,
        out_dim=n_classes,
        dropout=args.dropout,
    )
    model = model.to(device)

    # Step 3: Create training components ===================================================== #
    loss_fn = F.cross_entropy
    opt = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.lamb
    )

    # Step 4: training epochs =============================================================== #
    best_val_loss = float("inf")
    best_acc = 0
    no_improvement = 0
    epochs = trange(args.epochs, desc="Accuracy & Loss")

    for _ in epochs:
        model.train()

        logits = model(graph, feats)

        # compute loss
        train_loss = loss_fn(logits[train_idx], labels[train_idx])

        # backward
        opt.zero_grad()
        train_loss.backward()
        opt.step()

        (
            train_loss,
            train_acc,
            valid_loss,
            valid_acc,
            test_loss,
            test_acc,
        ) = evaluate(
            model, graph, feats, labels, (train_idx, val_idx, test_idx)
        )

        # Print out performance
        epochs.set_description(
            "Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}".format(
                train_acc, train_loss.item(), valid_acc, valid_loss.item()
            )
        )

        # early stopping on validation loss
        if valid_loss > best_val_loss:
            no_improvement += 1
            if no_improvement == args.early_stopping:
                print("Early stop.")
                break
        else:
            no_improvement = 0
            best_val_loss = valid_loss
            best_acc = test_acc

    print("Test Acc {:.4f}".format(best_acc))
    return best_acc


if __name__ == "__main__":
    """
    DAGNN Model Hyperparameters
    """
    parser = argparse.ArgumentParser(description="DAGNN")
    # data source params
    parser.add_argument(
        "--dataset",
        type=str,
        default="Cora",
        choices=["Cora", "Citeseer", "Pubmed"],
        help="Name of dataset.",
    )
    # cuda params
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
    )
    # training params
    parser.add_argument("--runs", type=int, default=1, help="Training runs.")
    parser.add_argument(
        "--epochs", type=int, default=1500, help="Training epochs."
    )
    parser.add_argument(
        "--early-stopping",
        type=int,
        default=100,
        help="Patient epochs to wait before early stopping.",
    )
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument("--lamb", type=float, default=0.005, help="L2 reg.")
    # model params
    parser.add_argument(
        "--k", type=int, default=12, help="Number of propagation layers."
    )
    parser.add_argument(
        "--hid-dim", type=int, default=64, help="Hidden layer dimensionalities."
    )
    parser.add_argument("--dropout", type=float, default=0.8, help="dropout")
    args = parser.parse_args()
    print(args)

    acc_lists = []
    random_seeds = generate_random_seeds(seed=1222, nums=args.runs)

    for run in range(args.runs):
        set_random_state(random_seeds[run])
        acc_lists.append(main(args))

    acc_lists = np.array(acc_lists)

    mean = np.around(np.mean(acc_lists, axis=0), decimals=4)
    std = np.around(np.std(acc_lists, axis=0), decimals=4)

    print("Total acc: ", acc_lists)
    print("mean", mean)
    print("std", std)