Unverified Commit 523fa4df authored by zhengdao-chen, committed by GitHub

Adding MWE-GCN for ogb/ogbn-proteins (#1803)



* 07132020_cleaned

* Update README.md

* Create README.md

* Update README.md

* Update README.md

* further cleaning

* further cleaning

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-170.us-west-2.compute.internal>
Co-authored-by: Zihao Ye <expye@outlook.com>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent 5c92f6c2
@@ -6,3 +6,5 @@ Currently it contains:
* OGB-Products
  * GraphSAGE with Neighbor Sampling
* OGB-Proteins
  * MWE-GCN and MWE-DGCN ([GCN models for graphs with multi-dimensionally weighted edges](https://cims.nyu.edu/~chenzh/files/GCN_with_edge_weights.pdf))
# DGL for ogbn-proteins
## Models
[MWE-GCN and MWE-DGCN](https://cims.nyu.edu/~chenzh/files/GCN_with_edge_weights.pdf) are GCN models designed for graphs whose edges carry multi-dimensional weights, with each dimension indicating the strength of one type of relation between the endpoints.
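Concretely, a sketch of the layer update as implemented in `models.py` (a paraphrase of the code, not necessarily the paper's exact notation): given node states $H$ and the adjacency $A_c$ weighted by edge-weight channel $c$, each layer computes

```latex
h_c = \sigma\left( A_c H W_c + b_c \right), \qquad c = 1, \dots, C,
```

then merges the per-channel outputs $h_c$ across channels (by summation or concatenation, per `aggr_mode`) and applies a linear layer.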
## Dependencies
- DGL 0.4.3
- PyTorch 1.4.0
- OGB 1.2.0
- Tensorboard 2.1.1
## Usage
To use MWE-GCN:
```bash
python main_proteins_full_dgl.py --model MWE-GCN
```
To use MWE-DGCN:
```bash
python main_proteins_full_dgl.py --model MWE-DGCN
```
Additional optional arguments include `rand_seed` (the random seed), `cuda` (the CUDA device number, if a GPU is available), and `postfix` (a string appended to the saved-model file name); see the example below.
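For example, to fix the random seed, run on GPU 0, and tag the saved-model file (flag names as defined in the script's argument parser; the values here are only illustrative):
```bash
python main_proteins_full_dgl.py --model MWE-DGCN --rand_seed 0 --cuda 0 --postfix run1
```

# configure.py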
"""Best hyperparameters found."""
MWE_GCN_proteins = {
'num_ew_channels': 8,
'num_epochs': 2000,
'in_feats': 1,
'hidden_feats': 10,
'out_feats': 112,
'n_layers': 3,
'lr': 2e-2,
'weight_decay': 0,
'patience': 1000,
'dropout': 0.2,
    'aggr_mode': 'sum',  # 'sum' or 'concat' for the aggregation across channels
'ewnorm': 'both'
}
MWE_DGCN_proteins = {
'num_ew_channels': 8,
'num_epochs': 2000,
'in_feats': 1,
'hidden_feats': 10,
'out_feats': 112,
'n_layers': 2,
'lr': 1e-2,
'weight_decay': 0,
'patience': 300,
'dropout': 0.5,
'aggr_mode': 'sum',
'residual': True,
'ewnorm': 'none'
}
def get_exp_configure(args):
    if args['model'] == 'MWE-GCN':
        return MWE_GCN_proteins
    elif args['model'] == 'MWE-DGCN':
        return MWE_DGCN_proteins
    else:
        raise ValueError('Unexpected model {}'.format(args['model']))
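# main_proteins_full_dgl.py: full-graph training and evaluation on ogbn-proteins.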
import os
import numpy as np
import time
import torch
import torch.nn as nn
import dgl.function as fn
import torch.nn.functional as F
from ogb.nodeproppred.dataset_dgl import DglNodePropPredDataset
from ogb.nodeproppred import Evaluator
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import load_model, set_random_seed
def normalize_edge_weights(graph, device, num_ew_channels):
    # Symmetric normalization: divide each edge weight by
    # sqrt(deg(u)) * sqrt(deg(v)) for edge (u, v).
    degs = graph.in_degrees().float()
    degs = torch.clamp(degs, min=1)
    norm = torch.pow(degs, 0.5)
    norm = norm.to(device)
    graph.ndata['norm'] = norm.unsqueeze(1)
    graph.apply_edges(fn.e_div_u('feat', 'norm', 'feat'))
    graph.apply_edges(fn.e_div_v('feat', 'norm', 'feat'))
    # Split the multi-dimensional edge weights into one edata field per channel.
    for channel in range(num_ew_channels):
        graph.edata['feat_' + str(channel)] = graph.edata['feat'][:, channel:channel+1]
def run_a_train_epoch(graph, node_idx, model, criterion, optimizer, evaluator):
model.train()
logits = model(graph)[node_idx]
labels = graph.ndata['labels'][node_idx]
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
    loss = loss.item()
labels = labels.cpu().numpy()
preds = logits.cpu().detach().numpy()
return loss, evaluator.eval({"y_true": labels, "y_pred": preds})['rocauc']
def run_an_eval_epoch(graph, splitted_idx, model, evaluator):
model.eval()
with torch.no_grad():
logits = model(graph)
labels = graph.ndata['labels'].cpu().numpy()
preds = logits.cpu().detach().numpy()
train_score = evaluator.eval({
"y_true": labels[splitted_idx["train"]],
"y_pred": preds[splitted_idx["train"]]
})
val_score = evaluator.eval({
"y_true": labels[splitted_idx["valid"]],
"y_pred": preds[splitted_idx["valid"]]
})
test_score = evaluator.eval({
"y_true": labels[splitted_idx["test"]],
"y_pred": preds[splitted_idx["test"]]
})
return train_score['rocauc'], val_score['rocauc'], test_score['rocauc']
def main(args):
    print(args)
    if args['rand_seed'] > -1:
        set_random_seed(args['rand_seed'])
dataset = DglNodePropPredDataset(name=args['dataset'])
print(dataset.meta_info[args['dataset']])
splitted_idx = dataset.get_idx_split()
graph = dataset.graph[0]
graph.ndata['labels'] = dataset.labels.float().to(args['device'])
graph.edata['feat'] = graph.edata['feat'].float().to(args['device'])
    if args['ewnorm'] == 'both':
        print('Symmetric normalization of edge weights by degree')
        normalize_edge_weights(graph, args['device'], args['num_ew_channels'])
    elif args['ewnorm'] == 'none':
        print('Not normalizing edge weights')
        # Still split the multi-dimensional edge weights into per-channel fields.
        for channel in range(args['num_ew_channels']):
            graph.edata['feat_' + str(channel)] = graph.edata['feat'][:, channel:channel+1]
model = load_model(args).to(args['device'])
optimizer = Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
    min_lr = 1e-3
    scheduler = ReduceLROnPlateau(optimizer, 'max', factor=0.7, patience=100,
                                  verbose=True, min_lr=min_lr)
    print('scheduler min_lr', min_lr)
    criterion = nn.BCEWithLogitsLoss()
    evaluator = Evaluator(args['dataset'])
    print('model', args['model'])
    print('n_layers', args['n_layers'])
    print('hidden dim', args['hidden_feats'])
    print('lr', args['lr'])
dur = []
best_val_score = 0.
num_patient_epochs = 0
model_folder = './saved_models/'
model_path = model_folder + str(args['exp_name']) + '_' + str(args['postfix'])
if not os.path.exists(model_folder):
os.makedirs(model_folder)
for epoch in range(1, args['num_epochs'] + 1):
if epoch >= 3:
t0 = time.time()
loss, train_score = run_a_train_epoch(graph, splitted_idx["train"], model,
criterion, optimizer, evaluator)
if epoch >= 3:
dur.append(time.time() - t0)
avg_time = np.mean(dur)
else:
avg_time = None
train_score, val_score, test_score = run_an_eval_epoch(graph, splitted_idx,
model, evaluator)
scheduler.step(val_score)
# Early stop
if val_score > best_val_score:
torch.save(model.state_dict(), model_path)
best_val_score = val_score
num_patient_epochs = 0
else:
num_patient_epochs += 1
print('Epoch {:d}, loss {:.4f}, train score {:.4f}, '
'val score {:.4f}, avg time {}, num patient epochs {:d}'.format(
epoch, loss, train_score, val_score, avg_time, num_patient_epochs))
if num_patient_epochs == args['patience']:
break
model.load_state_dict(torch.load(model_path))
train_score, val_score, test_score = run_an_eval_epoch(graph, splitted_idx, model, evaluator)
print('Train score {:.4f}'.format(train_score))
print('Valid score {:.4f}'.format(val_score))
print('Test score {:.4f}'.format(test_score))
with open('results.txt', 'w') as f:
f.write('loss {:.4f}\n'.format(loss))
f.write('Best validation rocauc {:.4f}\n'.format(best_val_score))
f.write('Test rocauc {:.4f}\n'.format(test_score))
    print(args)
if __name__ == '__main__':
import argparse
from configure import get_exp_configure
parser = argparse.ArgumentParser(
description='OGB node property prediction with DGL using full graph training')
parser.add_argument('-m', '--model', type=str, choices=['MWE-GCN', 'MWE-DGCN'], default='MWE-DGCN',
help='Model to use')
    parser.add_argument('-c', '--cuda', type=str, default='none',
                        help="CUDA device number, or 'none' for CPU")
parser.add_argument('--postfix', type=str, default='', help='a string appended to the file name of the saved model')
parser.add_argument('--rand_seed', type=int, default=-1, help='random seed for torch and numpy')
parser.add_argument('--residual', action='store_true')
parser.add_argument('--ewnorm', type=str, default='none', choices=['none', 'both'])
args = parser.parse_args().__dict__
# Get experiment configuration
args['dataset'] = 'ogbn-proteins'
args['exp_name'] = '_'.join([args['model'], args['dataset']])
args.update(get_exp_configure(args))
    if args['cuda'] != 'none':
        args['device'] = torch.device('cuda:' + str(args['cuda']))
    else:
        args['device'] = torch.device('cpu')
main(args)
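# models.py: the MWEConv layer and the MWE-GCN / MWE-DGCN models.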
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
class MWEConv(nn.Module):
def __init__(self,
in_feats,
out_feats,
activation,
bias=True,
num_channels=8,
aggr_mode='sum'):
super(MWEConv, self).__init__()
self.num_channels = num_channels
self._in_feats = in_feats
self._out_feats = out_feats
self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats, num_channels))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_feats, num_channels))
else:
self.bias = None
self.reset_parameters()
self.activation = activation
        if aggr_mode == 'concat':
            self.aggr_mode = 'concat'
            self.final = nn.Linear(out_feats * self.num_channels, out_feats)
        elif aggr_mode == 'sum':
            self.aggr_mode = 'sum'
            self.final = nn.Linear(out_feats, out_feats)
        else:
            raise ValueError("aggr_mode must be 'sum' or 'concat'")
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
stdv = 1. / math.sqrt(self.bias.size(0))
self.bias.data.uniform_(-stdv, stdv)
    def forward(self, g, node_state_prev):
        node_state = node_state_prev
        g = g.local_var()
        new_node_states = []

        # Perform a separately parameterized weighted convolution for each
        # channel of the edge weights.
        for c in range(self.num_channels):
            node_state_c = node_state
            # Apply the per-channel linear transform on the cheaper side of the
            # aggregation: before message passing if it shrinks the features,
            # after if it enlarges them.
            if self._out_feats < self._in_feats:
                g.ndata['feat_' + str(c)] = torch.mm(node_state_c, self.weight[:, :, c])
            else:
                g.ndata['feat_' + str(c)] = node_state_c
            # Scale each message by the channel-c edge weight and sum over
            # neighbors. (fn.src_mul_edge was renamed fn.u_mul_e in later DGL releases.)
            g.update_all(fn.src_mul_edge('feat_' + str(c), 'feat_' + str(c), 'm'),
                         fn.sum('m', 'feat_' + str(c) + '_new'))
            node_state_c = g.ndata.pop('feat_' + str(c) + '_new')
            if self._out_feats >= self._in_feats:
                node_state_c = torch.mm(node_state_c, self.weight[:, :, c])
            if self.bias is not None:
                node_state_c = node_state_c + self.bias[:, c]
            node_state_c = self.activation(node_state_c)
            new_node_states.append(node_state_c)

        # Aggregate the per-channel results, then apply a final linear layer.
        if self.aggr_mode == 'sum':
            node_states = torch.stack(new_node_states, dim=1).sum(1)
        elif self.aggr_mode == 'concat':
            node_states = torch.cat(new_node_states, dim=1)
        node_states = self.final(node_states)
        return node_states
class MWE_GCN(nn.Module):
def __init__(self,
n_input,
n_hidden,
n_output,
n_layers,
activation,
dropout,
aggr_mode='sum',
device='cpu'):
super(MWE_GCN, self).__init__()
self.dropout = dropout
self.activation = activation
self.layers = nn.ModuleList()
        self.layers.append(MWEConv(n_input, n_hidden, activation=activation,
                                   aggr_mode=aggr_mode))
        for i in range(n_layers - 1):
            self.layers.append(MWEConv(n_hidden, n_hidden, activation=activation,
                                       aggr_mode=aggr_mode))
self.pred_out = nn.Linear(n_hidden, n_output)
self.device = device
    def forward(self, g, node_state=None):
        # The node_state argument is ignored: every node starts from a constant
        # scalar feature of 1 (in_feats = 1 in the configuration).
        node_state = torch.ones(g.number_of_nodes(), 1).float().to(self.device)
for layer in self.layers:
node_state = F.dropout(node_state, p=self.dropout, training=self.training)
node_state = layer(g, node_state)
node_state = self.activation(node_state)
out = self.pred_out(node_state)
return out
class MWE_DGCN(nn.Module):
def __init__(self,
n_input,
n_hidden,
n_output,
n_layers,
activation,
dropout,
residual=False,
aggr_mode='sum',
device='cpu'):
super(MWE_DGCN, self).__init__()
self.n_layers = n_layers
self.activation = activation
self.dropout = dropout
self.residual = residual
self.layers = nn.ModuleList()
self.layer_norms = nn.ModuleList()
        self.layers.append(MWEConv(n_input, n_hidden, activation=activation,
                                   aggr_mode=aggr_mode))
        for i in range(n_layers - 1):
            self.layers.append(MWEConv(n_hidden, n_hidden, activation=activation,
                                       aggr_mode=aggr_mode))
for i in range(n_layers):
self.layer_norms.append(nn.LayerNorm(n_hidden, elementwise_affine=True))
self.pred_out = nn.Linear(n_hidden, n_output)
self.device = device
    def forward(self, g, node_state=None):
        # As in MWE_GCN, every node starts from a constant scalar feature of 1.
        node_state = torch.ones(g.number_of_nodes(), 1).float().to(self.device)
        node_state = self.layers[0](g, node_state)
        for layer in range(1, self.n_layers):
            node_state_new = self.layer_norms[layer - 1](node_state)
            node_state_new = self.activation(node_state_new)
            node_state_new = F.dropout(node_state_new, p=self.dropout, training=self.training)
            if self.residual:
                # Residual connection across MWEConv layers.
                node_state = node_state + self.layers[layer](g, node_state_new)
            else:
                node_state = self.layers[layer](g, node_state_new)
node_state = self.layer_norms[self.n_layers-1](node_state)
node_state = self.activation(node_state)
node_state = F.dropout(node_state, p=self.dropout, training=self.training)
out = self.pred_out(node_state)
return out
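# utils.py: random seeding, model construction, and result-logging helpers.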
import numpy as np
import random
import torch
from models import MWE_GCN, MWE_DGCN
import torch.nn.functional as F
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
    print('random seed set to be ' + str(seed))
def load_model(args):
if args['model'] == 'MWE-GCN':
model = MWE_GCN(
n_input=args['in_feats'],
n_hidden=args['hidden_feats'],
n_output=args['out_feats'],
n_layers=args['n_layers'],
activation=torch.nn.Tanh(),
dropout=args['dropout'],
aggr_mode=args['aggr_mode'],
device=args['device'])
elif args['model'] == 'MWE-DGCN':
model = MWE_DGCN(
n_input=args['in_feats'],
n_hidden=args['hidden_feats'],
n_output=args['out_feats'],
n_layers=args['n_layers'],
activation=torch.nn.ReLU(),
dropout=args['dropout'],
aggr_mode=args['aggr_mode'],
residual=args['residual'],
device=args['device'])
else:
raise ValueError('Unexpected model {}'.format(args['model']))
return model
class Logger(object):
def __init__(self, runs, info=None):
self.info = info
self.results = [[] for _ in range(runs)]
def add_result(self, run, result):
assert len(result) == 3
assert run >= 0 and run < len(self.results)
self.results[run].append(result)
def print_statistics(self, run=None):
if run is not None:
result = 100 * torch.tensor(self.results[run])
argmax = result[:, 1].argmax().item()
print(f'Run {run + 1:02d}:')
print(f'Highest Train: {result[:, 0].max():.2f}')
print(f'Highest Valid: {result[:, 1].max():.2f}')
print(f' Final Train: {result[argmax, 0]:.2f}')
print(f' Final Test: {result[argmax, 2]:.2f}')
else:
result = 100 * torch.tensor(self.results)
best_results = []
for r in result:
train1 = r[:, 0].max().item()
valid = r[:, 1].max().item()
train2 = r[r[:, 1].argmax(), 0].item()
test = r[r[:, 1].argmax(), 2].item()
best_results.append((train1, valid, train2, test))
best_result = torch.tensor(best_results)
            print('All runs:')
r = best_result[:, 0]
print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}')
r = best_result[:, 1]
print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}')
r = best_result[:, 2]
print(f' Final Train: {r.mean():.2f} ± {r.std():.2f}')
r = best_result[:, 3]
print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}')