main.py 6.38 KB
Newer Older
Smile's avatar
Smile committed
1
import time
2

Smile's avatar
Smile committed
3
4
import numpy as np
import torch
5
6
7
8
import torch.multiprocessing
from logger import LightLogging
from model import DGCNN, GCN
from sampler import SEALData
Smile's avatar
Smile committed
9
from torch.nn import BCEWithLogitsLoss
10
11
12
13
from tqdm import tqdm
from utils import evaluate_hits, load_ogb_dataset, parse_arguments

from dgl import EID, NID
Smile's avatar
Smile committed
14
15
from dgl.dataloading import GraphDataLoader

16
17
18
# Use PyTorch's "file_system" strategy for sharing tensors between processes;
# the GraphDataLoaders built in main() may spawn `args.num_workers` worker
# processes. NOTE(review): presumably chosen to avoid exhausting the
# open-file-descriptor limit of the default strategy — confirm.
torch.multiprocessing.set_sharing_strategy("file_system")

"""
Parts of the code are adapted from
https://github.com/facebookresearch/SEAL_OGB
"""


def train(
    model,
    dataloader,
    loss_fn,
    optimizer,
    device,
    num_graphs=32,
    total_graphs=None,
):
    """Run one training epoch and return the average per-graph loss.

    Args:
        model: Network called as ``model(g, g.ndata["z"], g.ndata[NID],
            g.edata[EID])`` on each batched subgraph; its logits go to
            ``loss_fn``.
        dataloader: Iterable yielding ``(batched_graph, labels)`` pairs.
        loss_fn: Loss taking ``(logits, labels)``. Assumed to use mean
            reduction (``BCEWithLogitsLoss()``'s default), so each batch's
            loss is rescaled by its element count before accumulating.
        optimizer: Optimizer stepped once per batch.
        device: Device the graphs and labels are moved to.
        num_graphs: Kept only for backward compatibility and no longer used.
            The real size of every batch is taken from its labels, so a short
            final batch is not over-weighted.
        total_graphs: Divisor for the accumulated loss. When ``None``, the
            number of label elements actually seen this epoch is used.

    Returns:
        float: Summed per-graph loss divided by ``total_graphs`` (or by the
        number of graphs seen); ``0.0`` if the divisor is zero.
    """
    model.train()

    total_loss = 0.0
    seen_graphs = 0
    for g, labels in tqdm(dataloader, ncols=100):
        g = g.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = model(g, g.ndata["z"], g.ndata[NID], g.edata[EID])
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        # `loss` is a batch mean; multiply by the true batch size instead of
        # the configured one (the last batch is usually smaller).
        # NOTE(review): assumes one label per graph — confirm in SEALData.
        batch_graphs = labels.numel()
        total_loss += loss.item() * batch_graphs
        seen_graphs += batch_graphs

    denominator = seen_graphs if total_graphs is None else total_graphs
    if not denominator:
        return 0.0
    return total_loss / denominator


@torch.no_grad()
def evaluate(model, dataloader, device):
    """Score every subgraph in ``dataloader`` and split scores by label.

    Gradients are disabled and the model is put in eval mode.

    Returns:
        tuple: ``(pos_pred, neg_pred)`` — 1-D CPU tensors holding the
        flattened logits of samples whose label equals 1 and 0 respectively
        (samples with any other label value appear in neither).
    """
    model.eval()

    pred_chunks = []
    label_chunks = []
    for batch_graph, batch_labels in tqdm(dataloader, ncols=100):
        batch_graph = batch_graph.to(device)
        out = model(
            batch_graph,
            batch_graph.ndata["z"],
            batch_graph.ndata[NID],
            batch_graph.edata[EID],
        )
        # Flatten and move to CPU so batches can be concatenated at the end;
        # labels are cast to float to compare against 1/0 below.
        pred_chunks.append(out.view(-1).cpu())
        label_chunks.append(batch_labels.view(-1).cpu().to(torch.float))

    scores = torch.cat(pred_chunks)
    targets = torch.cat(label_chunks)
    return scores[targets == 1], scores[targets == 0]


def _build_model(args, node_attribute, edge_weight, num_nodes):
    """Instantiate the network selected by ``args.model``.

    Args:
        args: Parsed experiment arguments. Reads ``model``, ``num_layers``,
            ``hidden_units``, ``gcn_type`` and ``dropout``, plus ``pooling``
            (GCN only) or ``sort_k`` (DGCNN only).
        node_attribute: Value of ``seal_data.ndata["feat"]``.
        edge_weight: Value of ``seal_data.edata["weight"]`` cast to float.
        num_nodes: Number of nodes in the full input graph.

    Returns:
        A ``GCN`` or ``DGCNN`` instance.

    Raises:
        ValueError: If ``args.model`` is neither ``"gcn"`` nor ``"dgcnn"``.
    """
    if args.model == "gcn":
        return GCN(
            num_layers=args.num_layers,
            hidden_units=args.hidden_units,
            gcn_type=args.gcn_type,
            pooling_type=args.pooling,
            node_attributes=node_attribute,
            edge_weights=edge_weight,
            node_embedding=None,
            use_embedding=True,
            num_nodes=num_nodes,
            dropout=args.dropout,
        )
    if args.model == "dgcnn":
        return DGCNN(
            num_layers=args.num_layers,
            hidden_units=args.hidden_units,
            k=args.sort_k,
            gcn_type=args.gcn_type,
            node_attributes=node_attribute,
            edge_weights=edge_weight,
            node_embedding=None,
            use_embedding=True,
            num_nodes=num_nodes,
            dropout=args.dropout,
        )
    raise ValueError("Model error: unknown model '{}'".format(args.model))


def _report_results(args, eval_epochs, summary_val, summary_test, print_fn):
    """Log the final experiment summary.

    Reports the best test metric with the *epoch number* it was reached at.
    It also reports the test metric at the epoch with the best validation
    metric, which is the usual model-selection protocol.

    Args:
        args: Parsed experiment arguments (reads ``hits_k``).
        eval_epochs: Epoch numbers at which evaluation ran, aligned with
            ``summary_val`` / ``summary_test``.
        summary_val: Validation metric per evaluation.
        summary_test: Test metric per evaluation.
        print_fn: Logging callable.
    """
    print_fn("Experiment Results:")
    if not eval_epochs:
        # e.g. epochs == 0: np.max/argmax would raise on an empty array.
        print_fn("No evaluation was run; nothing to report.")
        return

    val_scores = np.asarray(summary_val)
    test_scores = np.asarray(summary_test)

    best_test = int(np.argmax(test_scores))
    print_fn(
        "Best hits@{}: {:.4f}, epoch: {}".format(
            args.hits_k, test_scores[best_test], eval_epochs[best_test]
        )
    )

    best_val = int(np.argmax(val_scores))
    print_fn(
        "Test hits@{} at best val epoch {}: {:.4f} (val: {:.4f})".format(
            args.hits_k,
            eval_epochs[best_val],
            test_scores[best_val],
            val_scores[best_val],
        )
    )


def main(args, print_fn=print):
    """Train and evaluate a SEAL link-prediction model on an OGB dataset.

    Args:
        args: Parsed experiment arguments (from ``parse_arguments()``).
        print_fn: Callable used for all progress and result messages.

    Raises:
        NotImplementedError: If ``args.dataset`` does not start with "ogbl".
        ValueError: If ``args.model`` is not "gcn" or "dgcnn".
    """
    print_fn("Experiment arguments: {}".format(args))

    # NOTE(review): a seed of 0 is falsy and falls back to 123; kept as-is
    # because the parse_arguments() default is not visible here — confirm.
    torch.manual_seed(args.random_seed if args.random_seed else 123)

    # Load dataset
    if not args.dataset.startswith("ogbl"):
        raise NotImplementedError
    graph, split_edge = load_ogb_dataset(args.dataset)

    num_nodes = graph.num_nodes()

    # set gpu
    if args.gpu_id >= 0 and torch.cuda.is_available():
        device = "cuda:{}".format(args.gpu_id)
    else:
        device = "cpu"

    # ogbl-collab is a multi-edge graph; use_coalesce is forwarded to SEALData
    # (presumably to merge parallel edges — confirm in sampler).
    use_coalesce = args.dataset == "ogbl-collab"

    # Generate positive and negative edges and corresponding labels, sample
    # subgraphs and generate node labeling features.
    seal_data = SEALData(
        g=graph,
        split_edge=split_edge,
        hop=args.hop,
        neg_samples=args.neg_samples,
        subsample_ratio=args.subsample_ratio,
        use_coalesce=use_coalesce,
        prefix=args.dataset,
        save_dir=args.save_dir,
        num_workers=args.num_workers,
        print_fn=print_fn,
    )
    node_attribute = seal_data.ndata["feat"]
    edge_weight = seal_data.edata["weight"].float()

    train_data = seal_data("train")
    val_data = seal_data("valid")
    test_data = seal_data("test")

    train_graphs = len(train_data.graph_list)

    # Set data loaders
    train_loader = GraphDataLoader(
        train_data, batch_size=args.batch_size, num_workers=args.num_workers
    )
    val_loader = GraphDataLoader(
        val_data, batch_size=args.batch_size, num_workers=args.num_workers
    )
    test_loader = GraphDataLoader(
        test_data, batch_size=args.batch_size, num_workers=args.num_workers
    )

    # Set model, optimizer and loss
    model = _build_model(args, node_attribute, edge_weight, num_nodes)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    loss_fn = BCEWithLogitsLoss()
    print_fn(
        "Total parameters: {}".format(
            sum(p.numel() for p in model.parameters())
        )
    )

    # Train and evaluate loop. eval_epochs records the actual epoch number
    # of each evaluation, so the reported "best epoch" is correct even when
    # eval_steps > 1.
    summary_val = []
    summary_test = []
    eval_epochs = []
    for epoch in range(args.epochs):
        start_time = time.time()
        loss = train(
            model=model,
            dataloader=train_loader,
            loss_fn=loss_fn,
            optimizer=optimizer,
            device=device,
            num_graphs=args.batch_size,
            total_graphs=train_graphs,
        )
        train_time = time.time()
        if epoch % args.eval_steps != 0:
            continue

        val_pos_pred, val_neg_pred = evaluate(
            model=model, dataloader=val_loader, device=device
        )
        test_pos_pred, test_neg_pred = evaluate(
            model=model, dataloader=test_loader, device=device
        )

        val_metric = evaluate_hits(
            args.dataset, val_pos_pred, val_neg_pred, args.hits_k
        )
        test_metric = evaluate_hits(
            args.dataset, test_pos_pred, test_neg_pred, args.hits_k
        )
        evaluate_time = time.time()
        print_fn(
            "Epoch-{}, train loss: {:.4f}, hits@{}: val-{:.4f}, test-{:.4f}, "
            "cost time: train-{:.1f}s, total-{:.1f}s".format(
                epoch,
                loss,
                args.hits_k,
                val_metric,
                test_metric,
                train_time - start_time,
                evaluate_time - start_time,
            )
        )
        eval_epochs.append(epoch)
        summary_val.append(val_metric)
        summary_test.append(test_metric)

    _report_results(args, eval_epochs, summary_val, summary_test, print_fn)
Smile's avatar
Smile committed
222
223


224
if __name__ == "__main__":
    # Script entry point: parse CLI options and route every progress/result
    # message from main() through the "SEAL" logger writing under ./logs.
    args = parse_arguments()
    logger = LightLogging(log_name="SEAL", log_path="./logs")
    main(args, print_fn=logger.info)