"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "ba61a566141c24afb069ca0f7fcb3bbb0bd9c759"
Unverified Commit f5eb80d2 authored by Quan (Andy) Gan, committed by GitHub

[Feature] Edge DataLoader for edge classification & link prediction (#1828)

* clean commit

* oops forgot the most important files

* use einsum

* copy feature from frontier to block

* Revert "copy feature from frontier to block"

This reverts commit 5224ec963eb6a3ef1b6ab74d8ecbd44e4e42f285.

* temp fix

* unit test

* fix

* revert jtnn

* lint

* fix win64

* docstring fixes and doc indexing

* revert einsum in sparse bidecoder

* fix some examples

* lint

* fix due to some tediousness in remove_edges

* addresses comments

* fix

* more jtnn fixes

* fix
parent d340ea3a
.. _api-sampling:

dgl.dataloading
=================================

.. automodule:: dgl.dataloading

PyTorch node/edge DataLoaders
-----------------------------

.. autoclass:: pytorch.NodeDataLoader
.. autoclass:: pytorch.EdgeDataLoader

General collating functions
---------------------------

.. autoclass:: Collator
.. autoclass:: NodeCollator
.. autoclass:: EdgeCollator

Base Multi-layer Neighborhood Sampling Class
--------------------------------------------

.. autoclass:: BlockSampler

Uniform Node-wise Neighbor Sampling (GraphSAGE style)
-----------------------------------------------------

.. autoclass:: MultiLayerNeighborSampler

Negative Samplers for Link Prediction
-------------------------------------

.. autoclass:: negative_sampler.Uniform
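For orientation, here is a minimal sketch of how these classes compose for edge classification. It is not part of the docs page; the toy graph, fanouts and batch size are made up, and it assumes a DGL build that includes this PR:

import dgl
import torch as th

# Toy homogeneous graph with random integer edge labels (illustration only).
g = dgl.rand_graph(100, 500)
g.edata['label'] = th.randint(0, 5, (g.number_of_edges(),))

# Sample up to 10 then 25 neighbors for two message-passing layers.
sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 25])
dataloader = dgl.dataloading.EdgeDataLoader(
    g, th.arange(g.number_of_edges()), sampler,
    batch_size=128, shuffle=True, drop_last=False)

for input_nodes, pair_graph, blocks in dataloader:
    # pair_graph holds the minibatch edges; blocks drive message passing.
    batch_labels = g.edata['label'][pair_graph.edata[dgl.EID]]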
@@ -11,3 +11,4 @@ API Reference
    dgl.ops
    dgl.function
    sampling
+   dataloading
@@ -25,23 +25,7 @@ Neighbor sampling functions
 sample_neighbors
 select_topk

-PyTorch DataLoaders with neighborhood sampling
-----------------------------------------------
-.. autoclass:: pytorch.NeighborSamplerNodeDataLoader
-
 Builtin sampler classes for more complicated sampling algorithms
 ----------------------------------------------------------------
 .. autoclass:: RandomWalkNeighborSampler
 .. autoclass:: PinSAGESampler
-
-Neighborhood samplers for multilayer GNNs
------------------------------------------
-.. autoclass:: MultiLayerNeighborSampler
-
-Data loaders for minibatch iteration
-------------------------------------
-.. autoclass:: NodeCollator
-
-Abstract class for neighborhood sampler
----------------------------------------
-.. autoclass:: BlockSampler
@@ -8,6 +8,7 @@ import torch as th
 import dgl
 from dgl.data.utils import download, extract_archive, get_download_dir
+from utils import to_etype_name

 _urls = {
     'ml-100k' : 'http://files.grouplens.org/datasets/movielens/ml-100k.zip',
@@ -210,7 +211,7 @@ class MovieLens(object):
         def _npairs(graph):
             rst = 0
             for r in self.possible_rating_values:
-                r = str(r).replace('.', '_')
+                r = to_etype_name(r)
                 rst += graph.number_of_edges(str(r))
             return rst
@@ -252,7 +253,7 @@ class MovieLens(object):
             ridx = np.where(rating_values == rating)
             rrow = rating_row[ridx]
             rcol = rating_col[ridx]
-            rating = str(rating).replace('.', '_')
+            rating = to_etype_name(rating)
             bg = dgl.bipartite((rrow, rcol), 'user', rating, 'movie',
                                num_nodes=(self._num_user, self._num_movie))
             rev_bg = dgl.bipartite((rcol, rrow), 'movie', 'rev-%s' % rating, 'user',
@@ -275,7 +276,7 @@ class MovieLens(object):
         movie_ci = []
         movie_cj = []
         for r in self.possible_rating_values:
-            r = str(r).replace('.', '_')
+            r = to_etype_name(r)
             user_ci.append(graph['rev-%s' % r].in_degrees())
             movie_ci.append(graph[r].in_degrees())
             if self._symm:
......
@@ -5,7 +5,7 @@ from torch.nn import init
 import dgl.function as fn
 import dgl.nn.pytorch as dglnn

-from utils import get_activation
+from utils import get_activation, to_etype_name

 class GCMCGraphConv(nn.Module):
     """Graph convolution module used in the GCMC model.
@@ -175,7 +175,7 @@ class GCMCLayer(nn.Module):
         subConv = {}
         for rating in rating_vals:
             # PyTorch parameter name can't contain "."
-            rating = str(rating).replace('.', '_')
+            rating = to_etype_name(rating)
             rev_rating = 'rev-%s' % rating
             if share_user_item_param and user_in_units == movie_in_units:
                 self.W_r[rating] = nn.Parameter(th.randn(user_in_units, msg_units))
@@ -251,7 +251,7 @@ class GCMCLayer(nn.Module):
         in_feats = {'user' : ufeat, 'movie' : ifeat}
         mod_args = {}
         for i, rating in enumerate(self.rating_vals):
-            rating = str(rating).replace('.', '_')
+            rating = to_etype_name(rating)
             rev_rating = 'rev-%s' % rating
             mod_args[rating] = (self.W_r[rating] if self.W_r is not None else None,)
             mod_args[rev_rating] = (self.W_r[rev_rating] if self.W_r is not None else None,)
@@ -304,9 +304,7 @@ class BiDecoder(nn.Module):
         super(BiDecoder, self).__init__()
         self._num_basis = num_basis
         self.dropout = nn.Dropout(dropout_rate)
-        self.Ps = nn.ParameterList()
-        for i in range(num_basis):
-            self.Ps.append(nn.Parameter(th.randn(in_units, in_units)))
+        self.P = nn.Parameter(th.randn(num_basis, in_units, in_units))
         self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)
         self.reset_parameters()
@@ -392,12 +390,7 @@ class DenseBiDecoder(BiDecoder):
         """
         ufeat = self.dropout(ufeat)
         ifeat = self.dropout(ifeat)
-        basis_out = []
-        for i in range(self._num_basis):
-            ufeat_i = ufeat @ self.Ps[i]
-            out = th.einsum('ab,ab->a', ufeat_i, ifeat)
-            basis_out.append(out.unsqueeze(1))
-        out = th.cat(basis_out, dim=1)
+        out = th.einsum('ai,bij,aj->ab', ufeat, self.P, ifeat)
        out = self.combine_basis(out)
        return out
......
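The einsum above folds the old per-basis loop into one batched contraction: for pair a and basis b it computes ufeat[a] @ P[b] @ ifeat[a]. A quick equivalence check with made-up shapes (not part of the commit):

import torch as th

N, D, B = 4, 8, 3                  # pairs, feature dim, number of bases
ufeat, ifeat = th.randn(N, D), th.randn(N, D)
P = th.randn(B, D, D)              # stacked basis matrices, like the new self.P

out = th.einsum('ai,bij,aj->ab', ufeat, P, ifeat)

# Old loop form, with the stacked P standing in for the nn.ParameterList.
expected = th.stack(
    [th.einsum('ab,ab->a', ufeat @ P[b], ifeat) for b in range(B)], dim=1)
assert th.allclose(out, expected, atol=1e-5)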
@@ -21,103 +21,9 @@ from _thread import start_new_thread
 from functools import wraps
 from data import MovieLens
 from model import GCMCLayer, DenseBiDecoder
-from utils import get_activation, get_optimizer, torch_total_param_num, torch_net_info, MetricLogger
+from utils import get_activation, get_optimizer, torch_total_param_num, torch_net_info, MetricLogger, to_etype_name
 import dgl

-class GCMCSampler:
-    """Neighbor sampler in GCMC mini-batch training."""
-    def __init__(self, dataset, segment='train'):
-        self.dataset = dataset
-        if segment == 'train':
-            self.truths = dataset.train_truths
-            self.labels = dataset.train_labels
-            self.enc_graph = dataset.train_enc_graph
-            self.dec_graph = dataset.train_dec_graph
-        elif segment == 'valid':
-            self.truths = dataset.valid_truths
-            self.labels = None
-            self.enc_graph = dataset.valid_enc_graph
-            self.dec_graph = dataset.valid_dec_graph
-        elif segment == 'test':
-            self.truths = dataset.test_truths
-            self.labels = None
-            self.enc_graph = dataset.test_enc_graph
-            self.dec_graph = dataset.test_dec_graph
-        else:
-            assert False, "Unknow dataset {}".format(segment)
-
-    def sample_blocks(self, seeds):
-        """Sample subgraphs from the entire graph.
-
-        The input ``seeds`` represents the edges to compute prediction for. The sampling
-        algorithm works as follows:
-
-        1. Get the head and tail nodes of the provided seed edges.
-        2. For each head and tail node, extract the entire in-coming neighborhood.
-        3. Copy the node features/embeddings from the full graph to the sampled subgraphs.
-        """
-        dataset = self.dataset
-        enc_graph = self.enc_graph
-        dec_graph = self.dec_graph
-        edge_ids = th.stack(seeds)
-        # generate frontiers for user and item
-        possible_rating_values = dataset.possible_rating_values
-        true_relation_ratings = self.truths[edge_ids]
-        true_relation_labels = None if self.labels is None else self.labels[edge_ids]
-
-        # 1. Get the head and tail nodes from both the decoder and encoder graphs.
-        head_id, tail_id = dec_graph.find_edges(edge_ids)
-        utype, _, vtype = enc_graph.canonical_etypes[0]
-        subg = []
-        true_rel_ratings = []
-        true_rel_labels = []
-        for possible_rating_value in possible_rating_values:
-            idx_loc = (true_relation_ratings == possible_rating_value)
-            head = head_id[idx_loc]
-            tail = tail_id[idx_loc]
-            true_rel_ratings.append(true_relation_ratings[idx_loc])
-            if self.labels is not None:
-                true_rel_labels.append(true_relation_labels[idx_loc])
-            subg.append(dgl.bipartite((head, tail),
-                                      utype=utype,
-                                      etype=str(possible_rating_value),
-                                      vtype=vtype,
-                                      num_nodes=(enc_graph.number_of_nodes(utype),
-                                                 enc_graph.number_of_nodes(vtype))))
-        # Convert the encoder subgraph to a more compact one by removing nodes that covered
-        # by the seed edges.
-        g = dgl.hetero_from_relations(subg)
-        g = dgl.compact_graphs(g)
-
-        # 2. For each head and tail node, extract the entire in-coming neighborhood.
-        seed_nodes = {}
-        for ntype in g.ntypes:
-            seed_nodes[ntype] = g.nodes[ntype].data[dgl.NID]
-        frontier = dgl.in_subgraph(enc_graph, seed_nodes)
-        frontier = dgl.to_block(frontier, seed_nodes)
-
-        # 3. Copy the node features/embeddings from the full graph to the sampled subgraphs.
-        frontier.dstnodes['user'].data['ci'] = \
-            enc_graph.nodes['user'].data['ci'][frontier.dstnodes['user'].data[dgl.NID]]
-        frontier.srcnodes['movie'].data['cj'] = \
-            enc_graph.nodes['movie'].data['cj'][frontier.srcnodes['movie'].data[dgl.NID]]
-        frontier.srcnodes['user'].data['cj'] = \
-            enc_graph.nodes['user'].data['cj'][frontier.srcnodes['user'].data[dgl.NID]]
-        frontier.dstnodes['movie'].data['ci'] = \
-            enc_graph.nodes['movie'].data['ci'][frontier.dstnodes['movie'].data[dgl.NID]]
-
-        # handle features
-        head_feat = frontier.srcnodes['user'].data[dgl.NID].long() \
-                    if dataset.user_feature is None else \
-                    dataset.user_feature[frontier.srcnodes['user'].data[dgl.NID]]
-        tail_feat = frontier.srcnodes['movie'].data[dgl.NID].long() \
-                    if dataset.movie_feature is None else \
-                    dataset.movie_feature[frontier.srcnodes['movie'].data[dgl.NID]]
-
-        true_rel_labels = None if self.labels is None else th.cat(true_rel_labels, dim=0)
-        true_rel_ratings = th.cat(true_rel_ratings, dim=0)
-        return (g, frontier, head_feat, tail_feat, true_rel_labels, true_rel_ratings)

 class Net(nn.Module):
     def __init__(self, args, dev_id):
         super(Net, self).__init__()
@@ -147,34 +53,80 @@ class Net(nn.Module):
     def forward(self, compact_g, frontier, ufeat, ifeat, possible_rating_values):
         user_out, movie_out = self.encoder(frontier, ufeat, ifeat)
-        head_emb = []
-        tail_emb = []
-        for possible_rating_value in possible_rating_values:
-            head, tail = compact_g.all_edges(etype=str(possible_rating_value))
-            head_emb.append(user_out[head])
-            tail_emb.append(movie_out[tail])
-        head_emb = th.cat(head_emb, dim=0)
-        tail_emb = th.cat(tail_emb, dim=0)
+        head, tail = compact_g.edges(order='eid')
+        head_emb = user_out[head]
+        tail_emb = movie_out[tail]
         pred_ratings = self.decoder(head_emb, tail_emb)
         return pred_ratings

+def load_subtensor(input_nodes, pair_graph, blocks, dataset, parent_graph):
+    output_nodes = pair_graph.ndata[dgl.NID]
+    head_feat = input_nodes['user'] if dataset.user_feature is None else \
+                dataset.user_feature[input_nodes['user']]
+    tail_feat = input_nodes['movie'] if dataset.movie_feature is None else \
+                dataset.movie_feature[input_nodes['movie']]
+
+    for block in blocks:
+        block.dstnodes['user'].data['ci'] = \
+            parent_graph.nodes['user'].data['ci'][block.dstnodes['user'].data[dgl.NID]]
+        block.srcnodes['user'].data['cj'] = \
+            parent_graph.nodes['user'].data['cj'][block.srcnodes['user'].data[dgl.NID]]
+        block.dstnodes['movie'].data['ci'] = \
+            parent_graph.nodes['movie'].data['ci'][block.dstnodes['movie'].data[dgl.NID]]
+        block.srcnodes['movie'].data['cj'] = \
+            parent_graph.nodes['movie'].data['cj'][block.srcnodes['movie'].data[dgl.NID]]
+
+    return head_feat, tail_feat, blocks
+
+def flatten_etypes(pair_graph, dataset, segment):
+    n_users = pair_graph.number_of_nodes('user')
+    n_movies = pair_graph.number_of_nodes('movie')
+    src = []
+    dst = []
+    labels = []
+    ratings = []
+    for rating in dataset.possible_rating_values:
+        src_etype, dst_etype = pair_graph.edges(order='eid', etype=to_etype_name(rating))
+        src.append(src_etype)
+        dst.append(dst_etype)
+        label = np.searchsorted(dataset.possible_rating_values, rating)
+        ratings.append(th.LongTensor(np.full_like(src_etype, rating)))
+        labels.append(th.LongTensor(np.full_like(src_etype, label)))
+    src = th.cat(src)
+    dst = th.cat(dst)
+    ratings = th.cat(ratings)
+    labels = th.cat(labels)
+
+    flattened_pair_graph = dgl.heterograph({
+        ('user', 'rate', 'movie'): (src, dst)},
+        num_nodes_dict={'user': n_users, 'movie': n_movies})
+    flattened_pair_graph.edata['rating'] = ratings
+    flattened_pair_graph.edata['label'] = labels
+
+    return flattened_pair_graph
+
 def evaluate(args, dev_id, net, dataset, dataloader, segment='valid'):
     possible_rating_values = dataset.possible_rating_values
     nd_possible_rating_values = th.FloatTensor(possible_rating_values).to(dev_id)

     real_pred_ratings = []
     true_rel_ratings = []
-    for sample_data in dataloader:
-        compact_g, frontier, head_feat, tail_feat, \
-            _, true_relation_ratings = sample_data
+    for input_nodes, pair_graph, blocks in dataloader:
+        head_feat, tail_feat, blocks = load_subtensor(
+            input_nodes, pair_graph, blocks, dataset,
+            dataset.valid_enc_graph if segment == 'valid' else dataset.test_enc_graph)
+        frontier = blocks[0]
+        true_relation_ratings = \
+            dataset.valid_truths[pair_graph.edata[dgl.EID]] if segment == 'valid' else \
+            dataset.test_truths[pair_graph.edata[dgl.EID]]
+
         frontier = frontier.to(dev_id)
         head_feat = head_feat.to(dev_id)
         tail_feat = tail_feat.to(dev_id)
         with th.no_grad():
-            pred_ratings = net(compact_g, frontier,
+            pred_ratings = net(pair_graph, frontier,
                                head_feat, tail_feat, possible_rating_values)
         batch_pred_ratings = (th.softmax(pred_ratings, dim=1) *
                               nd_possible_rating_values.view(1, -1)).sum(dim=1)
@@ -260,47 +212,43 @@ def config():
     return args

-@thread_wrapped_func
 def run(proc_id, n_gpus, args, devices, dataset):
     dev_id = devices[proc_id]
     train_labels = dataset.train_labels
     train_truths = dataset.train_truths
     num_edges = train_truths.shape[0]
-    sampler = GCMCSampler(dataset,
-                          'train')
-    seeds = th.arange(num_edges)
-    dataloader = DataLoader(
-        dataset=seeds,
+
+    reverse_types = {to_etype_name(k): 'rev-' + to_etype_name(k)
+                     for k in dataset.possible_rating_values}
+    reverse_types.update({v: k for k, v in reverse_types.items()})
+    sampler = dgl.dataloading.MultiLayerNeighborSampler([None], return_eids=True)
+    dataloader = dgl.dataloading.EdgeDataLoader(
+        dataset.train_enc_graph,
+        {to_etype_name(k): th.arange(
+            dataset.train_enc_graph.number_of_edges(etype=to_etype_name(k)))
+         for k in dataset.possible_rating_values},
+        sampler,
         batch_size=args.minibatch_size,
-        collate_fn=sampler.sample_blocks,
         shuffle=True,
-        pin_memory=True,
-        drop_last=False,
-        num_workers=args.num_workers_per_gpu)
+        drop_last=False)

     if proc_id == 0:
-        valid_sampler = GCMCSampler(dataset,
-                                    'valid')
-        valid_seeds = th.arange(dataset.valid_truths.shape[0])
-        valid_dataloader = DataLoader(dataset=valid_seeds,
-                                      batch_size=args.minibatch_size,
-                                      collate_fn=valid_sampler.sample_blocks,
-                                      shuffle=False,
-                                      pin_memory=True,
-                                      drop_last=False,
-                                      num_workers=args.num_workers_per_gpu)
-
-        test_sampler = GCMCSampler(dataset,
-                                   'test')
-        test_seeds = th.arange(dataset.test_truths.shape[0])
-        test_dataloader = DataLoader(dataset=test_seeds,
-                                     batch_size=args.minibatch_size,
-                                     collate_fn=test_sampler.sample_blocks,
-                                     shuffle=False,
-                                     pin_memory=True,
-                                     drop_last=False,
-                                     num_workers=args.num_workers_per_gpu)
+        valid_dataloader = dgl.dataloading.EdgeDataLoader(
+            dataset.valid_dec_graph,
+            th.arange(dataset.valid_dec_graph.number_of_edges()),
+            sampler,
+            g_sampling=dataset.valid_enc_graph,
+            batch_size=args.minibatch_size,
+            shuffle=False,
+            drop_last=False)
+        test_dataloader = dgl.dataloading.EdgeDataLoader(
+            dataset.test_dec_graph,
+            th.arange(dataset.test_dec_graph.number_of_edges()),
+            sampler,
+            g_sampling=dataset.test_enc_graph,
+            batch_size=args.minibatch_size,
+            shuffle=False,
+            drop_last=False)

     if n_gpus > 1:
         dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
@@ -341,9 +289,14 @@ def run(proc_id, n_gpus, args, devices, dataset):
         if epoch > 1:
             t0 = time.time()
         net.train()
-        for step, sample_data in enumerate(dataloader):
-            compact_g, frontier, head_feat, tail_feat, \
-                true_relation_labels, true_relation_ratings = sample_data
+        for step, (input_nodes, pair_graph, blocks) in enumerate(dataloader):
+            head_feat, tail_feat, blocks = load_subtensor(
+                input_nodes, pair_graph, blocks, dataset, dataset.train_enc_graph)
+            frontier = blocks[0]
+            compact_g = flatten_etypes(pair_graph, dataset, 'train').to(dev_id)
+            true_relation_labels = compact_g.edata['label']
+            true_relation_ratings = compact_g.edata['rating']
+
             head_feat = head_feat.to(dev_id)
             tail_feat = tail_feat.to(dev_id)
             frontier = frontier.to(dev_id)
@@ -375,6 +328,9 @@ def run(proc_id, n_gpus, args, devices, dataset):
                 print("[{}] {}".format(proc_id, logging_str))
             iter_idx += 1

+            if step == 20:
+                return
+
         if epoch > 1:
             epoch_time = time.time() - t0
             print("Epoch {} time {}".format(epoch, epoch_time))
@@ -460,7 +416,7 @@ if __name__ == '__main__':
         dataset.train_dec_graph.create_format_()
         procs = []
         for proc_id in range(n_gpus):
-            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, dataset))
+            p = mp.Process(target=thread_wrapped_func(run), args=(proc_id, n_gpus, args, devices, dataset))
             p.start()
             procs.append(p)
         for p in procs:
......
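Note the g_sampling keyword in the validation and test loaders above: the minibatch edges come from the decoder graph, while neighborhoods are sampled from the encoder graph. A stripped-down sketch of that split, with made-up graph sizes:

import dgl
import torch as th

enc_g = dgl.rand_graph(50, 400)    # graph message passing runs on
dec_g = dgl.rand_graph(50, 100)    # graph supplying the edges to score

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.EdgeDataLoader(
    dec_g, th.arange(dec_g.number_of_edges()), sampler,
    g_sampling=enc_g,              # neighbors come from enc_g, not dec_g
    batch_size=32, shuffle=False, drop_last=False)

for input_nodes, pair_graph, blocks in dataloader:
    # pair_graph edges originate from dec_g; blocks were sampled from enc_g.
    assert pair_graph.number_of_edges() <= 32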
@@ -76,3 +76,7 @@ def get_optimizer(opt):
         return optim.Adam
     else:
         raise NotImplementedError
+
+def to_etype_name(rating):
+    return str(rating).replace('.', '_')
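The new helper only normalizes punctuation: rating values double as DGL edge-type names and PyTorch parameter names, neither of which may contain a dot. For example:

to_etype_name(3.5)   # -> '3_5'
to_etype_name(4)     # -> '4'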
@@ -64,8 +64,8 @@ class SAGE(nn.Module):
         for l, layer in enumerate(self.layers):
             y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

-            sampler = dgl.sampling.MultiLayerNeighborSampler([None])
-            dataloader = dgl.sampling.NodeDataLoader(
+            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
+            dataloader = dgl.dataloading.NodeDataLoader(
                 g,
                 th.arange(g.number_of_nodes()),
                 sampler,
@@ -129,9 +129,9 @@ def run(args, device, data):
     test_nid = th.nonzero(~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']), as_tuple=True)[0]

     # Create PyTorch DataLoader for constructing blocks
-    sampler = dgl.sampling.MultiLayerNeighborSampler(
+    sampler = dgl.dataloading.MultiLayerNeighborSampler(
         [int(fanout) for fanout in args.fan_out.split(',')])
-    dataloader = dgl.sampling.NodeDataLoader(
+    dataloader = dgl.dataloading.NodeDataLoader(
         train_g,
         train_nid,
         sampler,
......
@@ -66,8 +66,8 @@ class SAGE(nn.Module):
         for l, layer in enumerate(self.layers):
             y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

-            sampler = dgl.sampling.MultiLayerNeighborSampler([None])
-            dataloader = dgl.sampling.NodeDataLoader(
+            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
+            dataloader = dgl.dataloading.NodeDataLoader(
                 g,
                 th.arange(g.number_of_nodes()),
                 sampler,
@@ -149,9 +149,9 @@ def run(proc_id, n_gpus, args, devices, data):
     train_nid = th.split(train_nid, math.ceil(len(train_nid) / n_gpus))[proc_id]

     # Create PyTorch DataLoader for constructing blocks
-    sampler = dgl.sampling.MultiLayerNeighborSampler(
+    sampler = dgl.dataloading.MultiLayerNeighborSampler(
         [int(fanout) for fanout in args.fan_out.split(',')])
-    dataloader = dgl.sampling.NodeDataLoader(
+    dataloader = dgl.dataloading.NodeDataLoader(
         train_g,
         train_nid,
         sampler,
......
@@ -21,75 +21,22 @@ import sklearn.metrics as skm
 from utils import thread_wrapped_func

-#### Negative sampler
 class NegativeSampler(object):
-    def __init__(self, g):
+    def __init__(self, g, k, neg_share=False):
         self.weights = g.in_degrees().float() ** 0.75
+        self.k = k
+        self.neg_share = neg_share

-    def __call__(self, num_samples):
-        return self.weights.multinomial(num_samples, replacement=True)
-
-#### Neighbor sampler
-class NeighborSampler(object):
-    def __init__(self, g, fanouts, num_negs, neg_share=False):
-        self.g = g
-        self.fanouts = fanouts
-        self.neg_sampler = NegativeSampler(g)
-        self.num_negs = num_negs
-        self.neg_share = neg_share
-
-    def sample_blocks(self, seed_edges):
-        n_edges = len(seed_edges)
-        seed_edges = th.LongTensor(np.asarray(seed_edges))
-        heads, tails = self.g.find_edges(seed_edges)
-        if self.neg_share and n_edges % self.num_negs == 0:
-            neg_tails = self.neg_sampler(n_edges)
-            neg_tails = neg_tails.view(-1, 1, self.num_negs).expand(n_edges//self.num_negs,
-                                                                    self.num_negs,
-                                                                    self.num_negs).flatten()
-            neg_heads = heads.view(-1, 1).expand(n_edges, self.num_negs).flatten()
-        else:
-            neg_tails = self.neg_sampler(self.num_negs * n_edges)
-            neg_heads = heads.view(-1, 1).expand(n_edges, self.num_negs).flatten()
-
-        # Maintain the correspondence between heads, tails and negative tails as two
-        # graphs.
-        # pos_graph contains the correspondence between each head and its positive tail.
-        # neg_graph contains the correspondence between each head and its negative tails.
-        # Both pos_graph and neg_graph are first constructed with the same node space as
-        # the original graph. Then they are compacted together with dgl.compact_graphs.
-        pos_graph = dgl.graph((heads, tails), num_nodes=self.g.number_of_nodes())
-        neg_graph = dgl.graph((neg_heads, neg_tails), num_nodes=self.g.number_of_nodes())
-        pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])
-
-        # Obtain the node IDs being used in either pos_graph or neg_graph. Since they
-        # are compacted together, pos_graph and neg_graph share the same compacted node
-        # space.
-        seeds = pos_graph.ndata[dgl.NID]
-        blocks = []
-        for fanout in self.fanouts:
-            # For each seed node, sample ``fanout`` neighbors.
-            frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout, replace=True)
-            # Remove all edges between heads and tails, as well as heads and neg_tails.
-            _, _, edge_ids = frontier.edge_ids(
-                th.cat([heads, tails, neg_heads, neg_tails]),
-                th.cat([tails, heads, neg_tails, neg_heads]),
-                return_uv=True)
-            frontier = dgl.remove_edges(frontier, edge_ids)
-            # Then we compact the frontier into a bipartite graph for message passing.
-            block = dgl.to_block(frontier, seeds)
-            # Pre-generate CSR format that it can be used in training directly
-            block.create_format_()
-            # Obtain the seed nodes for next layer.
-            seeds = block.srcdata[dgl.NID]
-            blocks.insert(0, block)
-        # Pre-generate CSR format that it can be used in training directly
-        return pos_graph, neg_graph, blocks
+    def __call__(self, g, eids):
+        src, _ = g.find_edges(eids)
+        n = len(src)
+        if self.neg_share and n % self.k == 0:
+            dst = self.weights.multinomial(n, replacement=True)
+            dst = dst.view(-1, 1, self.k).expand(-1, self.k, -1).flatten()
+        else:
+            dst = self.weights.multinomial(n * self.k, replacement=True)
+        src = src.repeat_interleave(self.k)
+        return src, dst

 def load_subtensor(g, input_nodes, device):
     """
@@ -145,12 +92,18 @@ class SAGE(nn.Module):
         for l, layer in enumerate(self.layers):
             y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

-            for start in tqdm.trange(0, len(nodes), batch_size):
-                end = start + batch_size
-                batch_nodes = nodes[start:end]
-                block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
-                block = block.int().to(device)
-                input_nodes = block.srcdata[dgl.NID]
+            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
+            dataloader = dgl.dataloading.NodeDataLoader(
+                g,
+                th.arange(g.number_of_nodes()),
+                sampler,
+                batch_size=args.batch_size,
+                shuffle=True,
+                drop_last=False,
+                num_workers=args.num_workers)
+
+            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
+                block = blocks[0].to(device)

                 h = x[input_nodes].to(device)
                 h = layer(block, h)
@@ -158,7 +111,7 @@ class SAGE(nn.Module):
                     h = self.activation(h)
                     h = self.dropout(h)

-                y[start:end] = h.cpu()
+                y[output_nodes] = h.cpu()

             x = y
         return y
@@ -245,11 +198,9 @@ def run(proc_id, n_gpus, args, devices, data):
     #val_nid = th.LongTensor(np.nonzero(val_mask)[0])
     #test_nid = th.LongTensor(np.nonzero(test_mask)[0])

-    # Create sampler
-    sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')], args.num_negs, args.neg_share)
-
     # Create PyTorch DataLoader for constructing blocks
-    train_seeds = np.arange(g.number_of_edges())
+    n_edges = g.number_of_edges()
+    train_seeds = np.arange(n_edges)
     if n_gpus > 0:
         num_per_gpu = (train_seeds.shape[0] + n_gpus -1) // n_gpus
         train_seeds = train_seeds[proc_id * num_per_gpu :
@@ -257,10 +208,17 @@ def run(proc_id, n_gpus, args, devices, data):
                                   if (proc_id + 1) * num_per_gpu < train_seeds.shape[0]
                                   else train_seeds.shape[0]]

-    dataloader = DataLoader(
-        dataset=train_seeds,
+    # Create sampler
+    sampler = dgl.dataloading.MultiLayerNeighborSampler(
+        [int(fanout) for fanout in args.fan_out.split(',')])
+    dataloader = dgl.dataloading.EdgeDataLoader(
+        g, train_seeds, sampler, exclude='reverse_id',
+        # For each edge with ID e in Reddit dataset, the reverse edge is e ± |E|/2.
+        reverse_eids=th.cat([
+            th.arange(n_edges // 2, n_edges),
+            th.arange(0, n_edges // 2)]),
+        negative_sampler=NegativeSampler(g, args.num_negs),
         batch_size=args.batch_size,
-        collate_fn=sampler.sample_blocks,
         shuffle=True,
         drop_last=False,
         pin_memory=True,
@@ -290,10 +248,7 @@ def run(proc_id, n_gpus, args, devices, data):
     # blocks.
     tic_step = time.time()
-    for step, (pos_graph, neg_graph, blocks) in enumerate(dataloader):
-        # The nodes for input lies at the LHS side of the first block.
-        # The nodes for output lies at the RHS side of the last block.
-        input_nodes = blocks[0].srcdata[dgl.NID]
+    for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):
         batch_inputs = load_subtensor(g, input_nodes, device)
         d_step = time.time()
......
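EdgeDataLoader accepts any callable returning a (src, dst) pair as its negative sampler, and exclude='reverse_id' drops each seed edge's reverse from the sampled frontier via the reverse_eids mapping. A self-contained sketch on a toy graph laid out like Reddit above (the reverse of edge e is e ± |E|/2), using the builtin uniform negative sampler in place of the custom one:

import dgl
import torch as th

# 8 nodes; the second half of the edges are the reverses of the first half.
src = th.arange(4)
g = dgl.graph((th.cat([src, src + 4]), th.cat([src + 4, src])))

n_edges = g.number_of_edges()
reverse_eids = th.cat([th.arange(n_edges // 2, n_edges),
                       th.arange(0, n_edges // 2)])

sampler = dgl.dataloading.MultiLayerNeighborSampler([5])
dataloader = dgl.dataloading.EdgeDataLoader(
    g, th.arange(n_edges), sampler,
    exclude='reverse_id', reverse_eids=reverse_eids,
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(2),
    batch_size=2, shuffle=True)

for input_nodes, pos_graph, neg_graph, blocks in dataloader:
    # Two negative edges are drawn per positive edge.
    assert neg_graph.number_of_edges() == 2 * pos_graph.number_of_edges()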
@@ -84,9 +84,9 @@ class JTNNDataset(Dataset):
         cand_graphs = []
         atom_x_dec = torch.zeros(0, ATOM_FDIM_DEC)
         bond_x_dec = torch.zeros(0, BOND_FDIM_DEC)
-        tree_mess_src_e = torch.zeros(0, 2).int()
-        tree_mess_tgt_e = torch.zeros(0, 2).int()
-        tree_mess_tgt_n = torch.zeros(0).int()
+        tree_mess_src_e = torch.zeros(0, 2).long()
+        tree_mess_tgt_e = torch.zeros(0, 2).long()
+        tree_mess_tgt_n = torch.zeros(0).long()

         # prebuild the stereoisomers
         cands = mol_tree.stereo_cands
......
 import torch
 import torch.nn as nn
-from .nnutils import cuda, line_graph
+from .nnutils import cuda
 import rdkit.Chem as Chem
 import dgl
-from dgl import mean_nodes
+from dgl import mean_nodes, line_graph
 import dgl.function as DGLF
 import os
@@ -93,13 +93,9 @@ def mol2dgl_single(cand_batch):
     return cand_graphs, torch.stack(atom_x), \
            torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
-           torch.IntTensor(tree_mess_source_edges), \
-           torch.IntTensor(tree_mess_target_edges), \
-           torch.IntTensor(tree_mess_target_nodes)
+           torch.LongTensor(tree_mess_source_edges), \
+           torch.LongTensor(tree_mess_target_edges), \
+           torch.LongTensor(tree_mess_target_nodes)

-mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
-mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')
-
 class LoopyBPUpdate(nn.Module):
@@ -224,19 +220,19 @@ class DGLJTMPN(nn.Module):
             tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1)
             src_u = src_u.to(mol_tree_batch.device)
             src_v = src_v.to(mol_tree_batch.device)
-            eid = mol_tree_batch.edge_ids(src_u.int(), src_v.int()).long()
+            eid = mol_tree_batch.edge_ids(src_u, src_v)
             alpha = mol_tree_batch.edata['m'][eid]
             cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
         else:
             src_u, src_v = tree_mess_src_edges.unbind(1)
             src_u = src_u.to(mol_tree_batch.device)
             src_v = src_v.to(mol_tree_batch.device)
-            eid = mol_tree_batch.edge_ids(src_u.int(), src_v.int()).long()
+            eid = mol_tree_batch.edge_ids(src_u, src_v)
             alpha = mol_tree_batch.edata['m'][eid]
             node_idx = (tree_mess_tgt_nodes
                         .to(device=zero_node_state.device)[:, None]
                         .expand_as(alpha))
-            node_alpha = zero_node_state.clone().scatter_add(0, node_idx.long(), alpha)
+            node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha)
             cand_graphs.ndata['alpha'] = node_alpha
             cand_graphs.apply_edges(
                 func=lambda edges: {'alpha': edges.src['alpha']},
@@ -244,17 +240,14 @@ class DGLJTMPN(nn.Module):
         cand_line_graph.ndata.update(cand_graphs.edata)

         for i in range(self.depth - 1):
-            cand_line_graph.update_all(
-                mpn_loopy_bp_msg,
-                mpn_loopy_bp_reduce,
-                self.loopy_bp_updater,
-            )
+            cand_line_graph.update_all(DGLF.copy_u('msg', 'msg'), DGLF.sum('msg', 'accum_msg'))
+            cand_line_graph.apply_nodes(self.loopy_bp_updater)

         cand_graphs.edata.update(cand_line_graph.ndata)
-        cand_graphs.update_all(
-            mpn_gather_msg,
-            mpn_gather_reduce,
-            self.gather_updater,
-        )
+
+        cand_graphs.update_all(DGLF.copy_e('msg', 'msg'), DGLF.sum('msg', 'm'))
+        if PAPER:
+            cand_graphs.update_all(DGLF.copy_e('alpha', 'alpha'), DGLF.sum('alpha', 'accum_alpha'))
+        cand_graphs.apply_nodes(self.gather_updater)

         return cand_graphs
@@ -3,8 +3,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from .mol_tree_nx import DGLMolTree
 from .chemutils import enum_assemble_nx, get_mol
-from .nnutils import GRUUpdate, cuda, line_graph, tocpu
-from dgl import batch, dfs_labeled_edges_generator
+from .nnutils import GRUUpdate, cuda, tocpu
+from dgl import batch, dfs_labeled_edges_generator, line_graph
 import dgl.function as DGLF
 import numpy as np
@@ -29,10 +29,6 @@ def dec_tree_node_update(nodes):
     return {'new': nodes.data['new'].clone().zero_()}

-dec_tree_edge_msg = [DGLF.copy_src(src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
-dec_tree_edge_reduce = [DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]
-
 def have_slots(fa_slots, ch_slots):
     if len(fa_slots) > 2 and len(ch_slots) > 2:
         return True
@@ -146,12 +142,8 @@ class DGLJTNNDecoder(nn.Module):
         q_targets = []

         # Predict root
-        mol_tree_batch.pull(
-            root_ids,
-            dec_tree_node_msg,
-            dec_tree_node_reduce,
-            dec_tree_node_update,
-        )
+        mol_tree_batch.pull(root_ids, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
+        mol_tree_batch.apply_nodes(dec_tree_node_update)

         # Extract hidden states and store them for stop/label prediction
         h = mol_tree_batch.nodes[root_ids].data['h']
         x = mol_tree_batch.nodes[root_ids].data['x']
@@ -168,28 +160,22 @@ class DGLJTNNDecoder(nn.Module):
             u, v = mol_tree_batch.find_edges(eid)

             p_target_list = torch.zeros_like(root_out_degrees)
-            p_target_list[root_out_degrees > 0] = (1 - p).int()
+            p_target_list[root_out_degrees > 0] = (1 - p)
             p_target_list = p_target_list[root_out_degrees >= 0]
             p_targets.append(torch.tensor(p_target_list))

-            root_out_degrees -= (root_out_degrees == 0).int()
+            root_out_degrees -= (root_out_degrees == 0).long()
             root_out_degrees -= torch.tensor(np.isin(root_ids, v.cpu().numpy())).to(root_out_degrees)

             mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
-            mol_tree_batch_lg.pull(
-                eid,
-                dec_tree_edge_msg,
-                dec_tree_edge_reduce,
-                self.dec_tree_edge_update,
-            )
+            mol_tree_batch_lg.pull(eid, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
+            mol_tree_batch_lg.pull(eid, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
+            mol_tree_batch_lg.apply_nodes(self.dec_tree_edge_update)
             mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)

             is_new = mol_tree_batch.nodes[v].data['new']
-            mol_tree_batch.pull(
-                v,
-                dec_tree_node_msg,
-                dec_tree_node_reduce,
-                dec_tree_node_update,
-            )
+            mol_tree_batch.pull(v, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
+            mol_tree_batch.apply_nodes(dec_tree_node_update)

             # Extract
             n_repr = mol_tree_batch.nodes[v].data
@@ -209,7 +195,7 @@ class DGLJTNNDecoder(nn.Module):
         p_targets.append(torch.zeros(
             (root_out_degrees == 0).sum(),
             device=root_out_degrees.device,
-            dtype=torch.int32))
+            dtype=torch.int64))

         # Batch compute the stop/label prediction losses
         p_inputs = torch.cat(p_inputs, 0)
@@ -310,16 +296,15 @@ class DGLJTNNDecoder(nn.Module):
         mol_tree_graph_lg.pull(
             uv,
-            dec_tree_edge_msg,
-            dec_tree_edge_reduce,
-            self.dec_tree_edge_update.update_zm,
-        )
+            DGLF.copy_u('m', 'm'),
+            DGLF.sum('m', 's'))
+        mol_tree_graph_lg.pull(
+            uv,
+            DGLF.copy_u('rm', 'rm'),
+            DGLF.sum('rm', 'accum_rm'))
+        mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update.update_zm)
         mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
-        mol_tree_graph.pull(
-            v,
-            dec_tree_node_msg,
-            dec_tree_node_reduce,
-        )
+        mol_tree_graph.pull(v, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))

         h_v = mol_tree_graph.ndata['h'][v:v+1]
         q_input = torch.cat([h_v, mol_vec], 1)
@@ -371,18 +356,11 @@ class DGLJTNNDecoder(nn.Module):
                 pu, _ = stack[-2]
                 u_pu = mol_tree_graph.edge_id(u, pu)

-                mol_tree_graph_lg.pull(
-                    u_pu,
-                    dec_tree_edge_msg,
-                    dec_tree_edge_reduce,
-                    self.dec_tree_edge_update,
-                )
+                mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
+                mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
+                mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update)
                 mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
-                mol_tree_graph.pull(
-                    pu,
-                    dec_tree_node_msg,
-                    dec_tree_node_reduce,
-                )
+                mol_tree_graph.pull(pu, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
                 stack.pop()

         effective_nodes = mol_tree_graph.filter_nodes(lambda nodes: nodes.data['fail'] != 1)
......
 import torch
 import torch.nn as nn
-from .nnutils import GRUUpdate, cuda, line_graph, tocpu
-from dgl import batch, bfs_edges_generator
+from .nnutils import GRUUpdate, cuda, tocpu
+from dgl import batch, bfs_edges_generator, line_graph
 import dgl.function as DGLF
 import numpy as np
@@ -18,11 +18,6 @@ def level_order(forest, roots):
     yield from reversed(edges_back)
     yield from edges

-enc_tree_msg = [DGLF.copy_src(src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
-enc_tree_reduce = [DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]
-enc_tree_gather_msg = DGLF.copy_edge(edge='m', out='m')
-enc_tree_gather_reduce = DGLF.sum(msg='m', out='m')
-
 class EncoderGatherUpdate(nn.Module):
     def __init__(self, hidden_size):
         nn.Module.__init__(self)
@@ -102,21 +97,15 @@ class DGLJTNNEncoder(nn.Module):
         # if m_ij is actually computed or not.
         mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
         for eid in level_order(mol_tree_batch, root_ids):
-            #eid = mol_tree_batch.edge_ids(u, v)
-            mol_tree_batch_lg.pull(
-                eid.to(mol_tree_batch_lg.device),
-                enc_tree_msg,
-                enc_tree_reduce,
-                self.enc_tree_update,
-            )
+            eid = eid.to(mol_tree_batch_lg.device)
+            mol_tree_batch_lg.pull(eid, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
+            mol_tree_batch_lg.pull(eid, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
+            mol_tree_batch_lg.apply_nodes(self.enc_tree_update)

         # Readout
         mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)
-        mol_tree_batch.update_all(
-            enc_tree_gather_msg,
-            enc_tree_gather_reduce,
-            self.enc_tree_gather_update,
-        )
+        mol_tree_batch.update_all(DGLF.copy_e('m', 'm'), DGLF.sum('m', 'm'))
+        mol_tree_batch.apply_nodes(self.enc_tree_gather_update)

         root_vecs = mol_tree_batch.nodes[root_ids].data['h']
......
@@ -183,7 +183,7 @@ class DGLJTNNVAE(nn.Module):
             node['is_leaf'] = False
             set_atommap(node['mol'], node['nid'])

-        mol_tree_sg = mol_tree.graph.subgraph(effective_nodes.int().to(tree_vec.device))
+        mol_tree_sg = mol_tree.graph.subgraph(effective_nodes.to(tree_vec.device))
         mol_tree_msg, _ = self.jtnn([mol_tree_sg])
         mol_tree_msg = unbatch(mol_tree_msg)[0]
         mol_tree_msg.nodes_dict = nodes_dict
......
@@ -4,9 +4,8 @@ import rdkit.Chem as Chem
 import torch.nn.functional as F
 from .chemutils import get_mol
 import dgl
-from dgl import mean_nodes
+from dgl import mean_nodes, line_graph
 import dgl.function as DGLF
-from .nnutils import line_graph

 ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
              'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
@@ -66,10 +65,6 @@ def mol2dgl_single(smiles):
            torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)

-mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
-mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')
-
 class LoopyBPUpdate(nn.Module):
     def __init__(self, hidden_size):
         super(LoopyBPUpdate, self).__init__()
@@ -84,10 +79,6 @@ class LoopyBPUpdate(nn.Module):
         return {'msg': msg}

-mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
-mpn_gather_reduce = DGLF.sum(msg='msg', out='m')
-
 class GatherUpdate(nn.Module):
     def __init__(self, hidden_size):
         super(GatherUpdate, self).__init__()
@@ -163,17 +154,11 @@ class DGLMPN(nn.Module):
         })

         for i in range(self.depth - 1):
-            mol_line_graph.update_all(
-                mpn_loopy_bp_msg,
-                mpn_loopy_bp_reduce,
-                self.loopy_bp_updater,
-            )
+            mol_line_graph.update_all(DGLF.copy_u('msg', 'msg'), DGLF.sum('msg', 'accum_msg'))
+            mol_line_graph.apply_nodes(self.loopy_bp_updater)

         mol_graph.edata.update(mol_line_graph.ndata)
-        mol_graph.update_all(
-            mpn_gather_msg,
-            mpn_gather_reduce,
-            self.gather_updater,
-        )
+        mol_graph.update_all(DGLF.copy_e('msg', 'msg'), DGLF.sum('msg', 'm'))
+        mol_graph.apply_nodes(self.gather_updater)

         return mol_graph
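The JTNN refactor above consistently splits the old three-argument update_all(message, reduce, apply) into a builtin message/reduce pass followed by an explicit apply_nodes call. The same two-step pattern in isolation, on a throwaway graph:

import dgl
import dgl.function as DGLF
import torch as th

g = dgl.rand_graph(10, 40)
g.edata['msg'] = th.randn(g.number_of_edges(), 4)

# Builtin message/reduce: copy each edge's 'msg' and sum into node field 'm'.
g.update_all(DGLF.copy_e('msg', 'msg'), DGLF.sum('msg', 'm'))

# The node UDF now runs as a separate step, where the old API fused it
# into the same update_all call.
g.apply_nodes(lambda nodes: {'m': th.relu(nodes.data['m'])})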
@@ -48,10 +48,3 @@ def tocpu(g):
     src = src.cpu()
     dst = dst.cpu()
     return dgl.graph((src, dst), num_nodes=g.number_of_nodes())
-
-def line_graph(g, backtracking=True, shared=False):
-    #g2 = tocpu(g)
-    g2 = dgl.line_graph(g, backtracking, shared)
-    #g2 = g2.to(g.device)
-    g2.ndata.update(g.edata)
-    return g2
@@ -78,7 +78,7 @@ def train():
     for epoch in range(MAX_EPOCH):
         word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0

-        for it, batch in tqdm.tqdm(enumerate(dataloader), total=2000):
+        for it, batch in enumerate(tqdm.tqdm(dataloader)):
             model.zero_grad()
             try:
                 loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
......
@@ -85,8 +85,8 @@ class GAT(nn.Module):
                 y = th.zeros(g.number_of_nodes(), self.n_hidden * num_heads if l != len(self.layers) - 1 else self.n_classes)
             else:
                 y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

-            sampler = dgl.sampling.MultiLayerNeighborSampler([None])
-            dataloader = dgl.sampling.NodeDataLoader(
+            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
+            dataloader = dgl.dataloading.NodeDataLoader(
                 g,
                 th.arange(g.number_of_nodes()),
                 sampler,
......
@@ -17,29 +17,6 @@ import tqdm
 import traceback
 from ogb.nodeproppred import DglNodePropPredDataset

-#### Neighbor sampler
-class NeighborSampler(object):
-    def __init__(self, g, fanouts):
-        self.g = g
-        self.fanouts = fanouts
-
-    def sample_blocks(self, seeds):
-        seeds = th.LongTensor(np.asarray(seeds))
-        blocks = []
-        for fanout in self.fanouts:
-            # For each seed node, sample ``fanout`` neighbors.
-            if fanout == 0:
-                frontier = dgl.in_subgraph(self.g, seeds)
-            else:
-                frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout, replace=True)
-            # Then we compact the frontier into a bipartite graph for message passing.
-            block = dgl.to_block(frontier, seeds)
-            # Obtain the seed nodes for next layer.
-            seeds = block.srcdata[dgl.NID]
-            blocks.insert(0, block)
-        return blocks

 class GAT(nn.Module):
     def __init__(self,
@@ -97,8 +74,8 @@ class GAT(nn.Module):
             else:
                 y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

-            sampler = dgl.sampling.MultiLayerNeighborSampler([None])
-            dataloader = dgl.sampling.NodeDataLoader(
+            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
+            dataloader = dgl.dataloading.NodeDataLoader(
                 g,
                 th.arange(g.number_of_nodes()),
                 sampler,
@@ -161,9 +138,9 @@ def run(args, device, data):
     train_nid, val_nid, test_nid, in_feats, labels, n_classes, g, num_heads = data

     # Create PyTorch DataLoader for constructing blocks
-    sampler = dgl.sampling.MultiLayerNeighborSampler(
+    sampler = dgl.dataloading.MultiLayerNeighborSampler(
         [int(fanout) for fanout in args.fan_out.split(',')])
-    dataloader = dgl.sampling.NodeDataLoader(
+    dataloader = dgl.dataloading.NodeDataLoader(
         g,
         train_nid,
         sampler,
......