"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5aa31bd6741a3469415bd95c91b0b47b98ba87a8"
Unverified commit f5eb80d2, authored by Quan (Andy) Gan, committed by GitHub

[Feature] Edge DataLoader for edge classification & link prediction (#1828)

* clean commit

* oops forgot the most important files

* use einsum

* copy feature from frontier to block

* Revert "copy feature from frontier to block"

This reverts commit 5224ec963eb6a3ef1b6ab74d8ecbd44e4e42f285.

* temp fix

* unit test

* fix

* revert jtnn

* lint

* fix win64

* docstring fixes and doc indexing

* revert einsum in sparse bidecoder

* fix some examples

* lint

* fix due to some tediousness in remove_edges

* addresses comments

* fix

* more jtnn fixes

* fix
parent d340ea3a
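The feature at a glance: ``dgl.dataloading.EdgeDataLoader`` iterates over minibatches of edges and yields the sampled input nodes, a *pair graph* containing only the minibatch edges, and the message-passing blocks. A minimal sketch of an edge-classification loop, assuming a homogeneous graph ``g`` with node features ``'feat'`` and edge labels ``'label'`` (the feature/label names are illustrative):

import torch as th
import dgl

sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 25])
dataloader = dgl.dataloading.EdgeDataLoader(
    g, th.arange(g.number_of_edges()), sampler,
    batch_size=1024, shuffle=True, drop_last=False)
for input_nodes, pair_graph, blocks in dataloader:
    batch_inputs = g.ndata['feat'][input_nodes]                  # features of all sampled nodes
    batch_labels = g.edata['label'][pair_graph.edata[dgl.EID]]   # labels of the minibatch edges
    # run the GNN on ``blocks``, then score the edges of ``pair_graph``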
.. _api-sampling:
dgl.dataloading
=================================
.. automodule:: dgl.dataloading
PyTorch node/edge DataLoaders
-----------------------------
.. autoclass:: pytorch.NodeDataLoader
.. autoclass:: pytorch.EdgeDataLoader
General collating functions
---------------------------
.. autoclass:: Collator
.. autoclass:: NodeCollator
.. autoclass:: EdgeCollator
Base Multi-layer Neighborhood Sampling Class
--------------------------------------------
.. autoclass:: BlockSampler
Uniform Node-wise Neighbor Sampling (GraphSAGE style)
-----------------------------------------------------
.. autoclass:: MultiLayerNeighborSampler
Negative Samplers for Link Prediction
-------------------------------------
.. autoclass:: negative_sampler.Uniform
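A minimal usage sketch (assuming a homogeneous graph ``g``, a tensor of training node IDs ``train_nid``, and per-layer fanouts of 10 and 25):

.. code:: python

    sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 25])
    dataloader = dgl.dataloading.NodeDataLoader(
        g, train_nid, sampler, batch_size=1024, shuffle=True, drop_last=False)
    for input_nodes, output_nodes, blocks in dataloader:
        pass  # train on the blocks here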
......@@ -11,3 +11,4 @@ API Reference
dgl.ops
dgl.function
sampling
dataloading
......@@ -25,23 +25,7 @@ Neighbor sampling functions
sample_neighbors
select_topk
PyTorch DataLoaders with neighborhood sampling
----------------------------------------------
.. autoclass:: pytorch.NeighborSamplerNodeDataLoader
Builtin sampler classes for more complicated sampling algorithms
----------------------------------------------------------------
.. autoclass:: RandomWalkNeighborSampler
.. autoclass:: PinSAGESampler
Neighborhood samplers for multilayer GNNs
-----------------------------------------
.. autoclass:: MultiLayerNeighborSampler
Data loaders for minibatch iteration
------------------------------------
.. autoclass:: NodeCollator
Abstract class for neighborhood sampler
---------------------------------------
.. autoclass:: BlockSampler
......@@ -8,6 +8,7 @@ import torch as th
import dgl
from dgl.data.utils import download, extract_archive, get_download_dir
from utils import to_etype_name
_urls = {
'ml-100k' : 'http://files.grouplens.org/datasets/movielens/ml-100k.zip',
......@@ -210,7 +211,7 @@ class MovieLens(object):
def _npairs(graph):
rst = 0
for r in self.possible_rating_values:
r = str(r).replace('.', '_')
r = to_etype_name(r)
rst += graph.number_of_edges(str(r))
return rst
......@@ -252,7 +253,7 @@ class MovieLens(object):
ridx = np.where(rating_values == rating)
rrow = rating_row[ridx]
rcol = rating_col[ridx]
rating = str(rating).replace('.', '_')
rating = to_etype_name(rating)
bg = dgl.bipartite((rrow, rcol), 'user', rating, 'movie',
num_nodes=(self._num_user, self._num_movie))
rev_bg = dgl.bipartite((rcol, rrow), 'movie', 'rev-%s' % rating, 'user',
......@@ -275,7 +276,7 @@ class MovieLens(object):
movie_ci = []
movie_cj = []
for r in self.possible_rating_values:
r = str(r).replace('.', '_')
r = to_etype_name(r)
user_ci.append(graph['rev-%s' % r].in_degrees())
movie_ci.append(graph[r].in_degrees())
if self._symm:
......
......@@ -5,7 +5,7 @@ from torch.nn import init
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from utils import get_activation
from utils import get_activation, to_etype_name
class GCMCGraphConv(nn.Module):
"""Graph convolution module used in the GCMC model.
......@@ -175,7 +175,7 @@ class GCMCLayer(nn.Module):
subConv = {}
for rating in rating_vals:
# PyTorch parameter name can't contain "."
rating = str(rating).replace('.', '_')
rating = to_etype_name(rating)
rev_rating = 'rev-%s' % rating
if share_user_item_param and user_in_units == movie_in_units:
self.W_r[rating] = nn.Parameter(th.randn(user_in_units, msg_units))
......@@ -251,7 +251,7 @@ class GCMCLayer(nn.Module):
in_feats = {'user' : ufeat, 'movie' : ifeat}
mod_args = {}
for i, rating in enumerate(self.rating_vals):
rating = str(rating).replace('.', '_')
rating = to_etype_name(rating)
rev_rating = 'rev-%s' % rating
mod_args[rating] = (self.W_r[rating] if self.W_r is not None else None,)
mod_args[rev_rating] = (self.W_r[rev_rating] if self.W_r is not None else None,)
......@@ -304,9 +304,7 @@ class BiDecoder(nn.Module):
super(BiDecoder, self).__init__()
self._num_basis = num_basis
self.dropout = nn.Dropout(dropout_rate)
self.Ps = nn.ParameterList()
for i in range(num_basis):
self.Ps.append(nn.Parameter(th.randn(in_units, in_units)))
self.P = nn.Parameter(th.randn(num_basis, in_units, in_units))
self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)
self.reset_parameters()
......@@ -392,12 +390,7 @@ class DenseBiDecoder(BiDecoder):
"""
ufeat = self.dropout(ufeat)
ifeat = self.dropout(ifeat)
basis_out = []
for i in range(self._num_basis):
ufeat_i = ufeat @ self.Ps[i]
out = th.einsum('ab,ab->a', ufeat_i, ifeat)
basis_out.append(out.unsqueeze(1))
out = th.cat(basis_out, dim=1)
out = th.einsum('ai,bij,aj->ab', ufeat, self.P, ifeat)
out = self.combine_basis(out)
return out
......
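The ``DenseBiDecoder`` change above replaces the per-basis Python loop with one batched contraction: stacking the basis matrices into ``P`` of shape ``(num_basis, in_units, in_units)`` lets ``th.einsum('ai,bij,aj->ab', ufeat, self.P, ifeat)`` compute the bilinear score for every basis at once. A quick equivalence sketch (shapes are illustrative):

import torch as th
u, v, P = th.randn(4, 8), th.randn(4, 8), th.randn(3, 8, 8)
looped = th.stack([th.einsum('ab,ab->a', u @ P[i], v) for i in range(3)], dim=1)
batched = th.einsum('ai,bij,aj->ab', u, P, v)
assert th.allclose(looped, batched, atol=1e-5)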
......@@ -21,103 +21,9 @@ from _thread import start_new_thread
from functools import wraps
from data import MovieLens
from model import GCMCLayer, DenseBiDecoder
from utils import get_activation, get_optimizer, torch_total_param_num, torch_net_info, MetricLogger
from utils import get_activation, get_optimizer, torch_total_param_num, torch_net_info, MetricLogger, to_etype_name
import dgl
class GCMCSampler:
"""Neighbor sampler in GCMC mini-batch training."""
def __init__(self, dataset, segment='train'):
self.dataset = dataset
if segment == 'train':
self.truths = dataset.train_truths
self.labels = dataset.train_labels
self.enc_graph = dataset.train_enc_graph
self.dec_graph = dataset.train_dec_graph
elif segment == 'valid':
self.truths = dataset.valid_truths
self.labels = None
self.enc_graph = dataset.valid_enc_graph
self.dec_graph = dataset.valid_dec_graph
elif segment == 'test':
self.truths = dataset.test_truths
self.labels = None
self.enc_graph = dataset.test_enc_graph
self.dec_graph = dataset.test_dec_graph
else:
assert False, "Unknow dataset {}".format(segment)
def sample_blocks(self, seeds):
"""Sample subgraphs from the entire graph.
The input ``seeds`` represents the edges to compute predictions for. The sampling
algorithm works as follows:
1. Get the head and tail nodes of the provided seed edges.
2. For each head and tail node, extract the entire in-coming neighborhood.
3. Copy the node features/embeddings from the full graph to the sampled subgraphs.
"""
dataset = self.dataset
enc_graph = self.enc_graph
dec_graph = self.dec_graph
edge_ids = th.stack(seeds)
# generate frontiers for user and item
possible_rating_values = dataset.possible_rating_values
true_relation_ratings = self.truths[edge_ids]
true_relation_labels = None if self.labels is None else self.labels[edge_ids]
# 1. Get the head and tail nodes from both the decoder and encoder graphs.
head_id, tail_id = dec_graph.find_edges(edge_ids)
utype, _, vtype = enc_graph.canonical_etypes[0]
subg = []
true_rel_ratings = []
true_rel_labels = []
for possible_rating_value in possible_rating_values:
idx_loc = (true_relation_ratings == possible_rating_value)
head = head_id[idx_loc]
tail = tail_id[idx_loc]
true_rel_ratings.append(true_relation_ratings[idx_loc])
if self.labels is not None:
true_rel_labels.append(true_relation_labels[idx_loc])
subg.append(dgl.bipartite((head, tail),
utype=utype,
etype=str(possible_rating_value),
vtype=vtype,
num_nodes=(enc_graph.number_of_nodes(utype),
enc_graph.number_of_nodes(vtype))))
# Convert the encoder subgraph to a more compact one by removing nodes that are
# not covered by the seed edges.
g = dgl.hetero_from_relations(subg)
g = dgl.compact_graphs(g)
# 2. For each head and tail node, extract the entire in-coming neighborhood.
seed_nodes = {}
for ntype in g.ntypes:
seed_nodes[ntype] = g.nodes[ntype].data[dgl.NID]
frontier = dgl.in_subgraph(enc_graph, seed_nodes)
frontier = dgl.to_block(frontier, seed_nodes)
# 3. Copy the node features/embeddings from the full graph to the sampled subgraphs.
frontier.dstnodes['user'].data['ci'] = \
enc_graph.nodes['user'].data['ci'][frontier.dstnodes['user'].data[dgl.NID]]
frontier.srcnodes['movie'].data['cj'] = \
enc_graph.nodes['movie'].data['cj'][frontier.srcnodes['movie'].data[dgl.NID]]
frontier.srcnodes['user'].data['cj'] = \
enc_graph.nodes['user'].data['cj'][frontier.srcnodes['user'].data[dgl.NID]]
frontier.dstnodes['movie'].data['ci'] = \
enc_graph.nodes['movie'].data['ci'][frontier.dstnodes['movie'].data[dgl.NID]]
# handle features
head_feat = frontier.srcnodes['user'].data[dgl.NID].long() \
if dataset.user_feature is None else \
dataset.user_feature[frontier.srcnodes['user'].data[dgl.NID]]
tail_feat = frontier.srcnodes['movie'].data[dgl.NID].long()\
if dataset.movie_feature is None else \
dataset.movie_feature[frontier.srcnodes['movie'].data[dgl.NID]]
true_rel_labels = None if self.labels is None else th.cat(true_rel_labels, dim=0)
true_rel_ratings = th.cat(true_rel_ratings, dim=0)
return (g, frontier, head_feat, tail_feat, true_rel_labels, true_rel_ratings)
class Net(nn.Module):
def __init__(self, args, dev_id):
super(Net, self).__init__()
......@@ -147,34 +53,80 @@ class Net(nn.Module):
def forward(self, compact_g, frontier, ufeat, ifeat, possible_rating_values):
user_out, movie_out = self.encoder(frontier, ufeat, ifeat)
head_emb = []
tail_emb = []
for possible_rating_value in possible_rating_values:
head, tail = compact_g.all_edges(etype=str(possible_rating_value))
head_emb.append(user_out[head])
tail_emb.append(movie_out[tail])
head_emb = th.cat(head_emb, dim=0)
tail_emb = th.cat(tail_emb, dim=0)
head, tail = compact_g.edges(order='eid')
head_emb = user_out[head]
tail_emb = movie_out[tail]
pred_ratings = self.decoder(head_emb, tail_emb)
return pred_ratings
def load_subtensor(input_nodes, pair_graph, blocks, dataset, parent_graph):
output_nodes = pair_graph.ndata[dgl.NID]
head_feat = input_nodes['user'] if dataset.user_feature is None else \
dataset.user_feature[input_nodes['user']]
tail_feat = input_nodes['movie'] if dataset.movie_feature is None else \
dataset.movie_feature[input_nodes['movie']]
for block in blocks:
block.dstnodes['user'].data['ci'] = \
parent_graph.nodes['user'].data['ci'][block.dstnodes['user'].data[dgl.NID]]
block.srcnodes['user'].data['cj'] = \
parent_graph.nodes['user'].data['cj'][block.srcnodes['user'].data[dgl.NID]]
block.dstnodes['movie'].data['ci'] = \
parent_graph.nodes['movie'].data['ci'][block.dstnodes['movie'].data[dgl.NID]]
block.srcnodes['movie'].data['cj'] = \
parent_graph.nodes['movie'].data['cj'][block.srcnodes['movie'].data[dgl.NID]]
return head_feat, tail_feat, blocks
def flatten_etypes(pair_graph, dataset, segment):
n_users = pair_graph.number_of_nodes('user')
n_movies = pair_graph.number_of_nodes('movie')
src = []
dst = []
labels = []
ratings = []
for rating in dataset.possible_rating_values:
src_etype, dst_etype = pair_graph.edges(order='eid', etype=to_etype_name(rating))
src.append(src_etype)
dst.append(dst_etype)
label = np.searchsorted(dataset.possible_rating_values, rating)
ratings.append(th.LongTensor(np.full_like(src_etype, rating)))
labels.append(th.LongTensor(np.full_like(src_etype, label)))
src = th.cat(src)
dst = th.cat(dst)
ratings = th.cat(ratings)
labels = th.cat(labels)
flattened_pair_graph = dgl.heterograph({
('user', 'rate', 'movie'): (src, dst)},
num_nodes_dict={'user': n_users, 'movie': n_movies})
flattened_pair_graph.edata['rating'] = ratings
flattened_pair_graph.edata['label'] = labels
return flattened_pair_graph
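``flatten_etypes`` collapses the per-rating edge types of the sampled pair graph back into a single ``('user', 'rate', 'movie')`` relation, keeping both the raw rating and its class index as edge features; the class index is the position of the rating in the sorted ``possible_rating_values``. For illustration (assuming the usual MovieLens rating set):

import numpy as np
np.searchsorted(np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 4.0)  # -> 3, the class label for rating 4.0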
def evaluate(args, dev_id, net, dataset, dataloader, segment='valid'):
possible_rating_values = dataset.possible_rating_values
nd_possible_rating_values = th.FloatTensor(possible_rating_values).to(dev_id)
real_pred_ratings = []
true_rel_ratings = []
for sample_data in dataloader:
compact_g, frontier, head_feat, tail_feat, \
_, true_relation_ratings = sample_data
for input_nodes, pair_graph, blocks in dataloader:
head_feat, tail_feat, blocks = load_subtensor(
input_nodes, pair_graph, blocks, dataset,
dataset.valid_enc_graph if segment == 'valid' else dataset.test_enc_graph)
frontier = blocks[0]
true_relation_ratings = \
dataset.valid_truths[pair_graph.edata[dgl.EID]] if segment == 'valid' else \
dataset.test_truths[pair_graph.edata[dgl.EID]]
frontier = frontier.to(dev_id)
head_feat = head_feat.to(dev_id)
tail_feat = tail_feat.to(dev_id)
with th.no_grad():
pred_ratings = net(compact_g, frontier,
pred_ratings = net(pair_graph, frontier,
head_feat, tail_feat, possible_rating_values)
batch_pred_ratings = (th.softmax(pred_ratings, dim=1) *
nd_possible_rating_values.view(1, -1)).sum(dim=1)
......@@ -260,47 +212,43 @@ def config():
return args
@thread_wrapped_func
def run(proc_id, n_gpus, args, devices, dataset):
dev_id = devices[proc_id]
train_labels = dataset.train_labels
train_truths = dataset.train_truths
num_edges = train_truths.shape[0]
sampler = GCMCSampler(dataset,
'train')
seeds = th.arange(num_edges)
dataloader = DataLoader(
dataset=seeds,
reverse_types = {to_etype_name(k): 'rev-' + to_etype_name(k)
for k in dataset.possible_rating_values}
reverse_types.update({v: k for k, v in reverse_types.items()})
sampler = dgl.dataloading.MultiLayerNeighborSampler([None], return_eids=True)
dataloader = dgl.dataloading.EdgeDataLoader(
dataset.train_enc_graph,
{to_etype_name(k): th.arange(
dataset.train_enc_graph.number_of_edges(etype=to_etype_name(k)))
for k in dataset.possible_rating_values},
sampler,
batch_size=args.minibatch_size,
collate_fn=sampler.sample_blocks,
shuffle=True,
pin_memory=True,
drop_last=False,
num_workers=args.num_workers_per_gpu)
drop_last=False)
if proc_id == 0:
valid_sampler = GCMCSampler(dataset,
'valid')
valid_seeds = th.arange(dataset.valid_truths.shape[0])
valid_dataloader = DataLoader(dataset=valid_seeds,
batch_size=args.minibatch_size,
collate_fn=valid_sampler.sample_blocks,
shuffle=False,
pin_memory=True,
drop_last=False,
num_workers=args.num_workers_per_gpu)
test_sampler = GCMCSampler(dataset,
'test')
test_seeds = th.arange(dataset.test_truths.shape[0])
test_dataloader = DataLoader(dataset=test_seeds,
batch_size=args.minibatch_size,
collate_fn=test_sampler.sample_blocks,
shuffle=False,
pin_memory=True,
drop_last=False,
num_workers=args.num_workers_per_gpu)
valid_dataloader = dgl.dataloading.EdgeDataLoader(
dataset.valid_dec_graph,
th.arange(dataset.valid_dec_graph.number_of_edges()),
sampler,
g_sampling=dataset.valid_enc_graph,
batch_size=args.minibatch_size,
shuffle=False,
drop_last=False)
test_dataloader = dgl.dataloading.EdgeDataLoader(
dataset.test_dec_graph,
th.arange(dataset.test_dec_graph.number_of_edges()),
sampler,
g_sampling=dataset.test_enc_graph,
batch_size=args.minibatch_size,
shuffle=False,
drop_last=False)
if n_gpus > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
......@@ -341,9 +289,14 @@ def run(proc_id, n_gpus, args, devices, dataset):
if epoch > 1:
t0 = time.time()
net.train()
for step, sample_data in enumerate(dataloader):
compact_g, frontier, head_feat, tail_feat, \
true_relation_labels, true_relation_ratings = sample_data
for step, (input_nodes, pair_graph, blocks) in enumerate(dataloader):
head_feat, tail_feat, blocks = load_subtensor(
input_nodes, pair_graph, blocks, dataset, dataset.train_enc_graph)
frontier = blocks[0]
compact_g = flatten_etypes(pair_graph, dataset, 'train').to(dev_id)
true_relation_labels = compact_g.edata['label']
true_relation_ratings = compact_g.edata['rating']
head_feat = head_feat.to(dev_id)
tail_feat = tail_feat.to(dev_id)
frontier = frontier.to(dev_id)
......@@ -375,6 +328,9 @@ def run(proc_id, n_gpus, args, devices, dataset):
print("[{}] {}".format(proc_id, logging_str))
iter_idx += 1
if step == 20:
return
if epoch > 1:
epoch_time = time.time() - t0
print("Epoch {} time {}".format(epoch, epoch_time))
......@@ -460,7 +416,7 @@ if __name__ == '__main__':
dataset.train_dec_graph.create_format_()
procs = []
for proc_id in range(n_gpus):
p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, dataset))
p = mp.Process(target=thread_wrapped_func(run), args=(proc_id, n_gpus, args, devices, dataset))
p.start()
procs.append(p)
for p in procs:
......
......@@ -76,3 +76,7 @@ def get_optimizer(opt):
return optim.Adam
else:
raise NotImplementedError
def to_etype_name(rating):
return str(rating).replace('.', '_')
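This helper exists because PyTorch parameter names (and thus the per-rating keys in the GCMC model) cannot contain '.', so fractional ratings become underscore-separated edge-type names, e.g.:

to_etype_name(4.5)  # -> '4_5'
to_etype_name(3)    # -> '3'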
......@@ -64,8 +64,8 @@ class SAGE(nn.Module):
for l, layer in enumerate(self.layers):
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
sampler = dgl.sampling.MultiLayerNeighborSampler([None])
dataloader = dgl.sampling.NodeDataLoader(
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
g,
th.arange(g.number_of_nodes()),
sampler,
......@@ -129,9 +129,9 @@ def run(args, device, data):
test_nid = th.nonzero(~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']), as_tuple=True)[0]
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.sampling.MultiLayerNeighborSampler(
sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')])
dataloader = dgl.sampling.NodeDataLoader(
dataloader = dgl.dataloading.NodeDataLoader(
train_g,
train_nid,
sampler,
......
......@@ -66,8 +66,8 @@ class SAGE(nn.Module):
for l, layer in enumerate(self.layers):
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
sampler = dgl.sampling.MultiLayerNeighborSampler([None])
dataloader = dgl.sampling.NodeDataLoader(
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
g,
th.arange(g.number_of_nodes()),
sampler,
......@@ -149,9 +149,9 @@ def run(proc_id, n_gpus, args, devices, data):
train_nid = th.split(train_nid, math.ceil(len(train_nid) / n_gpus))[proc_id]
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.sampling.MultiLayerNeighborSampler(
sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')])
dataloader = dgl.sampling.NodeDataLoader(
dataloader = dgl.dataloading.NodeDataLoader(
train_g,
train_nid,
sampler,
......
......@@ -21,75 +21,22 @@ import sklearn.metrics as skm
from utils import thread_wrapped_func
#### Negative sampler
class NegativeSampler(object):
def __init__(self, g):
def __init__(self, g, k, neg_share=False):
self.weights = g.in_degrees().float() ** 0.75
def __call__(self, num_samples):
return self.weights.multinomial(num_samples, replacement=True)
#### Neighbor sampler
class NeighborSampler(object):
def __init__(self, g, fanouts, num_negs, neg_share=False):
self.g = g
self.fanouts = fanouts
self.neg_sampler = NegativeSampler(g)
self.num_negs = num_negs
self.k = k
self.neg_share = neg_share
def sample_blocks(self, seed_edges):
n_edges = len(seed_edges)
seed_edges = th.LongTensor(np.asarray(seed_edges))
heads, tails = self.g.find_edges(seed_edges)
if self.neg_share and n_edges % self.num_negs == 0:
neg_tails = self.neg_sampler(n_edges)
neg_tails = neg_tails.view(-1, 1, self.num_negs).expand(n_edges//self.num_negs,
self.num_negs,
self.num_negs).flatten()
neg_heads = heads.view(-1, 1).expand(n_edges, self.num_negs).flatten()
def __call__(self, g, eids):
src, _ = g.find_edges(eids)
n = len(src)
if self.neg_share and n % self.k == 0:
dst = self.weights.multinomial(n, replacement=True)
dst = dst.view(-1, 1, self.k).expand(-1, self.k, -1).flatten()
else:
neg_tails = self.neg_sampler(self.num_negs * n_edges)
neg_heads = heads.view(-1, 1).expand(n_edges, self.num_negs).flatten()
# Maintain the correspondence between heads, tails and negative tails as two
# graphs.
# pos_graph contains the correspondence between each head and its positive tail.
# neg_graph contains the correspondence between each head and its negative tails.
# Both pos_graph and neg_graph are first constructed with the same node space as
# the original graph. Then they are compacted together with dgl.compact_graphs.
pos_graph = dgl.graph((heads, tails), num_nodes=self.g.number_of_nodes())
neg_graph = dgl.graph((neg_heads, neg_tails), num_nodes=self.g.number_of_nodes())
pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])
# Obtain the node IDs being used in either pos_graph or neg_graph. Since they
# are compacted together, pos_graph and neg_graph share the same compacted node
# space.
seeds = pos_graph.ndata[dgl.NID]
blocks = []
for fanout in self.fanouts:
# For each seed node, sample ``fanout`` neighbors.
frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout, replace=True)
# Remove all edges between heads and tails, as well as heads and neg_tails.
_, _, edge_ids = frontier.edge_ids(
th.cat([heads, tails, neg_heads, neg_tails]),
th.cat([tails, heads, neg_tails, neg_heads]),
return_uv=True)
frontier = dgl.remove_edges(frontier, edge_ids)
# Then we compact the frontier into a bipartite graph for message passing.
block = dgl.to_block(frontier, seeds)
# Pre-generate the CSR format so that it can be used in training directly.
block.create_format_()
# Obtain the seed nodes for next layer.
seeds = block.srcdata[dgl.NID]
blocks.insert(0, block)
return pos_graph, neg_graph, blocks
dst = self.weights.multinomial(n, replacement=True)
src = src.repeat_interleave(self.k)
return src, dst
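With the new API, a negative sampler is a callable taking ``(g, eids)`` and returning a ``(src, dst)`` pair of negative edges, which ``EdgeDataLoader`` assembles into the per-batch negative graph; the class above draws destinations with probability proportional to in-degree**0.75, word2vec style. For uniform negatives DGL also provides a builtin. A sketch of wiring either into the loader (``train_eids`` and ``sampler`` as elsewhere in this file):

dataloader = dgl.dataloading.EdgeDataLoader(
    g, train_eids, sampler,
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),  # 5 negatives per edge
    batch_size=1024, shuffle=True, drop_last=False)
for input_nodes, pos_graph, neg_graph, blocks in dataloader:
    pass  # score pos_graph edges against neg_graph edges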
def load_subtensor(g, input_nodes, device):
"""
......@@ -145,12 +92,18 @@ class SAGE(nn.Module):
for l, layer in enumerate(self.layers):
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
for start in tqdm.trange(0, len(nodes), batch_size):
end = start + batch_size
batch_nodes = nodes[start:end]
block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
block = block.int().to(device)
input_nodes = block.srcdata[dgl.NID]
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
g,
th.arange(g.number_of_nodes()),
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0].to(device)
h = x[input_nodes].to(device)
h = layer(block, h)
......@@ -158,7 +111,7 @@ class SAGE(nn.Module):
h = self.activation(h)
h = self.dropout(h)
y[start:end] = h.cpu()
y[output_nodes] = h.cpu()
x = y
return y
......@@ -245,11 +198,9 @@ def run(proc_id, n_gpus, args, devices, data):
#val_nid = th.LongTensor(np.nonzero(val_mask)[0])
#test_nid = th.LongTensor(np.nonzero(test_mask)[0])
# Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')], args.num_negs, args.neg_share)
# Create PyTorch DataLoader for constructing blocks
train_seeds = np.arange(g.number_of_edges())
n_edges = g.number_of_edges()
train_seeds = np.arange(n_edges)
if n_gpus > 0:
num_per_gpu = (train_seeds.shape[0] + n_gpus -1) // n_gpus
train_seeds = train_seeds[proc_id * num_per_gpu :
......@@ -257,10 +208,17 @@ def run(proc_id, n_gpus, args, devices, data):
if (proc_id + 1) * num_per_gpu < train_seeds.shape[0]
else train_seeds.shape[0]]
dataloader = DataLoader(
dataset=train_seeds,
# Create sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')])
dataloader = dgl.dataloading.EdgeDataLoader(
g, train_seeds, sampler, exclude='reverse_id',
# For each edge with ID e in Reddit dataset, the reverse edge is e ± |E|/2.
reverse_eids=th.cat([
th.arange(n_edges // 2, n_edges),
th.arange(0, n_edges // 2)]),
negative_sampler=NegativeSampler(g, args.num_negs),
batch_size=args.batch_size,
collate_fn=sampler.sample_blocks,
shuffle=True,
drop_last=False,
pin_memory=True,
......@@ -290,10 +248,7 @@ def run(proc_id, n_gpus, args, devices, data):
# blocks.
tic_step = time.time()
for step, (pos_graph, neg_graph, blocks) in enumerate(dataloader):
# The nodes for input lie on the LHS of the first block.
# The nodes for output lie on the RHS of the last block.
input_nodes = blocks[0].srcdata[dgl.NID]
for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):
batch_inputs = load_subtensor(g, input_nodes, device)
d_step = time.time()
......
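A note on the exclusion above: ``exclude='reverse_id'`` drops, for every minibatch edge, both the edge itself and its paired reverse edge from the sampled neighborhoods, so the model never sees the edge it is asked to predict. The ``reverse_eids`` tensor encodes the pairing, and applying it twice should be the identity; a sanity-check sketch:

n_edges = g.number_of_edges()
reverse_eids = th.cat([th.arange(n_edges // 2, n_edges), th.arange(0, n_edges // 2)])
assert (reverse_eids[reverse_eids] == th.arange(n_edges)).all()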
......@@ -84,9 +84,9 @@ class JTNNDataset(Dataset):
cand_graphs = []
atom_x_dec = torch.zeros(0, ATOM_FDIM_DEC)
bond_x_dec = torch.zeros(0, BOND_FDIM_DEC)
tree_mess_src_e = torch.zeros(0, 2).int()
tree_mess_tgt_e = torch.zeros(0, 2).int()
tree_mess_tgt_n = torch.zeros(0).int()
tree_mess_src_e = torch.zeros(0, 2).long()
tree_mess_tgt_e = torch.zeros(0, 2).long()
tree_mess_tgt_n = torch.zeros(0).long()
# prebuild the stereoisomers
cands = mol_tree.stereo_cands
......
import torch
import torch.nn as nn
from .nnutils import cuda, line_graph
from .nnutils import cuda
import rdkit.Chem as Chem
import dgl
from dgl import mean_nodes
from dgl import mean_nodes, line_graph
import dgl.function as DGLF
import os
......@@ -93,13 +93,9 @@ def mol2dgl_single(cand_batch):
return cand_graphs, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
torch.IntTensor(tree_mess_source_edges), \
torch.IntTensor(tree_mess_target_edges), \
torch.IntTensor(tree_mess_target_nodes)
mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')
torch.LongTensor(tree_mess_source_edges), \
torch.LongTensor(tree_mess_target_edges), \
torch.LongTensor(tree_mess_target_nodes)
class LoopyBPUpdate(nn.Module):
......@@ -224,19 +220,19 @@ class DGLJTMPN(nn.Module):
tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1)
src_u = src_u.to(mol_tree_batch.device)
src_v = src_v.to(mol_tree_batch.device)
eid = mol_tree_batch.edge_ids(src_u.int(), src_v.int()).long()
eid = mol_tree_batch.edge_ids(src_u, src_v)
alpha = mol_tree_batch.edata['m'][eid]
cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
else:
src_u, src_v = tree_mess_src_edges.unbind(1)
src_u = src_u.to(mol_tree_batch.device)
src_v = src_v.to(mol_tree_batch.device)
eid = mol_tree_batch.edge_ids(src_u.int(), src_v.int()).long()
eid = mol_tree_batch.edge_ids(src_u, src_v)
alpha = mol_tree_batch.edata['m'][eid]
node_idx = (tree_mess_tgt_nodes
.to(device=zero_node_state.device)[:, None]
.expand_as(alpha))
node_alpha = zero_node_state.clone().scatter_add(0, node_idx.long(), alpha)
node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha)
cand_graphs.ndata['alpha'] = node_alpha
cand_graphs.apply_edges(
func=lambda edges: {'alpha': edges.src['alpha']},
......@@ -244,17 +240,14 @@ class DGLJTMPN(nn.Module):
cand_line_graph.ndata.update(cand_graphs.edata)
for i in range(self.depth - 1):
cand_line_graph.update_all(
mpn_loopy_bp_msg,
mpn_loopy_bp_reduce,
self.loopy_bp_updater,
)
cand_line_graph.update_all(DGLF.copy_u('msg', 'msg'), DGLF.sum('msg', 'accum_msg'))
cand_line_graph.apply_nodes(self.loopy_bp_updater)
cand_graphs.edata.update(cand_line_graph.ndata)
cand_graphs.update_all(
mpn_gather_msg,
mpn_gather_reduce,
self.gather_updater,
)
cand_graphs.update_all(DGLF.copy_e('msg', 'msg'), DGLF.sum('msg', 'm'))
if PAPER:
cand_graphs.update_all(DGLF.copy_e('alpha', 'alpha'), DGLF.sum('alpha', 'accum_alpha'))
cand_graphs.apply_nodes(self.gather_updater)
return cand_graphs
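A pattern worth noting in the JTNN edits: the deprecated module-level message/reduce globals built from ``DGLF.copy_src``/``DGLF.copy_edge`` are replaced by inline ``DGLF.copy_u``/``DGLF.copy_e``, and the fused three-argument ``update_all(msg, reduce, apply)`` calls are split into ``update_all`` followed by ``apply_nodes``. The generic shape of the new idiom, as a sketch on a graph ``g`` with node feature 'h':

import dgl.function as DGLF
g.update_all(DGLF.copy_u('h', 'm'), DGLF.sum('m', 'h_sum'))   # message + sum-reduce
g.apply_nodes(lambda nodes: {'h': nodes.data['h_sum']})       # separate node update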
......@@ -3,8 +3,8 @@ import torch.nn as nn
import torch.nn.functional as F
from .mol_tree_nx import DGLMolTree
from .chemutils import enum_assemble_nx, get_mol
from .nnutils import GRUUpdate, cuda, line_graph, tocpu
from dgl import batch, dfs_labeled_edges_generator
from .nnutils import GRUUpdate, cuda, tocpu
from dgl import batch, dfs_labeled_edges_generator, line_graph
import dgl.function as DGLF
import numpy as np
......@@ -29,10 +29,6 @@ def dec_tree_node_update(nodes):
return {'new': nodes.data['new'].clone().zero_()}
dec_tree_edge_msg = [DGLF.copy_src(src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
dec_tree_edge_reduce = [DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]
def have_slots(fa_slots, ch_slots):
if len(fa_slots) > 2 and len(ch_slots) > 2:
return True
......@@ -146,12 +142,8 @@ class DGLJTNNDecoder(nn.Module):
q_targets = []
# Predict root
mol_tree_batch.pull(
root_ids,
dec_tree_node_msg,
dec_tree_node_reduce,
dec_tree_node_update,
)
mol_tree_batch.pull(root_ids, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
mol_tree_batch.apply_nodes(dec_tree_node_update)
# Extract hidden states and store them for stop/label prediction
h = mol_tree_batch.nodes[root_ids].data['h']
x = mol_tree_batch.nodes[root_ids].data['x']
......@@ -168,28 +160,22 @@ class DGLJTNNDecoder(nn.Module):
u, v = mol_tree_batch.find_edges(eid)
p_target_list = torch.zeros_like(root_out_degrees)
p_target_list[root_out_degrees > 0] = (1 - p).int()
p_target_list[root_out_degrees > 0] = (1 - p)
p_target_list = p_target_list[root_out_degrees >= 0]
p_targets.append(torch.tensor(p_target_list))
root_out_degrees -= (root_out_degrees == 0).int()
root_out_degrees -= (root_out_degrees == 0).long()
root_out_degrees -= torch.tensor(np.isin(root_ids, v.cpu().numpy())).to(root_out_degrees)
mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
mol_tree_batch_lg.pull(
eid,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update,
)
mol_tree_batch_lg.pull(eid, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
mol_tree_batch_lg.pull(eid, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
mol_tree_batch_lg.apply_nodes(self.dec_tree_edge_update)
mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)
is_new = mol_tree_batch.nodes[v].data['new']
mol_tree_batch.pull(
v,
dec_tree_node_msg,
dec_tree_node_reduce,
dec_tree_node_update,
)
mol_tree_batch.pull(v, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
mol_tree_batch.apply_nodes(dec_tree_node_update)
# Extract
n_repr = mol_tree_batch.nodes[v].data
......@@ -209,7 +195,7 @@ class DGLJTNNDecoder(nn.Module):
p_targets.append(torch.zeros(
(root_out_degrees == 0).sum(),
device=root_out_degrees.device,
dtype=torch.int32))
dtype=torch.int64))
# Batch compute the stop/label prediction losses
p_inputs = torch.cat(p_inputs, 0)
......@@ -310,16 +296,15 @@ class DGLJTNNDecoder(nn.Module):
mol_tree_graph_lg.pull(
uv,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update.update_zm,
)
DGLF.copy_u('m', 'm'),
DGLF.sum('m', 's'))
mol_tree_graph_lg.pull(
uv,
DGLF.copy_u('rm', 'rm'),
DGLF.sum('rm', 'accum_rm'))
mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update.update_zm)
mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
mol_tree_graph.pull(
v,
dec_tree_node_msg,
dec_tree_node_reduce,
)
mol_tree_graph.pull(v, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
h_v = mol_tree_graph.ndata['h'][v:v+1]
q_input = torch.cat([h_v, mol_vec], 1)
......@@ -371,18 +356,11 @@ class DGLJTNNDecoder(nn.Module):
pu, _ = stack[-2]
u_pu = mol_tree_graph.edge_id(u, pu)
mol_tree_graph_lg.pull(
u_pu,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update,
)
mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update)
mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
mol_tree_graph.pull(
pu,
dec_tree_node_msg,
dec_tree_node_reduce,
)
mol_tree_graph.pull(pu, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
stack.pop()
effective_nodes = mol_tree_graph.filter_nodes(lambda nodes: nodes.data['fail'] != 1)
......
import torch
import torch.nn as nn
from .nnutils import GRUUpdate, cuda, line_graph, tocpu
from dgl import batch, bfs_edges_generator
from .nnutils import GRUUpdate, cuda, tocpu
from dgl import batch, bfs_edges_generator, line_graph
import dgl.function as DGLF
import numpy as np
......@@ -18,11 +18,6 @@ def level_order(forest, roots):
yield from reversed(edges_back)
yield from edges
enc_tree_msg = [DGLF.copy_src(src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
enc_tree_reduce = [DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]
enc_tree_gather_msg = DGLF.copy_edge(edge='m', out='m')
enc_tree_gather_reduce = DGLF.sum(msg='m', out='m')
class EncoderGatherUpdate(nn.Module):
def __init__(self, hidden_size):
nn.Module.__init__(self)
......@@ -102,21 +97,15 @@ class DGLJTNNEncoder(nn.Module):
# if m_ij is actually computed or not.
mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
for eid in level_order(mol_tree_batch, root_ids):
#eid = mol_tree_batch.edge_ids(u, v)
mol_tree_batch_lg.pull(
eid.to(mol_tree_batch_lg.device),
enc_tree_msg,
enc_tree_reduce,
self.enc_tree_update,
)
eid = eid.to(mol_tree_batch_lg.device)
mol_tree_batch_lg.pull(eid, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
mol_tree_batch_lg.pull(eid, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'rm'))
mol_tree_batch_lg.apply_nodes(self.enc_tree_update)
# Readout
mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)
mol_tree_batch.update_all(
enc_tree_gather_msg,
enc_tree_gather_reduce,
self.enc_tree_gather_update,
)
mol_tree_batch.update_all(DGLF.copy_e('m', 'm'), DGLF.sum('m', 'm'))
mol_tree_batch.apply_nodes(self.enc_tree_gather_update)
root_vecs = mol_tree_batch.nodes[root_ids].data['h']
......
......@@ -183,7 +183,7 @@ class DGLJTNNVAE(nn.Module):
node['is_leaf'] = False
set_atommap(node['mol'], node['nid'])
mol_tree_sg = mol_tree.graph.subgraph(effective_nodes.int().to(tree_vec.device))
mol_tree_sg = mol_tree.graph.subgraph(effective_nodes.to(tree_vec.device))
mol_tree_msg, _ = self.jtnn([mol_tree_sg])
mol_tree_msg = unbatch(mol_tree_msg)[0]
mol_tree_msg.nodes_dict = nodes_dict
......
......@@ -4,9 +4,8 @@ import rdkit.Chem as Chem
import torch.nn.functional as F
from .chemutils import get_mol
import dgl
from dgl import mean_nodes
from dgl import mean_nodes, line_graph
import dgl.function as DGLF
from .nnutils import line_graph
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
......@@ -66,10 +65,6 @@ def mol2dgl_single(smiles):
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)
mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')
class LoopyBPUpdate(nn.Module):
def __init__(self, hidden_size):
super(LoopyBPUpdate, self).__init__()
......@@ -84,10 +79,6 @@ class LoopyBPUpdate(nn.Module):
return {'msg': msg}
mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
mpn_gather_reduce = DGLF.sum(msg='msg', out='m')
class GatherUpdate(nn.Module):
def __init__(self, hidden_size):
super(GatherUpdate, self).__init__()
......@@ -163,17 +154,11 @@ class DGLMPN(nn.Module):
})
for i in range(self.depth - 1):
mol_line_graph.update_all(
mpn_loopy_bp_msg,
mpn_loopy_bp_reduce,
self.loopy_bp_updater,
)
mol_line_graph.update_all(DGLF.copy_u('msg', 'msg'), DGLF.sum('msg', 'accum_msg'))
mol_line_graph.apply_nodes(self.loopy_bp_updater)
mol_graph.edata.update(mol_line_graph.ndata)
mol_graph.update_all(
mpn_gather_msg,
mpn_gather_reduce,
self.gather_updater,
)
mol_graph.update_all(DGLF.copy_e('msg', 'msg'), DGLF.sum('msg', 'm'))
mol_graph.apply_nodes(self.gather_updater)
return mol_graph
......@@ -48,10 +48,3 @@ def tocpu(g):
src = src.cpu()
dst = dst.cpu()
return dgl.graph((src, dst), num_nodes=g.number_of_nodes())
def line_graph(g, backtracking=True, shared=False):
#g2 = tocpu(g)
g2 = dgl.line_graph(g, backtracking, shared)
#g2 = g2.to(g.device)
g2.ndata.update(g.edata)
return g2
......@@ -78,7 +78,7 @@ def train():
for epoch in range(MAX_EPOCH):
word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
for it, batch in tqdm.tqdm(enumerate(dataloader), total=2000):
for it, batch in enumerate(tqdm.tqdm(dataloader)):
model.zero_grad()
try:
loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
......
......@@ -85,8 +85,8 @@ class GAT(nn.Module):
y = th.zeros(g.number_of_nodes(), self.n_hidden * num_heads if l != len(self.layers) - 1 else self.n_classes)
else:
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
sampler = dgl.sampling.MultiLayerNeighborSampler([None])
dataloader = dgl.sampling.NodeDataLoader(
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
g,
th.arange(g.number_of_nodes()),
sampler,
......
......@@ -17,29 +17,6 @@ import tqdm
import traceback
from ogb.nodeproppred import DglNodePropPredDataset
#### Neighbor sampler
class NeighborSampler(object):
def __init__(self, g, fanouts):
self.g = g
self.fanouts = fanouts
def sample_blocks(self, seeds):
seeds = th.LongTensor(np.asarray(seeds))
blocks = []
for fanout in self.fanouts:
# For each seed node, sample ``fanout`` neighbors.
if fanout == 0:
frontier = dgl.in_subgraph(self.g, seeds)
else:
frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout, replace=True)
# Then we compact the frontier into a bipartite graph for message passing.
block = dgl.to_block(frontier, seeds)
# Obtain the seed nodes for next layer.
seeds = block.srcdata[dgl.NID]
blocks.insert(0, block)
return blocks
class GAT(nn.Module):
def __init__(self,
......@@ -97,8 +74,8 @@ class GAT(nn.Module):
else:
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
sampler = dgl.sampling.MultiLayerNeighborSampler([None])
dataloader = dgl.sampling.NodeDataLoader(
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
g,
th.arange(g.number_of_nodes()),
sampler,
......@@ -161,9 +138,9 @@ def run(args, device, data):
train_nid, val_nid, test_nid, in_feats, labels, n_classes, g, num_heads = data
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.sampling.MultiLayerNeighborSampler(
sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')])
dataloader = dgl.sampling.NodeDataLoader(
dataloader = dgl.dataloading.NodeDataLoader(
g,
train_nid,
sampler,
......