Unverified commit 44089c8b authored by Minjie Wang, committed by GitHub

[Refactor][Graph] Merge DGLGraph and DGLHeteroGraph (#1862)



* Merge

* [Graph][CUDA] Graph on GPU and many refactoring (#1791)

* change edge_ids behavior and C++ impl

* fix unittests; remove utils.Index in edge_id

* pass mx and th tests

* pass tf test

* add aten::Scatter_

* Add nonzero; impl CSRGetDataAndIndices/CSRSliceMatrix

* CSRGetData and CSRGetDataAndIndices passed tests

* CSRSliceMatrix basic tests

* fix bug in empty slice

* CUDA CSRHasDuplicate

* has_node; has_edge_between

* predecessors, successors

* deprecate send/recv; fix send_and_recv

* deprecate send/recv; fix send_and_recv

* in_edges; out_edges; all_edges; apply_edges

* in deg/out deg

* subgraph/edge_subgraph

* adj

* in_subgraph/out_subgraph

* sample neighbors

* set/get_n/e_repr

* wip: working on refactoring all idtypes

* pass ndata/edata tests on gpu

* fix

* stash

* workaround nonzero issue

* stash

* nx conversion

* test_hetero_basics except update routines

* test_update_routines

* test_hetero_basics for pytorch

* more fixes

* WIP: flatten graph

* wip: flatten

* test_flatten

* test_to_device

* fix bug in to_homo

* fix bug in CSRSliceMatrix

* pass subgraph test

* fix send_and_recv

* fix filter

* test_heterograph

* passed all pytorch tests

* fix mx unittest

* fix pytorch test_nn

* fix all unittests for PyTorch

* passed all mxnet tests

* lint

* fix tf nn test

* pass all tf tests

* lint

* lint

* change deprecation

* try fix compile

* lint

* update METIS

* fix utest

* fix

* fix utests

* try debug

* revert

* small fix

* fix utests

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* trigger

* +1s

* [kernel] Use heterograph index instead of unitgraph index (#1813)

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* trigger

* +1s

* [Graph] Mutation for Heterograph (#1818)

* mutation add_nodes and add_edges

* Add support for remove_edges, remove_nodes, add_selfloop, remove_selfloop

* Fix
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>

* upd

* upd

* upd

* fix

* [Transform] Mutable transform (#1833)

* add nodes

* All three

* Fix

* lint

* Add some test case

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* fix

* trigger

* Fix

* fix
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>

* [Graph] Migrate Batch & Readout module to heterograph (#1836)

* dgl.batch

* unbatch

* fix to device

* reduce readout; segment reduce

* change batch_num_nodes|edges to function

* reduce readout/ softmax

* broadcast

* topk

* fix

* fix tf and mx

* fix some ci

* fix batch but unbatch differently

* new check

* upd

* upd

* upd

* idtype behavior; code reorg

* idtype behavior; code reorg

* wip: test_basics

* pass test_basics

* WIP: from nx/ to nx

* missing files

* upd

* pass test_basics:test_nx_conversion

* Fix test

* Fix inplace update

* WIP: fixing tests

* upd

* pass test_transform cpu

* pass gpu test_transform

* pass test_batched_graph

* GPU graph auto cast to int32

* missing file

* stash

* WIP: rgcn-hetero

* Fix two datasets

* upd

* weird

* Fix capsule

* misc fix

* more misc fixes

* Fix dgmg

* fix bug in block degrees; pass rgcn-hetero

* rgcn

* gat and diffpool fix
also fix ppi and tu dataset

* Tree LSTM

* pointcloud

* rrn; wip: sgc

* resolve conflicts

* upd

* sgc and reddit dataset

* upd

* Fix deepwalk, gindt and gcn

* fix datasets and sign

* optimization

* optimization

* upd

* upd

* Fix GIN

* fix bug in add_nodes add_edges; tagcn

* adaptive sampling and gcmc

* upd

* upd

* fix geometric

* fix

* metapath2vec

* fix agnn

* fix pickling problem of block

* fix utests

* miss file

* linegraph

* upd

* upd

* upd

* graphsage

* stgcn_wave

* fix hgt

* on unittests

* Fix transformer

* Fix HAN

* passed pytorch unittests

* lint

* fix

* Fix cluster gcn

* cluster-gcn is ready

* on fixing block related codes

* 2nd order derivative

* Revert "2nd order derivative"

This reverts commit 523bf6c249bee61b51b1ad1babf42aad4167f206.

* passed torch utests again

* fix all mxnet unittests

* delete some useless tests

* pass all tf cpu tests

* disable

* disable distributed unittest

* fix

* fix

* lint

* fix

* fix

* fix script

* fix tutorial

* fix apply edges bug

* fix 2 basics

* fix tutorial
Co-authored-by: yzh119 <expye@outlook.com>
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-7-42.us-west-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-1-5.us-west-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-68-185.ec2.internal>
parent 015acfd2
@@ -105,8 +105,8 @@ class DiffPoolBatchedGraphLayer(nn.Module):
         assign_tensor = self.pool_gc(g, h)
         device = feat.device
         assign_tensor_masks = []
-        batch_size = len(g.batch_num_nodes)
-        for g_n_nodes in g.batch_num_nodes:
+        batch_size = len(g.batch_num_nodes())
+        for g_n_nodes in g.batch_num_nodes():
             mask = torch.ones((g_n_nodes,
                                int(assign_tensor.size()[1] / batch_size)))
             assign_tensor_masks.append(mask)
......
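The change above reflects the merged graph API: on the unified DGLGraph, batch_num_nodes (and batch_num_edges) is now a method returning a tensor rather than a plain attribute. A minimal sketch of the new calling convention, using two toy graphs that are not part of the example:

    import torch
    import dgl

    # Two small toy graphs, batched into a single DGLGraph.
    g1 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
    g2 = dgl.graph((torch.tensor([0]), torch.tensor([1])))
    bg = dgl.batch([g1, g2])

    # batch_num_nodes() / batch_num_edges() are now called as functions and return tensors.
    print(bg.batch_num_nodes())   # tensor([3, 2])
    print(bg.batch_num_edges())   # tensor([2, 1])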
@@ -202,7 +202,7 @@ def collate_fn(batch):
     # batch graphs and cast to PyTorch tensor
     for graph in graphs:
         for (key, value) in graph.ndata.items():
-            graph.ndata[key] = torch.FloatTensor(value)
+            graph.ndata[key] = value.float()
     batched_graphs = dgl.batch(graphs)
     # cast to PyTorch tensor
@@ -234,8 +234,7 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
     computation_time = 0.0
     for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader):
         if torch.cuda.is_available():
-            for (key, value) in batch_graph.ndata.items():
-                batch_graph.ndata[key] = value.cuda()
+            batch_graph = batch_graph.to(torch.cuda.current_device())
             graph_labels = graph_labels.cuda()
         model.zero_grad()
@@ -285,8 +284,7 @@ def evaluate(dataloader, model, prog_args, logger=None):
     with torch.no_grad():
         for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
             if torch.cuda.is_available():
-                for (key, value) in batch_graph.ndata.items():
-                    batch_graph.ndata[key] = value.cuda()
+                batch_graph = batch_graph.to(torch.cuda.current_device())
                 graph_labels = graph_labels.cuda()
             ypred = model(batch_graph)
             indi = torch.argmax(ypred, dim=1)
......
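After the merge, a graph, including a batched one, is moved to a device as a whole with .to() instead of copying every ndata tensor by hand, which is exactly what the train/evaluate hunks above switch to. A hedged sketch of the pattern (the graphs and labels variables stand in for one minibatch from a dataloader):

    import torch
    import dgl

    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    batch_graph = dgl.batch(graphs)        # 'graphs' is assumed to be a list of DGLGraphs
    batch_graph = batch_graph.to(device)   # moves structure plus all ndata/edata in one call
    labels = labels.to(device)             # ordinary tensors are moved as before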
@@ -78,6 +78,8 @@ def main(args):
     g.remove_edges_from(nx.selfloop_edges(g))
     g = DGLGraph(g)
     g.add_edges(g.nodes(), g.nodes())
+    if cuda:
+        g = g.to(args.gpu)
     n_edges = g.number_of_edges()
     # create model
     heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
......
@@ -63,6 +63,7 @@ def main(args):
     n_classes = train_dataset.labels.shape[1]
     num_feats = train_dataset.features.shape[1]
     g = train_dataset.graph
+    g = g.to(device)
     heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
     # define the model
     model = GAT(g,
@@ -84,6 +85,7 @@ def main(args):
         loss_list = []
         for batch, data in enumerate(train_dataloader):
             subgraph, feats, labels = data
+            subgraph = subgraph.to(device)
             feats = feats.to(device)
             labels = labels.to(device)
             model.g = subgraph
@@ -102,6 +104,7 @@ def main(args):
             val_loss_list = []
             for batch, valid_data in enumerate(valid_dataloader):
                 subgraph, feats, labels = valid_data
+                subgraph = subgraph.to(device)
                 feats = feats.to(device)
                 labels = labels.to(device)
                 score, val_loss = evaluate(feats.float(), model, subgraph, labels.float(), loss_fcn)
@@ -125,6 +128,7 @@ def main(args):
     test_score_list = []
     for batch, test_data in enumerate(test_dataloader):
         subgraph, feats, labels = test_data
+        subgraph = subgraph.to(device)
         feats = feats.to(device)
         labels = labels.to(device)
         test_score_list.append(evaluate(feats, model, subgraph, labels.float(), loss_fcn)[0])
......
@@ -269,7 +269,7 @@ class MovieLens(object):
             x = x.numpy().astype('float32')
             x[x == 0.] = np.inf
             x = th.FloatTensor(1. / np.sqrt(x))
-            return x.to(self._device).unsqueeze(1)
+            return x.unsqueeze(1)
         user_ci = []
         user_cj = []
         movie_ci = []
@@ -290,8 +290,8 @@ class MovieLens(object):
             user_cj = _calc_norm(sum(user_cj))
             movie_cj = _calc_norm(sum(movie_cj))
         else:
-            user_cj = th.ones(self.num_user,).to(self._device)
-            movie_cj = th.ones(self.num_movie,).to(self._device)
+            user_cj = th.ones(self.num_user,)
+            movie_cj = th.ones(self.num_movie,)
         graph.nodes['user'].data.update({'ci' : user_ci, 'cj' : user_cj})
         graph.nodes['movie'].data.update({'ci' : movie_ci, 'cj' : movie_cj})
......
@@ -32,7 +32,7 @@ class GCMCGraphConv(nn.Module):
         super(GCMCGraphConv, self).__init__()
         self._in_feats = in_feats
         self._out_feats = out_feats
         self.device = device
         self.dropout = nn.Dropout(dropout_rate)
         if weight:
@@ -210,7 +210,7 @@ class GCMCLayer(nn.Module):
     def partial_to(self, device):
         """Put parameters into device except W_r
         Parameters
         ----------
         device : torch device
......
@@ -105,6 +105,13 @@ def train(args):
     count_num = 0
     count_loss = 0
+    dataset.train_enc_graph = dataset.train_enc_graph.to(args.device)
+    dataset.train_dec_graph = dataset.train_dec_graph.to(args.device)
+    dataset.valid_enc_graph = dataset.train_enc_graph
+    dataset.valid_dec_graph = dataset.valid_dec_graph.to(args.device)
+    dataset.test_enc_graph = dataset.test_enc_graph.to(args.device)
+    dataset.test_dec_graph = dataset.test_dec_graph.to(args.device)
     print("Start training ...")
     dur = []
     for iter_idx in range(1, args.train_max_iter):
......
@@ -51,7 +51,7 @@ class GCMCSampler:
     The input ``seeds`` represents the edges to compute prediction for. The sampling
     algorithm works as follows:
     1. Get the head and tail nodes of the provided seed edges.
     2. For each head and tail node, extract the entire in-coming neighborhood.
     3. Copy the node features/embeddings from the full graph to the sampled subgraphs.
@@ -170,6 +170,7 @@ def evaluate(args, dev_id, net, dataset, dataloader, segment='valid'):
         compact_g, frontier, head_feat, tail_feat, \
             _, true_relation_ratings = sample_data
+        frontier = frontier.to(dev_id)
         head_feat = head_feat.to(dev_id)
         tail_feat = tail_feat.to(dev_id)
         with th.no_grad():
@@ -352,11 +353,12 @@ def run(proc_id, n_gpus, args, devices, dataset):
         if epoch > 1:
             t0 = time.time()
         net.train()
         for step, sample_data in enumerate(dataloader):
             compact_g, frontier, head_feat, tail_feat, \
                 true_relation_labels, true_relation_ratings = sample_data
             head_feat = head_feat.to(dev_id)
             tail_feat = tail_feat.to(dev_id)
+            frontier = frontier.to(dev_id)
             pred_ratings = net(compact_g, frontier, head_feat, tail_feat, dataset.possible_rating_values)
             loss = rating_loss_net(pred_ratings, true_relation_labels.to(dev_id)).mean()
......
@@ -67,6 +67,8 @@ def main(args):
     g.add_edges_from(zip(g.nodes(), g.nodes()))
     g = DGLGraph(g)
     n_edges = g.number_of_edges()
+    if cuda:
+        g = g.to(args.gpu)
     # normalization
     degs = g.in_degrees().float()
     norm = torch.pow(degs, -0.5)
......
@@ -23,7 +23,8 @@ def train(args, net, trainloader, optimizer, criterion, epoch):
     for pos, (graphs, labels) in zip(bar, trainloader):
         # batch graphs will be shipped to device in forward part of model
         labels = labels.to(args.device)
-        feat = graphs.ndata['attr'].to(args.device)
+        feat = graphs.ndata.pop('attr').to(args.device)
+        graphs = graphs.to(args.device)
         outputs = net(graphs, feat)
         loss = criterion(outputs, labels)
@@ -52,7 +53,8 @@ def eval_net(args, net, dataloader, criterion):
     for data in dataloader:
         graphs, labels = data
-        feat = graphs.ndata['attr'].to(args.device)
+        feat = graphs.ndata.pop('attr').to(args.device)
+        graphs = graphs.to(args.device)
         labels = labels.to(args.device)
         total += len(labels)
         outputs = net(graphs, feat)
......
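The GIN hunks pop the 'attr' feature off the batched graph before moving both the features and the graph to the device; popping first means the feature tensor is transferred once explicitly rather than again as part of the graph's ndata. A short sketch of that pattern, assuming graphs is a batched DGLGraph from the dataloader and args.device is a torch device:

    feat = graphs.ndata.pop('attr').to(args.device)   # detach the feature tensor, move it once
    graphs = graphs.to(args.device)                   # then move the graph (without 'attr')
    outputs = net(graphs, feat)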
@@ -109,6 +109,7 @@ class SAGE(nn.Module):
                 end = start + batch_size
                 batch_nodes = nodes[start:end]
                 block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
+                block = block.to(device)
                 induced_nodes = block.srcdata[dgl.NID]
                 h = x[induced_nodes].to(device)
@@ -185,20 +186,27 @@ def load_subtensor(g, labels, blocks, hist_blocks, dev_id, aggregation_on_device=False):
     """
     Copys features and labels of a set of nodes onto GPU.
     """
-    blocks[0].srcdata['features'] = g.ndata['features'][blocks[0].srcdata[dgl.NID]].to(dev_id)
-    blocks[-1].dstdata['label'] = labels[blocks[-1].dstdata[dgl.NID]].to(dev_id)
+    blocks[0].srcdata['features'] = g.ndata['features'][blocks[0].srcdata[dgl.NID]]
+    blocks[-1].dstdata['label'] = labels[blocks[-1].dstdata[dgl.NID]]
+    ret_blocks = []
+    ret_hist_blocks = []
     for i, (block, hist_block) in enumerate(zip(blocks, hist_blocks)):
         hist_col = 'features' if i == 0 else 'hist_%d' % i
-        block.srcdata['hist'] = g.ndata[hist_col][block.srcdata[dgl.NID]].to(dev_id)
+        block.srcdata['hist'] = g.ndata[hist_col][block.srcdata[dgl.NID]]
         # Aggregate history
         hist_block.srcdata['hist'] = g.ndata[hist_col][hist_block.srcdata[dgl.NID]]
         if aggregation_on_device:
-            hist_block.srcdata['hist'] = hist_block.srcdata['hist'].to(dev_id)
+            hist_block = hist_block.to(dev_id)
         hist_block.update_all(fn.copy_u('hist', 'm'), fn.mean('m', 'agg_hist'))
-        block.dstdata['agg_hist'] = hist_block.dstdata['agg_hist']
+        block = block.to(dev_id)
         if not aggregation_on_device:
-            block.dstdata['agg_hist'] = block.dstdata['agg_hist'].to(dev_id)
+            hist_block = hist_block.to(dev_id)
+        block.dstdata['agg_hist'] = hist_block.dstdata['agg_hist']
+        ret_blocks.append(block)
+        ret_hist_blocks.append(hist_block)
+    return ret_blocks, ret_hist_blocks

 def init_history(g, model, dev_id):
     with th.no_grad():
@@ -211,7 +219,7 @@ def init_history(g, model, dev_id):
 def update_history(g, blocks):
     with th.no_grad():
         for i, block in enumerate(blocks):
-            ids = block.dstdata[dgl.NID]
+            ids = block.dstdata[dgl.NID].cpu()
             hist_col = 'hist_%d' % (i + 1)
             h_new = block.dstdata['h_new'].cpu()
@@ -269,7 +277,7 @@ def run(args, dev_id, data):
             input_nodes = blocks[0].srcdata[dgl.NID]
             seeds = blocks[-1].dstdata[dgl.NID]
-            load_subtensor(g, labels, blocks, hist_blocks, dev_id, True)
+            blocks, hist_blocks = load_subtensor(g, labels, blocks, hist_blocks, dev_id, True)
             # forward
             batch_pred = model(blocks)
......
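Because .to() on a graph returns a new object rather than moving it in place, load_subtensor now collects the transferred blocks and returns them, and the caller rebinds its local variables, which is why the call site changes from a bare call to an assignment. A minimal sketch of the calling side, reusing the names from the example:

    blocks, hist_blocks = load_subtensor(g, labels, blocks, hist_blocks, dev_id, True)
    batch_pred = model(blocks)   # the model now receives blocks that already live on dev_id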
@@ -112,6 +112,7 @@ class SAGE(nn.Module):
                 induced_nodes = block.srcdata[dgl.NID]
                 h = x[induced_nodes].to(device)
+                block = block.to(device)
                 h_dst = h[:block.number_of_dst_nodes()]
                 h = layer(block, (h, h_dst))
@@ -213,20 +214,28 @@ def load_subtensor(g, labels, blocks, hist_blocks, dev_id, aggregation_on_device=False):
     """
     Copys features and labels of a set of nodes onto GPU.
     """
-    blocks[0].srcdata['features'] = g.ndata['features'][blocks[0].srcdata[dgl.NID]].to(dev_id)
-    blocks[-1].dstdata['label'] = labels[blocks[-1].dstdata[dgl.NID]].to(dev_id)
+    blocks[0].srcdata['features'] = g.ndata['features'][blocks[0].srcdata[dgl.NID]]
+    blocks[-1].dstdata['label'] = labels[blocks[-1].dstdata[dgl.NID]]
+    ret_blocks = []
+    ret_hist_blocks = []
     for i, (block, hist_block) in enumerate(zip(blocks, hist_blocks)):
         hist_col = 'features' if i == 0 else 'hist_%d' % i
-        block.srcdata['hist'] = g.ndata[hist_col][block.srcdata[dgl.NID]].to(dev_id)
+        block.srcdata['hist'] = g.ndata[hist_col][block.srcdata[dgl.NID]]
         # Aggregate history
         hist_block.srcdata['hist'] = g.ndata[hist_col][hist_block.srcdata[dgl.NID]]
         if aggregation_on_device:
-            hist_block.srcdata['hist'] = hist_block.srcdata['hist'].to(dev_id)
+            hist_block = hist_block.to(dev_id)
+            hist_block.srcdata['hist'] = hist_block.srcdata['hist']
         hist_block.update_all(fn.copy_u('hist', 'm'), fn.mean('m', 'agg_hist'))
-        block.dstdata['agg_hist'] = hist_block.dstdata['agg_hist']
+        block = block.to(dev_id)
         if not aggregation_on_device:
-            block.dstdata['agg_hist'] = block.dstdata['agg_hist'].to(dev_id)
+            hist_block = hist_block.to(dev_id)
+        block.dstdata['agg_hist'] = hist_block.dstdata['agg_hist']
+        ret_blocks.append(block)
+        ret_hist_blocks.append(hist_block)
+    return ret_blocks, ret_hist_blocks

 def create_history_storage(g, args, n_classes):
     # Initialize history storage
@@ -241,7 +250,7 @@ def init_history(g, model, dev_id, batch_size):
 def update_history(g, blocks):
     with th.no_grad():
         for i, block in enumerate(blocks):
-            ids = block.dstdata[dgl.NID]
+            ids = block.dstdata[dgl.NID].cpu()
             hist_col = 'hist_%d' % (i + 1)
             h_new = block.dstdata['h_new'].cpu()
@@ -317,10 +326,9 @@ def run(proc_id, n_gpus, args, devices, data):
             # The nodes for input lies at the LHS side of the first block.
             # The nodes for output lies at the RHS side of the last block.
-            input_nodes = blocks[0].srcdata[dgl.NID]
             seeds = blocks[-1].dstdata[dgl.NID]
-            load_subtensor(g, labels, blocks, hist_blocks, dev_id, True)
+            blocks, hist_blocks = load_subtensor(g, labels, blocks, hist_blocks, dev_id, True)
             # forward
             batch_pred = model(blocks)
......
@@ -77,6 +77,7 @@ class SAGE(nn.Module):
             for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                 block = blocks[0]
+                block = block.to(device)
                 h = x[input_nodes].to(device)
                 h = layer(block, h)
                 if l != len(self.layers) - 1:
@@ -171,6 +172,7 @@ def run(args, device, data):
             # Load the input features as well as output labels
             batch_inputs, batch_labels = load_subtensor(train_g, seeds, input_nodes, device)
+            blocks = [block.to(device) for block in blocks]
             # Compute loss and prediction
             batch_pred = model(blocks, batch_inputs)
......
@@ -78,6 +78,7 @@ class SAGE(nn.Module):
             for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                 block = blocks[0]
+                block = block.to(device)
                 h = x[input_nodes].to(device)
                 h = layer(block, h)
                 if l != len(self.layers) - 1:
@@ -150,7 +151,7 @@ def run(proc_id, n_gpus, args, devices, data):
     in_feats, n_classes, train_g, val_g, test_g = data
     train_mask = train_g.ndata['train_mask']
     val_mask = val_g.ndata['val_mask']
-    test_mask = ~(train_g.ndata['train_mask'] | val_g.ndata['val_mask'])
+    test_mask = ~(test_g.ndata['train_mask'] | test_g.ndata['val_mask'])
     train_nid = train_mask.nonzero()[:, 0]
     val_nid = val_mask.nonzero()[:, 0]
     test_nid = test_mask.nonzero()[:, 0]
@@ -193,7 +194,7 @@ def run(proc_id, n_gpus, args, devices, data):
             # Load the input features as well as output labels
             batch_inputs, batch_labels = load_subtensor(train_g, train_g.ndata['labels'], seeds, input_nodes, dev_id)
+            blocks = [block.to(dev_id) for block in blocks]
             # Compute loss and prediction
             batch_pred = model(blocks, batch_inputs)
             loss = loss_fcn(batch_pred, batch_labels)
......
@@ -149,6 +149,7 @@ class SAGE(nn.Module):
                 end = start + batch_size
                 batch_nodes = nodes[start:end]
                 block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
+                block = block.to(device)
                 input_nodes = block.srcdata[dgl.NID]
                 h = x[input_nodes].to(device)
@@ -293,6 +294,9 @@ def run(proc_id, n_gpus, args, devices, data):
             batch_inputs = load_subtensor(g, input_nodes, device)
             d_step = time.time()
+            pos_graph = pos_graph.to(device)
+            neg_graph = neg_graph.to(device)
+            blocks = [block.to(device) for block in blocks]
             # Compute loss and prediction
             batch_pred = model(blocks, batch_inputs)
             loss = loss_fcn(batch_pred, pos_graph, neg_graph)
......
@@ -48,6 +48,7 @@ def main(args):
                     out_size=num_classes,
                     num_heads=args['num_heads'],
                     dropout=args['dropout']).to(args['device'])
+        g = g.to(args['device'])
     else:
         from model import HAN
         model = HAN(num_meta_paths=len(g),
@@ -56,6 +57,7 @@ def main(args):
                     out_size=num_classes,
                     num_heads=args['num_heads'],
                     dropout=args['dropout']).to(args['device'])
+        g = [graph.to(args['device']) for graph in g]
     stopper = EarlyStopping(patience=args['patience'])
     loss_fcn = torch.nn.CrossEntropyLoss()
......
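In the HAN example, g is a single heterograph in one branch and a list of per-metapath graphs in the other, so the two hunks move it with one .to() call or a list comprehension respectively. A hedged sketch of the two cases (the hetero flag name is assumed from the example's structure):

    if args['hetero']:
        g = g.to(args['device'])                          # one heterograph
    else:
        g = [graph.to(args['device']) for graph in g]     # one graph per metapath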
@@ -6,14 +6,23 @@ import torch.nn.functional as F
 import dgl.function as fn

 class HGTLayer(nn.Module):
-    def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads, dropout = 0.2, use_norm = False):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 node_dict,
+                 edge_dict,
+                 n_heads,
+                 dropout = 0.2,
+                 use_norm = False):
         super(HGTLayer, self).__init__()
         self.in_dim = in_dim
         self.out_dim = out_dim
-        self.num_types = num_types
-        self.num_relations = num_relations
-        self.total_rel = num_types * num_relations * num_types
+        self.node_dict = node_dict
+        self.edge_dict = edge_dict
+        self.num_types = len(node_dict)
+        self.num_relations = len(edge_dict)
+        self.total_rel = self.num_types * self.num_relations * self.num_types
         self.n_heads = n_heads
         self.d_k = out_dim // n_heads
         self.sqrt_dk = math.sqrt(self.d_k)
@@ -26,7 +35,7 @@ class HGTLayer(nn.Module):
         self.norms = nn.ModuleList()
         self.use_norm = use_norm

-        for t in range(num_types):
+        for t in range(self.num_types):
             self.k_linears.append(nn.Linear(in_dim, out_dim))
             self.q_linears.append(nn.Linear(in_dim, out_dim))
             self.v_linears.append(nn.Linear(in_dim, out_dim))
@@ -34,10 +43,10 @@ class HGTLayer(nn.Module):
             if use_norm:
                 self.norms.append(nn.LayerNorm(out_dim))

-        self.relation_pri = nn.Parameter(torch.ones(num_relations, self.n_heads))
-        self.relation_att = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
-        self.relation_msg = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
-        self.skip = nn.Parameter(torch.ones(num_types))
+        self.relation_pri = nn.Parameter(torch.ones(self.num_relations, self.n_heads))
+        self.relation_att = nn.Parameter(torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k))
+        self.relation_msg = nn.Parameter(torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k))
+        self.skip = nn.Parameter(torch.ones(self.num_types))
         self.drop = nn.Dropout(dropout)
         nn.init.xavier_uniform_(self.relation_att)
@@ -74,54 +83,62 @@ class HGTLayer(nn.Module):
         h = torch.sum(att.unsqueeze(dim = -1) * nodes.mailbox['v'], dim=1)
         return {'t': h.view(-1, self.out_dim)}

-    def forward(self, G, inp_key, out_key):
-        node_dict, edge_dict = G.node_dict, G.edge_dict
-        for srctype, etype, dsttype in G.canonical_etypes:
-            k_linear = self.k_linears[node_dict[srctype]]
-            v_linear = self.v_linears[node_dict[srctype]]
-            q_linear = self.q_linears[node_dict[dsttype]]
-            G.nodes[srctype].data['k'] = k_linear(G.nodes[srctype].data[inp_key]).view(-1, self.n_heads, self.d_k)
-            G.nodes[srctype].data['v'] = v_linear(G.nodes[srctype].data[inp_key]).view(-1, self.n_heads, self.d_k)
-            G.nodes[dsttype].data['q'] = q_linear(G.nodes[dsttype].data[inp_key]).view(-1, self.n_heads, self.d_k)
-            G.apply_edges(func=self.edge_attention, etype=etype)
-        G.multi_update_all({etype : (self.message_func, self.reduce_func) \
-                            for etype in edge_dict}, cross_reducer = 'mean')
-        for ntype in G.ntypes:
-            '''
-                Step 3: Target-specific Aggregation
-                x = norm( W[node_type] * gelu( Agg(x) ) + x )
-            '''
-            n_id = node_dict[ntype]
-            alpha = torch.sigmoid(self.skip[n_id])
-            trans_out = self.drop(self.a_linears[n_id](G.nodes[ntype].data['t']))
-            trans_out = trans_out * alpha + G.nodes[ntype].data[inp_key] * (1-alpha)
-            if self.use_norm:
-                G.nodes[ntype].data[out_key] = self.norms[n_id](trans_out)
+    def forward(self, G, h):
+        with G.local_scope():
+            node_dict, edge_dict = self.node_dict, self.edge_dict
+            for srctype, etype, dsttype in G.canonical_etypes:
+                k_linear = self.k_linears[node_dict[srctype]]
+                v_linear = self.v_linears[node_dict[srctype]]
+                q_linear = self.q_linears[node_dict[dsttype]]
+                G.nodes[srctype].data['k'] = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
+                G.nodes[srctype].data['v'] = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
+                G.nodes[dsttype].data['q'] = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)
+                G.apply_edges(func=self.edge_attention, etype=etype)
+            G.multi_update_all({etype : (self.message_func, self.reduce_func) \
+                                for etype in edge_dict}, cross_reducer = 'mean')
+            new_h = {}
+            for ntype in G.ntypes:
+                '''
+                    Step 3: Target-specific Aggregation
+                    x = norm( W[node_type] * gelu( Agg(x) ) + x )
+                '''
+                n_id = node_dict[ntype]
+                alpha = torch.sigmoid(self.skip[n_id])
+                trans_out = self.drop(self.a_linears[n_id](G.nodes[ntype].data['t']))
+                trans_out = trans_out * alpha + h[ntype] * (1-alpha)
+                if self.use_norm:
+                    new_h[ntype] = self.norms[n_id](trans_out)
+                else:
+                    new_h[ntype] = trans_out
+            return new_h

 class HGT(nn.Module):
-    def __init__(self, G, n_inp, n_hid, n_out, n_layers, n_heads, use_norm = True):
+    def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads, use_norm = True):
         super(HGT, self).__init__()
+        self.node_dict = node_dict
+        self.edge_dict = edge_dict
         self.gcs = nn.ModuleList()
         self.n_inp = n_inp
         self.n_hid = n_hid
         self.n_out = n_out
         self.n_layers = n_layers
         self.adapt_ws = nn.ModuleList()
-        for t in range(len(G.node_dict)):
+        for t in range(len(node_dict)):
             self.adapt_ws.append(nn.Linear(n_inp, n_hid))
         for _ in range(n_layers):
-            self.gcs.append(HGTLayer(n_hid, n_hid, len(G.node_dict), len(G.edge_dict), n_heads, use_norm = use_norm))
+            self.gcs.append(HGTLayer(n_hid, n_hid, node_dict, edge_dict, n_heads, use_norm = use_norm))
         self.out = nn.Linear(n_hid, n_out)

     def forward(self, G, out_key):
+        h = {}
         for ntype in G.ntypes:
-            n_id = G.node_dict[ntype]
-            G.nodes[ntype].data['h'] = F.gelu(self.adapt_ws[n_id](G.nodes[ntype].data['inp']))
+            n_id = self.node_dict[ntype]
+            h[ntype] = F.gelu(self.adapt_ws[n_id](G.nodes[ntype].data['inp']))
         for i in range(self.n_layers):
-            self.gcs[i](G, 'h', 'h')
-        return self.out(G.nodes[out_key].data['h'])
+            h = self.gcs[i](G, h)
+        return self.out(h[out_key])

 class HeteroRGCNLayer(nn.Module):
     def __init__(self, in_size, out_size, etypes):
......
@@ -77,6 +77,7 @@ def train(model, G):
         best_test_acc.item(),
     ))

+device = torch.device("cuda:0")

 G = dgl.heterograph({
     ('paper', 'written-by', 'author') : data['PvsA'],
@@ -88,8 +89,6 @@ G = dgl.heterograph({
 })
 print(G)

 pvc = data['PvsC'].tocsr()
 p_selected = pvc.tocoo()
 # generate labels
@@ -103,26 +102,30 @@ train_idx = torch.tensor(shuffle[0:800]).long()
 val_idx = torch.tensor(shuffle[800:900]).long()
 test_idx = torch.tensor(shuffle[900:]).long()

-device = torch.device("cuda:0")
-G.node_dict = {}
-G.edge_dict = {}
+node_dict = {}
+edge_dict = {}
 for ntype in G.ntypes:
-    G.node_dict[ntype] = len(G.node_dict)
+    node_dict[ntype] = len(node_dict)
 for etype in G.etypes:
-    G.edge_dict[etype] = len(G.edge_dict)
-    G.edges[etype].data['id'] = torch.ones(G.number_of_edges(etype), dtype=torch.long) * G.edge_dict[etype]
+    edge_dict[etype] = len(edge_dict)
+    G.edges[etype].data['id'] = torch.ones(G.number_of_edges(etype), dtype=torch.long) * edge_dict[etype]

 # Random initialize input feature
 for ntype in G.ntypes:
-    emb = nn.Parameter(torch.Tensor(G.number_of_nodes(ntype), 256), requires_grad = False).to(device)
+    emb = nn.Parameter(torch.Tensor(G.number_of_nodes(ntype), 256), requires_grad = False)
     nn.init.xavier_uniform_(emb)
     G.nodes[ntype].data['inp'] = emb

+G = G.to(device)

-model = HGT(G, n_inp=args.n_inp, n_hid=args.n_hid, n_out=labels.max().item()+1, n_layers=2, n_heads=4, use_norm = True).to(device)
+model = HGT(G,
+            node_dict, edge_dict,
+            n_inp=args.n_inp,
+            n_hid=args.n_hid,
+            n_out=labels.max().item()+1,
+            n_layers=2,
+            n_heads=4,
+            use_norm = True).to(device)
 optimizer = torch.optim.AdamW(model.parameters())
 scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=args.n_epoch, max_lr = args.max_lr)
 print('Training HGT with #param: %d' % (get_n_params(model)))
@@ -131,7 +134,10 @@ train(model, G)

-model = HeteroRGCN(G, in_size=args.n_inp, hidden_size=args.n_hid, out_size=labels.max().item()+1).to(device)
+model = HeteroRGCN(G,
+                   in_size=args.n_inp,
+                   hidden_size=args.n_hid,
+                   out_size=labels.max().item()+1).to(device)
 optimizer = torch.optim.AdamW(model.parameters())
 scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=args.n_epoch, max_lr = args.max_lr)
 print('Training RGCN with #param: %d' % (get_n_params(model)))
@@ -139,7 +145,13 @@ train(model, G)

-model = HGT(G, n_inp=args.n_inp, n_hid=args.n_hid, n_out=labels.max().item()+1, n_layers=0, n_heads=4).to(device)
+model = HGT(G,
+            node_dict, edge_dict,
+            n_inp=args.n_inp,
+            n_hid=args.n_hid,
+            n_out=labels.max().item()+1,
+            n_layers=0,
+            n_heads=4).to(device)
 optimizer = torch.optim.AdamW(model.parameters())
 scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=args.n_epoch, max_lr = args.max_lr)
 print('Training MLP with #param: %d' % (get_n_params(model)))
......
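With the refactor, the node-type and edge-type id maps live in plain dictionaries passed to the model instead of being patched onto the graph, the graph itself is moved to the device once its features are set, and HGT.forward threads a per-type feature dict through the layers. A condensed sketch of how the model is now driven, following the training script above (the 'paper' output key and the argument names mirror that script):

    node_dict = {ntype: i for i, ntype in enumerate(G.ntypes)}
    edge_dict = {etype: i for i, etype in enumerate(G.etypes)}
    G = G.to(device)

    model = HGT(G, node_dict, edge_dict,
                n_inp=args.n_inp, n_hid=args.n_hid,
                n_out=labels.max().item() + 1,
                n_layers=2, n_heads=4, use_norm=True).to(device)
    logits = model(G, 'paper')   # forward(G, out_key) returns logits for the 'paper' nodes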
@@ -30,13 +30,13 @@ class GNNModule(nn.Module):
     def aggregate(self, g, z):
         z_list = []
-        g.set_n_repr({'z' : z})
+        g.ndata['z'] = z
         g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
-        z_list.append(g.get_n_repr()['z'])
+        z_list.append(g.ndata['z'])
         for i in range(self.radius - 1):
             for j in range(2 ** i):
                 g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
-            z_list.append(g.get_n_repr()['z'])
+            z_list.append(g.ndata['z'])
         return z_list

     def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):
@@ -44,9 +44,9 @@ class GNNModule(nn.Module):
         sum_x = sum(theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x)))

-        g.set_e_repr({'y' : y})
+        g.edata['y'] = y
         g.update_all(fn.copy_edge(edge='y', out='m'), fn.sum('m', 'pmpd_y'))
-        pmpd_y = g.pop_n_repr('pmpd_y')
+        pmpd_y = g.ndata.pop('pmpd_y')

         x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum_x + self.theta_y(pmpd_y)
         n = self.out_feats // 2
@@ -63,11 +63,6 @@ class GNNModule(nn.Module):
 class GNN(nn.Module):
     def __init__(self, feats, radius, n_classes):
-        """
-        Parameters
-        ----------
-        g : networkx.DiGraph
-        """
         super(GNN, self).__init__()
         self.linear = nn.Linear(feats[-1], n_classes)
         self.module_list = nn.ModuleList([GNNModule(m, n, radius)
......
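The aggregate and forward hunks above replace the deprecated set_n_repr / get_n_repr / pop_n_repr calls with the ndata / edata views, which is the idiom the merged graph class standardizes on. A minimal sketch of the equivalent calls on any DGLGraph g with a node feature tensor z:

    import dgl.function as fn

    g.ndata['z'] = z                          # was: g.set_n_repr({'z': z})
    g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
    z_new = g.ndata['z']                      # was: g.get_n_repr()['z']
    z_popped = g.ndata.pop('z')               # was: g.pop_n_repr('z')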
@@ -67,8 +67,10 @@ def from_np(f, *args):
 @from_np
 def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
     """ One step of training. """
-    deg_g = deg_g.to(dev)
-    deg_lg = deg_lg.to(dev)
+    g = g.to(dev)
+    lg = lg.to(dev)
+    deg_g = deg_g.to(dev).unsqueeze(1)
+    deg_lg = deg_lg.to(dev).unsqueeze(1)
     pm_pd = pm_pd.to(dev)
     t0 = time.time()
     z = model(g, lg, deg_g, deg_lg, pm_pd)
@@ -88,8 +90,10 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
 @from_np
 def inference(g, lg, deg_g, deg_lg, pm_pd):
-    deg_g = deg_g.to(dev)
-    deg_lg = deg_lg.to(dev)
+    g = g.to(dev)
+    lg = lg.to(dev)
+    deg_g = deg_g.to(dev).unsqueeze(1)
+    deg_lg = deg_lg.to(dev).unsqueeze(1)
     pm_pd = pm_pd.to(dev)
     z = model(g, lg, deg_g, deg_lg, pm_pd)
......