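"""Train and evaluate BGRL (Bootstrapped Graph Latents) with DGL.

BGRL is a self-supervised, BYOL-style method: an online encoder is trained to
predict the representations produced by a momentum (EMA) target encoder under
two stochastic graph augmentations, and the frozen encoder is then scored with
a linear probe.
"""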
import copy
import os
import warnings

import dgl
import numpy as np
import torch
from eval_function import (
    fit_logistic_regression,
    fit_logistic_regression_preset_splits,
    fit_ppi_linear,
)
from model import (
    BGRL,
    compute_representations,
    GCN,
    GraphSAGE_GCN,
    MLP_Predictor,
)
from torch.nn.functional import cosine_similarity
from torch.optim import AdamW
from tqdm import tqdm
from utils import CosineDecayScheduler, get_dataset, get_graph_drop_transform

warnings.filterwarnings("ignore")


def train(
    step,
    model,
    optimizer,
    lr_scheduler,
    mm_scheduler,
    transform_1,
    transform_2,
    data,
    args,
):
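    """Run one BGRL training step and return the scalar loss value."""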
    model.train()

    # update learning rate
    lr = lr_scheduler.get(step)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    # update target-network momentum (EMA decay rate)
    mm = 1 - mm_scheduler.get(step)

    # forward
    optimizer.zero_grad()

    x1, x2 = transform_1(data), transform_2(data)

    if args.dataset != "ppi":
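        # add self-loops so the GCN's normalized aggregation is defined for
        # isolated nodes; PPI uses the GraphSAGE encoder and skips this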
        x1, x2 = dgl.add_self_loop(x1), dgl.add_self_loop(x2)

    q1, y2 = model(x1, x2)
    q2, y1 = model(x2, x1)

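    # symmetric BYOL-style loss: each online prediction regresses the other
    # view's target projection, with gradients blocked on the target side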
    loss = (
        2
        - cosine_similarity(q1, y2.detach(), dim=-1).mean()
        - cosine_similarity(q2, y1.detach(), dim=-1).mean()
    )
    loss.backward()

    # update online network
    optimizer.step()
    # update target network as an exponential moving average of the online network
    model.update_target_network(mm)

    return loss.item()


def evaluate(model, dataset, device, args, train_data, val_data, test_data):
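    """Linear-probe evaluation of a frozen copy of the online encoder."""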
    # make temporary copy of encoder
    tmp_encoder = copy.deepcopy(model.online_encoder).eval()
    val_scores = None

    if args.dataset == "ppi":
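        # PPI is multi-label with dedicated train/val/test graphs: embed each
        # split with the frozen encoder, then fit linear classifiers on top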
        train_data = compute_representations(tmp_encoder, train_data, device)
        val_data = compute_representations(tmp_encoder, val_data, device)
        test_data = compute_representations(tmp_encoder, test_data, device)
        num_classes = train_data[1].shape[1]
        val_scores, test_scores = fit_ppi_linear(
            num_classes,
            train_data,
            val_data,
            test_data,
            device,
            args.num_eval_splits,
        )
    elif args.dataset != "wiki_cs":
        representations, labels = compute_representations(
            tmp_encoder, dataset, device
        )
        test_scores = fit_logistic_regression(
            representations.cpu().numpy(),
            labels.cpu().numpy(),
            data_random_seed=args.data_seed,
            repeat=args.num_eval_splits,
        )
    else:
        g = dataset[0]
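        # wiki_cs ships with preset train/val/test masks, so evaluate on those splits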
        train_mask = g.ndata["train_mask"]
        val_mask = g.ndata["val_mask"]
        test_mask = g.ndata["test_mask"]
        representations, labels = compute_representations(
            tmp_encoder, dataset, device
        )
        test_scores = fit_logistic_regression_preset_splits(
            representations.cpu().numpy(),
            labels.cpu().numpy(),
            train_mask,
            val_mask,
            test_mask,
        )

    return val_scores, test_scores


def main(args):
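    """Train BGRL on the chosen dataset, evaluate periodically, save the encoder."""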
    # use CUDA_VISIBLE_DEVICES to select gpu
    device = (
        torch.device("cuda")
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    print("Using device:", device)

    dataset, train_data, val_data, test_data = get_dataset(args.dataset)

    g = dataset[0]
    g = g.to(device)

    input_size, representation_size = (
        g.ndata["feat"].size(1),
        args.graph_encoder_layer[-1],
    )

    # prepare transforms
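    # the two views use independent edge-drop and feature-mask probabilities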
    transform_1 = get_graph_drop_transform(
        drop_edge_p=args.drop_edge_p[0], feat_mask_p=args.feat_mask_p[0]
    )
    transform_2 = get_graph_drop_transform(
        drop_edge_p=args.drop_edge_p[1], feat_mask_p=args.feat_mask_p[1]
    )

    # schedulers
    lr_scheduler = CosineDecayScheduler(
        args.lr, args.lr_warmup_epochs, args.epochs
    )
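    # the scheduler decays (1 - mm), so the target EMA momentum anneals from args.mm toward 1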
    mm_scheduler = CosineDecayScheduler(1 - args.mm, 0, args.epochs)

    # build networks
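    # PPI uses a GraphSAGE-style encoder; the remaining datasets use a GCN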
    if args.dataset == "ppi":
        encoder = GraphSAGE_GCN([input_size] + args.graph_encoder_layer)
    else:
        encoder = GCN([input_size] + args.graph_encoder_layer)
    predictor = MLP_Predictor(
        representation_size,
        representation_size,
        hidden_size=args.predictor_hidden_size,
    )
    model = BGRL(encoder, predictor).to(device)

    # optimizer
    optimizer = AdamW(
        model.trainable_parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # train
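    # pass a 0-indexed step (epoch - 1) to the schedulers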
    for epoch in tqdm(range(1, args.epochs + 1), desc="  - (Training)  "):
        train(
            epoch - 1,
            model,
            optimizer,
            lr_scheduler,
            mm_scheduler,
            transform_1,
            transform_2,
            g,
            args,
        )
        if epoch % args.eval_epochs == 0:
            val_scores, test_scores = evaluate(
                model, dataset, device, args, train_data, val_data, test_data
            )
            if args.dataset == "ppi":
                print(
                    "Epoch: {:04d} | Best Val F1: {:.4f} | Test F1: {:.4f}".format(
                        epoch, np.mean(val_scores), np.mean(test_scores)
                    )
                )
            else:
                print(
                    "Epoch: {:04d} | Test Accuracy: {:.4f}".format(
                        epoch, np.mean(test_scores)
                    )
                )

    # save encoder weights
    os.makedirs(args.weights_dir, exist_ok=True)
    torch.save(
        {"model": model.online_encoder.state_dict()},
        os.path.join(args.weights_dir, "bgrl-{}.pt".format(args.dataset)),
    )


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()

    # Dataset options.
    parser.add_argument(
        "--dataset",
        type=str,
        default="amazon_photos",
        choices=[
            "coauthor_cs",
            "coauthor_physics",
            "amazon_photos",
            "amazon_computers",
            "wiki_cs",
            "ppi",
        ],
    )

    # Model options.
    parser.add_argument(
        "--graph_encoder_layer", type=int, nargs="+", default=[256, 128]
    )
    parser.add_argument("--predictor_hidden_size", type=int, default=512)

    # Training options.
    parser.add_argument("--epochs", type=int, default=10000)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--weight_decay", type=float, default=1e-5)
    parser.add_argument("--mm", type=float, default=0.99)
    parser.add_argument("--lr_warmup_epochs", type=int, default=1000)
    parser.add_argument("--weights_dir", type=str, default="../weights")

    # Augmentation options.
    parser.add_argument(
        "--drop_edge_p", type=float, nargs="+", default=[0.0, 0.0]
    )
    parser.add_argument(
        "--feat_mask_p", type=float, nargs="+", default=[0.0, 0.0]
    )

    # Evaluation options.
    parser.add_argument("--eval_epochs", type=int, default=250)
    parser.add_argument("--num_eval_splits", type=int, default=20)
    parser.add_argument("--data_seed", type=int, default=1)

    # Experiment options.
    parser.add_argument("--num_experiments", type=int, default=20)
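    # example: python main.py --dataset wiki_cs --epochs 10000 --eval_epochs 250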

    args = parser.parse_args()

    main(args)