import argparse

import dgl.function as fn
import numpy as np
import torch
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from torch import nn
from torch.nn import functional as F, Parameter
from tqdm import trange

from utils import evaluate, generate_random_seeds, set_random_state


class DAGNNConv(nn.Module):
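    """DAGNN propagation layer: runs ``k`` steps of symmetrically
    normalized feature propagation, then combines the ``k + 1``
    intermediate representations with a learned per-node gating score.
    """
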
    def __init__(self, in_dim, k):
        super(DAGNNConv, self).__init__()

        self.s = Parameter(torch.FloatTensor(in_dim, 1))
        self.k = k

        self.reset_parameters()

    def reset_parameters(self):
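        # Xavier init with sigmoid gain, matching the sigmoid gate applied
        # to the hop scores in forward().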
        gain = nn.init.calculate_gain("sigmoid")
        nn.init.xavier_uniform_(self.s, gain=gain)

    def forward(self, graph, feats):
        with graph.local_scope():
            results = [feats]

            degs = graph.in_degrees().float()
            norm = torch.pow(degs, -0.5)
            norm = norm.to(feats.device).unsqueeze(1)

            for _ in range(self.k):
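                # One propagation hop of D^{-1/2} A D^{-1/2} X: scale by
                # deg^{-1/2}, aggregate neighbor messages, scale again.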
                feats = feats * norm
                graph.ndata["h"] = feats
                graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                feats = graph.ndata["h"]
                feats = feats * norm
                results.append(feats)

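            # Adaptive combination: score every hop depth with the learned
            # projection s, gate it through a sigmoid, and take the
            # score-weighted sum over the k + 1 stacked representations.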
            H = torch.stack(results, dim=1)
            S = torch.sigmoid(torch.matmul(H, self.s))
            S = S.permute(0, 2, 1)
            H = torch.matmul(S, H).squeeze(1)

            return H


class MLPLayer(nn.Module):
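    """Linear layer with optional activation; dropout is applied to the
    input features before the linear transformation.
    """
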
    def __init__(self, in_dim, out_dim, bias=True, activation=None, dropout=0):
        super(MLPLayer, self).__init__()

        self.linear = nn.Linear(in_dim, out_dim, bias=bias)
        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        gain = 1.0
        if self.activation is F.relu:
            gain = nn.init.calculate_gain("relu")
        nn.init.xavier_uniform_(self.linear.weight, gain=gain)
        if self.linear.bias is not None:
            nn.init.zeros_(self.linear.bias)

    def forward(self, feats):
        feats = self.dropout(feats)
        feats = self.linear(feats)
        if self.activation:
            feats = self.activation(feats)

        return feats


class DAGNN(nn.Module):
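    """Deep Adaptive GNN: a two-layer MLP for feature transformation,
    decoupled from the parameter-free DAGNNConv propagation
    (Liu et al., "Towards Deeper Graph Neural Networks", KDD 2020).
    """
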
    def __init__(
        self,
        k,
        in_dim,
        hid_dim,
        out_dim,
        bias=True,
        activation=F.relu,
        dropout=0,
    ):
        super(DAGNN, self).__init__()
        self.mlp = nn.ModuleList()
        self.mlp.append(
            MLPLayer(
                in_dim=in_dim,
                out_dim=hid_dim,
                bias=bias,
                activation=activation,
                dropout=dropout,
            )
        )
        self.mlp.append(
            MLPLayer(
                in_dim=hid_dim,
                out_dim=out_dim,
                bias=bias,
                activation=None,
                dropout=dropout,
            )
        )
        self.dagnn = DAGNNConv(in_dim=out_dim, k=k)

    def forward(self, graph, feats):
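        # Transform first, then propagate: the MLP produces class logits,
        # which DAGNNConv then smooths over the graph.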
        for layer in self.mlp:
            feats = layer(feats)
        feats = self.dagnn(graph, feats)
        return feats


def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load from DGL dataset
    if args.dataset == "Cora":
        dataset = CoraGraphDataset()
    elif args.dataset == "Citeseer":
        dataset = CiteseerGraphDataset()
    elif args.dataset == "Pubmed":
        dataset = PubmedGraphDataset()
    else:
        raise ValueError("Dataset {} is invalid.".format(args.dataset))

    graph = dataset[0]
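    # Self-loops let each node keep its own features during propagation
    # and avoid zero in-degrees in the normalization.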
    graph = graph.add_self_loop()

    # check cuda
    if args.gpu >= 0 and torch.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    # retrieve the number of classes
    n_classes = dataset.num_classes

    # retrieve labels of ground truth
    labels = graph.ndata.pop("label").to(device).long()

    # Extract node features
    feats = graph.ndata.pop("feat").to(device)
    n_features = feats.shape[-1]

    # retrieve masks for train/validation/test
    train_mask = graph.ndata.pop("train_mask")
    val_mask = graph.ndata.pop("val_mask")
    test_mask = graph.ndata.pop("test_mask")

    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)
    val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze().to(device)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device)

    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = DAGNN(
        k=args.k,
        in_dim=n_features,
        hid_dim=args.hid_dim,
        out_dim=n_classes,
        dropout=args.dropout,
    )
    model = model.to(device)

    # Step 3: Create training components ===================================================== #
    loss_fn = F.cross_entropy
    opt = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.lamb
    )

    # Step 4: training epochs =============================================================== #
    best_loss = float("inf")
    best_acc = 0
    no_improvement = 0
    epochs = trange(args.epochs, desc="Accuracy & Loss")

    for _ in epochs:
        model.train()

        logits = model(graph, feats)

        # compute loss
        train_loss = loss_fn(logits[train_idx], labels[train_idx])

        # backward
        opt.zero_grad()
        train_loss.backward()
        opt.step()

        (
            train_loss,
            train_acc,
            valid_loss,
            valid_acc,
            test_loss,
            test_acc,
        ) = evaluate(
            model, graph, feats, labels, (train_idx, val_idx, test_idx)
        )

        # Print out performance
        epochs.set_description(
            "Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}".format(
                train_acc, train_loss.item(), valid_acc, valid_loss.item()
            )
        )

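        # Early stopping on validation loss: reset patience whenever it
        # improves and keep the test accuracy from the best-validation epoch.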
        if valid_loss > best_loss:
            no_improvement += 1
            if no_improvement == args.early_stopping:
                print("Early stop.")
                break
        else:
            no_improvement = 0
            best_loss = valid_loss
            best_acc = test_acc

    print("Test Acc {:.4f}".format(best_acc))
    return best_acc


if __name__ == "__main__":
    """
    DAGNN Model Hyperparameters
    """
    parser = argparse.ArgumentParser(description="DAGNN")
    # data source params
    parser.add_argument(
        "--dataset",
        type=str,
        default="Cora",
        choices=["Cora", "Citeseer", "Pubmed"],
        help="Name of dataset.",
    )
    # cuda params
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
    )
    # training params
    parser.add_argument("--runs", type=int, default=1, help="Training runs.")
    parser.add_argument(
        "--epochs", type=int, default=1500, help="Training epochs."
    )
    parser.add_argument(
        "--early-stopping",
        type=int,
        default=100,
        help="Patient epochs to wait before early stopping.",
    )
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument("--lamb", type=float, default=0.005, help="L2 reg.")
    # model params
    parser.add_argument(
        "--k", type=int, default=12, help="Number of propagation layers."
    )
    parser.add_argument(
        "--hid-dim", type=int, default=64, help="Hidden layer dimensionalities."
    )
    parser.add_argument("--dropout", type=float, default=0.8, help="dropout")
    args = parser.parse_args()
    print(args)

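    # Repeat training over args.runs seeded runs and report mean/std accuracy.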
    acc_lists = []
    random_seeds = generate_random_seeds(seed=1222, nums=args.runs)

    for run in range(args.runs):
        set_random_state(random_seeds[run])
        acc_lists.append(main(args))

    acc_lists = np.array(acc_lists)

    mean = np.around(np.mean(acc_lists, axis=0), decimals=4)
    std = np.around(np.std(acc_lists, axis=0), decimals=4)

    print("Total acc: ", acc_lists)
    print("mean", mean)
    print("std", std)