Commit ac282a5e authored by Ziniu Hu, committed via GitHub

[Model][Hetero] HGT (#1778)



* add HGT example

* Update README.md

* Update model.py

* Update train_acm.py

* Add comments
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
# Heterogeneous Graph Transformer (HGT)
[Alternative PyTorch-Geometric implementation](https://github.com/acbull/pyHGT)
[“**Heterogeneous Graph Transformer**”](https://arxiv.org/abs/2003.01332) is a graph neural network architecture that can deal with large-scale heterogeneous and dynamic graphs.
This toy experiment is based on DGL's official [tutorial](https://docs.dgl.ai/en/0.4.x/generated/dgl.heterograph.html). Since the ACM dataset doesn't provide input features, we simply assign random features to each node (see the sketch after the table below); this step can be replaced with any prepared node features.
Reference performance against R-GCN and MLP, averaged over 5 runs:
| Model | Test Accuracy | # Parameters |
| --------- | --------------- | -------------|
| 2-layer HGT | 0.465 ± 0.007 | 2,176,324 |
| 2-layer RGCN | 0.392 ± 0.013 | 416,340 |
| MLP | 0.132 ± 0.003 | 200,974 |
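
For reference, here is a minimal sketch of the random feature assignment mentioned above, assuming `G` is the loaded `dgl.heterograph` (this is what `train_acm.py` does; the 256 dimension matches the default `--n_inp`):

```python
import torch
import torch.nn as nn

# Give every node of each type a frozen, randomly initialized feature vector.
for ntype in G.ntypes:
    emb = nn.Parameter(torch.Tensor(G.number_of_nodes(ntype), 256), requires_grad=False)
    nn.init.xavier_uniform_(emb)
    G.nodes[ntype].data['inp'] = emb
```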
# model.py
import dgl
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
class HGTLayer(nn.Module):
    def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads, dropout = 0.2, use_norm = False):
        super(HGTLayer, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_types = num_types
        self.num_relations = num_relations
        self.total_rel = num_types * num_relations * num_types
        self.n_heads = n_heads
        self.d_k = out_dim // n_heads
        self.sqrt_dk = math.sqrt(self.d_k)
        self.att = None
        self.k_linears = nn.ModuleList()
        self.q_linears = nn.ModuleList()
        self.v_linears = nn.ModuleList()
        self.a_linears = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.use_norm = use_norm
        for _ in range(num_types):
            self.k_linears.append(nn.Linear(in_dim, out_dim))
            self.q_linears.append(nn.Linear(in_dim, out_dim))
            self.v_linears.append(nn.Linear(in_dim, out_dim))
            self.a_linears.append(nn.Linear(out_dim, out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))
        self.relation_pri = nn.Parameter(torch.ones(num_relations, self.n_heads))
        self.relation_att = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        self.relation_msg = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        self.skip = nn.Parameter(torch.ones(num_types))
        self.drop = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.relation_att)
        nn.init.xavier_uniform_(self.relation_msg)
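    # Parameter shapes, for reference (H = n_heads, d_k = out_dim // H):
    #   relation_pri : (num_relations, H)           per-relation, per-head prior scaling
    #   relation_att : (num_relations, H, d_k, d_k) attention transform per relation/head
    #   relation_msg : (num_relations, H, d_k, d_k) message transform per relation/head
    #   skip         : (num_types,)                 per-node-type skip-connection gate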
    def edge_attention(self, edges):
        # apply_edges is called once per relation type, so every edge in this
        # batch shares the same relation id.
        etype = edges.data['id'][0]
        '''
        Step 1: Heterogeneous Mutual Attention
        '''
        relation_att = self.relation_att[etype]
        relation_pri = self.relation_pri[etype]
        key = torch.bmm(edges.src['k'].transpose(1,0), relation_att).transpose(1,0)
        att = (edges.dst['q'] * key).sum(dim=-1) * relation_pri / self.sqrt_dk
        '''
        Step 2: Heterogeneous Message Passing
        '''
        relation_msg = self.relation_msg[etype]
        val = torch.bmm(edges.src['v'].transpose(1,0), relation_msg).transpose(1,0)
        return {'a': att, 'v': val}
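    # Shape walkthrough for edge_attention, for reference (E = edges in the batch,
    # H = n_heads):
    #   edges.src['k']          : (E, H, d_k)
    #   .transpose(1, 0)        : (H, E, d_k)
    #   bmm with (H, d_k, d_k)  : (H, E, d_k), transposed back to (E, H, d_k)
    #   att = (q * key).sum(-1) : (E, H), scaled by relation_pri / sqrt(d_k)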
    def message_func(self, edges):
        return {'v': edges.data['v'], 'a': edges.data['a']}
    def reduce_func(self, nodes):
        '''
        Softmax over each target node's incoming edges.
        NOTE: with DGL's API there is a minor difference from the original softmax:
        this implementation normalizes only over edges of the same relation type,
        rather than over all incoming edges of the node.
        '''
        att = F.softmax(nodes.mailbox['a'], dim=1)
        h = torch.sum(att.unsqueeze(dim = -1) * nodes.mailbox['v'], dim=1)
        return {'t': h.view(-1, self.out_dim)}
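    # Mailbox shapes in reduce_func, for reference (N = nodes in the degree bucket,
    # D = their shared in-degree, H = n_heads):
    #   nodes.mailbox['a'] : (N, D, H)      -> softmax over D (the incoming edges)
    #   nodes.mailbox['v'] : (N, D, H, d_k)
    #   h                  : (N, H, d_k)    -> reshaped to (N, out_dim)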
    def forward(self, G, inp_key, out_key):
        node_dict, edge_dict = G.node_dict, G.edge_dict
        for srctype, etype, dsttype in G.canonical_etypes:
            k_linear = self.k_linears[node_dict[srctype]]
            v_linear = self.v_linears[node_dict[srctype]]
            q_linear = self.q_linears[node_dict[dsttype]]
            G.nodes[srctype].data['k'] = k_linear(G.nodes[srctype].data[inp_key]).view(-1, self.n_heads, self.d_k)
            G.nodes[srctype].data['v'] = v_linear(G.nodes[srctype].data[inp_key]).view(-1, self.n_heads, self.d_k)
            G.nodes[dsttype].data['q'] = q_linear(G.nodes[dsttype].data[inp_key]).view(-1, self.n_heads, self.d_k)
            G.apply_edges(func=self.edge_attention, etype=etype)
        G.multi_update_all({etype : (self.message_func, self.reduce_func)
                            for etype in edge_dict}, cross_reducer = 'mean')
        for ntype in G.ntypes:
            '''
            Step 3: Target-specific Aggregation
            x = norm( W[node_type] * gelu( Agg(x) ) + x )
            '''
            n_id = node_dict[ntype]
            alpha = torch.sigmoid(self.skip[n_id])
            trans_out = self.drop(self.a_linears[n_id](G.nodes[ntype].data['t']))
            trans_out = trans_out * alpha + G.nodes[ntype].data[inp_key] * (1-alpha)
            if self.use_norm:
                G.nodes[ntype].data[out_key] = self.norms[n_id](trans_out)
            else:
                G.nodes[ntype].data[out_key] = trans_out
class HGT(nn.Module):
    def __init__(self, G, n_inp, n_hid, n_out, n_layers, n_heads, use_norm = True):
        super(HGT, self).__init__()
        self.gcs = nn.ModuleList()
        self.n_inp = n_inp
        self.n_hid = n_hid
        self.n_out = n_out
        self.n_layers = n_layers
        self.adapt_ws = nn.ModuleList()
        for _ in range(len(G.node_dict)):
            self.adapt_ws.append(nn.Linear(n_inp, n_hid))
        for _ in range(n_layers):
            self.gcs.append(HGTLayer(n_hid, n_hid, len(G.node_dict), len(G.edge_dict), n_heads, use_norm = use_norm))
        self.out = nn.Linear(n_hid, n_out)

    def forward(self, G, out_key):
        for ntype in G.ntypes:
            n_id = G.node_dict[ntype]
            G.nodes[ntype].data['h'] = F.gelu(self.adapt_ws[n_id](G.nodes[ntype].data['inp']))
        for i in range(self.n_layers):
            self.gcs[i](G, 'h', 'h')
        return self.out(G.nodes[out_key].data['h'])
class HeteroRGCNLayer(nn.Module):
    def __init__(self, in_size, out_size, etypes):
        super(HeteroRGCNLayer, self).__init__()
        # W_r for each relation
        self.weight = nn.ModuleDict({
            name : nn.Linear(in_size, out_size) for name in etypes
        })

    def forward(self, G, feat_dict):
        # The input is a dictionary of node features for each type
        funcs = {}
        for srctype, etype, dsttype in G.canonical_etypes:
            # Compute W_r * h
            Wh = self.weight[etype](feat_dict[srctype])
            # Save it in the graph for message passing
            G.nodes[srctype].data['Wh_%s' % etype] = Wh
            # Specify per-relation message passing functions: (message_func, reduce_func).
            # Note that the results are saved to the same destination feature 'h',
            # which hints the type-wise reducer for aggregation.
            funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
        # Trigger message passing of multiple types.
        # The first argument is the message passing functions for each relation.
        # The second one is the type-wise reducer; it can be "sum", "max",
        # "min", "mean", or "stack".
        G.multi_update_all(funcs, 'sum')
        # Return the updated node feature dictionary
        return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}
class HeteroRGCN(nn.Module):
    def __init__(self, G, in_size, hidden_size, out_size):
        super(HeteroRGCN, self).__init__()
        # create layers
        self.layer1 = HeteroRGCNLayer(in_size, hidden_size, G.etypes)
        self.layer2 = HeteroRGCNLayer(hidden_size, out_size, G.etypes)

    def forward(self, G, out_key):
        input_dict = {ntype : G.nodes[ntype].data['inp'] for ntype in G.ntypes}
        h_dict = self.layer1(G, input_dict)
        h_dict = {k : F.leaky_relu(h) for k, h in h_dict.items()}
        h_dict = self.layer2(G, h_dict)
        # get paper logits
        return h_dict[out_key]
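
# A minimal smoke test for HGTLayer on a hypothetical one-type, one-relation toy
# graph (a sketch only; the real ACM pipeline lives in train_acm.py):
if __name__ == '__main__':
    g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1], [1, 2])})
    g.node_dict = {'user': 0}
    g.edge_dict = {'follows': 0}
    g.edges['follows'].data['id'] = torch.zeros(g.number_of_edges('follows'), dtype=torch.long)
    g.nodes['user'].data['h'] = torch.randn(3, 64)
    layer = HGTLayer(64, 64, num_types=1, num_relations=1, n_heads=4, use_norm=True)
    layer(g, 'h', 'h')  # overwrites g.nodes['user'].data['h'] with updated features
    print(g.nodes['user'].data['h'].shape)  # expected: torch.Size([3, 64])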
#!/usr/bin/env python
# train_acm.py
import scipy.io
import urllib.request
import dgl
import math
import numpy as np
from model import *
import argparse

torch.manual_seed(0)
data_url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/ACM.mat'
data_file_path = '/tmp/ACM.mat'
urllib.request.urlretrieve(data_url, data_file_path)
data = scipy.io.loadmat(data_file_path)

parser = argparse.ArgumentParser(description='Training GNN on the ACM benchmark')
parser.add_argument('--n_epoch', type=int, default=200)
parser.add_argument('--n_hid', type=int, default=256)
parser.add_argument('--n_inp', type=int, default=256)
parser.add_argument('--clip', type=float, default=1.0)
parser.add_argument('--max_lr', type=float, default=1e-3)
args = parser.parse_args()
def get_n_params(model):
    # Count the total number of parameters in the model
    total = 0
    for p in model.parameters():
        n = 1
        for s in p.size():
            n = n * s
        total += n
    return total
def train(model, G):
    best_val_acc = 0
    best_test_acc = 0
    train_step = 0
    for epoch in np.arange(args.n_epoch) + 1:
        model.train()
        logits = model(G, 'paper')
        # The loss is computed only for labeled nodes.
        loss = F.cross_entropy(logits[train_idx], labels[train_idx].to(device))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()
        train_step += 1
        scheduler.step(train_step)
        if epoch % 5 == 0:
            model.eval()
            logits = model(G, 'paper')
            pred = logits.argmax(1).cpu()
            train_acc = (pred[train_idx] == labels[train_idx]).float().mean()
            val_acc = (pred[val_idx] == labels[val_idx]).float().mean()
            test_acc = (pred[test_idx] == labels[test_idx]).float().mean()
            if best_val_acc < val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc
            print('Epoch: %d LR: %.5f Loss %.4f, Train Acc %.4f, Val Acc %.4f (Best %.4f), Test Acc %.4f (Best %.4f)' % (
                epoch,
                optimizer.param_groups[0]['lr'],
                loss.item(),
                train_acc.item(),
                val_acc.item(),
                best_val_acc.item(),
                test_acc.item(),
                best_test_acc.item(),
            ))
G = dgl.heterograph({
    ('paper', 'written-by', 'author') : data['PvsA'],
    ('author', 'writing', 'paper') : data['PvsA'].transpose(),
    ('paper', 'citing', 'paper') : data['PvsP'],
    ('paper', 'cited', 'paper') : data['PvsP'].transpose(),
    ('paper', 'is-about', 'subject') : data['PvsL'],
    ('subject', 'has', 'paper') : data['PvsL'].transpose(),
})
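# Each value in the dictionary above is a scipy sparse matrix loaded from ACM.mat
# (e.g. data['PvsA'] relates papers to authors); dgl.heterograph builds one
# bipartite relation per entry.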
print(G)
pvc = data['PvsC'].tocsr()
p_selected = pvc.tocoo()
# generate labels
labels = pvc.indices
labels = torch.tensor(labels).long()
# generate train/val/test split
pid = p_selected.row
shuffle = np.random.permutation(pid)
train_idx = torch.tensor(shuffle[0:800]).long()
val_idx = torch.tensor(shuffle[800:900]).long()
test_idx = torch.tensor(shuffle[900:]).long()
device = torch.device("cuda:0")
G.node_dict = {}
G.edge_dict = {}
for ntype in G.ntypes:
    G.node_dict[ntype] = len(G.node_dict)
for etype in G.etypes:
    G.edge_dict[etype] = len(G.edge_dict)
    G.edges[etype].data['id'] = torch.ones(G.number_of_edges(etype), dtype=torch.long) * G.edge_dict[etype]
# Randomly initialize input features (the dimension must match args.n_inp)
for ntype in G.ntypes:
    emb = nn.Parameter(torch.Tensor(G.number_of_nodes(ntype), 256), requires_grad = False).to(device)
    nn.init.xavier_uniform_(emb)
    G.nodes[ntype].data['inp'] = emb
model = HGT(G, n_inp=args.n_inp, n_hid=args.n_hid, n_out=labels.max().item()+1, n_layers=2, n_heads=4, use_norm = True).to(device)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=args.n_epoch, max_lr = args.max_lr)
print('Training HGT with #param: %d' % (get_n_params(model)))
train(model, G)
model = HeteroRGCN(G, in_size=args.n_inp, hidden_size=args.n_hid, out_size=labels.max().item()+1).to(device)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=args.n_epoch, max_lr = args.max_lr)
print('Training RGCN with #param: %d' % (get_n_params(model)))
train(model, G)
# With n_layers=0, HGT reduces to an MLP: the per-type input adaptation plus the output layer.
model = HGT(G, n_inp=args.n_inp, n_hid=args.n_hid, n_out=labels.max().item()+1, n_layers=0, n_heads=4).to(device)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=args.n_epoch, max_lr = args.max_lr)
print('Training MLP with #param: %d' % (get_n_params(model)))
train(model, G)
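
# Example invocation (assumed defaults; requires a CUDA device, since the script
# pins everything to cuda:0):
#   python3 train_acm.py --n_epoch 200 --n_hid 256 --max_lr 1e-3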