Unverified Commit 43d49c1c authored by Zhiteng Li's avatar Zhiteng Li Committed by GitHub
Browse files

[Example] Graphormer for ogbg-molhiv dataset (#5915)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 96c89c0b
Graphormer
==============================
## Introduction
* Graphormer is a Transformer model designed for graph-structured data, which encodes the structural information of a graph into the standard Transformer. Specifically, Graphormer utilizes Degree Encoding to measure the importance of nodes, Spatial Encoding and Path Encoding to measure the relation between node pairs. The former plus the node features serve as input to Graphormer, while the latter acts as bias terms in the self-attention module.
* paper link: [https://arxiv.org/abs/2106.05234](https://arxiv.org/abs/2106.05234)
## Requirements
- accelerate
- transformers
- ogb
## Dataset
Task: Graph Property Prediction
| Dataset | #Graphs | #Node Feats | #Edge Feats | Metric |
| :---------: | :-----: | :---------: | :---------: | :-----: |
| ogbg-molhiv | 41,127 | 9 | 3 | ROC-AUC |
## How to run
```bash
accelerate launch --multi_gpu --mixed_precision=fp16 train.py
```
> **_NOTE:_** The script will automatically download weights pre-trained on PCQM4Mv2. To reproduce the same result, set the total batch size to 64.
## Summary
* ogbg-molhiv (pretrained on PCQM4Mv2): ~0.791
"""
This file contains the MolHIVDataset class, which handles data preprocessing
(computing required graph features, converting graphs to tensors) of the
ogbg-molhiv dataset.
"""
import torch as th
import torch.nn.functional as F
from dgl import shortest_dist
from ogb.graphproppred import DglGraphPropPredDataset
from torch.nn.utils.rnn import pad_sequence
class MolHIVDataset(th.utils.data.Dataset):
    """ogbg-molhiv dataset wrapper that precomputes Graphormer inputs.

    On construction, the shortest-path distance matrix ("spd") and the
    corresponding edge-ID paths ("path") of every graph are computed once and
    cached in ``ndata``, so that :meth:`collate` only needs to pad and batch.
    """

    def __init__(self):
        dataset = DglGraphPropPredDataset(name="ogbg-molhiv")
        split_idx = dataset.get_idx_split()
        # Compute the shortest path distances and their corresponding paths
        # of all graphs during preprocessing.
        for g, label in dataset:
            spd, path = shortest_dist(g, root=None, return_paths=True)
            g.ndata["spd"] = spd
            g.ndata["path"] = path
        # Standard OGB scaffold split.
        self.train, self.val, self.test = (
            dataset[split_idx["train"]],
            dataset[split_idx["valid"]],
            dataset[split_idx["test"]],
        )

    def collate(self, samples):
        """Batch ``(graph, label)`` pairs into padded Graphormer tensors.

        Returns a 7-tuple of
        ``(labels, attn_mask, node_feat, in_degree, out_degree, path_data,
        dist)`` where all per-graph tensors are padded to the largest graph
        in the batch.
        """
        # To match Graphormer's input style, all graph features should be
        # padded to the same size. Keep in mind that different graphs may
        # have varying feature sizes since they have different number of
        # nodes, so they will be aligned with the graph having the maximum
        # number of nodes.
        graphs, labels = map(list, zip(*samples))
        labels = th.stack(labels)

        num_graphs = len(graphs)
        num_nodes = [g.num_nodes() for g in graphs]
        max_num_nodes = max(num_nodes)

        # Graphormer adds a virtual node to the graph, which is connected to
        # all other nodes and supposed to represent the graph embedding. So
        # here +1 is for the virtual node.
        attn_mask = th.zeros(num_graphs, max_num_nodes + 1, max_num_nodes + 1)

        node_feat = []
        in_degree, out_degree = [], []
        path_data = []
        # Since shortest_dist returns -1 for unreachable node pairs and padded
        # nodes are unreachable to others, distance relevant to padded nodes
        # use -1 padding as well.
        dist = -th.ones(
            (num_graphs, max_num_nodes, max_num_nodes), dtype=th.long
        )

        for i in range(num_graphs):
            # Mask out padded (non-existent) node positions: invalid columns
            # are marked with a nonzero entry (index 0 is the virtual node,
            # so real nodes occupy columns 1..num_nodes[i]).
            attn_mask[i, :, num_nodes[i] + 1 :] = 1

            # +1 to distinguish padded non-existing nodes from real nodes
            node_feat.append(graphs[i].ndata["feat"] + 1)

            # Degrees are shifted by 1 (padding id 0) and capped at 512 to
            # match the DegreeEncoder's embedding-table size.
            in_degree.append(
                th.clamp(graphs[i].in_degrees() + 1, min=0, max=512)
            )
            out_degree.append(
                th.clamp(graphs[i].out_degrees() + 1, min=0, max=512)
            )

            # Path padding to make all paths to the same length "max_len".
            path = graphs[i].ndata["path"]
            path_len = path.size(dim=2)
            # shape of shortest_path: [n, n, max_len]
            max_len = 5
            if path_len >= max_len:
                # Truncate paths longer than max_len hops.
                shortest_path = path[:, :, :max_len]
            else:
                p1d = (0, max_len - path_len)
                # Use the same -1 padding as shortest_dist for
                # invalid edge IDs.
                shortest_path = F.pad(path, p1d, "constant", -1)
            # Pad the two node dimensions up to max_num_nodes as well.
            pad_num_nodes = max_num_nodes - num_nodes[i]
            p3d = (0, 0, 0, pad_num_nodes, 0, pad_num_nodes)
            shortest_path = F.pad(shortest_path, p3d, "constant", -1)
            # +1 to distinguish padded non-existing edges from real edges
            edata = graphs[i].edata["feat"] + 1
            # shortest_dist pads non-existing edges (at the end of shortest
            # paths) with edge IDs -1, and th.zeros(1, edata.shape[1]) stands
            # for all padded edge features.  Appending that zero row means the
            # -1 indices below select it (index -1 = last row).
            edata = th.cat(
                (edata, th.zeros(1, edata.shape[1]).to(edata.device)), dim=0
            )
            path_data.append(edata[shortest_path])

            dist[i, : num_nodes[i], : num_nodes[i]] = graphs[i].ndata["spd"]

        # node feat padding
        node_feat = pad_sequence(node_feat, batch_first=True)

        # degree padding
        in_degree = pad_sequence(in_degree, batch_first=True)
        out_degree = pad_sequence(out_degree, batch_first=True)

        return (
            labels.reshape(num_graphs, -1),
            attn_mask,
            node_feat,
            in_degree,
            out_degree,
            th.stack(path_data),
            dist,
        )
"""
This script finetunes and tests a Graphormer model (pretrained on PCQM4Mv2)
for graph classification on ogbg-molhiv dataset.
Paper: [Do Transformers Really Perform Bad for Graph Representation?]
(https://arxiv.org/abs/2106.05234)
This flowchart describes the main functional sequence of the provided example.
main
└───> train_val_pipeline
├───> Load and preprocess dataset
├───> Download pretrained model
├───> train_epoch
│ │
│ └───> Graphormer.forward
└───> evaluate_network
└───> Graphormer.inference
"""
import argparse
import random
import torch as th
import torch.nn as nn
from accelerate import Accelerator
from dataset import MolHIVDataset
from dgl.data import download
from dgl.dataloading import GraphDataLoader
from model import Graphormer
from ogb.graphproppred import Evaluator
from transformers.optimization import (
AdamW,
get_polynomial_decay_schedule_with_warmup,
)
# Instantiate an accelerator object to support distributed
# training and inference.
accelerator = Accelerator()
def train_epoch(model, optimizer, data_loader, lr_scheduler):
    """Run one training epoch; return the mean loss and the ROC-AUC."""
    model.train()
    criterion = nn.BCEWithLogitsLoss()
    running_loss = 0.0
    scores_all, labels_all = [], []
    for (
        batch_labels,
        attn_mask,
        node_feat,
        in_degree,
        out_degree,
        path_data,
        dist,
    ) in data_loader:
        optimizer.zero_grad()
        device = accelerator.device
        batch_scores = model(
            node_feat.to(device),
            in_degree.to(device),
            out_degree.to(device),
            path_data.to(device),
            dist.to(device),
            attn_mask=attn_mask,
        )
        loss = criterion(batch_scores, batch_labels.float())
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        running_loss += loss.item()
        scores_all.append(batch_scores)
        labels_all.append(batch_labels)
        # Drop references to per-batch tensors so the cached GPU memory
        # can actually be reclaimed by empty_cache().
        del batch_labels, batch_scores, loss
        del attn_mask, node_feat, in_degree, out_degree, path_data, dist
        th.cuda.empty_cache()
    mean_loss = running_loss / len(data_loader)
    # OGB's evaluator computes ROC-AUC over the whole epoch's predictions.
    epoch_auc = Evaluator(name="ogbg-molhiv").eval(
        {"y_pred": th.cat(scores_all), "y_true": th.cat(labels_all)}
    )["rocauc"]
    return mean_loss, epoch_auc
def evaluate_network(model, data_loader):
    """Evaluate the model on one split; return the mean loss and ROC-AUC."""
    model.eval()
    criterion = nn.BCEWithLogitsLoss()
    running_loss = 0.0
    preds, targets = [], []
    with th.no_grad():
        for (
            batch_labels,
            attn_mask,
            node_feat,
            in_degree,
            out_degree,
            path_data,
            dist,
        ) in data_loader:
            device = accelerator.device
            batch_scores = model(
                node_feat.to(device),
                in_degree.to(device),
                out_degree.to(device),
                path_data.to(device),
                dist.to(device),
                attn_mask=attn_mask,
            )
            # Gather predictions and targets from every process so the
            # metric covers the full split under distributed evaluation.
            gathered_preds, gathered_targets = accelerator.gather_for_metrics(
                (batch_scores, batch_labels)
            )
            running_loss += criterion(
                gathered_preds, gathered_targets.float()
            ).item()
            preds.append(gathered_preds)
            targets.append(gathered_targets)
    mean_loss = running_loss / len(data_loader)
    epoch_auc = Evaluator(name="ogbg-molhiv").eval(
        {"y_pred": th.cat(preds), "y_true": th.cat(targets)}
    )["rocauc"]
    return mean_loss, epoch_auc
def train_val_pipeline(params):
    """Finetune the pretrained Graphormer on ogbg-molhiv and report results.

    Loads the dataset, downloads PCQM4Mv2-pretrained weights, trains for a
    fixed number of epochs, and prints the train/test AUC at the epoch with
    the best validation AUC.

    Args:
        params: Parsed CLI arguments; only ``params.batch_size`` is read.
    """
    dataset = MolHIVDataset()
    accelerator.print(
        f"train, test, val sizes: {len(dataset.train)}, "
        f"{len(dataset.test)}, {len(dataset.val)}."
    )
    accelerator.print("Finished loading.")

    def _make_loader(split, shuffle):
        # All three loaders share collate_fn and worker configuration.
        return GraphDataLoader(
            split,
            batch_size=params.batch_size,
            shuffle=shuffle,
            collate_fn=dataset.collate,
            pin_memory=True,
            num_workers=16,
        )

    train_loader = _make_loader(dataset.train, shuffle=True)
    val_loader = _make_loader(dataset.val, shuffle=False)
    test_loader = _make_loader(dataset.test, shuffle=False)

    # Load pre-trained model and reset its output head for this task.
    download(url="https://data.dgl.ai/pre_trained/graphormer_pcqm.pth")
    model = Graphormer()
    state_dict = th.load("graphormer_pcqm.pth")
    model.load_state_dict(state_dict)
    model.reset_output_layer_parameters()

    num_epochs = 16
    # The HF scheduler expects integer step counts, so truncate explicitly
    # instead of passing floats.
    total_updates = int(33000 * num_epochs / params.batch_size)
    # Use warmup schedule to avoid overfitting at the very beginning
    # of training, the ratio 0.16 is the same as the paper.
    warmup_updates = int(total_updates * 0.16)
    optimizer = AdamW(model.parameters(), lr=1e-4, eps=1e-8, weight_decay=0)
    lr_scheduler = get_polynomial_decay_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_updates,
        num_training_steps=total_updates,
        lr_end=1e-9,
        power=1.0,
    )

    epoch_train_AUCs, epoch_val_AUCs, epoch_test_AUCs = [], [], []

    # Pass all objects relevant to training to the prepare() method as required
    # by Accelerate.
    (
        model,
        optimizer,
        train_loader,
        val_loader,
        test_loader,
        lr_scheduler,
    ) = accelerator.prepare(
        model, optimizer, train_loader, val_loader, test_loader, lr_scheduler
    )

    for epoch in range(num_epochs):
        epoch_train_loss, epoch_train_auc = train_epoch(
            model, optimizer, train_loader, lr_scheduler
        )
        epoch_val_loss, epoch_val_auc = evaluate_network(model, val_loader)
        epoch_test_loss, epoch_test_auc = evaluate_network(model, test_loader)

        epoch_train_AUCs.append(epoch_train_auc)
        epoch_val_AUCs.append(epoch_val_auc)
        epoch_test_AUCs.append(epoch_test_auc)

        accelerator.print(
            f"Epoch={epoch + 1} | train_AUC={epoch_train_auc:.3f} | "
            f"val_AUC={epoch_val_auc:.3f} | test_AUC={epoch_test_auc:.3f}"
        )

    # Return test and train AUCs with best val AUC.
    index = epoch_val_AUCs.index(max(epoch_val_AUCs))
    val_auc = epoch_val_AUCs[index]
    train_auc = epoch_train_AUCs[index]
    test_auc = epoch_test_AUCs[index]

    accelerator.print("Test ROCAUC: {:.4f}".format(test_auc))
    accelerator.print("Val ROCAUC: {:.4f}".format(val_auc))
    accelerator.print("Train ROCAUC: {:.4f}".format(train_auc))
    # The index is an integer, so print it as one (was "{:.4f}").
    accelerator.print("Best epoch index: {}".format(index))
if __name__ == "__main__":
    # CLI: the seed fixes the training-data order, batch_size the per-device
    # batch.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Please give a value for random seed",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Please give a value for batch_size",
    )
    cli_args = parser.parse_args()

    # Set manual seed to bind the order of training data to the random seed.
    random.seed(cli_args.seed)
    th.manual_seed(cli_args.seed)
    if th.cuda.is_available():
        th.cuda.manual_seed(cli_args.seed)

    train_val_pipeline(cli_args)
"""
This file defines the Graphormer model, which utilizes the DegreeEncoder,
SpatialEncoder, PathEncoder and GraphormerLayer modules built into DGL.
"""
import torch as th
import torch.nn as nn
from dgl.nn import DegreeEncoder, GraphormerLayer, PathEncoder, SpatialEncoder
class Graphormer(nn.Module):
    """Graphormer for graph-level property prediction.

    Node features plus degree encodings form the input token sequence
    (prepended with a learnable virtual "graph token"), while spatial and
    path encodings enter each GraphormerLayer as attention biases.
    """

    def __init__(
        self,
        num_classes=1,
        edge_dim=3,
        num_atoms=4608,
        max_degree=512,
        num_spatial=511,
        multi_hop_max_dist=5,
        num_encoder_layers=12,
        embedding_dim=768,
        ffn_embedding_dim=768,
        num_attention_heads=32,
        dropout=0.1,
        pre_layernorm=True,
        # NOTE(review): a module instance as a default is shared across all
        # Graphormer constructions; GELU is stateless so this is harmless,
        # but a factory would be safer.
        activation_fn=nn.GELU(),
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.embedding_dim = embedding_dim
        self.num_heads = num_attention_heads

        # Atom embedding; index 0 is reserved for padded nodes (collate
        # shifts real feature values by +1).
        self.atom_encoder = nn.Embedding(
            num_atoms + 1, embedding_dim, padding_idx=0
        )
        # Learnable embedding of the virtual graph token prepended in forward.
        self.graph_token = nn.Embedding(1, embedding_dim)

        self.degree_encoder = DegreeEncoder(
            max_degree=max_degree, embedding_dim=embedding_dim
        )

        self.path_encoder = PathEncoder(
            max_len=multi_hop_max_dist,
            feat_dim=edge_dim,
            num_heads=num_attention_heads,
        )

        self.spatial_encoder = SpatialEncoder(
            max_dist=num_spatial, num_heads=num_attention_heads
        )
        # Per-head bias between the virtual node and all other nodes.
        self.graph_token_virtual_dist = nn.Embedding(1, num_attention_heads)

        self.emb_layer_norm = nn.LayerNorm(self.embedding_dim)

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                GraphormerLayer(
                    feat_size=self.embedding_dim,
                    hidden_size=ffn_embedding_dim,
                    num_heads=num_attention_heads,
                    dropout=dropout,
                    activation=activation_fn,
                    norm_first=pre_layernorm,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        # map graph_rep to num_classes
        self.lm_head_transform_weight = nn.Linear(
            self.embedding_dim, self.embedding_dim
        )
        self.layer_norm = nn.LayerNorm(self.embedding_dim)
        self.activation_fn = activation_fn
        self.embed_out = nn.Linear(self.embedding_dim, num_classes, bias=False)
        self.lm_output_learned_bias = nn.Parameter(th.zeros(num_classes))

    def reset_output_layer_parameters(self):
        """Re-initialize the output head (used after loading pretrained
        weights to finetune on a new task)."""
        self.lm_output_learned_bias = nn.Parameter(th.zeros(1))
        self.embed_out.reset_parameters()

    def forward(
        self,
        node_feat,
        in_degree,
        out_degree,
        path_data,
        dist,
        attn_mask=None,
    ):
        """Predict graph-level scores from padded batch tensors.

        Args:
            node_feat: (num_graphs, max_num_nodes, num_feats) integer-coded
                node features, 0 marking padded nodes.
            in_degree, out_degree: per-node degree codes for DegreeEncoder.
            path_data: edge features along shortest paths (PathEncoder input).
            dist: (num_graphs, max_num_nodes, max_num_nodes) shortest-path
                distances, -1 for unreachable/padded pairs.
            attn_mask: optional mask of padded positions for the layers.

        Returns:
            (num_graphs, num_classes) logits taken from the virtual node.
        """
        num_graphs, max_num_nodes, _ = node_feat.shape
        deg_emb = self.degree_encoder(th.stack((in_degree, out_degree)))

        # node feature + degree encoding as input; per-node feature columns
        # are embedded separately and summed.
        node_feat = self.atom_encoder(node_feat.int()).sum(dim=-2)
        node_feat = node_feat + deg_emb
        graph_token_feat = self.graph_token.weight.unsqueeze(0).repeat(
            num_graphs, 1, 1
        )
        x = th.cat([graph_token_feat, node_feat], dim=1)

        # spatial encoding and path encoding serve as attention bias
        attn_bias = th.zeros(
            num_graphs,
            max_num_nodes + 1,
            max_num_nodes + 1,
            self.num_heads,
            device=dist.device,
        )
        path_encoding = self.path_encoder(dist, path_data)
        spatial_encoding = self.spatial_encoder(dist)
        attn_bias[:, 1:, 1:, :] = path_encoding + spatial_encoding

        # spatial encoding of the virtual node
        t = self.graph_token_virtual_dist.weight.reshape(1, 1, self.num_heads)
        # Since the virtual node comes first, the spatial encodings between it
        # and other nodes will fill the 1st row and 1st column (omit num_graphs
        # and num_heads dimensions) of attn_bias matrix by broadcasting.
        attn_bias[:, 1:, 0, :] = attn_bias[:, 1:, 0, :] + t
        attn_bias[:, 0, :, :] = attn_bias[:, 0, :, :] + t

        x = self.emb_layer_norm(x)

        for layer in self.layers:
            x = layer(
                x,
                attn_mask=attn_mask,
                attn_bias=attn_bias,
            )

        # Graph representation = the virtual node's final embedding.
        graph_rep = x[:, 0, :]
        graph_rep = self.layer_norm(
            self.activation_fn(self.lm_head_transform_weight(graph_rep))
        )
        graph_rep = self.embed_out(graph_rep) + self.lm_output_learned_bias

        return graph_rep
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment