Unverified Commit a655123b authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Example] Baselines for OGB-LSC-PCQM4M (#2778)

* Update

* Add files via upload

* Add files via upload

* Add files via upload
parent 0576a33b
# Baseline Code for PCQM4M-LSC
The code is ported from the official examples [here](https://github.com/snap-stanford/ogb/tree/master/examples/lsc/pcqm4m). Please refer to the [OGB-LSC paper](https://arxiv.org/abs/2103.09430) for the detailed setting.
## Installation Requirements
```
ogb>=1.3.0
rdkit>=2019.03.1
torch>=1.7.0
```
We recommend installing RDKit with `conda install -c rdkit rdkit==2019.03.1`.
## Commandline Arguments
- `LOG_DIR`: Tensorboard log directory.
- `CHECKPOINT_DIR`: Directory to save the best validation checkpoint. The checkpoint file will be saved at `${CHECKPOINT_DIR}/checkpoint.pt`.
- `TEST_DIR`: Directory path to save the test submission. The test file will be saved at `${TEST_DIR}/y_pred_pcqm4m.npz`.
## Baseline Models
### GIN [1]
```
python main.py --gnn gin --log_dir $LOG_DIR --checkpoint_dir $CHECKPOINT_DIR --save_test_dir $TEST_DIR
```
### GIN-virtual [1,3]
```
python main.py --gnn gin-virtual --log_dir $LOG_DIR --checkpoint_dir $CHECKPOINT_DIR --save_test_dir $TEST_DIR
```
### GCN [2]
```
python main.py --gnn gcn --log_dir $LOG_DIR --checkpoint_dir $CHECKPOINT_DIR --save_test_dir $TEST_DIR
```
### GCN-virtual [2,3]
```
python main.py --gnn gcn-virtual --log_dir $LOG_DIR --checkpoint_dir $CHECKPOINT_DIR --save_test_dir $TEST_DIR
```
## Measuring the Test Inference Time
The code below takes **the raw SMILES strings as input**, uses the saved checkpoint, and performs inference for all 377,423 test molecules.
```
python test_inference.py --gnn $GNN --checkpoint_dir $CHECKPOINT_DIR --save_test_dir $TEST_DIR
```
For your model, **the total inference time needs to be less than 12 hours on a single GPU and a CPU**. Ideally, you
should use the CPU/GPU spec of the organizers, which consists of a single GeForce RTX 2080 GPU and an Intel(R) Xeon(R)
Gold 6148 CPU @ 2.40GHz. However, the organizers also allow the use of other GPU/CPU specs, as long as the specs are
clearly reported in the final submission.
## Performance
| Model | Original Valid MAE | DGL Valid MAE | #Parameters |
| ----------- | ------------------ | ------------- | ----------- |
| GIN | 0.1536 | 0.1536 | 3.8M |
| GIN-virtual | 0.1396 | 0.1407 | 6.7M |
| GCN | 0.1684 | 0.1683 | 2.0M |
| GCN-virtual | 0.1510 | 0.1557 | 4.9M |
## References
[1] Xu, K., Hu, W., Leskovec, J., & Jegelka, S. (2019). How powerful are graph neural networks?. ICLR 2019
[2] Kipf, T. N., & Welling, M. (2017). Semi-supervised classification with graph convolutional networks. ICLR 2017
[3] Gilmer, J., Schoenholz, S. S., Riley, P. F., Vinyals, O., & Dahl, G. E. Neural message passing for quantum chemistry. ICML 2017.
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import SumPooling
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
### GIN convolution along the graph structure
class GINConv(nn.Module):
    """Graph Isomorphism Network layer augmented with bond (edge) features."""

    def __init__(self, emb_dim):
        '''
        emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        )
        # Learnable epsilon weighting the center node, as in GIN.
        self.eps = nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, g, x, edge_attr):
        with g.local_scope():
            bond_emb = self.bond_encoder(edge_attr)
            g.ndata['x'] = x
            # Copy source-node features onto each edge.
            g.apply_edges(fn.copy_u('x', 'm'))
            # Fuse node and bond information into the edge messages.
            g.edata['m'] = F.relu(g.edata['m'] + bond_emb)
            # Sum incoming messages per node.
            g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'new_x'))
            return self.mlp((1 + self.eps) * x + g.ndata['new_x'])
### GCN convolution along the graph structure
class GCNConv(nn.Module):
    """GCN layer with bond-feature-augmented messages and symmetric degree normalization."""

    def __init__(self, emb_dim):
        '''
        emb_dim (int): node embedding dimensionality
        '''
        super(GCNConv, self).__init__()
        self.linear = nn.Linear(emb_dim, emb_dim)
        # Learned embedding acting as a "self-loop" message shared by all nodes.
        self.root_emb = nn.Embedding(1, emb_dim)
        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, g, x, edge_attr):
        with g.local_scope():
            x = self.linear(x)
            edge_embedding = self.bond_encoder(edge_attr)
            # Molecular graphs are undirected
            # g.out_degrees() is the same as g.in_degrees()
            # +1 accounts for the implicit self-loop contributed by root_emb below.
            degs = (g.out_degrees().float() + 1).to(x.device)
            norm = torch.pow(degs, -0.5).unsqueeze(-1)  # (N, 1)
            g.ndata['norm'] = norm
            # Per-edge normalization 1/sqrt(d_u * d_v); note this OVERWRITES
            # 'norm' as an edge field, distinct from the node field above.
            g.apply_edges(fn.u_mul_v('norm', 'norm', 'norm'))
            g.ndata['x'] = x
            g.apply_edges(fn.copy_u('x', 'm'))
            # Normalized, bond-augmented messages.
            g.edata['m'] = g.edata['norm'] * F.relu(g.edata['m'] + edge_embedding)
            g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'new_x'))
            # Neighbor aggregation plus the 1/deg-normalized self-loop term.
            out = g.ndata['new_x'] + F.relu(x + self.root_emb.weight) * 1. / degs.view(-1, 1)
            return out
### GNN to generate node embedding
class GNN_node(nn.Module):
    """
    Stack of GIN/GCN convolutions producing per-node representations.

    Output:
        node representations
    """
    def __init__(self, num_layers, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        '''
        num_layers (int): number of GNN message passing layers
        emb_dim (int): node embedding dimensionality
        drop_ratio (float): dropout probability applied after every layer
        JK (str): jumping-knowledge mode, 'last' or 'sum'
        residual (bool): whether to add residual connections between layers
        gnn_type (str): 'gin' or 'gcn'
        '''
        super(GNN_node, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # Fail fast on a JK mode forward() cannot handle (it would otherwise
        # die with an obscure UnboundLocalError at forward time).
        if JK not in ('last', 'sum'):
            raise ValueError('Undefined JK mode called {}'.format(JK))

        self.atom_encoder = AtomEncoder(emb_dim)

        ### List of GNNs
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for layer in range(num_layers):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                # Bug fix: the exception was previously constructed but never
                # raised, silently producing a model with missing layers.
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))
            self.batch_norms.append(nn.BatchNorm1d(emb_dim))

    def forward(self, g, x, edge_attr):
        ### computing input node embedding
        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layers):
            h = self.convs[layer](g, h_list[layer], edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            if self.residual:
                # Out-of-place add (consistent with GNN_node_Virtualnode).
                h = h + h_list[layer]
            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers):
                node_representation += h_list[layer]
        return node_representation
### Virtual GNN to generate node embedding
class GNN_node_Virtualnode(nn.Module):
    """
    Stack of GIN/GCN convolutions with a per-graph virtual node that
    exchanges messages with all graph nodes between layers.

    Output:
        node representations
    """
    def __init__(self, num_layers, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        '''
        num_layers (int): number of GNN message passing layers
        emb_dim (int): node embedding dimensionality
        drop_ratio (float): dropout probability applied after every layer
        JK (str): jumping-knowledge mode, 'last' or 'sum'
        residual (bool): whether to add residual connections between layers
        gnn_type (str): 'gin' or 'gcn'
        '''
        super(GNN_node_Virtualnode, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # Fail fast on a JK mode forward() cannot handle.
        if JK not in ('last', 'sum'):
            raise ValueError('Undefined JK mode called {}'.format(JK))

        self.atom_encoder = AtomEncoder(emb_dim)

        ### set the initial virtual node embedding to 0.
        self.virtualnode_embedding = nn.Embedding(1, emb_dim)
        nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

        ### List of GNNs
        self.convs = nn.ModuleList()
        ### batch norms applied to node embeddings
        self.batch_norms = nn.ModuleList()
        ### List of MLPs to transform virtual node at every layer
        self.mlp_virtualnode_list = nn.ModuleList()

        for layer in range(num_layers):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                # Bug fix: the exception was previously constructed but never
                # raised, silently producing a model with missing layers.
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))
            self.batch_norms.append(nn.BatchNorm1d(emb_dim))

        # One virtual-node MLP per transition between consecutive layers.
        for layer in range(num_layers - 1):
            self.mlp_virtualnode_list.append(nn.Sequential(nn.Linear(emb_dim, emb_dim),
                                                           nn.BatchNorm1d(emb_dim),
                                                           nn.ReLU(),
                                                           nn.Linear(emb_dim, emb_dim),
                                                           nn.BatchNorm1d(emb_dim),
                                                           nn.ReLU()))
        self.pool = SumPooling()

    def forward(self, g, x, edge_attr):
        ### virtual node embeddings for graphs: an all-zeros index tensor
        ### selects the single learned embedding once per graph in the batch.
        # NOTE(review): this reuses x.dtype as the index dtype; it works because
        # atom features are int64, but would break for float inputs — confirm.
        virtualnode_embedding = self.virtualnode_embedding(
            torch.zeros(g.batch_size).to(x.dtype).to(x.device))

        h_list = [self.atom_encoder(x)]
        # Maps every node to the index of its graph within the batch.
        batch_id = dgl.broadcast_nodes(g, torch.arange(g.batch_size).to(x.device))
        for layer in range(self.num_layers):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch_id]

            ### Message passing among graph nodes
            h = self.convs[layer](g, h_list[layer], edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            if self.residual:
                h = h + h_list[layer]
            h_list.append(h)

            ### update the virtual nodes
            if layer < self.num_layers - 1:
                ### add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = self.pool(g, h_list[layer]) + virtualnode_embedding
                ### transform virtual nodes using MLP
                virtualnode_embedding_temp = self.mlp_virtualnode_list[layer](
                    virtualnode_embedding_temp)
                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + F.dropout(
                        virtualnode_embedding_temp, self.drop_ratio, training = self.training)
                else:
                    virtualnode_embedding = F.dropout(
                        virtualnode_embedding_temp, self.drop_ratio, training = self.training)

        ### Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers):
                node_representation += h_list[layer]
        return node_representation
import torch
import torch.nn as nn
from dgl.nn.pytorch import SumPooling, AvgPooling, MaxPooling, GlobalAttentionPooling, Set2Set
from conv import GNN_node, GNN_node_Virtualnode
class GNN(nn.Module):
    """Graph-level regression model: node GNN + graph pooling + linear head."""

    def __init__(self, num_tasks = 1, num_layers = 5, emb_dim = 300, gnn_type = 'gin',
                 virtual_node = True, residual = False, drop_ratio = 0, JK = "last",
                 graph_pooling = "sum"):
        '''
        num_tasks (int): number of labels to be predicted
        num_layers (int): number of GNN message passing layers
        emb_dim (int): node embedding dimensionality
        gnn_type (str): 'gin' or 'gcn'
        virtual_node (bool): whether to add virtual node or not
        residual (bool): whether to add residual connections
        drop_ratio (float): dropout ratio
        JK (str): jumping-knowledge mode, 'last' or 'sum'
        graph_pooling (str): one of 'sum', 'mean', 'max', 'attention', 'set2set'
        '''
        super(GNN, self).__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if virtual_node:
            self.gnn_node = GNN_node_Virtualnode(num_layers, emb_dim, JK = JK,
                                                 drop_ratio = drop_ratio,
                                                 residual = residual,
                                                 gnn_type = gnn_type)
        else:
            self.gnn_node = GNN_node(num_layers, emb_dim, JK = JK, drop_ratio = drop_ratio,
                                     residual = residual, gnn_type = gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = SumPooling()
        elif self.graph_pooling == "mean":
            self.pool = AvgPooling()
        elif self.graph_pooling == "max":
            # Bug fix: MaxPooling was previously assigned as the class itself
            # (missing parentheses), which would crash at forward time.
            self.pool = MaxPooling()
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttentionPooling(
                gate_nn = nn.Sequential(nn.Linear(emb_dim, 2*emb_dim),
                                        nn.BatchNorm1d(2*emb_dim),
                                        nn.ReLU(),
                                        nn.Linear(2*emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, n_iters = 2, n_layers = 2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # Set2Set doubles the graph-embedding dimension.
        if graph_pooling == "set2set":
            self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, g, x, edge_attr):
        h_node = self.gnn_node(g, x, edge_attr)
        h_graph = self.pool(g, h_node)
        output = self.graph_pred_linear(h_graph)

        if self.training:
            return output
        else:
            # At inference time clamp to the known target range (HOMO-LUMO gap).
            return torch.clamp(output, min=0, max=50)
import argparse
import dgl
import numpy as np
import os
import random
import torch
import torch.optim as optim
from ogb.lsc import DglPCQM4MDataset, PCQM4MEvaluator
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from gnn import GNN
reg_criterion = torch.nn.L1Loss()
def collate_dgl(samples):
    """Collate a list of (graph, label) pairs into (batched_graph, stacked_labels)."""
    graph_list, label_list = map(list, zip(*samples))
    bg = dgl.batch(graph_list)
    return bg, torch.stack(label_list)
def train(model, device, loader, optimizer):
    """Run one training epoch and return the mean L1 loss over iterations."""
    model.train()
    running_loss = 0.0

    for step, (bg, labels) in enumerate(tqdm(loader, desc="Iteration")):
        bg = bg.to(device)
        node_feat = bg.ndata.pop('feat')
        edge_feat = bg.edata.pop('feat')
        labels = labels.to(device)

        pred = model(bg, node_feat, edge_feat).view(-1,)
        optimizer.zero_grad()
        loss = reg_criterion(pred, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.detach().cpu().item()

    return running_loss / (step + 1)
# NOTE: shadows the builtin eval(); kept for compatibility with existing callers.
def eval(model, device, loader, evaluator):
    """Evaluate on a labeled loader and return the MAE reported by the evaluator."""
    model.eval()
    true_chunks = []
    pred_chunks = []

    for step, (bg, labels) in enumerate(tqdm(loader, desc="Iteration")):
        bg = bg.to(device)
        node_feat = bg.ndata.pop('feat')
        edge_feat = bg.edata.pop('feat')
        labels = labels.to(device)

        with torch.no_grad():
            pred = model(bg, node_feat, edge_feat).view(-1, )

        true_chunks.append(labels.view(pred.shape).detach().cpu())
        pred_chunks.append(pred.detach().cpu())

    input_dict = {"y_true": torch.cat(true_chunks, dim=0),
                  "y_pred": torch.cat(pred_chunks, dim=0)}
    return evaluator.eval(input_dict)["mae"]
def test(model, device, loader):
    """Run inference on an unlabeled loader; return all predictions as one CPU tensor."""
    model.eval()
    pred_chunks = []

    for step, (bg, _) in enumerate(tqdm(loader, desc="Iteration")):
        bg = bg.to(device)
        node_feat = bg.ndata.pop('feat')
        edge_feat = bg.edata.pop('feat')

        with torch.no_grad():
            pred = model(bg, node_feat, edge_feat).view(-1, )

        pred_chunks.append(pred.detach().cpu())

    return torch.cat(pred_chunks, dim=0)
def main():
    """Train a GNN baseline on PCQM4M; optionally log, checkpoint, and predict on test."""
    # Training settings
    parser = argparse.ArgumentParser(description='GNN baselines on pcqm4m with DGL')
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed to use (default: 42)')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--gnn', type=str, default='gin-virtual',
                        help='GNN to use, which can be from '
                             '[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)')
    parser.add_argument('--graph_pooling', type=str, default='sum',
                        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio', type=float, default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument('--num_layers', type=int, default=5,
                        help='number of GNN message passing layers (default: 5)')
    parser.add_argument('--emb_dim', type=int, default=600,
                        help='dimensionality of hidden units in GNNs (default: 600)')
    parser.add_argument('--train_subset', action='store_true',
                        help='use 10% of the training set for training')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of workers (default: 4)')
    parser.add_argument('--log_dir', type=str, default="",
                        help='tensorboard log directory. If not specified, '
                             'tensorboard will not be used.')
    parser.add_argument('--checkpoint_dir', type=str, default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir', type=str, default='',
                        help='directory to save test submission file')
    args = parser.parse_args()
    print(args)

    # Seed every RNG in play for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device("cuda:" + str(args.device))
    else:
        device = torch.device("cpu")

    ### automatic dataloading and splitting
    dataset = DglPCQM4MDataset(root='dataset/')

    # split_idx['train'], split_idx['valid'], split_idx['test']
    # separately gives a 1D int64 tensor
    split_idx = dataset.get_idx_split()

    ### automatic evaluator.
    evaluator = PCQM4MEvaluator()

    if args.train_subset:
        # Randomly keep 10% of the training indices.
        subset_ratio = 0.1
        subset_idx = torch.randperm(len(split_idx["train"]))[:int(subset_ratio * len(split_idx["train"]))]
        train_loader = DataLoader(dataset[split_idx["train"][subset_idx]], batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.num_workers, collate_fn=collate_dgl)
    else:
        train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.num_workers, collate_fn=collate_dgl)
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False,
                              num_workers=args.num_workers, collate_fn=collate_dgl)

    # Bug fix: the original tests `args.save_test_dir is not ''` etc., which
    # compares string *identity* (interning-dependent) and raises a
    # SyntaxWarning on Python >= 3.8. Empty strings are falsy, so plain
    # truthiness tests are the correct and equivalent idiom.
    if args.save_test_dir:
        test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False,
                                 num_workers=args.num_workers, collate_fn=collate_dgl)

    if args.checkpoint_dir:
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }
    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False, **shared_params).to(device)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True, **shared_params).to(device)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False, **shared_params).to(device)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True, **shared_params).to(device)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    if args.log_dir:
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    if args.train_subset:
        # Longer schedule when training on the 10% subset.
        scheduler = StepLR(optimizer, step_size=300, gamma=0.25)
        args.epochs = 1000
    else:
        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train_mae = train(model, device, train_loader, optimizer)

        print('Evaluating...')
        valid_mae = eval(model, device, valid_loader, evaluator)

        print({'Train': train_mae, 'Validation': valid_mae})

        if args.log_dir:
            writer.add_scalar('valid/mae', valid_mae, epoch)
            writer.add_scalar('train/mae', train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.checkpoint_dir:
                print('Saving checkpoint...')
                checkpoint = {'epoch': epoch, 'model_state_dict': model.state_dict(),
                              'optimizer_state_dict': optimizer.state_dict(),
                              'scheduler_state_dict': scheduler.state_dict(), 'best_val_mae': best_valid_mae,
                              'num_params': num_params}
                torch.save(checkpoint, os.path.join(args.checkpoint_dir, 'checkpoint.pt'))

            # Save test predictions whenever validation improves.
            if args.save_test_dir:
                print('Predicting on test data...')
                y_pred = test(model, device, test_loader)
                print('Saving test submission file...')
                evaluator.save_test_submission({'y_pred': y_pred}, args.save_test_dir)

        scheduler.step()

        print(f'Best validation MAE so far: {best_valid_mae}')

    if args.log_dir:
        writer.close()


if __name__ == "__main__":
    main()
import argparse
import dgl
import numpy as np
import os
import random
import torch
from ogb.lsc import PCQM4MDataset, PCQM4MEvaluator
from ogb.utils import smiles2graph
from torch.utils.data import DataLoader
from tqdm import tqdm
from gnn import GNN
def collate_dgl(graphs):
    """Merge a list of DGLGraphs into a single batched graph."""
    return dgl.batch(graphs)
def test(model, device, loader):
    """Run inference over batched graphs; return all predictions as one CPU tensor."""
    model.eval()
    pred_chunks = []

    for bg in tqdm(loader, desc="Iteration"):
        bg = bg.to(device)
        node_feat = bg.ndata.pop('feat')
        edge_feat = bg.edata.pop('feat')

        with torch.no_grad():
            pred = model(bg, node_feat, edge_feat).view(-1, )

        pred_chunks.append(pred.detach().cpu())

    return torch.cat(pred_chunks, dim=0)
class OnTheFlyPCQMDataset(object):
    """Map-style dataset that converts raw SMILES strings to DGLGraphs lazily."""

    def __init__(self, smiles_list, smiles2graph=smiles2graph):
        super(OnTheFlyPCQMDataset, self).__init__()
        # Each entry is a (smiles, label) pair; only the SMILES is used here.
        self.smiles_list = smiles_list
        self.smiles2graph = smiles2graph

    def __getitem__(self, idx):
        '''Get datapoint with index'''
        smiles, _ = self.smiles_list[idx]
        graph_dict = self.smiles2graph(smiles)
        src, dst = graph_dict['edge_index'][0], graph_dict['edge_index'][1]
        g = dgl.graph((src, dst), num_nodes=graph_dict['num_nodes'])
        g.ndata['feat'] = torch.from_numpy(graph_dict['node_feat']).to(torch.int64)
        g.edata['feat'] = torch.from_numpy(graph_dict['edge_feat']).to(torch.int64)
        return g

    def __len__(self):
        '''Length of the dataset

        Returns
        -------
        int
            Length of Dataset
        '''
        return len(self.smiles_list)
def main():
    """Measure test-time inference: rebuild the model from a checkpoint and
    predict on the test split directly from raw SMILES strings."""
    # Training settings
    parser = argparse.ArgumentParser(description='GNN baselines on pcqm4m with DGL')
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed to use (default: 42)')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--gnn', type=str, default='gin-virtual',
                        help='GNN to use, which can be from '
                             '[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)')
    parser.add_argument('--graph_pooling', type=str, default='sum',
                        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio', type=float, default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument('--num_layers', type=int, default=5,
                        help='number of GNN message passing layers (default: 5)')
    parser.add_argument('--emb_dim', type=int, default=600,
                        help='dimensionality of hidden units in GNNs (default: 600)')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--checkpoint_dir', type=str, default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir', type=str, default='',
                        help='directory to save test submission file')
    args = parser.parse_args()
    print(args)

    # Seed every RNG in play for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device("cuda:" + str(args.device))
    else:
        device = torch.device("cpu")

    ### automatic data loading and splitting
    ### Read in the raw SMILES strings
    smiles_dataset = PCQM4MDataset(root='dataset/', only_smiles=True)
    split_idx = smiles_dataset.get_idx_split()
    test_smiles_dataset = [smiles_dataset[i] for i in split_idx['test']]
    # Graph construction happens on the fly inside the DataLoader workers,
    # so SMILES-to-graph conversion is counted in the measured inference time.
    onthefly_dataset = OnTheFlyPCQMDataset(test_smiles_dataset)
    test_loader = DataLoader(onthefly_dataset, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.num_workers, collate_fn=collate_dgl)

    ### automatic evaluator.
    evaluator = PCQM4MEvaluator()

    # Model hyperparameters must match the checkpointed model's.
    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }
    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False, **shared_params).to(device)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True, **shared_params).to(device)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False, **shared_params).to(device)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True, **shared_params).to(device)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint.pt')
    if not os.path.exists(checkpoint_path):
        raise RuntimeError(f'Checkpoint file not found at {checkpoint_path}')

    ## reading in checkpoint
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])

    print('Predicting on test data...')
    y_pred = test(model, device, test_loader)
    print('Saving test submission file...')
    evaluator.save_test_submission({'y_pred': y_pred}, args.save_test_dir)


if __name__ == "__main__":
    main()
# Baselines for OGB Large-Scale Challenge (LSC) at KDD Cup 2021
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment