Unverified Commit 2c489fad authored by Minjie Wang, committed by GitHub

SPMV specialization (#32)

* fix edge list order problem in cached graph.

* minor fix

* fix bug in edge iter

* SPMV works

* gcn spmv on CPU

* change gcn style

* fix cached graph performance; fixed gcn dataset bug

* reorg dir

* non-batch spmv; partial update problem with shape change

* fix reorder problem; finish gcn-batch impl

* pop API

* GPU context
parent 11e42d10
"""
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn
"""
import argparse
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import load_cora, load_citeseer, load_pubmed
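# Non-batched message passing used by this example:
#  - gcn_msg: each edge forwards its source node's feature 'h' as the message.
#  - gcn_reduce: each node sums the messages arriving on its in-edges.
#  - NodeUpdateModule: applies a linear layer (and optional activation) to the
#    aggregated neighborhood feature and stores the result back as 'h'.
# Note that this is the plain sum-aggregation form; the symmetric normalization
# from the Kipf & Welling paper is not applied in this example.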
def gcn_msg(src, edge):
return src['h']
def gcn_reduce(node, msgs):
return sum(msgs)
class NodeUpdateModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node, accum):
h = self.linear(accum)
if self.activation:
h = self.activation(h)
return {'h' : h}
class GCN(nn.Module):
def __init__(self,
nx_graph,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super(GCN, self).__init__()
self.g = DGLGraph(nx_graph)
self.dropout = dropout
# input layer
self.layers = nn.ModuleList([NodeUpdateModule(in_feats, n_hidden, activation)])
# hidden layers
for i in range(n_layers - 1):
self.layers.append(NodeUpdateModule(n_hidden, n_hidden, activation))
# output layer
self.layers.append(NodeUpdateModule(n_hidden, n_classes))
def forward(self, features, train_nodes):
for n, feat in features.items():
self.g.nodes[n]['h'] = feat
for layer in self.layers:
# apply dropout
if self.dropout:
for n in self.g.nodes:
self.g.nodes[n]['h'] = F.dropout(self.g.nodes[n]['h'], p=self.dropout)
self.g.update_all(gcn_msg, gcn_reduce, layer)
return torch.cat([self.g.nodes[n]['h'] for n in train_nodes])
def main(args):
# load and preprocess dataset
if args.dataset == 'cora':
data = load_cora()
elif args.dataset == 'citeseer':
data = load_citeseer()
elif args.dataset == 'pubmed':
data = load_pubmed()
else:
raise RuntimeError('Unknown dataset: {}'.format(args.dataset))
# features of each sample
features = {}
labels = []
train_nodes = []
for n in data.graph.nodes():
features[n] = torch.FloatTensor(data.features[n, :])
if data.train_mask[n] == 1:
train_nodes.append(n)
labels.append(data.labels[n])
labels = torch.LongTensor(labels)
in_feats = data.features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
if args.gpu < 0:
cuda = False
else:
cuda = True
torch.cuda.set_device(args.gpu)
features = {n : feat.cuda() for n, feat in features.items()}
labels = labels.cuda()
# create GCN model
model = GCN(data.graph,
in_feats,
args.n_hidden,
n_classes,
args.n_layers,
F.relu,
args.dropout)
if cuda:
model.cuda()
# use optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# initialize graph
dur = []
for epoch in range(args.n_epochs):
if epoch >= 3:
t0 = time.time()
# forward
logits = model(features, train_nodes)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch >= 3:
dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.item(), np.mean(dur), n_edges / np.mean(dur) / 1000))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
parser.add_argument("--dataset", type=str, required=True,
help="dataset")
parser.add_argument("--dropout", type=float, default=0,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-3,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=20,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
args = parser.parse_args()
print(args)
main(args)
"""
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn
GCN with batch processing
"""
import argparse
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import load_cora, load_citeseer, load_pubmed
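# Batched variant: all node features live in one (N, D) tensor attached to the
# graph via set_n_repr. gcn_msg forwards the source-feature tensor as a whole,
# and gcn_reduce sums over dimension 1, where messages destined for nodes of the
# same in-degree are stacked as (num_nodes_in_bucket, degree, D). The trailing
# pop_n_repr() in GCN.forward reads and removes the representation so that
# features whose shape changes between layers do not linger on the graph.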
def gcn_msg(src, edge):
return src
def gcn_reduce(node, msgs):
return torch.sum(msgs, 1)
class NodeUpdateModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node, accum):
h = self.linear(accum)
if self.activation:
h = self.activation(h)
return h
class GCN(nn.Module):
def __init__(self,
g,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super(GCN, self).__init__()
self.g = g
self.dropout = dropout
# input layer
self.layers = nn.ModuleList([NodeUpdateModule(in_feats, n_hidden, activation)])
# hidden layers
for i in range(n_layers - 1):
self.layers.append(NodeUpdateModule(n_hidden, n_hidden, activation))
# output layer
self.layers.append(NodeUpdateModule(n_hidden, n_classes))
def forward(self, features):
self.g.set_n_repr(features)
for layer in self.layers:
# apply dropout
if self.dropout:
val = F.dropout(self.g.get_n_repr(), p=self.dropout)
self.g.set_n_repr(val)
self.g.update_all(gcn_msg, gcn_reduce, layer, batchable=True)
return self.g.pop_n_repr()
def main(args):
# load and preprocess dataset
if args.dataset == 'cora':
data = load_cora()
elif args.dataset == 'citeseer':
data = load_citeseer()
elif args.dataset == 'pubmed':
data = load_pubmed()
else:
raise RuntimeError('Unknown dataset: {}'.format(args.dataset))
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
mask = torch.ByteTensor(data.train_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
if args.gpu < 0:
cuda = False
else:
cuda = True
torch.cuda.set_device(args.gpu)
features = features.cuda()
labels = labels.cuda()
mask = mask.cuda()
# create GCN model
g = DGLGraph(data.graph)
if cuda:
g.set_device(dgl.gpu(args.gpu))
model = GCN(g,
in_feats,
args.n_hidden,
n_classes,
args.n_layers,
F.relu,
args.dropout)
if cuda:
model.cuda()
# use optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# initialize graph
dur = []
for epoch in range(args.n_epochs):
if epoch >= 3:
t0 = time.time()
# forward
logits = model(features)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp[mask], labels[mask])
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch >= 3:
dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.item(), np.mean(dur), n_edges / np.mean(dur) / 1000))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
parser.add_argument("--dataset", type=str, required=True,
help="dataset")
parser.add_argument("--dropout", type=float, default=0,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-3,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=20,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
args = parser.parse_args()
print(args)
main(args)
"""
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn
GCN with SPMV specialization.
"""
import argparse
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import load_cora, load_citeseer, load_pubmed
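# SPMV specialization: instead of Python message/reduce functions, this example
# passes the builtin strings 'from_src' and 'sum' to update_all. For that
# combination DGLGraph lowers the whole message-passing step to a sparse
# matrix multiplication (F.spmm) against the cached adjacency matrix, avoiding
# the materialization of one message per edge.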
class NodeUpdateModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node, accum):
h = self.linear(accum)
if self.activation:
h = self.activation(h)
return h
class GCN(nn.Module):
def __init__(self,
g,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super(GCN, self).__init__()
self.g = g
self.dropout = dropout
# input layer
self.layers = nn.ModuleList([NodeUpdateModule(in_feats, n_hidden, activation)])
# hidden layers
for i in range(n_layers - 1):
self.layers.append(NodeUpdateModule(n_hidden, n_hidden, activation))
# output layer
self.layers.append(NodeUpdateModule(n_hidden, n_classes))
def forward(self, features):
self.g.set_n_repr(features)
for layer in self.layers:
# apply dropout
if self.dropout:
val = F.dropout(self.g.get_n_repr(), p=self.dropout)
self.g.set_n_repr(val)
self.g.update_all('from_src', 'sum', layer, batchable=True)
return self.g.pop_n_repr()
def main(args):
# load and preprocess dataset
if args.dataset == 'cora':
data = load_cora()
elif args.dataset == 'citeseer':
data = load_citeseer()
elif args.dataset == 'pubmed':
data = load_pubmed()
else:
raise RuntimeError('Unknown dataset: {}'.format(args.dataset))
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
mask = torch.ByteTensor(data.train_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
if args.gpu < 0:
cuda = False
else:
cuda = True
torch.cuda.set_device(args.gpu)
features = features.cuda()
labels = labels.cuda()
mask = mask.cuda()
# create GCN model
g = DGLGraph(data.graph)
if cuda:
g.set_device(dgl.gpu(args.gpu))
model = GCN(g,
in_feats,
args.n_hidden,
n_classes,
args.n_layers,
F.relu,
args.dropout)
if cuda:
model.cuda()
# use optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# initialize graph
dur = []
for epoch in range(args.n_epochs):
if epoch >= 3:
t0 = time.time()
# forward
logits = model(features)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp[mask], labels[mask])
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch >= 3:
dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.item(), np.mean(dur), n_edges / np.mean(dur) / 1000))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
parser.add_argument("--dataset", type=str, required=True,
help="dataset")
parser.add_argument("--dropout", type=float, default=0,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-3,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=20,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
args = parser.parse_args()
print(args)
main(args)
from .base import ALL
from .graph import DGLGraph
from .graph import __MSG__, __REPR__, ALL
from .graph import __MSG__, __REPR__
from .context import cpu, gpu
......@@ -9,20 +9,9 @@ SparseTensor = sp.sparse.spmatrix
def asnumpy(a):
return a
def concatenate(arrays, axis=0):
return np.concatenate(arrays, axis)
def packable(arrays):
return all(isinstance(a, np.ndarray) for a in arrays) and \
all(a.dtype == arrays[0].dtype for a in arrays) and \
all(a.shape[1:] == arrays[0].shape[1:] for a in arrays)
def pack(arrays):
return np.concatenate(arrays, axis=0)
def unpackable(a):
return isinstance(a, np.ndarray) and a.size > 0
def unpack(a):
return np.split(a, a.shape[0], axis=0)
......
......@@ -5,7 +5,7 @@ import scipy.sparse
# Tensor types
Tensor = th.Tensor
SparseTensor = scipy.sparse.spmatrix
SparseTensor = th.sparse.FloatTensor
# Data types
float16 = th.float16
......@@ -19,17 +19,13 @@ int64 = th.int64
# Operators
tensor = th.tensor
sparse_tensor = th.sparse.FloatTensor
sum = th.sum
max = th.max
def asnumpy(a):
return a.cpu().numpy()
def packable(tensors):
return all(isinstance(x, th.Tensor) and \
x.dtype == tensors[0].dtype and \
x.shape[1:] == tensors[0].shape[1:] for x in tensors)
def pack(tensors):
return th.cat(tensors)
......@@ -54,7 +50,24 @@ def broadcast_to(x, to_array):
return x + th.zeros_like(to_array)
nonzero = th.nonzero
def eq_scalar(x, val):
return th.eq(x, float(val))
squeeze = th.squeeze
unsqueeze = th.unsqueeze
reshape = th.reshape
ones = th.ones
spmm = th.spmm
sort = th.sort
def to_context(x, ctx):
if ctx is None:
return x
elif ctx.device == 'gpu':
th.cuda.set_device(ctx.device_id)
return x.cuda()
elif ctx.device == 'cpu':
return x.cpu()
else:
raise RuntimeError('Invalid context', ctx)
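# Usage sketch: to_context(x, None) returns x unchanged; to_context(x, ctx) moves
# x onto the GPU selected by ctx.device_id when ctx.device == 'gpu', or back to
# the CPU when ctx.device == 'cpu'.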
"""Module for base types and utilities."""
# A special argument for selecting all nodes/edges.
ALL = "__ALL__"
def is_all(arg):
return isinstance(arg, str) and arg == ALL
......@@ -14,63 +14,80 @@ import dgl.utils as utils
class CachedGraph:
def __init__(self):
self._graph = igraph.Graph(directed=True)
self._adjmat = None # cached adjacency matrix
def add_nodes(self, num_nodes):
self._graph.add_vertices(num_nodes)
def add_edge(self, u, v):
self._graph.add_edge(u, v)
def add_edges(self, u, v):
# Edges are assigned ids equal to their insertion order.
# TODO(minjie): tensorize the loop
for uu, vv in utils.edge_iter(u, v):
self._graph.add_edge(uu, vv)
uvs = list(utils.edge_iter(u, v))
self._graph.add_edges(uvs)
def get_edge_id(self, u, v):
# TODO(minjie): tensorize the loop
uvs = list(utils.edge_iter(u, v))
eids = self._graph.get_eids(uvs)
return F.tensor(eids, dtype=F.int64)
return utils.convert_to_id_tensor(eids)
def in_edges(self, v):
# TODO(minjie): tensorize the loop
src = []
dst = []
for vv in utils.node_iter(v):
uu = self._graph.predecessors(vv)
src += uu
dst += [vv] * len(uu)
src = F.tensor(src, dtype=F.int64)
dst = F.tensor(dst, dtype=F.int64)
src = utils.convert_to_id_tensor(src)
dst = utils.convert_to_id_tensor(dst)
return src, dst
def out_edges(self, u):
# TODO(minjie): tensorize the loop
src = []
dst = []
for uu in utils.node_iter(u):
vv = self._graph.successors(uu)
src += [uu] * len(vv)
dst += vv
src = F.tensor(src, dtype=F.int64)
dst = F.tensor(dst, dtype=F.int64)
src = utils.convert_to_id_tensor(src)
dst = utils.convert_to_id_tensor(dst)
return src, dst
def edges(self):
# TODO(minjie): tensorize
elist = self._graph.get_edgelist()
src = [u for u, _ in elist]
dst = [v for _, v in elist]
src = F.tensor(src, dtype=F.int64)
dst = F.tensor(dst, dtype=F.int64)
src = utils.convert_to_id_tensor(src)
dst = utils.convert_to_id_tensor(dst)
return src, dst
def in_degrees(self, v):
degs = self._graph.indegree(list(v))
return F.tensor(degs, dtype=F.int64)
return utils.convert_to_id_tensor(degs)
def adjmat(self, ctx):
"""Return a sparse adjacency matrix.
The row dimension represents the dst nodes; the column dimension
represents the src nodes.
"""
if self._adjmat is None:
elist = self._graph.get_edgelist()
src = [u for u, _ in elist]
dst = [v for _, v in elist]
src = F.unsqueeze(utils.convert_to_id_tensor(src), 0)
dst = F.unsqueeze(utils.convert_to_id_tensor(dst), 0)
idx = F.pack([dst, src])
n = self._graph.vcount()
dat = F.ones((len(elist),))
self._adjmat = F.sparse_tensor(idx, dat, [n, n])
# TODO(minjie): manually convert adjmat to context
self._adjmat = F.to_context(self._adjmat, ctx)
return self._adjmat
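# Because rows index destination nodes and columns index source nodes, F.spmm of
# this matrix with a node-feature matrix H of shape (N, D) yields, for every
# destination node, the sum of its source neighbors' rows of H -- the dense form
# of the 'from_src'/'sum' message-passing step.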
def create_cached_graph(dglgraph):
# TODO: tensorize the loop
cg = CachedGraph()
cg.add_nodes(dglgraph.number_of_nodes())
for u, v in dglgraph.edges():
cg.add_edges(u, v)
cg._graph.add_edges(dglgraph.edge_list)
return cg
"""DGL's device context shim."""
class Context(object):
def __init__(self, dev, devid=-1):
self.device = dev
self.device_id = devid
def __str__(self):
return '{}:{}'.format(self.device, self.device_id)
def gpu(gpuid):
return Context('gpu', gpuid)
def cpu():
return Context('cpu')
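# Usage sketch:
#   ctx = gpu(0)       # Context('gpu', 0), printed as 'gpu:0'
#   g.set_device(ctx)  # DGLGraph then places id tensors and the adjmat on this device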
from . import citation_graph as citegrh
load_cora = citegrh.load_cora
load_citeseer = citegrh.load_citeseer
load_pubmed = citegrh.load_pubmed
from .utils import *
"""Cora, citeseer, pubmed dataset.
(lingfan): the dataset loading and preprocessing code below follows tkipf/gcn
https://github.com/tkipf/gcn/blob/master/gcn/utils.py
"""
import numpy as np
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
import os, sys
from dgl.data.utils import download, extract_archive, get_download_dir
_urls = {
'cora' : 'https://www.dropbox.com/s/3ggdpkj7ou8svoc/cora.zip?dl=1',
'citeseer' : 'https://www.dropbox.com/s/cr4m05shgp8advz/citeseer.zip?dl=1',
'pubmed' : 'https://www.dropbox.com/s/fj5q6pi66xhymcm/pubmed.zip?dl=1',
}
class GCNDataset:
def __init__(self, name):
self.name = name
self.dir = get_download_dir()
self.zip_file_path='{}/{}.zip'.format(self.dir, name)
download(_urls[name], path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, name))
def load(self):
"""Loads input data from gcn/data directory
ind.name.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
ind.name.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
ind.name.allx => the feature vectors of both labeled and unlabeled training instances
(a superset of ind.name.x) as scipy.sparse.csr.csr_matrix object;
ind.name.y => the one-hot labels of the labeled training instances as numpy.ndarray object;
ind.name.ty => the one-hot labels of the test instances as numpy.ndarray object;
ind.name.ally => the labels for instances in ind.name.allx as numpy.ndarray object;
ind.name.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict
object;
ind.name.test.index => the indices of test instances in graph, for the inductive setting as list object.
All objects above must be saved using python pickle module.
:param name: Dataset name
:return: All data input files loaded (as well as the training/test data).
"""
root = '{}/{}'.format(self.dir, self.name)
objnames = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
objects = []
for i in range(len(objnames)):
with open("{}/ind.{}.{}".format(root, self.name, objnames[i]), 'rb') as f:
if sys.version_info > (3, 0):
objects.append(pkl.load(f, encoding='latin1'))
else:
objects.append(pkl.load(f))
x, y, tx, ty, allx, ally, graph = tuple(objects)
test_idx_reorder = _parse_index_file("{}/ind.{}.test.index".format(root, self.name))
test_idx_range = np.sort(test_idx_reorder)
if self.name == 'citeseer':
# Fix citeseer dataset (there are some isolated nodes in the graph)
# Find isolated nodes, add them as zero-vecs into the right position
test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
tx_extended[test_idx_range-min(test_idx_range), :] = tx
tx = tx_extended
ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
ty_extended[test_idx_range-min(test_idx_range), :] = ty
ty = ty_extended
features = sp.vstack((allx, tx)).tolil()
features[test_idx_reorder, :] = features[test_idx_range, :]
graph = nx.DiGraph(nx.from_dict_of_lists(graph))
onehot_labels = np.vstack((ally, ty))
onehot_labels[test_idx_reorder, :] = onehot_labels[test_idx_range, :]
labels = np.argmax(onehot_labels, 1)
idx_test = test_idx_range.tolist()
idx_train = range(len(y))
idx_val = range(len(y), len(y)+500)
train_mask = _sample_mask(idx_train, labels.shape[0])
val_mask = _sample_mask(idx_val, labels.shape[0])
test_mask = _sample_mask(idx_test, labels.shape[0])
#y_train = np.zeros(labels.shape)
#y_val = np.zeros(labels.shape)
#y_test = np.zeros(labels.shape)
#y_train[train_mask, :] = labels[train_mask, :]
#y_val[val_mask, :] = labels[val_mask, :]
#y_test[test_mask, :] = labels[test_mask, :]
self.graph = graph
self.features = _preprocess_features(features)
self.labels = labels
self.onehot_labels = onehot_labels
self.num_labels = onehot_labels.shape[1]
self.train_mask = train_mask
self.val_mask = val_mask
self.test_mask = test_mask
print('Finished data loading and preprocessing.')
print(' NumNodes: {}'.format(self.graph.number_of_nodes()))
print(' NumEdges: {}'.format(self.graph.number_of_edges()))
print(' NumFeats: {}'.format(self.features.shape[1]))
print(' NumClasses: {}'.format(self.num_labels))
print(' NumTrainingSamples: {}'.format(len(np.nonzero(self.train_mask)[0])))
print(' NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0])))
print(' NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0])))
def _preprocess_features(features):
"""Row-normalize feature matrix and convert to tuple representation"""
rowsum = np.array(features.sum(1))
r_inv = np.power(rowsum, -1).flatten()
r_inv[np.isinf(r_inv)] = 0.
r_mat_inv = sp.diags(r_inv)
features = r_mat_inv.dot(features)
return features.todense()
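# Example: a feature row [1., 1., 2.] has row sum 4 and becomes [0.25, 0.25, 0.5];
# all-zero rows stay zero because their infinite inverse is reset to 0.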
def _parse_index_file(filename):
"""Parse index file."""
index = []
for line in open(filename):
index.append(int(line.strip()))
return index
def _sample_mask(idx, l):
"""Create mask."""
mask = np.zeros(l)
mask[idx] = 1
return mask
def load_cora():
data = GCNDataset('cora')
data.load()
return data
def load_citeseer():
data = GCNDataset('citeseer')
data.load()
return data
def load_pubmed():
data = GCNDataset('pubmed')
data.load()
return data
"""Dataset utilities."""
import os
import hashlib
import warnings
import zipfile
import tarfile
try:
import requests
except ImportError:
class requests_failed_to_import(object):
pass
requests = requests_failed_to_import
def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""Download a given URL.
Code borrowed from mxnet/gluon/utils.py
Parameters
----------
url : str
URL to download
path : str, optional
Destination path to store downloaded file. By default stores to the
current directory with same name as in url.
overwrite : bool, optional
Whether to overwrite destination file if already exists.
sha1_hash : str, optional
Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
but doesn't match.
retries : integer, default 5
The number of times to attempt the download in case of failure or non-200 return codes.
verify_ssl : bool, default True
Verify SSL certificates.
Returns
-------
str
The file path of the downloaded file.
"""
if path is None:
fname = url.split('/')[-1]
# Empty filenames are invalid
assert fname, 'Can\'t construct file-name from this URL. ' \
'Please set the `path` option manually.'
else:
path = os.path.expanduser(path)
if os.path.isdir(path):
fname = os.path.join(path, url.split('/')[-1])
else:
fname = path
assert retries >= 0, "Number of retries should be at least 0"
if not verify_ssl:
warnings.warn(
'Unverified HTTPS request is being made (verify_ssl=False). '
'Adding certificate verification is strongly advised.')
if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
os.makedirs(dirname)
while retries+1 > 0:
# Disable pylint too-broad-Exception warning
# pylint: disable=W0703
try:
print('Downloading %s from %s...'%(fname, url))
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
raise RuntimeError("Failed downloading url %s"%url)
with open(fname, 'wb') as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
if sha1_hash and not check_sha1(fname, sha1_hash):
raise UserWarning('File {} is downloaded but the content hash does not match.'\
' The repo may be outdated or download may be incomplete. '\
'If the "repo_url" is overridden, consider switching to '\
'the default repo.'.format(fname))
break
except Exception as e:
retries -= 1
if retries <= 0:
raise e
else:
print("download failed, retrying, {} attempt{} left"
.format(retries, 's' if retries > 1 else ''))
return fname
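# Usage sketch (mirroring GCNDataset above):
#   local_zip = download(_urls['cora'], path='_download/cora.zip')
# The call retries up to `retries` times and, when `sha1_hash` is given, ignores an
# existing file whose hash does not match.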
def check_sha1(filename, sha1_hash):
"""Check whether the sha1 hash of the file content matches the expected hash.
Code borrowed from mxnet/gluon/utils.py
Parameters
----------
filename : str
Path to the file.
sha1_hash : str
Expected sha1 hash in hexadecimal digits.
Returns
-------
bool
Whether the file content matches the expected hash.
"""
sha1 = hashlib.sha1()
with open(filename, 'rb') as f:
while True:
data = f.read(1048576)
if not data:
break
sha1.update(data)
return sha1.hexdigest() == sha1_hash
def extract_archive(file, target_dir):
"""Extract archive file
Parameters
----------
file : str
Absolute path of the archive file.
target_dir : str
Target directory into which the archive is extracted.
"""
if file.endswith('.gz') or file.endswith('.tar') or file.endswith('.tgz'):
archive = tarfile.open(file, 'r')
elif file.endswith('.zip'):
archive = zipfile.ZipFile(file, 'r')
else:
raise Exception('Unrecognized file type: ' + file)
archive.extractall(path=target_dir)
archive.close()
def get_download_dir():
dirname = '_download'
if not os.path.exists(dirname):
os.makedirs(dirname)
return dirname
......@@ -43,6 +43,15 @@ class Frame:
else:
self.update_rows(key, val)
def __delitem__(self, key):
# delete column
del self._columns[key]
def pop(self, key):
col = self._columns[key]
del self._columns[key]
return col
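# pop removes the named column from the frame and returns it; DGLGraph.pop_n_repr
# and pop_e_repr build on this to read-and-clear a node or edge representation.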
def add_column(self, name, col):
if self.num_columns == 0:
self._num_rows = F.shape(col)[0]
......
......@@ -6,18 +6,56 @@ from collections import MutableMapping
import networkx as nx
from networkx.classes.digraph import DiGraph
from dgl.base import ALL, is_all
import dgl.backend as F
from dgl.backend import Tensor
import dgl.builtin as builtin
#import dgl.state as state
from dgl.frame import Frame
from dgl.cached_graph import CachedGraph, create_cached_graph
import dgl.context as context
from dgl.frame import Frame
import dgl.scheduler as scheduler
import dgl.utils as utils
__MSG__ = "__MSG__"
__REPR__ = "__REPR__"
ALL = "__ALL__"
class _NodeDict(MutableMapping):
def __init__(self, cb):
self._dict = {}
self._cb = cb
def __setitem__(self, key, val):
if isinstance(val, _AdjInnerDict):
# This node dict is used as adj_outer_list
val.src = key
elif key not in self._dict:
self._cb(key)
self._dict[key] = val
def __getitem__(self, key):
return self._dict[key]
def __delitem__(self, key):
del self._dict[key]
def __len__(self):
return len(self._dict)
def __iter__(self):
return iter(self._dict)
class _AdjInnerDict(MutableMapping):
def __init__(self, cb):
self._dict = {}
self.src = None
self._cb = cb
def __setitem__(self, key, val):
if key not in self._dict:
self._cb(self.src, key)
self._dict[key] = val
def __getitem__(self, key):
return self._dict[key]
def __delitem__(self, key):
del self._dict[key]
def __len__(self):
return len(self._dict)
def __iter__(self):
return iter(self._dict)
class DGLGraph(DiGraph):
"""Base graph class specialized for neural networks on graphs.
......@@ -32,14 +70,15 @@ class DGLGraph(DiGraph):
attr : keyword arguments, optional
Attributes to add to graph as key=value pairs.
"""
#node_dict_factory = state.NodeDict
#adjlist_outer_dict_factory = state.AdjOuterDict
#adjlist_inner_dict_factory = state.AdjInnerDict
#edge_attr_dict_factory = state.EdgeAttrDict
def __init__(self, graph_data=None, **attr):
# call base class init
super(DGLGraph, self).__init__(graph_data, **attr)
# setup dict overlay
self.node_dict_factory = lambda : _NodeDict(self._add_node_callback)
# In networkx 2.1, DiGraph is not using this factory. Instead, the outer
# dict uses the same data structure as the node dict.
self.adjlist_outer_dict_factory = None
self.adjlist_inner_dict_factory = lambda : _AdjInnerDict(self._add_edge_callback)
self.edge_attr_dict_factory = dict
# cached graph and storage
self._cached_graph = None
self._node_frame = Frame()
self._edge_frame = Frame()
......@@ -50,6 +89,11 @@ class DGLGraph(DiGraph):
self._reduce_func = None
self._update_func = None
self._edge_func = None
self._edge_cb_state = True
self._edge_list = []
self._context = context.cpu()
# call base class init
super(DGLGraph, self).__init__(graph_data, **attr)
def set_n_repr(self, hu, u=ALL):
"""Set node(s) representation.
......@@ -64,7 +108,7 @@ class DGLGraph(DiGraph):
Parameters
----------
hu : any
hu : tensor or dict of tensor
Node representation.
u : node, container or tensor
The node(s).
......@@ -73,13 +117,13 @@ class DGLGraph(DiGraph):
if isinstance(u, str) and u == ALL:
num_nodes = self.number_of_nodes()
else:
u = utils.convert_to_id_tensor(u)
u = utils.convert_to_id_tensor(u, self.context)
num_nodes = len(u)
if isinstance(hu, dict):
for key, val in hu.items():
assert F.shape(val)[0] == num_nodes
else:
F.shape(hu)[0] == num_nodes
assert F.shape(hu)[0] == num_nodes
# set
if isinstance(u, str) and u == ALL:
if isinstance(hu, dict):
......@@ -108,12 +152,22 @@ class DGLGraph(DiGraph):
else:
return dict(self._node_frame)
else:
u = utils.convert_to_id_tensor(u)
u = utils.convert_to_id_tensor(u, self.context)
if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
return self._node_frame[__REPR__][u]
else:
return self._node_frame.select_rows(u)
def pop_n_repr(self, key=__REPR__):
"""Get and remove the specified node repr.
Parameters
----------
key : str
The attribute name.
"""
return self._node_frame.pop(key)
def set_e_repr(self, h_uv, u=ALL, v=ALL):
"""Set edge(s) representation.
......@@ -127,7 +181,7 @@ class DGLGraph(DiGraph):
Parameters
----------
h_uv : any
h_uv : tensor or dict of tensor
Edge representation.
u : node, container or tensor
The source node(s).
......@@ -141,14 +195,14 @@ class DGLGraph(DiGraph):
if u_is_all:
num_edges = self.number_of_edges()
else:
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
u = utils.convert_to_id_tensor(u, self.context)
v = utils.convert_to_id_tensor(v, self.context)
num_edges = max(len(u), len(v))
if isinstance(h_uv, dict):
for key, val in h_uv.items():
assert F.shape(val)[0] == num_edges
else:
F.shape(h_uv)[0] == num_edges
assert F.shape(h_uv)[0] == num_edges
# set
if u_is_all:
if isinstance(h_uv, dict):
......@@ -164,6 +218,41 @@ class DGLGraph(DiGraph):
else:
self._edge_frame[__REPR__][eid] = h_uv
def set_e_repr_by_id(self, h_uv, eid=ALL):
"""Set edge(s) representation by edge id.
Parameters
----------
h_uv : tensor or dict of tensor
Edge representation.
eid : int, container or tensor
The edge id(s).
"""
# sanity check
if isinstance(eid, str) and eid == ALL:
num_edges = self.number_of_edges()
else:
eid = utils.convert_to_id_tensor(eid, self.context)
num_edges = len(eid)
if isinstance(h_uv, dict):
for key, val in h_uv.items():
assert F.shape(val)[0] == num_edges
else:
assert F.shape(h_uv)[0] == num_edges
# set
if isinstance(eid, str) and eid == ALL:
if isinstance(h_uv, dict):
for key, val in h_uv.items():
self._edge_frame[key] = val
else:
self._edge_frame[__REPR__] = h_uv
else:
if isinstance(h_uv, dict):
for key, val in h_uv.items():
self._edge_frame[key][eid] = val
else:
self._edge_frame[__REPR__][eid] = h_uv
def get_e_repr(self, u=ALL, v=ALL):
"""Get edge(s) representation.
......@@ -183,14 +272,59 @@ class DGLGraph(DiGraph):
else:
return dict(self._edge_frame)
else:
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
u = utils.convert_to_id_tensor(u, self.context)
v = utils.convert_to_id_tensor(v, self.context)
eid = self.cached_graph.get_edge_id(u, v)
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__][eid]
else:
return self._edge_frame.select_rows(eid)
def pop_e_repr(self, key=__REPR__):
"""Get and remove the specified edge repr.
Parameters
----------
key : str
The attribute name.
"""
return self._edge_frame.pop(key)
def get_e_repr_by_id(self, eid=ALL):
"""Get edge(s) representation by edge id.
Parameters
----------
eid : int, container or tensor
The edge id(s).
"""
if isinstance(eid, str) and eid == ALL:
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__]
else:
return dict(self._edge_frame)
else:
eid = utils.convert_to_id_tensor(eid, self.context)
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__][eid]
else:
return self._edge_frame.select_rows(eid)
def set_device(self, ctx):
"""Set device context for this graph.
Parameters
----------
ctx : dgl.context.Context
The device context.
"""
self._context = ctx
@property
def context(self):
"""Get the device context of this graph."""
return self._context
def register_message_func(self,
message_func,
batchable=False):
......@@ -308,14 +442,17 @@ class DGLGraph(DiGraph):
self.edges[uu, vv][__MSG__] = ret
def _batch_sendto(self, u, v, message_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
edge_id = self.cached_graph.get_edge_id(u, v)
eid = self.cached_graph.get_edge_id(u, v)
self.msg_graph.add_edges(u, v)
if len(u) != len(v) and len(u) == 1:
u = F.broadcast_to(u, v)
src_reprs = _get_repr(self._node_frame.select_rows(u))
edge_reprs = _get_repr(self._edge_frame.select_rows(edge_id))
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs)
if isinstance(msgs, dict):
self._msg_frame.append(msgs)
......@@ -362,16 +499,17 @@ class DGLGraph(DiGraph):
def _batch_update_edge(self, u, v, edge_func):
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
edge_id = self.cached_graph.get_edge_id(u, v)
eid = self.cached_graph.get_edge_id(u, v)
if len(u) != len(v) and len(u) == 1:
u = F.broadcast_to(u, v)
elif len(u) != len(v) and len(v) == 1:
v = F.broadcast_to(v, u)
src_reprs = _get_repr(self._node_frame.select_rows(u))
dst_reprs = _get_repr(self._node_frame.select_rows(v))
edge_reprs = _get_repr(self._edge_frame.select_rows(edge_id))
# call the UDF
src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v)
edge_reprs = self.get_e_repr_by_id(eid)
new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
_batch_set_repr(self._edge_frame, edge_id, new_edge_reprs)
self.set_e_repr_by_id(new_edge_reprs, eid)
def recv(self,
u,
......@@ -443,6 +581,9 @@ class DGLGraph(DiGraph):
_set_repr(self.nodes[uu], ret)
def _batch_recv(self, v, reduce_func, update_func):
v_is_all = is_all(v)
if v_is_all:
v = list(range(self.number_of_nodes()))
# sanity checks
v = utils.convert_to_id_tensor(v)
f_reduce = _get_reduce_func(reduce_func)
......@@ -454,9 +595,8 @@ class DGLGraph(DiGraph):
bkt_len = len(v_bkt)
uu, vv = self.msg_graph.in_edges(v_bkt)
in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
# The in_msgs represents the rows selected. Since our storage
# is column-based, it will only be materialized when user
# tries to get the column (e.g. when user called `msgs['h']`)
# TODO(minjie): manually convert ids to context.
in_msg_ids = F.to_context(in_msg_ids, self.context)
in_msgs = self._msg_frame.select_rows(in_msg_ids)
# Reshape the column tensor to (B, Deg, ...).
def _reshape_fn(msg):
......@@ -468,7 +608,7 @@ class DGLGraph(DiGraph):
else:
reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
dst_reprs = _get_repr(self._node_frame.select_rows(v_bkt))
dst_reprs = self.get_n_repr(v_bkt)
reduced_msgs.append(f_reduce(dst_reprs, reshaped_in_msgs))
# TODO: clear partial messages
......@@ -476,14 +616,26 @@ class DGLGraph(DiGraph):
# Read the node states in the degree-bucketing order.
reordered_v = F.pack(v_buckets)
reordered_ns = _get_repr(self._node_frame.select_rows(reordered_v))
reordered_ns = self.get_n_repr(reordered_v)
# Pack all reduced msgs together
if isinstance(reduced_msgs, dict):
if isinstance(reduced_msgs[0], dict):
all_reduced_msgs = {key : F.pack(val) for key, val in reduced_msgs.items()}
else:
all_reduced_msgs = F.pack(reduced_msgs)
new_ns = f_update(reordered_ns, all_reduced_msgs)
_batch_set_repr(self._node_frame, reordered_v, new_ns)
if v_is_all:
# First do reorder and then replace the whole column.
_, indices = F.sort(reordered_v)
# TODO(minjie): manually convert ids to context.
indices = F.to_context(indices, self.context)
if isinstance(new_ns, dict):
for key, val in new_ns.items():
self._node_frame[key] = F.gather_row(val, indices)
else:
self._node_frame[__REPR__] = F.gather_row(new_ns, indices)
else:
# Use setter to do reorder.
self.set_n_repr(new_ns, reordered_v)
def update_by_edge(self,
u, v,
......@@ -527,9 +679,9 @@ class DGLGraph(DiGraph):
def _nonbatch_update_by_edge(
self,
u, v,
message_func=None,
reduce_func=None,
update_func=None):
message_func,
reduce_func,
update_func):
self._nonbatch_sendto(u, v, message_func)
dst = set()
for uu, vv in utils.edge_iter(u, v):
......@@ -539,16 +691,29 @@ class DGLGraph(DiGraph):
def _batch_update_by_edge(
self,
u, v,
message_func=None,
reduce_func=None,
update_func=None):
if message_func == 'from_src' and reduce_func == 'sum':
# Specialized to generic-SPMV
raise NotImplementedError('SPVM specialization')
message_func,
reduce_func,
update_func):
if message_func == 'from_src' and reduce_func == 'sum' \
and is_all(u) and is_all(v):
# TODO(minjie): SPMV is only supported for updating all nodes right now.
adjmat = self.cached_graph.adjmat(self.context)
reduced_msgs = {}
for key in self._node_frame.schemes:
col = self._node_frame[key]
reduced_msgs[key] = F.spmm(adjmat, col)
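# adjmat is (N, N) with rows indexed by dst and columns by src, so F.spmm(adjmat, col)
# gives each node the sum of its in-neighbors' features -- the same result as
# message_func='from_src' with reduce_func='sum', but without per-edge messages.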
node_repr = self.get_n_repr()
if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
reduced_msgs = reduced_msgs[__REPR__]
self.set_n_repr(update_func(node_repr, reduced_msgs))
else:
self._batch_sendto(u, v, message_func)
unique_v = F.unique(v)
self._batch_recv(unique_v, reduce_func, update_func)
if is_all(u) and is_all(v):
self._batch_sendto(u, v, message_func)
self._batch_recv(v, reduce_func, update_func)
else:
self._batch_sendto(u, v, message_func)
unique_v = F.unique(v)
self._batch_recv(unique_v, reduce_func, update_func)
def update_to(self,
v,
......@@ -660,8 +825,7 @@ class DGLGraph(DiGraph):
assert reduce_func is not None
assert update_func is not None
if batchable:
u, v = self.cached_graph.edges()
self._batch_update_by_edge(u, v,
self._batch_update_by_edge(ALL, ALL,
message_func, reduce_func, update_func)
else:
u = [uu for uu, _ in self.edges]
......@@ -743,6 +907,23 @@ class DGLGraph(DiGraph):
def _edges_or_all(self, edges):
return self.edges() if edges == ALL else edges
def _add_node_callback(self, node):
self._cached_graph = None
def _add_edge_callback(self, u, v):
# In networkx 2.1, two adjlists are maintained. One for succ, one for pred.
# We record the edge only once, on the successor-side addition.
if self._edge_cb_state:
#print('New edge:', u, v)
self._edge_list.append((u, v))
self._edge_cb_state = not self._edge_cb_state
self._cached_graph = None
@property
def edge_list(self):
"""Return edges in the addition order."""
return self._edge_list
def _get_repr(attr_dict):
if len(attr_dict) == 1 and __REPR__ in attr_dict:
return attr_dict[__REPR__]
......@@ -755,12 +936,6 @@ def _set_repr(attr_dict, attr):
else:
attr_dict[__REPR__] = attr
def _batch_set_repr(frame, rows, attr):
if isinstance(attr, dict):
frame.update_rows(rows, attr)
else:
frame.update_rows(rows, {__REPR__ : attr})
def _get_reduce_func(reduce_func):
if isinstance(reduce_func, str):
# built-in reduce func
......
......@@ -2,12 +2,13 @@
from __future__ import absolute_import
import dgl.backend as F
import numpy as np
def degree_bucketing(cached_graph, v):
degrees = cached_graph.in_degrees(v)
unique_degrees = list(F.asnumpy(F.unique(degrees)))
degrees = F.asnumpy(cached_graph.in_degrees(v))
unique_degrees = list(np.unique(degrees))
v_bkt = []
for deg in unique_degrees:
idx = F.squeeze(F.nonzero(F.eq_scalar(degrees, deg)), 1)
idx = np.where(degrees == deg)
v_bkt.append(v[idx])
return unique_degrees, v_bkt
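# Degree bucketing groups destination nodes by in-degree so that, within a bucket,
# the incoming messages can be stacked into a dense (num_nodes, degree, D) tensor
# and reduced with a single batched operation.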
......@@ -12,59 +12,53 @@ def is_id_container(u):
return isinstance(u, list)
def node_iter(n):
if is_id_tensor(n):
n = list(F.asnumpy(n))
if is_id_container(n):
for nn in n:
yield nn
else:
yield n
n = convert_to_id_container(n)
for nn in n:
yield nn
def edge_iter(u, v):
u_is_container = is_id_container(u)
v_is_container = is_id_container(v)
u_is_tensor = is_id_tensor(u)
v_is_tensor = is_id_tensor(v)
if u_is_tensor:
u = F.asnumpy(u)
u_is_tensor = False
u_is_container = True
if v_is_tensor:
v = F.asnumpy(v)
v_is_tensor = False
v_is_container = True
if u_is_container and v_is_container:
u = convert_to_id_container(u)
v = convert_to_id_container(v)
if len(u) == len(v):
# many-many
for uu, vv in zip(u, v):
yield uu, vv
elif u_is_container and not v_is_container:
elif len(v) == 1:
# many-one
for uu in u:
yield uu, v
elif not u_is_container and v_is_container:
yield uu, v[0]
elif len(u) == 1:
# one-many
for vv in v:
yield u, vv
yield u[0], vv
else:
yield u, v
def homogeneous(x_list, type_x=None):
type_x = type_x if type_x else type(x_list[0])
return all(type(x) == type_x for x in x_list)
raise ValueError('Error edges:', u, v)
def convert_to_id_tensor(x):
def convert_to_id_container(x):
if is_id_container(x):
assert homogeneous(x, int)
return F.tensor(x)
elif is_id_tensor(x):
return x
elif isinstance(x, int):
x = F.tensor([x])
return x
elif is_id_tensor(x):
return F.asnumpy(x)
else:
raise TypeError('Error node: %s' % str(x))
try:
return [int(x)]
except:
raise TypeError('Error node: %s' % str(x))
return None
def convert_to_id_tensor(x, ctx=None):
if is_id_container(x):
ret = F.tensor(x, dtype=F.int64)
elif is_id_tensor(x):
ret = x
else:
try:
ret = F.tensor([int(x)], dtype=F.int64)
except:
raise TypeError('Error node: %s' % str(x))
ret = F.to_context(ret, ctx)
return ret
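# Example sketch: convert_to_id_tensor(3) -> tensor([3]) of dtype int64;
# convert_to_id_tensor([0, 1, 2], ctx) additionally moves the result to the device
# described by ctx via F.to_context.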
class LazyDict(Mapping):
"""A readonly dictionary that does not materialize the storage."""
def __init__(self, fn, keys):
......
......@@ -41,6 +41,10 @@ def test_batch_setter_getter():
# set all nodes
g.set_n_repr({'h' : th.zeros((10, D))})
assert _pfc(g.get_n_repr()['h']) == [0.] * 10
# pop nodes
assert _pfc(g.pop_n_repr('h')) == [0.] * 10
assert len(g.get_n_repr()) == 0
g.set_n_repr({'h' : th.zeros((10, D))})
# set partial nodes
u = th.tensor([1, 3, 5])
g.set_n_repr({'h' : th.ones((3, D))}, u)
......@@ -72,15 +76,41 @@ def test_batch_setter_getter():
# set all edges
g.set_e_repr({'l' : th.zeros((17, D))})
assert _pfc(g.get_e_repr()['l']) == [0.] * 17
# set partial nodes (many-many)
# TODO(minjie): following case will fail at the moment as CachedGraph
# does not maintain edge addition order.
#u = th.tensor([0, 0, 2, 5, 9])
#v = th.tensor([1, 3, 9, 9, 0])
#g.set_e_repr({'l' : th.ones((5, D))}, u, v)
#truth = [0.] * 17
#truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
#assert _pfc(g.get_e_repr()['l']) == truth
# pop edges
assert _pfc(g.pop_e_repr('l')) == [0.] * 17
assert len(g.get_e_repr()) == 0
g.set_e_repr({'l' : th.zeros((17, D))})
# set partial edges (many-many)
u = th.tensor([0, 0, 2, 5, 9])
v = th.tensor([1, 3, 9, 9, 0])
g.set_e_repr({'l' : th.ones((5, D))}, u, v)
truth = [0.] * 17
truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
# set partial edges (many-one)
u = th.tensor([3, 4, 6])
v = th.tensor([9])
g.set_e_repr({'l' : th.ones((3, D))}, u, v)
truth[5] = truth[7] = truth[11] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
# set partial edges (one-many)
u = th.tensor([0])
v = th.tensor([4, 5, 6])
g.set_e_repr({'l' : th.ones((3, D))}, u, v)
truth[6] = truth[8] = truth[10] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
# get partial edges (many-many)
u = th.tensor([0, 6, 0])
v = th.tensor([6, 9, 7])
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
# get partial edges (many-one)
u = th.tensor([5, 6, 7])
v = th.tensor([9])
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
# get partial edges (one-many)
u = th.tensor([0])
v = th.tensor([3, 4, 5])
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.]
def test_batch_send():
g = generate_graph()
......
......@@ -40,6 +40,10 @@ def test_batch_setter_getter():
# set all nodes
g.set_n_repr(th.zeros((10, D)))
assert _pfc(g.get_n_repr()) == [0.] * 10
# pop nodes
assert _pfc(g.pop_n_repr()) == [0.] * 10
assert len(g.get_n_repr()) == 0
g.set_n_repr(th.zeros((10, D)))
# set partial nodes
u = th.tensor([1, 3, 5])
g.set_n_repr(th.ones((3, D)), u)
......@@ -71,15 +75,41 @@ def test_batch_setter_getter():
# set all edges
g.set_e_repr(th.zeros((17, D)))
assert _pfc(g.get_e_repr()) == [0.] * 17
# set partial nodes (many-many)
# TODO(minjie): following case will fail at the moment as CachedGraph
# does not maintain edge addition order.
#u = th.tensor([0, 0, 2, 5, 9])
#v = th.tensor([1, 3, 9, 9, 0])
#g.set_e_repr({'l' : th.ones((5, D))}, u, v)
#truth = [0.] * 17
#truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
#assert _pfc(g.get_e_repr()['l']) == truth
# pop edges
assert _pfc(g.pop_e_repr()) == [0.] * 17
assert len(g.get_e_repr()) == 0
g.set_e_repr(th.zeros((17, D)))
# set partial edges (many-many)
u = th.tensor([0, 0, 2, 5, 9])
v = th.tensor([1, 3, 9, 9, 0])
g.set_e_repr(th.ones((5, D)), u, v)
truth = [0.] * 17
truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
assert _pfc(g.get_e_repr()) == truth
# set partial edges (many-one)
u = th.tensor([3, 4, 6])
v = th.tensor([9])
g.set_e_repr(th.ones((3, D)), u, v)
truth[5] = truth[7] = truth[11] = 1.
assert _pfc(g.get_e_repr()) == truth
# set partial edges (one-many)
u = th.tensor([0])
v = th.tensor([4, 5, 6])
g.set_e_repr(th.ones((3, D)), u, v)
truth[6] = truth[8] = truth[10] = 1.
assert _pfc(g.get_e_repr()) == truth
# get partial edges (many-many)
u = th.tensor([0, 6, 0])
v = th.tensor([6, 9, 7])
assert _pfc(g.get_e_repr(u, v)) == [1., 1., 0.]
# get partial edges (many-one)
u = th.tensor([5, 6, 7])
v = th.tensor([9])
assert _pfc(g.get_e_repr(u, v)) == [1., 1., 0.]
# get partial edges (one-many)
u = th.tensor([0])
v = th.tensor([3, 4, 5])
assert _pfc(g.get_e_repr(u, v)) == [1., 1., 1.]
def test_batch_send():
g = generate_graph()
......
......@@ -29,5 +29,8 @@ def test_basics():
check_eq(s, th.tensor([1, 1, 2, 2]))
check_eq(d, th.tensor([2, 3, 4, 5]))
print(cg._graph.get_adjacency())
print(cg._graph.get_adjacency(eids=True))
if __name__ == '__main__':
test_basics()
import torch as th
import numpy as np
from dgl.graph import DGLGraph
D = 32
D = 5
def check_eq(a, b):
if not np.allclose(a.numpy(), b.numpy()):
print(a, b)
def message_func(hu, edge):
return hu
def reduce_func(hv, msgs):
return th.sum(msgs, 1)
def update_func(hv, accum):
assert hv.shape == accum.shape
......@@ -17,9 +28,8 @@ def generate_graph():
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
# TODO: use internal interface to set data.
col = th.randn(10, D)
g._node_frame['h'] = col
g.set_n_repr(col)
return g
def test_spmv_specialize():
......@@ -27,7 +37,15 @@ def test_spmv_specialize():
g.register_message_func('from_src', batchable=True)
g.register_reduce_func('sum', batchable=True)
g.register_update_func(update_func, batchable=True)
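# Run update_all once through the specialized SPMV path (builtin 'from_src'/'sum'),
# then again with the equivalent Python UDFs; the resulting node representations
# must match.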
v1 = g.get_n_repr()
g.update_all()
v2 = g.get_n_repr()
g.set_n_repr(v1)
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.update_all()
v3 = g.get_n_repr()
check_eq(v2, v3)
if __name__ == '__main__':
test_spmv_specialize()