Unverified commit 44089c8b, authored by Minjie Wang and committed by GitHub

[Refactor][Graph] Merge DGLGraph and DGLHeteroGraph (#1862)



* Merge

* [Graph][CUDA] Graph on GPU and many refactoring (#1791)

* change edge_ids behavior and C++ impl

* fix unittests; remove utils.Index in edge_id

* pass mx and th tests

* pass tf test

* add aten::Scatter_

* Add nonzero; impl CSRGetDataAndIndices/CSRSliceMatrix

* CSRGetData and CSRGetDataAndIndices passed tests

* CSRSliceMatrix basic tests

* fix bug in empty slice

* CUDA CSRHasDuplicate

* has_node; has_edge_between

* predecessors, successors

* deprecate send/recv; fix send_and_recv

* deprecate send/recv; fix send_and_recv

* in_edges; out_edges; all_edges; apply_edges

* in deg/out deg

* subgraph/edge_subgraph

* adj

* in_subgraph/out_subgraph

* sample neighbors

* set/get_n/e_repr

* wip: working on refactoring all idtypes

* pass ndata/edata tests on gpu

* fix

* stash

* workaround nonzero issue

* stash

* nx conversion

* test_hetero_basics except update routines

* test_update_routines

* test_hetero_basics for pytorch

* more fixes

* WIP: flatten graph

* wip: flatten

* test_flatten

* test_to_device

* fix bug in to_homo

* fix bug in CSRSliceMatrix

* pass subgraph test

* fix send_and_recv

* fix filter

* test_heterograph

* passed all pytorch tests

* fix mx unittest

* fix pytorch test_nn

* fix all unittests for PyTorch

* passed all mxnet tests

* lint

* fix tf nn test

* pass all tf tests

* lint

* lint

* change deprecation

* try fix compile

* lint

* update METIS

* fix utest

* fix

* fix utests

* try debug

* revert

* small fix

* fix utests

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* trigger

* +1s

* [kernel] Use heterograph index instead of unitgraph index (#1813)

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* trigger

* +1s

* [Graph] Mutation for Heterograph (#1818)

* mutation add_nodes and add_edges

* Add support for remove_edges, remove_nodes, add_selfloop, remove_selfloop

* Fix
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>

* upd

* upd

* upd

* fix

* [Transform] Mutable transform (#1833)

* add nodes

* All three

* Fix

* lint

* Add some test case

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* fix

* trigger

* Fix

* fix
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>

* [Graph] Migrate Batch & Readout module to heterograph (#1836)

* dgl.batch

* unbatch

* fix to device

* reduce readout; segment reduce

* change batch_num_nodes|edges to function

* reduce readout/ softmax

* broadcast

* topk

* fix

* fix tf and mx

* fix some ci

* fix batch but unbatch differently

* new check

* upd

* upd

* upd

* idtype behavior; code reorg

* idtype behavior; code reorg

* wip: test_basics

* pass test_basics

* WIP: from nx/ to nx

* missing files

* upd

* pass test_basics:test_nx_conversion

* Fix test

* Fix inplace update

* WIP: fixing tests

* upd

* pass test_transform cpu

* pass gpu test_transform

* pass test_batched_graph

* GPU graph auto cast to int32

* missing file

* stash

* WIP: rgcn-hetero

* Fix two datasets

* upd

* weird

* Fix capsule

* fix

* fix

* Fix dgmg

* fix bug in block degrees; pass rgcn-hetero

* rgcn

* gat and diffpool fix
also fix ppi and tu dataset

* Tree LSTM

* pointcloud

* rrn; wip: sgc

* resolve conflicts

* upd

* sgc and reddit dataset

* upd

* Fix deepwalk, gindt and gcn

* fix datasets and sign

* optimization

* optimization

* upd

* upd

* Fix GIN

* fix bug in add_nodes add_edges; tagcn

* adaptive sampling and gcmc

* upd

* upd

* fix geometric

* fix

* metapath2vec

* fix agnn

* fix pickling problem of block

* fix utests

* miss file

* linegraph

* upd

* upd

* upd

* graphsage

* stgcn_wave

* fix hgt

* on unittests

* Fix transformer

* Fix HAN

* passed pytorch unittests

* lint

* fix

* Fix cluster gcn

* cluster-gcn is ready

* on fixing block related codes

* 2nd order derivative

* Revert "2nd order derivative"

This reverts commit 523bf6c249bee61b51b1ad1babf42aad4167f206.

* passed torch utests again

* fix all mxnet unittests

* delete some useless tests

* pass all tf cpu tests

* disable

* disable distributed unittest

* fix

* fix

* lint

* fix

* fix

* fix script

* fix tutorial

* fix apply edges bug

* fix 2 basics

* fix tutorial
Co-authored-by: yzh119 <expye@outlook.com>
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-7-42.us-west-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-1-5.us-west-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-68-185.ec2.internal>
parent 015acfd2
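Taken together, the log above describes one merged graph class with explicit idtype handling, device placement, and mutation support. A minimal sketch of the post-merge usage, assuming a DGL 0.5-era build with the PyTorch backend (the specific node ids and names here are illustrative, not taken from the diff):

import dgl
import torch

# One class for both homogeneous and heterogeneous graphs.
g = dgl.graph(([0, 1, 2], [1, 2, 3]), idtype=torch.int64)   # idtype is explicit now
hg = dgl.heterograph({
    ('user', 'follows', 'user'): ([0, 1], [1, 2]),
    ('user', 'plays', 'game'): ([0, 1], [0, 1]),
})

# The merged class is mutable again (add_nodes/add_edges from the log above).
g.add_nodes(2)
g.add_edges([3, 4], [4, 5])

# A graph lives on a device and moves as a whole.
if torch.cuda.is_available():
    g = g.to(torch.device('cuda:0'))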
......@@ -105,8 +105,8 @@ class DiffPoolBatchedGraphLayer(nn.Module):
assign_tensor = self.pool_gc(g, h)
device = feat.device
assign_tensor_masks = []
batch_size = len(g.batch_num_nodes)
for g_n_nodes in g.batch_num_nodes:
batch_size = len(g.batch_num_nodes())
for g_n_nodes in g.batch_num_nodes():
mask = torch.ones((g_n_nodes,
int(assign_tensor.size()[1] / batch_size)))
assign_tensor_masks.append(mask)
......
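This hunk tracks the batching refactor noted in the log ("change batch_num_nodes|edges to function"): batch_num_nodes() is now a method returning a tensor rather than a Python-list attribute. A small sketch of the new behaviour:

import dgl
import torch

g1 = dgl.graph(([0, 1], [1, 2]))        # 3 nodes
g2 = dgl.graph(([0], [1]))              # 2 nodes
bg = dgl.batch([g1, g2])

sizes = bg.batch_num_nodes()            # tensor([3, 2]); previously a list attribute
batch_size = bg.batch_size              # 2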
......@@ -202,7 +202,7 @@ def collate_fn(batch):
# batch graphs and cast to PyTorch tensor
for graph in graphs:
for (key, value) in graph.ndata.items():
graph.ndata[key] = torch.FloatTensor(value)
graph.ndata[key] = value.float()
batched_graphs = dgl.batch(graphs)
# cast to PyTorch tensor
......@@ -234,8 +234,7 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
computation_time = 0.0
for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader):
if torch.cuda.is_available():
for (key, value) in batch_graph.ndata.items():
batch_graph.ndata[key] = value.cuda()
batch_graph = batch_graph.to(torch.cuda.current_device())
graph_labels = graph_labels.cuda()
model.zero_grad()
......@@ -285,8 +284,7 @@ def evaluate(dataloader, model, prog_args, logger=None):
with torch.no_grad():
for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
if torch.cuda.is_available():
for (key, value) in batch_graph.ndata.items():
batch_graph.ndata[key] = value.cuda()
batch_graph = batch_graph.to(torch.cuda.current_device())
graph_labels = graph_labels.cuda()
ypred = model(batch_graph)
indi = torch.argmax(ypred, dim=1)
......
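Both hunks above replace the per-tensor .cuda() loop with a single graph-level move; on the merged class, .to() is expected to carry the attached ndata/edata along with the structure. A rough, self-contained equivalent of the pattern (illustrative data):

import dgl
import torch

batch_graph = dgl.batch([dgl.graph(([0, 1], [1, 2])), dgl.graph(([0], [1]))])
batch_graph.ndata['feat'] = torch.randn(batch_graph.number_of_nodes(), 8)
graph_labels = torch.tensor([0, 1])

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch_graph = batch_graph.to(device)     # replaces looping over ndata and calling .cuda()
graph_labels = graph_labels.to(device)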
......@@ -78,6 +78,8 @@ def main(args):
g.remove_edges_from(nx.selfloop_edges(g))
g = DGLGraph(g)
g.add_edges(g.nodes(), g.nodes())
if cuda:
g = g.to(args.gpu)
n_edges = g.number_of_edges()
# create model
heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
......
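With the merged class, the self-loop trick and the GPU move happen directly on the DGLGraph. A rough sketch, with dgl.add_self_loop shown as an equivalent of the add_edges(g.nodes(), g.nodes()) idiom used in the example (my equivalence, not from the diff):

import dgl
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g = dgl.add_self_loop(g)                 # same effect here as g.add_edges(g.nodes(), g.nodes())
if torch.cuda.is_available():
    g = g.to(torch.device('cuda:0'))     # args.gpu in the example selects the device id
n_edges = g.number_of_edges()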
......@@ -63,6 +63,7 @@ def main(args):
n_classes = train_dataset.labels.shape[1]
num_feats = train_dataset.features.shape[1]
g = train_dataset.graph
g = g.to(device)
heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
# define the model
model = GAT(g,
......@@ -84,6 +85,7 @@ def main(args):
loss_list = []
for batch, data in enumerate(train_dataloader):
subgraph, feats, labels = data
subgraph = subgraph.to(device)
feats = feats.to(device)
labels = labels.to(device)
model.g = subgraph
......@@ -102,6 +104,7 @@ def main(args):
val_loss_list = []
for batch, valid_data in enumerate(valid_dataloader):
subgraph, feats, labels = valid_data
subgraph = subgraph.to(device)
feats = feats.to(device)
labels = labels.to(device)
score, val_loss = evaluate(feats.float(), model, subgraph, labels.float(), loss_fcn)
......@@ -125,6 +128,7 @@ def main(args):
test_score_list = []
for batch, test_data in enumerate(test_dataloader):
subgraph, feats, labels = test_data
subgraph = subgraph.to(device)
feats = feats.to(device)
labels = labels.to(device)
test_score_list.append(evaluate(feats, model, subgraph, labels.float(), loss_fcn)[0])
......
......@@ -269,7 +269,7 @@ class MovieLens(object):
x = x.numpy().astype('float32')
x[x == 0.] = np.inf
x = th.FloatTensor(1. / np.sqrt(x))
return x.to(self._device).unsqueeze(1)
return x.unsqueeze(1)
user_ci = []
user_cj = []
movie_ci = []
......@@ -290,8 +290,8 @@ class MovieLens(object):
user_cj = _calc_norm(sum(user_cj))
movie_cj = _calc_norm(sum(movie_cj))
else:
user_cj = th.ones(self.num_user,).to(self._device)
movie_cj = th.ones(self.num_movie,).to(self._device)
user_cj = th.ones(self.num_user,)
movie_cj = th.ones(self.num_movie,)
graph.nodes['user'].data.update({'ci' : user_ci, 'cj' : user_cj})
graph.nodes['movie'].data.update({'ci' : movie_ci, 'cj' : movie_cj})
......
......@@ -32,7 +32,7 @@ class GCMCGraphConv(nn.Module):
super(GCMCGraphConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self.device = device
self.device = device
self.dropout = nn.Dropout(dropout_rate)
if weight:
......@@ -210,7 +210,7 @@ class GCMCLayer(nn.Module):
def partial_to(self, device):
"""Put parameters into device except W_r
Parameters
----------
device : torch device
......
......@@ -105,6 +105,13 @@ def train(args):
count_num = 0
count_loss = 0
dataset.train_enc_graph = dataset.train_enc_graph.to(args.device)
dataset.train_dec_graph = dataset.train_dec_graph.to(args.device)
dataset.valid_enc_graph = dataset.train_enc_graph
dataset.valid_dec_graph = dataset.valid_dec_graph.to(args.device)
dataset.test_enc_graph = dataset.test_enc_graph.to(args.device)
dataset.test_dec_graph = dataset.test_dec_graph.to(args.device)
print("Start training ...")
dur = []
for iter_idx in range(1, args.train_max_iter):
......
......@@ -51,7 +51,7 @@ class GCMCSampler:
The input ``seeds`` represents the edges to compute prediction for. The sampling
algorithm works as follows:
1. Get the head and tail nodes of the provided seed edges.
2. For each head and tail node, extract the entire in-coming neighborhood.
3. Copy the node features/embeddings from the full graph to the sampled subgraphs.
......@@ -170,6 +170,7 @@ def evaluate(args, dev_id, net, dataset, dataloader, segment='valid'):
compact_g, frontier, head_feat, tail_feat, \
_, true_relation_ratings = sample_data
frontier = frontier.to(dev_id)
head_feat = head_feat.to(dev_id)
tail_feat = tail_feat.to(dev_id)
with th.no_grad():
......@@ -352,11 +353,12 @@ def run(proc_id, n_gpus, args, devices, dataset):
if epoch > 1:
t0 = time.time()
net.train()
for step, sample_data in enumerate(dataloader):
for step, sample_data in enumerate(dataloader):
compact_g, frontier, head_feat, tail_feat, \
true_relation_labels, true_relation_ratings = sample_data
head_feat = head_feat.to(dev_id)
tail_feat = tail_feat.to(dev_id)
frontier = frontier.to(dev_id)
pred_ratings = net(compact_g, frontier, head_feat, tail_feat, dataset.possible_rating_values)
loss = rating_loss_net(pred_ratings, true_relation_labels.to(dev_id)).mean()
......
......@@ -67,6 +67,8 @@ def main(args):
g.add_edges_from(zip(g.nodes(), g.nodes()))
g = DGLGraph(g)
n_edges = g.number_of_edges()
if cuda:
g = g.to(args.gpu)
# normalization
degs = g.in_degrees().float()
norm = torch.pow(degs, -0.5)
......
......@@ -23,7 +23,8 @@ def train(args, net, trainloader, optimizer, criterion, epoch):
for pos, (graphs, labels) in zip(bar, trainloader):
# batch graphs will be shipped to device in forward part of model
labels = labels.to(args.device)
feat = graphs.ndata['attr'].to(args.device)
feat = graphs.ndata.pop('attr').to(args.device)
graphs = graphs.to(args.device)
outputs = net(graphs, feat)
loss = criterion(outputs, labels)
......@@ -52,7 +53,8 @@ def eval_net(args, net, dataloader, criterion):
for data in dataloader:
graphs, labels = data
feat = graphs.ndata['attr'].to(args.device)
feat = graphs.ndata.pop('attr').to(args.device)
graphs = graphs.to(args.device)
labels = labels.to(args.device)
total += len(labels)
outputs = net(graphs, feat)
......
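The GIN hunks pop the feature tensor off the batched graph before moving the graph, presumably so the features are not transferred twice (once inside graph.to() and once explicitly) and the model keeps its (graphs, feat) calling convention. A self-contained sketch of that pattern, with illustrative names:

import dgl
import torch

graphs = dgl.batch([dgl.graph(([0, 1], [1, 2])), dgl.graph(([0], [1]))])
graphs.ndata['attr'] = torch.randn(graphs.number_of_nodes(), 16)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
feat = graphs.ndata.pop('attr').to(device)   # take the features off the graph first ...
graphs = graphs.to(device)                   # ... then move the now feature-free graph
# the model is then called as net(graphs, feat)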
......@@ -109,6 +109,7 @@ class SAGE(nn.Module):
end = start + batch_size
batch_nodes = nodes[start:end]
block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
block = block.to(device)
induced_nodes = block.srcdata[dgl.NID]
h = x[induced_nodes].to(device)
......@@ -185,20 +186,27 @@ def load_subtensor(g, labels, blocks, hist_blocks, dev_id, aggregation_on_device
"""
Copies features and labels of a set of nodes onto GPU.

"""
blocks[0].srcdata['features'] = g.ndata['features'][blocks[0].srcdata[dgl.NID]].to(dev_id)
blocks[-1].dstdata['label'] = labels[blocks[-1].dstdata[dgl.NID]].to(dev_id)
blocks[0].srcdata['features'] = g.ndata['features'][blocks[0].srcdata[dgl.NID]]
blocks[-1].dstdata['label'] = labels[blocks[-1].dstdata[dgl.NID]]
ret_blocks = []
ret_hist_blocks = []
for i, (block, hist_block) in enumerate(zip(blocks, hist_blocks)):
hist_col = 'features' if i == 0 else 'hist_%d' % i
block.srcdata['hist'] = g.ndata[hist_col][block.srcdata[dgl.NID]].to(dev_id)
block.srcdata['hist'] = g.ndata[hist_col][block.srcdata[dgl.NID]]
# Aggregate history
hist_block.srcdata['hist'] = g.ndata[hist_col][hist_block.srcdata[dgl.NID]]
if aggregation_on_device:
hist_block.srcdata['hist'] = hist_block.srcdata['hist'].to(dev_id)
hist_block = hist_block.to(dev_id)
hist_block.update_all(fn.copy_u('hist', 'm'), fn.mean('m', 'agg_hist'))
block.dstdata['agg_hist'] = hist_block.dstdata['agg_hist']
block = block.to(dev_id)
if not aggregation_on_device:
block.dstdata['agg_hist'] = block.dstdata['agg_hist'].to(dev_id)
hist_block = hist_block.to(dev_id)
block.dstdata['agg_hist'] = hist_block.dstdata['agg_hist']
ret_blocks.append(block)
ret_hist_blocks.append(hist_block)
return ret_blocks, ret_hist_blocks
def init_history(g, model, dev_id):
with th.no_grad():
......@@ -211,7 +219,7 @@ def init_history(g, model, dev_id):
def update_history(g, blocks):
with th.no_grad():
for i, block in enumerate(blocks):
ids = block.dstdata[dgl.NID]
ids = block.dstdata[dgl.NID].cpu()
hist_col = 'hist_%d' % (i + 1)
h_new = block.dstdata['h_new'].cpu()
......@@ -269,7 +277,7 @@ def run(args, dev_id, data):
input_nodes = blocks[0].srcdata[dgl.NID]
seeds = blocks[-1].dstdata[dgl.NID]
load_subtensor(g, labels, blocks, hist_blocks, dev_id, True)
blocks, hist_blocks = load_subtensor(g, labels, blocks, hist_blocks, dev_id, True)
# forward
batch_pred = model(blocks)
......
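These hunks make the device transfer explicit: blocks are built on CPU with in_subgraph/to_block, moved with .to(), and load_subtensor now returns the moved blocks instead of mutating them in place. A minimal sketch of the block workflow, assuming a DGL 0.5-era API and made-up features:

import dgl
import torch

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
g.ndata['features'] = torch.randn(4, 8)

batch_nodes = torch.tensor([0, 1])
block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
block.srcdata['features'] = g.ndata['features'][block.srcdata[dgl.NID]]
block = block.to(device)                     # move the block (and its data) explicitly
h_dst = block.srcdata['features'][:block.number_of_dst_nodes()]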
......@@ -112,6 +112,7 @@ class SAGE(nn.Module):
induced_nodes = block.srcdata[dgl.NID]
h = x[induced_nodes].to(device)
block = block.to(device)
h_dst = h[:block.number_of_dst_nodes()]
h = layer(block, (h, h_dst))
......@@ -213,20 +214,28 @@ def load_subtensor(g, labels, blocks, hist_blocks, dev_id, aggregation_on_device
"""
Copies features and labels of a set of nodes onto GPU.
"""
blocks[0].srcdata['features'] = g.ndata['features'][blocks[0].srcdata[dgl.NID]].to(dev_id)
blocks[-1].dstdata['label'] = labels[blocks[-1].dstdata[dgl.NID]].to(dev_id)
blocks[0].srcdata['features'] = g.ndata['features'][blocks[0].srcdata[dgl.NID]]
blocks[-1].dstdata['label'] = labels[blocks[-1].dstdata[dgl.NID]]
ret_blocks = []
ret_hist_blocks = []
for i, (block, hist_block) in enumerate(zip(blocks, hist_blocks)):
hist_col = 'features' if i == 0 else 'hist_%d' % i
block.srcdata['hist'] = g.ndata[hist_col][block.srcdata[dgl.NID]].to(dev_id)
block.srcdata['hist'] = g.ndata[hist_col][block.srcdata[dgl.NID]]
# Aggregate history
hist_block.srcdata['hist'] = g.ndata[hist_col][hist_block.srcdata[dgl.NID]]
if aggregation_on_device:
hist_block.srcdata['hist'] = hist_block.srcdata['hist'].to(dev_id)
hist_block = hist_block.to(dev_id)
hist_block.srcdata['hist'] = hist_block.srcdata['hist']
hist_block.update_all(fn.copy_u('hist', 'm'), fn.mean('m', 'agg_hist'))
block.dstdata['agg_hist'] = hist_block.dstdata['agg_hist']
block = block.to(dev_id)
if not aggregation_on_device:
block.dstdata['agg_hist'] = block.dstdata['agg_hist'].to(dev_id)
hist_block = hist_block.to(dev_id)
block.dstdata['agg_hist'] = hist_block.dstdata['agg_hist']
ret_blocks.append(block)
ret_hist_blocks.append(hist_block)
return ret_blocks, ret_hist_blocks
def create_history_storage(g, args, n_classes):
# Initialize history storage
......@@ -241,7 +250,7 @@ def init_history(g, model, dev_id, batch_size):
def update_history(g, blocks):
with th.no_grad():
for i, block in enumerate(blocks):
ids = block.dstdata[dgl.NID]
ids = block.dstdata[dgl.NID].cpu()
hist_col = 'hist_%d' % (i + 1)
h_new = block.dstdata['h_new'].cpu()
......@@ -317,10 +326,9 @@ def run(proc_id, n_gpus, args, devices, data):
# The nodes for input lies at the LHS side of the first block.
# The nodes for output lies at the RHS side of the last block.
input_nodes = blocks[0].srcdata[dgl.NID]
seeds = blocks[-1].dstdata[dgl.NID]
load_subtensor(g, labels, blocks, hist_blocks, dev_id, True)
blocks, hist_blocks = load_subtensor(g, labels, blocks, hist_blocks, dev_id, True)
# forward
batch_pred = model(blocks)
......
......@@ -77,6 +77,7 @@ class SAGE(nn.Module):
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0]
block = block.to(device)
h = x[input_nodes].to(device)
h = layer(block, h)
if l != len(self.layers) - 1:
......@@ -171,6 +172,7 @@ def run(args, device, data):
# Load the input features as well as output labels
batch_inputs, batch_labels = load_subtensor(train_g, seeds, input_nodes, device)
blocks = [block.to(device) for block in blocks]
# Compute loss and prediction
batch_pred = model(blocks, batch_inputs)
......
......@@ -78,6 +78,7 @@ class SAGE(nn.Module):
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0]
block = block.to(device)
h = x[input_nodes].to(device)
h = layer(block, h)
if l != len(self.layers) - 1:
......@@ -150,7 +151,7 @@ def run(proc_id, n_gpus, args, devices, data):
in_feats, n_classes, train_g, val_g, test_g = data
train_mask = train_g.ndata['train_mask']
val_mask = val_g.ndata['val_mask']
test_mask = ~(train_g.ndata['train_mask'] | val_g.ndata['val_mask'])
test_mask = ~(test_g.ndata['train_mask'] | test_g.ndata['val_mask'])
train_nid = train_mask.nonzero()[:, 0]
val_nid = val_mask.nonzero()[:, 0]
test_nid = test_mask.nonzero()[:, 0]
......@@ -193,7 +194,7 @@ def run(proc_id, n_gpus, args, devices, data):
# Load the input features as well as output labels
batch_inputs, batch_labels = load_subtensor(train_g, train_g.ndata['labels'], seeds, input_nodes, dev_id)
blocks = [block.to(dev_id) for block in blocks]
# Compute loss and prediction
batch_pred = model(blocks, batch_inputs)
loss = loss_fcn(batch_pred, batch_labels)
......
......@@ -149,6 +149,7 @@ class SAGE(nn.Module):
end = start + batch_size
batch_nodes = nodes[start:end]
block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
block = block.to(device)
input_nodes = block.srcdata[dgl.NID]
h = x[input_nodes].to(device)
......@@ -293,6 +294,9 @@ def run(proc_id, n_gpus, args, devices, data):
batch_inputs = load_subtensor(g, input_nodes, device)
d_step = time.time()
pos_graph = pos_graph.to(device)
neg_graph = neg_graph.to(device)
blocks = [block.to(device) for block in blocks]
# Compute loss and prediction
batch_pred = model(blocks, batch_inputs)
loss = loss_fcn(batch_pred, pos_graph, neg_graph)
......
......@@ -48,6 +48,7 @@ def main(args):
out_size=num_classes,
num_heads=args['num_heads'],
dropout=args['dropout']).to(args['device'])
g = g.to(args['device'])
else:
from model import HAN
model = HAN(num_meta_paths=len(g),
......@@ -56,6 +57,7 @@ def main(args):
out_size=num_classes,
num_heads=args['num_heads'],
dropout=args['dropout']).to(args['device'])
g = [graph.to(args['device']) for graph in g]
stopper = EarlyStopping(patience=args['patience'])
loss_fcn = torch.nn.CrossEntropyLoss()
......
......@@ -6,14 +6,23 @@ import torch.nn.functional as F
import dgl.function as fn
class HGTLayer(nn.Module):
def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads, dropout = 0.2, use_norm = False):
def __init__(self,
in_dim,
out_dim,
node_dict,
edge_dict,
n_heads,
dropout = 0.2,
use_norm = False):
super(HGTLayer, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.num_types = num_types
self.num_relations = num_relations
self.total_rel = num_types * num_relations * num_types
self.node_dict = node_dict
self.edge_dict = edge_dict
self.num_types = len(node_dict)
self.num_relations = len(edge_dict)
self.total_rel = self.num_types * self.num_relations * self.num_types
self.n_heads = n_heads
self.d_k = out_dim // n_heads
self.sqrt_dk = math.sqrt(self.d_k)
......@@ -26,7 +35,7 @@ class HGTLayer(nn.Module):
self.norms = nn.ModuleList()
self.use_norm = use_norm
for t in range(num_types):
for t in range(self.num_types):
self.k_linears.append(nn.Linear(in_dim, out_dim))
self.q_linears.append(nn.Linear(in_dim, out_dim))
self.v_linears.append(nn.Linear(in_dim, out_dim))
......@@ -34,10 +43,10 @@ class HGTLayer(nn.Module):
if use_norm:
self.norms.append(nn.LayerNorm(out_dim))
self.relation_pri = nn.Parameter(torch.ones(num_relations, self.n_heads))
self.relation_att = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
self.relation_msg = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
self.skip = nn.Parameter(torch.ones(num_types))
self.relation_pri = nn.Parameter(torch.ones(self.num_relations, self.n_heads))
self.relation_att = nn.Parameter(torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k))
self.relation_msg = nn.Parameter(torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k))
self.skip = nn.Parameter(torch.ones(self.num_types))
self.drop = nn.Dropout(dropout)
nn.init.xavier_uniform_(self.relation_att)
......@@ -74,54 +83,62 @@ class HGTLayer(nn.Module):
h = torch.sum(att.unsqueeze(dim = -1) * nodes.mailbox['v'], dim=1)
return {'t': h.view(-1, self.out_dim)}
def forward(self, G, inp_key, out_key):
node_dict, edge_dict = G.node_dict, G.edge_dict
for srctype, etype, dsttype in G.canonical_etypes:
k_linear = self.k_linears[node_dict[srctype]]
v_linear = self.v_linears[node_dict[srctype]]
q_linear = self.q_linears[node_dict[dsttype]]
G.nodes[srctype].data['k'] = k_linear(G.nodes[srctype].data[inp_key]).view(-1, self.n_heads, self.d_k)
G.nodes[srctype].data['v'] = v_linear(G.nodes[srctype].data[inp_key]).view(-1, self.n_heads, self.d_k)
G.nodes[dsttype].data['q'] = q_linear(G.nodes[dsttype].data[inp_key]).view(-1, self.n_heads, self.d_k)
G.apply_edges(func=self.edge_attention, etype=etype)
G.multi_update_all({etype : (self.message_func, self.reduce_func) \
for etype in edge_dict}, cross_reducer = 'mean')
for ntype in G.ntypes:
'''
Step 3: Target-specific Aggregation
x = norm( W[node_type] * gelu( Agg(x) ) + x )
'''
n_id = node_dict[ntype]
alpha = torch.sigmoid(self.skip[n_id])
trans_out = self.drop(self.a_linears[n_id](G.nodes[ntype].data['t']))
trans_out = trans_out * alpha + G.nodes[ntype].data[inp_key] * (1-alpha)
if self.use_norm:
G.nodes[ntype].data[out_key] = self.norms[n_id](trans_out)
def forward(self, G, h):
with G.local_scope():
node_dict, edge_dict = self.node_dict, self.edge_dict
for srctype, etype, dsttype in G.canonical_etypes:
k_linear = self.k_linears[node_dict[srctype]]
v_linear = self.v_linears[node_dict[srctype]]
q_linear = self.q_linears[node_dict[dsttype]]
G.nodes[srctype].data['k'] = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
G.nodes[srctype].data['v'] = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
G.nodes[dsttype].data['q'] = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)
G.apply_edges(func=self.edge_attention, etype=etype)
G.multi_update_all({etype : (self.message_func, self.reduce_func) \
for etype in edge_dict}, cross_reducer = 'mean')
new_h = {}
for ntype in G.ntypes:
'''
Step 3: Target-specific Aggregation
x = norm( W[node_type] * gelu( Agg(x) ) + x )
'''
n_id = node_dict[ntype]
alpha = torch.sigmoid(self.skip[n_id])
trans_out = self.drop(self.a_linears[n_id](G.nodes[ntype].data['t']))
trans_out = trans_out * alpha + h[ntype] * (1-alpha)
if self.use_norm:
new_h[ntype] = self.norms[n_id](trans_out)
else:
new_h[ntype] = trans_out
return new_h
class HGT(nn.Module):
def __init__(self, G, n_inp, n_hid, n_out, n_layers, n_heads, use_norm = True):
def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads, use_norm = True):
super(HGT, self).__init__()
self.node_dict = node_dict
self.edge_dict = edge_dict
self.gcs = nn.ModuleList()
self.n_inp = n_inp
self.n_hid = n_hid
self.n_out = n_out
self.n_layers = n_layers
self.adapt_ws = nn.ModuleList()
for t in range(len(G.node_dict)):
for t in range(len(node_dict)):
self.adapt_ws.append(nn.Linear(n_inp, n_hid))
for _ in range(n_layers):
self.gcs.append(HGTLayer(n_hid, n_hid, len(G.node_dict), len(G.edge_dict), n_heads, use_norm = use_norm))
self.gcs.append(HGTLayer(n_hid, n_hid, node_dict, edge_dict, n_heads, use_norm = use_norm))
self.out = nn.Linear(n_hid, n_out)
def forward(self, G, out_key):
h = {}
for ntype in G.ntypes:
n_id = G.node_dict[ntype]
G.nodes[ntype].data['h'] = F.gelu(self.adapt_ws[n_id](G.nodes[ntype].data['inp']))
n_id = self.node_dict[ntype]
h[ntype] = F.gelu(self.adapt_ws[n_id](G.nodes[ntype].data['inp']))
for i in range(self.n_layers):
self.gcs[i](G, 'h', 'h')
return self.out(G.nodes[out_key].data['h'])
h = self.gcs[i](G, h)
return self.out(h[out_key])
class HeteroRGCNLayer(nn.Module):
def __init__(self, in_size, out_size, etypes):
......
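The rewritten HGTLayer.forward no longer stashes 'k'/'q'/'v'/'t' fields permanently on the graph: it wraps the message passing in G.local_scope() and exchanges features as per-ntype dicts. A short sketch of those two ingredients, with illustrative data:

import dgl
import torch

g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1], [1, 2])})
g.ndata['inp'] = torch.randn(3, 8)

with g.local_scope():
    g.ndata['k'] = g.ndata['inp'] * 2.0      # temporary field
print('k' in g.ndata)                        # False: writes inside local_scope do not persist

# Features flow between layers as {ntype: tensor} dicts, e.g. h = layer(g, h).
h = {ntype: g.nodes[ntype].data['inp'] for ntype in g.ntypes}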
......@@ -77,6 +77,7 @@ def train(model, G):
best_test_acc.item(),
))
device = torch.device("cuda:0")
G = dgl.heterograph({
('paper', 'written-by', 'author') : data['PvsA'],
......@@ -88,8 +89,6 @@ G = dgl.heterograph({
})
print(G)
pvc = data['PvsC'].tocsr()
p_selected = pvc.tocoo()
# generate labels
......@@ -103,26 +102,30 @@ train_idx = torch.tensor(shuffle[0:800]).long()
val_idx = torch.tensor(shuffle[800:900]).long()
test_idx = torch.tensor(shuffle[900:]).long()
device = torch.device("cuda:0")
G.node_dict = {}
G.edge_dict = {}
node_dict = {}
edge_dict = {}
for ntype in G.ntypes:
G.node_dict[ntype] = len(G.node_dict)
node_dict[ntype] = len(node_dict)
for etype in G.etypes:
G.edge_dict[etype] = len(G.edge_dict)
G.edges[etype].data['id'] = torch.ones(G.number_of_edges(etype), dtype=torch.long) * G.edge_dict[etype]
edge_dict[etype] = len(edge_dict)
G.edges[etype].data['id'] = torch.ones(G.number_of_edges(etype), dtype=torch.long) * edge_dict[etype]
# Random initialize input feature
for ntype in G.ntypes:
emb = nn.Parameter(torch.Tensor(G.number_of_nodes(ntype), 256), requires_grad = False).to(device)
emb = nn.Parameter(torch.Tensor(G.number_of_nodes(ntype), 256), requires_grad = False)
nn.init.xavier_uniform_(emb)
G.nodes[ntype].data['inp'] = emb
model = HGT(G, n_inp=args.n_inp, n_hid=args.n_hid, n_out=labels.max().item()+1, n_layers=2, n_heads=4, use_norm = True).to(device)
G = G.to(device)
model = HGT(G,
node_dict, edge_dict,
n_inp=args.n_inp,
n_hid=args.n_hid,
n_out=labels.max().item()+1,
n_layers=2,
n_heads=4,
use_norm = True).to(device)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=args.n_epoch, max_lr = args.max_lr)
print('Training HGT with #param: %d' % (get_n_params(model)))
......@@ -131,7 +134,10 @@ train(model, G)
model = HeteroRGCN(G, in_size=args.n_inp, hidden_size=args.n_hid, out_size=labels.max().item()+1).to(device)
model = HeteroRGCN(G,
in_size=args.n_inp,
hidden_size=args.n_hid,
out_size=labels.max().item()+1).to(device)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=args.n_epoch, max_lr = args.max_lr)
print('Training RGCN with #param: %d' % (get_n_params(model)))
......@@ -139,7 +145,13 @@ train(model, G)
model = HGT(G, n_inp=args.n_inp, n_hid=args.n_hid, n_out=labels.max().item()+1, n_layers=0, n_heads=4).to(device)
model = HGT(G,
node_dict, edge_dict,
n_inp=args.n_inp,
n_hid=args.n_hid,
n_out=labels.max().item()+1,
n_layers=0,
n_heads=4).to(device)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=args.n_epoch, max_lr = args.max_lr)
print('Training MLP with #param: %d' % (get_n_params(model)))
......
......@@ -30,13 +30,13 @@ class GNNModule(nn.Module):
def aggregate(self, g, z):
z_list = []
g.set_n_repr({'z' : z})
g.ndata['z'] = z
g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
z_list.append(g.get_n_repr()['z'])
z_list.append(g.ndata['z'])
for i in range(self.radius - 1):
for j in range(2 ** i):
g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
z_list.append(g.get_n_repr()['z'])
z_list.append(g.ndata['z'])
return z_list
def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):
......@@ -44,9 +44,9 @@ class GNNModule(nn.Module):
sum_x = sum(theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x)))
g.set_e_repr({'y' : y})
g.edata['y'] = y
g.update_all(fn.copy_edge(edge='y', out='m'), fn.sum('m', 'pmpd_y'))
pmpd_y = g.pop_n_repr('pmpd_y')
pmpd_y = g.ndata.pop('pmpd_y')
x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum_x + self.theta_y(pmpd_y)
n = self.out_feats // 2
......@@ -63,11 +63,6 @@ class GNNModule(nn.Module):
class GNN(nn.Module):
def __init__(self, feats, radius, n_classes):
"""
Parameters
----------
g : networkx.DiGraph
"""
super(GNN, self).__init__()
self.linear = nn.Linear(feats[-1], n_classes)
self.module_list = nn.ModuleList([GNNModule(m, n, radius)
......
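This hunk swaps the old set_n_repr/get_n_repr/pop_n_repr calls for the ndata/edata views that survive the merge. A hedged one-to-one sketch of the replacements:

import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1], [1, 2]))
z = torch.randn(3, 4)

g.ndata['z'] = z                                        # was: g.set_n_repr({'z': z})
g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='zs'))
zs = g.ndata.pop('zs')                                  # was: g.pop_n_repr('zs')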
......@@ -67,8 +67,10 @@ def from_np(f, *args):
@from_np
def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
""" One step of training. """
deg_g = deg_g.to(dev)
deg_lg = deg_lg.to(dev)
g = g.to(dev)
lg = lg.to(dev)
deg_g = deg_g.to(dev).unsqueeze(1)
deg_lg = deg_lg.to(dev).unsqueeze(1)
pm_pd = pm_pd.to(dev)
t0 = time.time()
z = model(g, lg, deg_g, deg_lg, pm_pd)
......@@ -88,8 +90,10 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
@from_np
def inference(g, lg, deg_g, deg_lg, pm_pd):
deg_g = deg_g.to(dev)
deg_lg = deg_lg.to(dev)
g = g.to(dev)
lg = lg.to(dev)
deg_g = deg_g.to(dev).unsqueeze(1)
deg_lg = deg_lg.to(dev).unsqueeze(1)
pm_pd = pm_pd.to(dev)
z = model(g, lg, deg_g, deg_lg, pm_pd)
......
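The final hunks move both the graph and its line graph to the device and give the degree vectors an explicit feature dimension. A small sketch, assuming dgl.line_graph is available for building lg (the example's data pipeline constructs it elsewhere):

import dgl
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
lg = dgl.line_graph(g, backtracking=False)   # the line graph is a DGLGraph of its own

dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
g, lg = g.to(dev), lg.to(dev)                # both graphs are moved explicitly now
deg_g = g.in_degrees().float().to(dev).unsqueeze(1)    # degrees get a feature dimension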