"router/vscode:/vscode.git/clone" did not exist on "20c3c5940c6af1ceb50a8b4c713443690a148190"
test_inference.py 5.78 KB
Newer Older
1
2
3
4
import argparse
import os
import random

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
5
6
import dgl

7
8
9
import numpy as np
import torch
from gnn import GNN
10
11
12
13
14
15
16
17
18
19
20
from ogb.lsc import PCQM4MDataset, PCQM4MEvaluator
from ogb.utils import smiles2graph
from torch.utils.data import DataLoader
from tqdm import tqdm


def collate_dgl(graphs):
    """Collate function that merges a list of graphs into one batched graph.

    Parameters
    ----------
    graphs : list
        Graphs produced by the dataset for a single mini-batch.

    Returns
    -------
    The result of ``dgl.batch(graphs)``.
    """
    return dgl.batch(graphs)

21

22
23
24
25
26
27
def test(model, device, loader):
    """Run the model over every batch in ``loader`` and collect predictions.

    Parameters
    ----------
    model
        Module called as ``model(graph, node_feat, edge_feat)``; it is
        switched to evaluation mode before the loop.
    device
        Device each batched graph is moved to before the forward pass.
    loader
        Iterable yielding batched graphs that carry a ``"feat"`` entry in
        both ``ndata`` and ``edata``.

    Returns
    -------
    torch.Tensor
        1-D CPU tensor holding the flattened predictions of all batches,
        concatenated in iteration order.
    """
    model.eval()
    y_pred = []

    # Inference only: no autograd bookkeeping is needed anywhere in the loop.
    with torch.no_grad():
        for bg in tqdm(loader, desc="Iteration"):
            bg = bg.to(device)
            # The features are popped off the graph and passed to the model
            # as separate arguments.
            x = bg.ndata.pop("feat")
            edge_attr = bg.edata.pop("feat")

            pred = model(bg, x, edge_attr).view(-1)

            # Move each batch result to the CPU before accumulating it.
            y_pred.append(pred.detach().cpu())

    return torch.cat(y_pred, dim=0)


class OnTheFlyPCQMDataset:
    """Dataset that converts SMILES strings into DGL graphs on access.

    Parameters
    ----------
    smiles_list
        Indexable collection whose items unpack into two values; the first
        is the SMILES string, the second is ignored.
    smiles2graph : callable
        Function mapping a SMILES string to a dict with the keys
        ``"edge_index"``, ``"edge_feat"``, ``"node_feat"`` and
        ``"num_nodes"``.
    """

    def __init__(self, smiles_list, smiles2graph=smiles2graph):
        super().__init__()
        self.smiles_list = smiles_list
        self.smiles2graph = smiles2graph

    def __getitem__(self, idx):
        """Get datapoint with index

        Parameters
        ----------
        idx
            Index into ``smiles_list``.

        Returns
        -------
        DGL graph whose ``ndata["feat"]`` and ``edata["feat"]`` hold the
        node and edge features as ``torch.int64`` tensors.
        """
        # Only the SMILES string is needed; the second item is discarded.
        smiles, _ = self.smiles_list[idx]
        graph = self.smiles2graph(smiles)

        # Row 0 / row 1 of ``edge_index`` are used as the two endpoint
        # arrays of the edge list.
        dgl_graph = dgl.graph(
            (graph["edge_index"][0], graph["edge_index"][1]),
            num_nodes=graph["num_nodes"],
        )
        dgl_graph.edata["feat"] = torch.from_numpy(graph["edge_feat"]).to(
            torch.int64
        )
        dgl_graph.ndata["feat"] = torch.from_numpy(graph["node_feat"]).to(
            torch.int64
        )

        return dgl_graph

    def __len__(self):
        """Length of the dataset

        Returns
        -------
        int
            Length of Dataset
        """
        return len(self.smiles_list)


def main():
    """Load a trained GNN checkpoint and write test-set predictions.

    Parses the command-line options, seeds the random number generators,
    builds a data loader over the ``"test"`` split of ``PCQM4MDataset``,
    restores the model weights from ``<checkpoint_dir>/checkpoint.pt`` and
    hands the predictions to ``PCQM4MEvaluator.save_test_submission``.

    Raises
    ------
    ValueError
        If ``--gnn`` is not one of the supported model names.
    RuntimeError
        If the checkpoint file does not exist.
    """
    # Inference settings
    parser = argparse.ArgumentParser(
        description="GNN baselines on pcqm4m with DGL"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="random seed to use (default: 42)"
    )
    parser.add_argument(
        "--device",
        type=int,
        default=0,
        help="which gpu to use if any (default: 0)",
    )
    parser.add_argument(
        "--gnn",
        type=str,
        default="gin-virtual",
        help="GNN to use, which can be from "
        "[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)",
    )
    parser.add_argument(
        "--graph_pooling",
        type=str,
        default="sum",
        help="graph pooling strategy mean or sum (default: sum)",
    )
    parser.add_argument(
        "--drop_ratio", type=float, default=0, help="dropout ratio (default: 0)"
    )
    parser.add_argument(
        "--num_layers",
        type=int,
        default=5,
        help="number of GNN message passing layers (default: 5)",
    )
    parser.add_argument(
        "--emb_dim",
        type=int,
        default=600,
        help="dimensionality of hidden units in GNNs (default: 600)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="input batch size for training (default: 256)",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="number of workers (default: 0)",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="",
        help="directory to save checkpoint",
    )
    parser.add_argument(
        "--save_test_dir",
        type=str,
        default="",
        help="directory to save test submission file",
    )
    args = parser.parse_args()

    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device("cuda:" + str(args.device))
    else:
        device = torch.device("cpu")

    ### automatic data loading and splitting
    ### Read in the raw SMILES strings
    smiles_dataset = PCQM4MDataset(root="dataset/", only_smiles=True)
    split_idx = smiles_dataset.get_idx_split()

    test_smiles_dataset = [smiles_dataset[i] for i in split_idx["test"]]
    onthefly_dataset = OnTheFlyPCQMDataset(test_smiles_dataset)
    test_loader = DataLoader(
        onthefly_dataset,
        batch_size=args.batch_size,
        shuffle=False,  # keep dataset order in the collected predictions
        num_workers=args.num_workers,
        collate_fn=collate_dgl,
    )

    ### automatic evaluator.
    evaluator = PCQM4MEvaluator()

    shared_params = {
        "num_layers": args.num_layers,
        "emb_dim": args.emb_dim,
        "drop_ratio": args.drop_ratio,
        "graph_pooling": args.graph_pooling,
    }

    # Map each CLI model name to its (gnn_type, virtual_node) configuration.
    gnn_variants = {
        "gin": ("gin", False),
        "gin-virtual": ("gin", True),
        "gcn": ("gcn", False),
        "gcn-virtual": ("gcn", True),
    }
    if args.gnn not in gnn_variants:
        raise ValueError("Invalid GNN type")
    gnn_type, virtual_node = gnn_variants[args.gnn]
    model = GNN(
        gnn_type=gnn_type, virtual_node=virtual_node, **shared_params
    ).to(device)

    num_params = sum(p.numel() for p in model.parameters())
    print(f"#Params: {num_params}")

    checkpoint_path = os.path.join(args.checkpoint_dir, "checkpoint.pt")
    if not os.path.exists(checkpoint_path):
        raise RuntimeError(f"Checkpoint file not found at {checkpoint_path}")

    ## reading in checkpoint
    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved on a GPU still loads on a CPU-only machine or on a
    # different GPU index.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])

    print("Predicting on test data...")
    y_pred = test(model, device, test_loader)

    print("Saving test submission file...")
    evaluator.save_test_submission({"y_pred": y_pred}, args.save_test_dir)
216
217
218
219


# Run inference only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()