Unverified Commit 11e42d10 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

Batching semantics and naive frame storage (#31)

* batch message_func, reduce_func and update_func

Conflicts:
	python/dgl/backend/pytorch.py

* test cases for batching

Conflicts:
	python/dgl/graph.py

* resolve conflicts

* setter/getter

Conflicts:
	python/dgl/graph.py

* test setter/getter

Conflicts:
	python/dgl/graph.py

* merge DGLGraph and DGLBGraph

Conflicts:
	python/dgl/graph.py

Conflicts:
	python/dgl/graph.py

* batchability test

Conflicts:
	python/dgl/graph.py

Conflicts:
	python/dgl/graph.py

* New interface (draft)

Conflicts:
	_reference/gat_mx.py
	_reference/molecular-gcn.py
	_reference/molecular-gcn_mx.py
	_reference/multi-gcn.py
	_reference/multi-gcn_mx.py
	_reference/mx.py
	python/dgl/graph.py

* Batch operations on graph

Conflicts:
	.gitignore
	python/dgl/backend/__init__.py
	python/dgl/backend/numpy.py
	python/dgl/graph.py

* sendto

* storage

* NodeDict

* DGLFrame/DGLArray

* scaffold code for graph.py

* clean up files; initial frame code

* basic frame tests using pytorch

* frame autograd test passed

* fix non-batched tests

* initial code for cached graph; tested

* batch sendto

* batch recv

* update routines

* update all

* anonymous repr batching

* specialize test

* igraph dep

* fix

* fix

* fix

* fix

* fix

* clean some files

* batch setter and getter

* fix utests
parent 8361bbbe
......@@ -8,6 +8,7 @@ pipeline {
stage('SETUP') {
steps {
sh 'easy_install nose'
sh 'apt-get update && apt-get install -y libxml2-dev'
}
}
stage('BUILD') {
......@@ -20,6 +21,7 @@ pipeline {
stage('TEST') {
steps {
sh 'nosetests tests -v --with-xunit'
sh 'nosetests tests/pytorch -v --with-xunit'
}
}
}
......
from collections import MutableMapping
import dgl.backend as F
class DGLArray(MutableMapping):
    """Abstract base class for DGL column arrays.

    Concrete subclasses (dense / sparse) implement the MutableMapping
    protocol over row indices.
    """
    def __init__(self):
        pass

    def __delitem__(self, key):
        # Fixed: MutableMapping.__delitem__ takes only (self, key); the
        # original's extra `value` parameter made every `del arr[key]`
        # raise TypeError instead of NotImplementedError.
        raise NotImplementedError()

    def __getitem__(self, key):
        """Select rows.

        If the key is a DGLArray of identical length, this performs a
        logical filter: it subselects all the elements in this array
        where the corresponding value in the other array evaluates to
        true.  If the key is an integer this returns a single row of
        the DGLArray.  If the key is a slice, this returns a DGLArray
        with the sliced rows.
        """
        raise NotImplementedError()

    def __iter__(self):
        raise NotImplementedError()

    def __len__(self):
        raise NotImplementedError()

    def __setitem__(self, key, value):
        raise NotImplementedError()
class DGLDenseArray(DGLArray):
    """Dense column backed by a tensor (list backing is not implemented).

    `_applicable` is a 1-D boolean mask marking which rows hold valid
    values; rows with False act as missing/NA entries (see `dropna`).
    """
    def __init__(self, data, applicable=None):
        """
        Parameters
        ----------
        data : list or tensor
            Column values; only tensor data is currently supported.
        applicable : bool tensor, optional
            Per-row validity mask on the same device as `data`;
            defaults to all-True.
        """
        if type(data) is list:
            raise NotImplementedError()
        elif isinstance(data, F.Tensor):
            self._data = data
            if applicable is None:
                self._applicable = F.ones(F.shape(data)[0], dtype=F.bool)  # TODO: device
            else:
                assert isinstance(applicable, F.Tensor)
                assert F.device(applicable) == F.device(data)
                assert F.isboolean(applicable)
                a_shape = F.shape(applicable)
                assert len(a_shape) == 1
                assert a_shape[0] == F.shape(data)[0]
                self._applicable = applicable
        else:
            # Added: the original silently accepted unknown data types,
            # leaving the instance without `_data` and failing later with
            # an opaque AttributeError.
            raise NotImplementedError()

    def __getitem__(self, key):
        """Row selection.

        * DGLDenseArray key of identical length: logical filter -- keeps
          the rows where the key's boolean value is true.
        * int key: a single row of the underlying data.
        * slice key: not implemented yet.
        """
        if type(key) is DGLDenseArray:
            if type(key._data) is list:
                raise NotImplementedError()
            elif type(key._data) is F.Tensor:
                if type(self._data) is F.Tensor:
                    shape = F.shape(key._data)
                    assert len(shape) == 1
                    assert shape[0] == F.shape(self._data)[0]
                    assert F.dtype(key._data) is F.bool
                    data = self._data[key._data]
                    return DGLDenseArray(data)
                else:
                    raise NotImplementedError()
            else:
                raise RuntimeError()
        elif type(key) is int:
            return self._data[key]
        elif type(key) is slice:
            raise NotImplementedError()
        else:
            raise RuntimeError()

    def __iter__(self):
        # Iterates over row indices, not values.
        return iter(range(len(self)))

    def __len__(self):
        if type(self._data) is F.Tensor:
            return F.shape(self._data)[0]
        elif type(self._data) is list:
            return len(self._data)
        else:
            raise RuntimeError()

    def __setitem__(self, key, value):
        """Replace row(s) `key` with `value` (int keys only, for now)."""
        if type(key) is int:
            if type(self._data) is list:
                raise NotImplementedError()
            elif type(self._data) is F.Tensor:
                assert isinstance(value, F.Tensor)
                assert F.device(value) == F.device(self._data)
                assert F.dtype(value) == F.dtype(self._data)
                # TODO(gaiyu): shape
                # Rebuild the tensor with row `key` swapped out.
                x = []
                if key > 0:
                    x.append(self._data[:key])
                x.append(F.expand_dims(value, 0))
                if key < F.shape(self._data)[0] - 1:
                    x.append(self._data[key + 1:])
                self._data = F.concatenate(x)
            else:
                raise RuntimeError()
        elif type(key) is DGLDenseArray:
            # NOTE(review): this branch is unfinished upstream -- it
            # validates the boolean mask and gathers rows but never writes
            # `value` back; masked assignment still needs implementing.
            shape = F.shape(key._data)
            assert len(shape) == 1
            assert shape[0] == F.shape(self._data)[0]
            assert F.isboolean(key._data)
            data = self._data[key._data]
        elif type(key) is DGLSparseArray:
            raise NotImplementedError()
        else:
            raise RuntimeError()

    def _listize(self):
        raise NotImplementedError()

    def _tensorize(self):
        raise NotImplementedError()

    @property
    def shape(self):
        # Added: `append` reads `self.shape`, but no such attribute ever
        # existed, so append always raised AttributeError.
        # NOTE(review): `append` compares `shape is None` for an "empty"
        # array -- confirm when a None shape is supposed to occur.
        if isinstance(self._data, F.Tensor):
            return F.shape(self._data)
        return None

    def append(self, other):
        """Return a new array with `other`'s rows after this array's rows."""
        # Fixed: `type(other, DGLDenseArray)` is a TypeError (type() takes
        # one or three arguments); an isinstance check was intended.
        assert isinstance(other, DGLDenseArray)
        if self.shape is None:
            return other
        elif other.shape is None:
            return self
        else:
            assert self.shape[1:] == other.shape[1:]
            data = F.concatenate([self.data, other.data])
            return DGLDenseArray(data)

    @property
    def applicable(self):
        return self._applicable

    @property
    def data(self):
        return self._data

    def dropna(self):
        """Return a new array keeping only the applicable rows."""
        if type(self._data) is list:
            raise NotImplementedError()
        elif isinstance(self._data, F.Tensor):
            data = F.index_by_bool(self._data, self._applicable)
            return DGLDenseArray(data)
        else:
            raise RuntimeError()
class DGLSparseArray(DGLArray):
    """Sparse column array; not implemented yet -- construction raises."""
    def __init__(self):
        raise NotImplementedError()
from dgl.array import DGLArray, DGLDenseArray, DGLSparseArray
import dgl.backend as F
def _gridize(frame, key_column_names, src_column_name):
    """Build a (keys, grid) pair for groupby-style aggregation.

    Groups the rows of ``frame[key_column_names]`` (a 1-D integer dense
    column) by unique key.  Returns the unique keys ``col`` and a grid
    ``xy`` with ``xy[i, j] = (row[i] == col[j])``; when
    ``src_column_name`` is given, the grid is cast to the source dtype
    and multiplied by the source column's values (with singleton dims
    inserted so trailing feature dims broadcast), so summing over axis 0
    yields per-key sums.

    Only the single-string key / dense column / tensor data path is
    implemented; other combinations raise.
    NOTE(review): a non-dense src column falls through this function and
    returns None -- confirm that is intended.
    """
    if type(key_column_names) is str:
        key_column = frame[key_column_names]
        # Every key entry must be applicable (no missing values).
        assert F.prod(key_column.applicable)
        if type(key_column) is DGLDenseArray:
            row = key_column.data
            if type(row) is F.Tensor:
                # Keys must form a flat integer vector.
                assert F.isinteger(row) and len(F.shape(row)) == 1
                col = F.unique(row)
                # (n_rows, n_keys) boolean membership grid.
                xy = (F.expand_dims(row, 1) == F.expand_dims(col, 0))
                if src_column_name:
                    src_column = frame[src_column_name]
                    if type(src_column) is DGLDenseArray:
                        z = src_column.data
                        if type(z) is F.Tensor:
                            z = F.expand_dims(z, 1)
                            # Pad the grid's rank to match z's trailing dims.
                            for i in range(2, len(F.shape(z))):
                                xy = F.expand_dims(xy, i)
                            xy = F.astype(xy, F.dtype(z))
                            return col, xy * z
                        elif type(z) is list:
                            raise NotImplementedError()
                        else:
                            raise RuntimeError()
                else:
                    return col, xy
            elif type(row) is list:
                raise NotImplementedError()
            else:
                raise RuntimeError()
        else:
            raise NotImplementedError()
    elif type(key_column_names) is list:
        raise NotImplementedError()
    else:
        raise RuntimeError()
def aggregator(src_column_name=''):
    """Decorator factory turning a reduction into a groupby aggregator.

    The decorated reduction receives the (keys x rows) grid produced by
    `_gridize` and returns one value per key.  The resulting callable
    takes (frame, key_column_names) and returns a dict with the key
    column plus an output column named src_column_name + reduction name.
    """
    def decorator(a):
        def decorated(frame, key_column_names):
            keys, grid = _gridize(frame, key_column_names, src_column_name)
            out_name = src_column_name + a.__name__
            return {
                key_column_names: DGLDenseArray(keys),
                out_name: DGLDenseArray(a(grid)),
            }
        return decorated
    return decorator
def COUNT():
    """Aggregator counting rows per key (sums the 0/1 indicator grid)."""
    @aggregator()
    def count(grid):
        return F.sum(grid, 0)
    return count
def SUM(src_column_name):
    """Aggregator summing `src_column_name` per key.

    The inner function is deliberately named ``sum`` (shadowing the
    builtin only inside this scope) because `aggregator` derives the
    output column name from ``__name__``.
    """
    @aggregator(src_column_name)
    def sum(grid):  # noqa: shadows builtin on purpose -- names the column
        return F.sum(grid, 0)
    return sum
import dgl.backend as F
class DGLArray:
    """Minimal array scaffold; subscripting is not supported yet."""

    def __init__(self):
        pass

    def __getitem__(self, index):
        raise NotImplementedError()
class DGLDenseArray(DGLArray):
    """Dense placeholder in the scaffold module; no behaviour yet."""

    def __init__(self):
        pass
class DGLSparseArray(DGLArray):
    """Sparse placeholder in the scaffold module; cannot be built yet."""

    def __init__(self, data):
        raise NotImplementedError()
from dgl.array import DGLArray, DGLDenseArray, DGLSparseArray
import dgl.backend as F
from collections import MutableMapping
from functools import reduce
from itertools import dropwhile
import operator
class DGLFrame(MutableMapping):
    """Columnar storage: a mutable mapping from column name to DGLArray.

    All dense columns are expected to agree on device and row count.
    """
    def __init__(self, data=None):
        """
        Parameters
        ----------
        data : dict[str, DGLDenseArray], optional
            Initial columns.  Non-dense columns are not implemented.
        """
        self._columns = {}
        if data is None:
            pass
        elif isinstance(data, dict):
            for key, value in data.items():
                device = self.device()
                if device:
                    assert value.device() == device
                if type(value) is DGLDenseArray:
                    num_rows = self.num_rows()
                    if num_rows:
                        assert value.shape[0] == num_rows
                    self._columns[key] = value
                else:
                    raise NotImplementedError()

    def __copy__(self):
        # Fixed: copy.copy(frame) previously returned a bare dict; return
        # a DGLFrame so append's early-exit branches keep the documented
        # return type.
        return DGLFrame(self._columns.copy())

    def __delitem__(self, key):
        """Remove column `key`."""
        del self._columns[key]

    def __getitem__(self, key):
        """Column / row selection, depending on the key type.

        * str -- the column of that name.
        * DGLDenseArray -- logical row filter; returns a new DGLFrame.
        * int -- one row, as a dict of column name -> value.
        * type / list / slice -- not implemented yet.
        """
        if type(key) is str:
            return self._columns[key]
        elif type(key) is type:
            raise NotImplementedError()
        elif type(key) is list:
            raise NotImplementedError()
        elif type(key) is DGLDenseArray:
            return DGLFrame({k: v[key] for k, v in self._columns.items()})
        elif type(key) is int:
            return {k: v[key] for k, v in self._columns.items()}
        elif type(key) is slice:
            raise NotImplementedError()
        else:
            raise RuntimeError()

    def __iter__(self):
        """Iterate over column names."""
        return iter(self._columns.keys())

    def __len__(self):
        """Number of columns (not rows)."""
        return len(self._columns)

    def __setitem__(self, key, value):
        """Add or replace a column.

        Only str keys with DGLDenseArray values are implemented; the new
        column must match the frame's current row count.
        """
        if type(key) is str:
            if type(value) is DGLDenseArray:
                assert value.shape[0] == self.num_rows()
                self._columns[key] = value
            elif type(value) is DGLSparseArray:
                raise NotImplementedError()
            else:
                raise RuntimeError()
        elif type(key) is list:
            raise NotImplementedError()
        else:
            raise RuntimeError()

    def _next_dense_column(self):
        """Return some dense column, or None when there is none.

        Fixed: the original dropwhile predicate skipped columns *while
        they were dense*, i.e. it returned the first NON-dense column,
        so device()/num_rows() never actually saw a dense column.
        NOTE(review): callers use .device()/.shape on the result --
        confirm DGLDenseArray exposes both.
        """
        if self._columns:
            predicate = lambda x: type(x) is not DGLDenseArray
            try:
                return next(dropwhile(predicate, self._columns.values()))
            except StopIteration:
                return None
        else:
            return None

    def append(self, other):
        """Concatenate the rows of `other` after this frame's rows.

        Both frames must have the same set of column names.

        Parameters
        ----------
        other : DGLFrame

        Returns
        -------
        out : DGLFrame
        """
        # Fixed: `isisntance` typo (NameError at runtime), and the second
        # branch re-tested self.num_rows() so it was unreachable.
        assert isinstance(other, DGLFrame)
        assert set(self._columns) == set(other._columns)
        if self.num_rows() == 0:
            return other.__copy__()
        elif other.num_rows() == 0:
            return self.__copy__()
        else:
            # Fixed: wrap in DGLFrame to honour the documented return type.
            return DGLFrame({k: v.append(other[k]) for k, v in self._columns.items()})

    def device(self):
        """Device of the frame's dense columns (None when empty)."""
        dense_column = self._next_dense_column()
        return None if dense_column is None else dense_column.device()

    def dropna(self, columns=None, how='any'):
        """Drop rows with non-applicable values in `columns`.

        how='any': keep rows applicable in every listed column.
        how='all': keep rows applicable in at least one listed column.
        """
        columns = list(self._columns) if columns is None else columns
        assert type(columns) is list
        assert len(columns) > 0
        column_list = [self._columns[x] for x in columns]
        if all(type(x) is DGLDenseArray for x in column_list):
            a_list = [x.applicable for x in column_list]
            if how == 'any':
                a = reduce(operator.mul, a_list)
            elif how == 'all':
                a = (reduce(operator.add, a_list) > 0)
            else:
                raise RuntimeError()
            a_array = DGLDenseArray(a)
            return DGLFrame({k: v[a_array] for k, v in self._columns.items()})
        else:
            raise NotImplementedError()

    def filter_by(self, values, column_name, exclude=False):
        """Keep (or, with exclude=True, drop) rows whose `column_name`
        value occurs in `values`.

        Parameters
        ----------
        values : DGLDenseArray
            Values to match against (other containers not implemented).
        column_name : str
        exclude : bool

        Returns
        -------
        out : DGLFrame
        """
        if type(values) is DGLDenseArray:
            # NOTE(review): F.isin receives the column object (not .data)
            # and the raw mask indexes the frame directly -- confirm both
            # against the backend and __getitem__ contracts.
            mask = F.isin(self._columns[column_name], values.data)
            if exclude:
                mask = 1 - mask
            return self[mask]
        else:
            raise NotImplementedError()

    def groupby(self, key_column_names, operations, *args):
        """Group by `key_column_names` and aggregate.

        Only a single str key with a single-entry dict of
        {output_name: aggregator} is implemented; the aggregator is a
        callable of (frame, key_column_names) returning new columns.
        """
        if type(key_column_names) is str:
            if type(operations) is list:
                raise NotImplementedError()
            elif type(operations) is dict:
                if len(operations) == 1:
                    # Fixed: `dst_solumn_name` typo.
                    # NOTE(review): dst_column_name is currently unused --
                    # the aggregator derives the output column name itself.
                    dst_column_name, = operations.keys()
                    agg, = operations.values()
                    return DGLFrame(agg(self, key_column_names))
                else:
                    raise NotImplementedError()
            else:
                raise RuntimeError()
        else:
            raise NotImplementedError()

    def join(self, right, on=None, how='inner'):
        """SQL-style equi-join with `right` on the shared column `on`.

        Only how='inner' with a single shared str key over 1-D tensor
        columns is implemented.

        Returns
        -------
        out : DGLFrame
        """
        assert type(right) == DGLFrame
        if on is None:
            raise NotImplementedError()
        elif type(on) is str:
            assert set(self._columns).intersection(set(right._columns)) == {on}
        elif type(on) is list:
            raise NotImplementedError()
        elif type(on) is dict:
            raise NotImplementedError()
        else:
            raise RuntimeError()
        if how == 'left':
            raise NotImplementedError()
        elif how == 'right':
            raise NotImplementedError()
        elif how == 'outer':
            raise NotImplementedError()
        elif how == 'inner':
            lhs = self._columns[on]
            rhs = right._columns[on]
            if type(lhs) is DGLDenseArray and type(rhs) is DGLDenseArray:
                if isinstance(lhs.data, F.Tensor) and isinstance(rhs.data, F.Tensor) and \
                   len(F.shape(lhs.data)) == 1 and len(F.shape(rhs.data)) == 1:
                    assert F.prod(lhs.applicable) and F.prod(rhs.applicable)
                    isin = F.isin(lhs.data, rhs.data)
                    columns = {k: v[isin] for k, v in self._columns.items()}
                    # NOTE(review): this update re-reads self._columns; it
                    # almost certainly should merge the *right* frame's
                    # non-key columns -- confirm the intended semantics.
                    columns.update({k: v for k, v in self._columns.items() if k != on})
                    # Fixed: the original built `columns` and fell off the
                    # end of the method, returning None.
                    return DGLFrame(columns)
                else:
                    raise NotImplementedError()
            else:
                raise NotImplementedError()
        else:
            raise RuntimeError()

    def num_rows(self):
        """Row count of the dense columns (None when the frame is empty)."""
        dense_column = self._next_dense_column()
        return None if dense_column is None else dense_column.shape[0]
class NodeDictOverlay(MutableMapping):
    """Mapping-style view over node attributes backed by a frame.

    One frame row corresponds to one node; node ids are assumed to be
    contiguous from zero.
    """
    def __init__(self, frame):
        self._frame = frame

    @property
    def num_nodes(self):
        return self._frame.num_rows()

    def add_nodes(self, nodes, attrs):
        # NOTE: currently `nodes` are not used. Users need to make sure
        # the node ids are continuous ids from 0.
        # NOTE: this is a good place to hook any graph mutation logic.
        self._frame.append(attrs)

    def delete_nodes(self, nodes):
        # NOTE: this is a good place to hook any graph mutation logic.
        raise NotImplementedError('Delete nodes in the graph is currently not supported.')

    def get_node_attrs(self, nodes, key):
        """Read attribute `key` for `nodes` (ALL selects the whole column)."""
        if nodes == ALL:
            # get the whole column
            return self._frame[key]
        else:
            # TODO(minjie): should not rely on tensor's __getitem__ syntax.
            return utils.id_type_dispatch(
                nodes,
                lambda nid: self._frame[key][nid],
                lambda id_array: self._frame[key][id_array])

    def set_node_attrs(self, nodes, key, val):
        """Write attribute `key` for `nodes` (ALL replaces the column)."""
        if nodes == ALL:
            # replace the whole column
            self._frame[key] = val
        else:
            # TODO(minjie): should not rely on tensor's __setitem__ syntax.
            # Fixed: lambdas cannot contain assignment statements (the
            # original bodies were SyntaxErrors); use local functions.
            def _set_one(nid):
                self._frame[key][nid] = val

            def _set_many(id_array):
                self._frame[key][id_array] = val

            utils.id_type_dispatch(nodes, _set_one, _set_many)

    def __getitem__(self, nodes):
        """Lazy attribute dict for `nodes`; KeyError when out of range."""
        def _check_one(nid):
            if nid >= self.num_nodes:
                raise KeyError

        def _check_many(id_array):
            if F.max(id_array) >= self.num_nodes:
                raise KeyError

        utils.id_type_dispatch(nodes, _check_one, _check_many)
        # Fixed: a comma was missing before the third argument, which made
        # the whole call a SyntaxError.
        return utils.MutableLazyDict(
            lambda key: self.get_node_attrs(nodes, key),
            lambda key, val: self.set_node_attrs(nodes, key, val),
            self._frame.schemes)

    def __setitem__(self, nodes, attrs):
        # Happens when adding new nodes in the graph.
        self.add_nodes(nodes, attrs)

    def __delitem__(self, nodes):
        # Happens when deleting nodes in the graph.
        self.delete_nodes(nodes)

    def __len__(self):
        return self.num_nodes

    def __iter__(self):
        raise NotImplementedError()
class AdjOuterOverlay(MutableMapping):
    """Outer adjacency overlay: maps a source node to its inner dict.

    TODO: Replace this with a more efficient dict structure.
    TODO: Batch graph mutation is not supported.
    """
    def __init__(self):
        self._adj = {}

    def __setitem__(self, u, inner_dict):
        self._adj[u] = inner_dict

    def __getitem__(self, u):
        def _check_one(nid):
            if nid not in self._adj:
                raise KeyError

        def _check_many(id_array):
            pass

        utils.id_type_dispatch(u, _check_one, _check_many)
        # NOTE(review): the original calls id_type_dispatch with a single
        # argument here, which cannot be the intended lookup -- confirm
        # against utils.id_type_dispatch's signature.
        return utils.id_type_dispatch(u)

    def __delitem__(self, u):
        # The delitem is ignored.
        raise NotImplementedError('Delete edges in the graph is currently not supported.')

    # Added: MutableMapping declares __iter__/__len__ abstract; without
    # them the class could never be instantiated at all.
    def __iter__(self):
        return iter(self._adj)

    def __len__(self):
        return len(self._adj)
class AdjInnerOverlay(dict):
    """Inner adjacency mapping whose item writes are no-ops.

    TODO: replace this with a more efficient dict structure.
    """
    def __setitem__(self, key, value):
        # NOTE(review): assignments are silently discarded here --
        # presumably entries are populated through the base dict elsewhere;
        # confirm before relying on this.
        pass
from collections import defaultdict, MutableMapping
import dgl.backend as F
import dgl.utils as utils
class NodeDict(MutableMapping):
    """Mapping from node id(s) to their attribute dicts.

    Storage is per attribute name: either a plain dict (node id -> value)
    or a batched DGLNodeTensor.  The class object DGLNodeTensor itself is
    used as a sentinel meaning "tensorized but currently empty".
    """
    def __init__(self):
        # Known node ids (non-negative ints).
        self._node = set()
        # attr name -> dict / DGLNodeTensor / DGLNodeTensor sentinel.
        self._attrs = defaultdict(dict)

    @staticmethod
    def _deltensor(attr_value, u):
        """Drop the rows of DGLNodeTensor `attr_value` whose ids are in `u`.

        Parameters
        ----------
        u : Tensor
        """
        isin = F.isin(attr_value.idx, u)
        if F.sum(isin):
            if F.prod(isin):
                return DGLNodeTensor  # sentinel: everything deleted
            else:
                return attr_value[1 - isin]
        # Fixed: the original fell through and returned None when no id
        # matched, clobbering the stored tensor; keep it unchanged instead.
        return attr_value

    @staticmethod
    def _delitem(attrs, attr_name, u, uu):
        """Delete nodes `u` from one attribute's storage.

        Parameters
        ----------
        attrs : dict
            attr name -> dict or DGLNodeTensor (or sentinel).
        attr_name : str
        u : list, Tensor, or the full slice
        uu : list or None
            Python-int mirror of `u` when `u` is a Tensor.
        """
        attr_value = attrs[attr_name]
        deltensor = NodeDict._deltensor
        if isinstance(attr_value, dict):
            if isinstance(u, list):
                for x in u:
                    attr_value.pop(x, None)
            elif isinstance(u, F.Tensor):
                uu = uu if uu else map(F.item, u)
                for x in uu:
                    attr_value.pop(x, None)
            elif u == slice(None, None, None):
                assert not uu
                attrs[attr_name] = {}
            else:
                raise RuntimeError()
        elif isinstance(attr_value, DGLNodeTensor):
            if isinstance(u, list):
                uu = uu if uu else F.tensor(u)  # TODO(gaiyu): device, dtype, shape
                attrs[attr_name] = deltensor(attr_value, uu)
            elif isinstance(u, F.Tensor):
                # Fixed: bare `Tensor` was an undefined name.
                attrs[attr_name] = deltensor(attr_value, u)
            elif u == slice(None, None, None):
                assert not uu
                attrs[attr_name] = DGLNodeTensor
            else:
                raise RuntimeError()
        elif attr_value != DGLNodeTensor:
            raise RuntimeError()

    def __delitem__(self, u):
        """Delete node(s) `u` and all of their attributes."""
        if isinstance(u, list):
            assert utils.homogeneous(u, int)
            # Fixed: membership must be checked against self._node;
            # self._adj does not exist on NodeDict.
            if all(x not in self._node for x in u):
                raise KeyError()
            self._node = self._node.difference(set(u))
            uu = None
        elif isinstance(u, F.Tensor):
            assert len(F.shape(u)) == 1 \
                and F.isinteger(u) \
                and F.prod(u >= 0) \
                and F.unpackable(u)
            # Fixed: the original assigned the boolean result of
            # F.unpackable(u) to uu; unpack to a list of Python ints
            # (mirrors __setitem__).
            uu = list(map(F.item, F.unpack(u)))
            self._node = self._node.difference(set(uu))
        elif u == slice(None, None, None):
            # NOTE(review): the full-slice path clears attributes below
            # but leaves self._node populated -- confirm intended.
            uu = None
        else:
            assert isinstance(u, int) and u >= 0
            self._node.remove(u)
            u, uu = [u], None
        for attr_name in self._attrs:
            self._delitem(self._attrs, attr_name, u, uu)

    def __getitem__(self, u):
        """Return a lazy attribute view for node selection `u`.

        `u` may be an int, a list of ints, an integer tensor, or the
        full slice `[:]`.
        """
        if isinstance(u, list):
            assert utils.homogeneous(u, int) and all(x >= 0 for x in u)
            # NOTE(review): raises only when NONE of the ids is known --
            # partially-unknown selections pass; confirm intended.
            if all(x not in self._node for x in u):
                raise KeyError()
            uu = None
        elif isinstance(u, F.Tensor):
            assert len(F.shape(u)) == 1 and F.unpackable(u)
            uu = list(map(F.item, F.unpack(u)))
            assert utils.homogeneous(uu, int) and all(x >= 0 for x in uu)
            if all(x not in self._node for x in uu):
                raise KeyError()
        elif u == slice(None, None, None):
            uu = None
        elif isinstance(u, int):
            assert u >= 0
            if u not in self._node:
                raise KeyError()
            uu = None
        else:
            raise KeyError()
        return LazyNodeAttrDict(u, uu, self._node, self._attrs)

    def __iter__(self):
        return iter(self._node)

    def __len__(self):
        return len(self._node)

    @staticmethod
    def _settensor(attrs, attr_name, u, uu, attr_value):
        """Write batched values into one attribute's storage.

        Parameters
        ----------
        attrs : dict
        attr_name : str
        u : list, Tensor, or the full slice
        uu : list or None
            Python-int mirror of a tensor `u`.
        attr_value : Tensor
            One row per selected node.
        """
        x = attrs[attr_name]
        if isinstance(x, dict):
            if isinstance(u, list):
                for y, z in zip(u, F.unpack(attr_value)):
                    x[y] = z
            elif isinstance(u, F.Tensor):
                uu = uu if uu else map(F.item, F.unpack(u))
                assert F.unpackable(attr_value)
                for y, z in zip(uu, F.unpack(attr_value)):
                    x[y] = z
            elif u == slice(None, None, None):
                assert not uu
                # Fixed: `self` is unavailable in a staticmethod.
                # NOTE(review): _dictize expects a DGLNodeTensor while
                # attr_value is a plain tensor here -- confirm.
                attrs[attr_name] = NodeDict._dictize(attr_value)
            else:
                raise RuntimeError()
        elif isinstance(x, DGLNodeTensor):
            u = u if u else F.tensor(uu)
            isin = F.isin(x.idx, u)
            if F.sum(isin):
                if F.prod(isin):
                    # Fixed: node storage must hold DGLNodeTensor, not
                    # DGLEdgeTensor (copy-paste from the edge variant).
                    attrs[attr_name] = DGLNodeTensor(u, attr_value)
                else:
                    # Fixed: the keep-mask applies to the stored tensor,
                    # not to the incoming values.
                    y = x[1 - isin]
                    z = DGLNodeTensor(u, attr_value)
                    attrs[attr_name] = concatenate([y, z])
            # NOTE(review): when none of `u` is already stored the update
            # is silently dropped -- confirm new rows should not append.
        elif x == DGLNodeTensor:
            # Sentinel "empty tensorized storage": build fresh storage.
            # Fixed: DGLEdgeTensor -> DGLNodeTensor here as well.
            attrs[attr_name] = DGLNodeTensor(F.tensor(u), attr_value)

    @staticmethod
    def _setitem(node, attrs, attr_name, u, uu, attr_value):
        """Validate `attr_value`'s leading dimension and dispatch the write."""
        def valid(x):
            return isinstance(attr_value, F.Tensor) \
                and F.shape(attr_value)[0] == x \
                and F.unpackable(attr_value)

        settensor = NodeDict._settensor
        if isinstance(u, list):
            assert valid(len(u))
            settensor(attrs, attr_name, u, None, attr_value)
        elif isinstance(u, F.Tensor):
            assert valid(F.shape(u)[0])
            settensor(attrs, attr_name, u, uu, attr_value)
        elif u == slice(None, None, None):
            assert valid(len(node))
            settensor(attrs, attr_name, u, None, attr_value)
        elif isinstance(u, int):
            assert u >= 0
            if isinstance(attr_value, F.Tensor):
                assert valid(1)
                settensor(attrs, attr_name, [u], None, attr_value)
            else:
                attrs[attr_name][u] = attr_value
        else:
            raise RuntimeError()

    def __setitem__(self, u, attrs):
        """Register node(s) `u` and set their attributes.

        Parameters
        ----------
        u : int, list, Tensor, or the full slice
        attrs : dict
            attr name -> value (tensor rows for batched selections).
        """
        if isinstance(u, list):
            assert utils.homogeneous(u, int) and all(x >= 0 for x in u)
            self._node.update(u)
            uu = None
        elif isinstance(u, F.Tensor):
            assert len(F.shape(u)) == 1 and F.isinteger(u) and F.prod(u >= 0)
            uu = list(map(F.item, F.unpack(u)))
            self._node.update(uu)
        elif u == slice(None, None, None):
            uu = None
        elif isinstance(u, int):
            assert u >= 0
            self._node.add(u)
            uu = None
        else:
            raise RuntimeError()
        for attr_name, attr_value in attrs.items():
            self._setitem(self._node, self._attrs, attr_name, u, uu, attr_value)

    @staticmethod
    def _tensorize(attr_value):
        """Convert dict storage to a DGLNodeTensor (sentinel when empty)."""
        assert isinstance(attr_value, dict)
        if attr_value:
            keys, values = map(list, zip(*attr_value.items()))
            # Fixed: `homoegeneous` typo (AttributeError at runtime); the
            # redundant duplicate packable-assert was dropped.
            assert utils.homogeneous(keys, int) and all(x >= 0 for x in keys)
            assert F.packable(values)
            idx = F.tensor(keys)  # TODO(gaiyu): device, dtype, shape
            dat = F.pack(values)  # TODO(gaiyu): device, dtype, shape
            return DGLNodeTensor(idx, dat)
        else:
            return DGLNodeTensor

    def tensorize(self, attr_name):
        """Switch one attribute's storage from dict to tensor form."""
        # Fixed: `self.attrs` -> `self._attrs`.
        self._attrs[attr_name] = self._tensorize(self._attrs[attr_name])

    def istensorized(self, attr_name):
        attr_value = self._attrs[attr_name]
        return isinstance(attr_value, DGLNodeTensor) or attr_value == DGLNodeTensor

    @staticmethod
    def _dictize(attr_value):
        """Convert DGLNodeTensor storage back to a plain dict."""
        assert isinstance(attr_value, DGLNodeTensor)
        keys = map(F.item, F.unpack(attr_value.idx))
        values = F.unpack(attr_value.dat)
        return dict(zip(keys, values))

    def dictize(self, attr_name):
        """Switch one attribute's storage from tensor to dict form."""
        # Fixed: the original passed the attribute *name* to _dictize,
        # which asserts on a DGLNodeTensor value.
        self._attrs[attr_name] = self._dictize(self._attrs[attr_name])

    def isdictized(self, attr_name):
        return isinstance(self._attrs[attr_name], dict)

    def purge(self):
        """Drop empty dict storages and empty sentinels from _attrs."""
        predicate = lambda x: (isinstance(x, dict) and x) or isinstance(x, DGLNodeTensor)
        self._attrs = {k: v for k, v in self._attrs.items() if predicate(v)}
class LazyNodeAttrDict(MutableMapping):
    """Lazy view over the attributes of node selection `u`.

    `__iter__` and `__len__` are only defined for an int selection.
    """
    def __init__(self, u, uu, node, attrs):
        # u: selection (int, list, Tensor, or full slice); uu: Python-int
        # mirror of a tensor selection; node/attrs: NodeDict storage.
        self._u = u
        self._uu = uu
        self._node = node
        self._attrs = attrs

    def __delitem__(self, attr_name):
        # Fixed: arguments were passed out of order (and uu was missing);
        # NodeDict._delitem expects (attrs, attr_name, u, uu).
        NodeDict._delitem(self._attrs, attr_name, self._u, self._uu)

    def __getitem__(self, attr_name):
        """Gather attribute `attr_name` for the selected node(s)."""
        attr_value = self._attrs[attr_name]
        if isinstance(self._u, list):
            # NOTE(review): raises only when NONE of the ids is present.
            if all(x not in self._node for x in self._u):
                raise KeyError()
            if isinstance(attr_value, dict):
                y = [attr_value[x] for x in self._u]
                assert F.packable(y)
                return F.pack(y)
            elif isinstance(attr_value, DGLNodeTensor):
                # Fixed: `u` was an undefined name; use the stored selection.
                uu = self._uu if self._uu else F.tensor(self._u)
                isin = F.isin(attr_value.idx, uu)
                return attr_value[isin].dat
            else:
                raise KeyError()
        elif isinstance(self._u, F.Tensor):
            uu = self._uu if self._uu else list(map(F.item, F.unpack(self._u)))
            if all(x not in self._node for x in uu):
                raise KeyError()
            if isinstance(attr_value, dict):
                y_list = [attr_value[x] for x in uu]
                assert F.packable(y_list)
                return F.pack(y_list)
            elif isinstance(attr_value, DGLNodeTensor):
                isin = F.isin(attr_value.idx, self._u)
                return attr_value[isin].dat
            else:
                raise KeyError()
        elif self._u == slice(None, None, None):
            assert not self._uu
            if isinstance(attr_value, dict) and attr_value:
                return NodeDict._tensorize(attr_value).dat
            elif isinstance(attr_value, DGLNodeTensor):
                return attr_value.dat
            else:
                raise KeyError()
        elif isinstance(self._u, int):
            assert not self._uu
            if isinstance(attr_value, dict):
                return attr_value[self._u]
            elif isinstance(attr_value, DGLNodeTensor):
                try:  # TODO(gaiyu)
                    return attr_value.dat[self._u]
                except Exception:
                    raise KeyError()
            else:
                raise KeyError()
        else:
            raise KeyError()

    def __iter__(self):
        """Yield attr names defined for the single selected node."""
        if isinstance(self._u, int):
            for key, value in self._attrs.items():
                if (isinstance(value, dict) and self._u in value) or \
                   (isinstance(value, DGLNodeTensor) and F.sum(value.idx == self._u)):
                    yield key
        else:
            raise RuntimeError()

    def __len__(self):
        return sum(1 for x in self)

    def __setitem__(self, attr_name, attr_value):
        """Set attribute `attr_name` on the selected node(s)."""
        setitem = NodeDict._setitem
        if isinstance(self._u, int):
            assert self._u in self._node
            if isinstance(attr_value, F.Tensor):
                setitem(self._node, self._attrs, attr_name, self._u, None, attr_value)
            else:
                # Fixed: storage is keyed by attr name first, node id second
                # (mirrors NodeDict._setitem); the original indexed by
                # node id first.
                self._attrs[attr_name][self._u] = attr_value
        else:
            if all(x not in self._node for x in self._u):
                raise KeyError()
            # Fixed: the original dropped attr_value entirely and passed
            # the remaining arguments out of order.
            setitem(self._node, self._attrs, attr_name, self._u, self._uu, attr_value)

    def materialized(self):
        """Eagerly evaluate all attributes into a plain dict, skipping
        those that raise KeyError for this selection."""
        attrs = {}
        for key in self._attrs:
            try:
                attrs[key] = self[key]
            except KeyError:
                # Fixed: the original bare-except merely instantiated
                # KeyError() and discarded it, swallowing every error type.
                pass
        return attrs
class AdjOuterDict(MutableMapping):
    """Outer adjacency mapping: source node id -> inner (dst -> attrs)
    dict, with batched per-attribute storage in `_attrs`."""
    def __init__(self):
        self._adj = defaultdict(lambda: defaultdict(dict))
        self._attrs = defaultdict(dict)

    @staticmethod
    def _delitem(attrs, attr_name, u, uu):
        """Delete edges sourced at `u` from one attribute's storage."""
        attr_value = attrs[attr_name]
        if isinstance(attr_value, dict):
            if u == slice(None, None, None):
                assert not uu
                attrs[attr_name] = {}
            else:
                uu = uu if uu else map(F.item, u)
                for x in uu:
                    attr_value.pop(x, None)
        elif isinstance(attr_value, DGLNodeTensor):
            if u == slice(None, None, None):
                assert not uu
                attrs[attr_name] = DGLEdgeTensor
            else:
                u = u if u else F.tensor(uu)  # TODO(gaiyu): device, dtype, shape
                isin = F.isin(attr_value.idx, u)
                if F.sum(isin):
                    if F.prod(isin):
                        attrs[attr_name] = DGLEdgeTensor
                    else:
                        attrs[attr_name] = attr_value[1 - isin]
        elif attr_value != DGLEdgeTensor:
            raise RuntimeError()

    def __delitem__(self, u):
        """Delete all edges sourced at node(s) `u`."""
        if isinstance(u, list):
            assert utils.homogeneous(u, int) and all(x >= 0 for x in u)
            # NOTE(review): _attrs is keyed by attribute name, yet node ids
            # are popped from it here -- confirm intended.
            if all(x not in self._attrs for x in u):
                raise KeyError()
            for x in u:
                self._attrs.pop(x, None)
            uu = None
        elif isinstance(u, F.Tensor):
            # NOTE(review): the tensor path was left unimplemented upstream.
            uu = None
        else:
            # Fixed: `uu` was never defined on any path, so the loop below
            # always raised NameError.
            uu = None
        for attr_name in self._attrs:
            self._delitem(self._attrs, attr_name, u, uu)

    def __iter__(self):
        return iter(self._adj)

    def __len__(self):
        return len(self._adj)

    def __getitem__(self, u):
        """Return a lazy inner dict for source node(s) `u`."""
        if isinstance(u, list):
            assert utils.homogeneous(u, int)
            if all(x not in self._adj for x in u):
                raise KeyError()
        elif isinstance(u, slice):
            assert u == slice(None, None, None)
        elif u not in self._adj:
            raise KeyError()
        # NOTE(review): LazyAdjInnerDict.__init__ declares (u, uu, adj,
        # attrs); this call omits `uu` -- confirm the intended signature.
        return LazyAdjInnerDict(u, self._adj, self._attrs)

    def __setitem__(self, u, attrs):
        # NOTE(review): writes are silently ignored -- presumably handled
        # through the lazy inner dicts; confirm.
        pass

    def uv(self, attr_name, u=None, v=None):
        """List [src, dst] endpoint pairs stored under `attr_name`,
        restricted to sources `u` or destinations `v` (exactly one)."""
        if u:
            assert not v
            assert (isinstance(u, list) and utils.homogeneous(u, int)) or \
                   (isinstance(u, F.Tensor) and F.isinteger(u) and len(F.shape(u)) == 1)
        elif v:
            assert not u
            assert (isinstance(v, list) and utils.homogeneous(v, int)) or \
                   (isinstance(v, F.Tensor) and F.isinteger(v) and len(F.shape(v)) == 1)
        else:
            raise RuntimeError()
        attr_value = self._attrs[attr_name]
        if isinstance(attr_value, dict):
            if u:
                # Fixed: the comprehension's loops were in the wrong order,
                # so `src` was referenced before being bound (NameError).
                v = [[src, dst] for src in u for dst in attr_value.get(src, {})]
            elif v:
                # NOTE(review): destination-side lookup unimplemented.
                pass
        elif isinstance(attr_value, DGLEdgeTensor):
            u, v = attr_value._complete(u, v)
        return u, v
class LazyAdjInnerDict(MutableMapping):
    """Lazy view of the inner adjacency (dst -> edge attrs) of source `u`."""
    def __init__(self, u, uu, adj, attrs):
        # NOTE(review): AdjOuterDict.__getitem__ constructs this with only
        # (u, adj, attrs) -- the signatures disagree; confirm which is right.
        self._u = u
        self._uu = uu
        self._adj = adj
        self._attrs = attrs

    def __getitem__(self, v):
        pass

    def __iter__(self):
        if isinstance(self._u, int):
            pass
        else:
            raise RuntimeError()

    def __len__(self):
        # Fixed: isinstance needs a tuple of types (a list raises
        # TypeError), and `u` was an undefined name.
        if not isinstance(self._u, (list, slice)):
            return len(self._adj[self._u])
        else:
            raise RuntimeError()

    def __setitem__(self, v, attr_dict):
        pass

    def __delitem__(self, v):
        # Added: MutableMapping declares __delitem__ abstract; without a
        # definition the class could not be instantiated at all.
        raise NotImplementedError()
class LazyEdgeAttrDict(MutableMapping):
    """dict: attr_name -> attr

    Lazy view over the attributes of a batch of edges (u[i], v[i]).
    NOTE(review): the methods read `self._outer_dict`, but __init__ only
    assigns `self._adj`, so every access raises AttributeError as
    written -- presumably `_adj` was intended; confirm before use.
    """
    def __init__(self, u, v, uu, vv, adj, attrs):
        # u/v: edge endpoints (ids, lists or tensors); uu/vv: their
        # Python-int mirrors; adj: outer adjacency; attrs: attr storage.
        self._u = u
        self._v = v
        self._uu = uu
        self._vv = vv
        self._adj = adj
        self._attrs = attrs

    def __getitem__(self, attr_name):
        # Gather the attribute for each (u, v) edge; pack into one tensor
        # when the per-edge values are packable.
        edge_iter = utils.edge_iter(self._u, self._v)
        attr_list = [self._outer_dict[uu, vv][attr_name] for uu, vv in edge_iter]
        return F.pack(attr_list) if F.packable(attr_list) else attr_list

    def __iter__(self):
        raise NotImplementedError()

    def __len__(self):
        raise NotImplementedError()

    def __setitem__(self, attr_name, attr):
        # Fan a packed tensor out across the edges, or broadcast a single
        # attr to all of them.
        # NOTE(review): indexing here is [uu][vv] while __getitem__ uses
        # [uu, vv] -- the two cannot both be right; confirm the layout.
        if F.unpackable(attr):
            for [uu, vv], a in zip(utils.edge_iter(self._u, self._v), F.unpack(attr)):
                self._outer_dict[uu][vv][attr_name] = a
        else:
            for uu, vv in utils.edge_iter(self._u, self._v):
                self._outer_dict[uu][vv][attr_name] = attr
# Fallback factories: plain dicts serve as the adjacency inner dict and the
# per-edge attribute dict (the Lazy* classes above are the batched variants).
AdjInnerDict = dict
EdgeAttrDict = dict
# import numpy as F
import torch as F
from dgl.state import NodeDict
# TODO(gaiyu): more test cases
def test_node_dict():
    """NodeDict must behave exactly like a plain dict for scalar keys."""
    nd = NodeDict()
    nd[0] = {'k1' : 'n01'}
    nd[0]['k2'] = 'n02'
    nd[1] = {}
    nd[1]['k1'] = 'n11'
    print(nd)
    for k, val in nd.items():
        print(k, val)
    print(nd.items())
    nd.clear()
    print(nd)
def test_node_dict_batched():
    """Batched set/get of node attributes via list keys."""
    nd = NodeDict()
    # Set node 0, 1, 2 attrs in a batch.
    nd[[0, 1, 2]] = {'k1' : F.tensor([0, 1, 2]), 'k2' : F.tensor([0, 1, 2])}
    # Query in a batch.
    assert F.prod(nd[[0, 1]]['k1'] == F.tensor([0, 1]))
    assert F.prod(nd[[2, 1]]['k2'] == F.tensor([2, 1]))
    # Broadcasting one scalar to all nodes is unsupported (needs a loop):
    # nd[[0, 1, 2]]['k1'] = 10
    # assert F.prod(nd[[0, 1, 2]]['k1'] == F.tensor([10, 10, 10]))
    print(nd)
def test_node_dict_batched_tensor():
    """Batched set/get where every node carries a (10,) feature vector."""
    nd = NodeDict()
    # Assign features to nodes 0, 1, 2 in one shot.
    feats = F.ones([3, 10])
    nd[[0, 1, 2]] = {'k' : feats}
    assert nd[[0, 1]]['k'].shape == (2, 10)
    assert nd[[2, 1, 2, 0]]['k'].shape == (4, 10)
def test_node_dict_tensor_arg():
    """Tensor-typed node ids work for both assignment and query."""
    nd = NodeDict()
    # Assign (10,) feature vectors to nodes 0, 1, 2 using a tensor key.
    ids = F.arange(3).int()
    nd[ids] = {'k' : F.ones([3, 10])}
    assert nd[[0, 1]]['k'].shape == (2, 10)
    assert nd[[2, 1, 2, 0]]['k'].shape == (4, 10)
    query = F.tensor([2, 1, 2, 0, 1])
    assert nd[query]['k'].shape == (5, 10)
# Run the suite when executed as a script (nose also discovers the
# test_* functions above on import).
test_node_dict()
test_node_dict_batched()
test_node_dict_batched_tensor()
test_node_dict_tensor_arg()
import networkx as nx
# import numpy as np
import torch as F
from dgl.graph import DGLGraph
def test_node1():
    """Scalar and tensor node attributes, plus tensor-typed node queries."""
    graph = DGLGraph()
    n0 = 0
    n1 = 1
    graph.add_node(n0, x=F.tensor([10]))
    graph.add_node(n1, x=F.tensor([11]))
    assert len(graph.nodes()) == 2
    assert F.prod(graph.nodes[[n0, n1]]['x'] == F.tensor([10, 11]))
    # tensor state
    graph.add_node(n0, y=F.zeros([1, 10]))
    graph.add_node(n1, y=F.zeros([1, 10]))
    assert graph.nodes[[n0, n1]]['y'].shape == (2, 10)
    # tensor args
    nodes = F.tensor([n0, n1, n1, n0])
    # BUG FIX: use the `nodes` view consistently with the rest of the file
    # (the singular `graph.node` accessor is not part of DGLGraph's API).
    assert graph.nodes[nodes]['y'].shape == (4, 10)
def test_node2():
    """Adding a list of node ids creates one node per id."""
    graph = DGLGraph()
    graph.add_node([0, 1])
    assert len(graph.nodes()) == 2
def test_edge1():
    """Add single, many-many, one-many and many-one edges with attributes."""
    g = DGLGraph()
    g.add_node(list(range(10))) # add 10 nodes.
    g.add_edge(0, 1, x=10)
    assert g.number_of_edges() == 1
    assert g[0][1]['x'] == 10
    # add many-many edges
    u = [1, 2, 3]
    v = [2, 3, 4]
    g.add_edge(u, v, y=11) # add 3 edges.
    assert g.number_of_edges() == 4
    assert g[u][v]['y'] == [11, 11, 11]
    # add one-many edges
    u = 5
    v = [6, 7]
    g.add_edge(u, v, y=22) # add 2 edges.
    assert g.number_of_edges() == 6
    assert g[u][v]['y'] == [22, 22]
    # add many-one edges
    u = [8, 9]
    v = 7
    g.add_edge(u, v, y=33) # add 2 edges.
    assert g.number_of_edges() == 8
    assert g[u][v]['y'] == [33, 33]
    # tensor type edge attr
    # BUG FIX: numpy is not imported in this file (torch is imported as F),
    # so the original np.zeros / np.array calls raised NameError.
    z = F.zeros((5, 10)) # 5 edges, each of is (10,) vector
    u = [1, 2, 3, 5, 8]
    v = [2, 3, 4, 6, 7]
    g[u][v]['z'] = z
    u = F.tensor(u)
    v = F.tensor(v)
    assert g[u][v]['z'].shape == (5, 10)
def test_graph1():
    """DGLGraph can be constructed directly from a networkx graph."""
    graph = DGLGraph(nx.path_graph(3))
def test_view():
    """Node/edge views support scalar and batched indexing and assignment."""
    graph = DGLGraph(nx.path_graph(3))
    graph.nodes[0]
    graph.edges[0, 1]
    src = [0, 1]
    dst = [1, 2]
    graph.nodes[src]
    graph.edges[src, dst]['x'] = 1
    assert graph.edges[src, dst]['x'] == [1, 1]
# Run the suite when executed as a script (nose also discovers the
# test_* functions above on import).
test_node1()
test_node2()
test_edge1()
test_graph1()
test_view()
from .graph import DGLGraph
from .graph import __MSG__, __REPR__, ALL
......@@ -8,3 +8,23 @@ SparseTensor = sp.sparse.spmatrix
def asnumpy(a):
    # Identity in the numpy backend: the data is already an ndarray.
    # NOTE(review): no conversion/validation is performed — confirm callers
    # only ever pass ndarrays here.
    return a
def concatenate(arrays, axis=0):
    """Concatenate a sequence of arrays along the given axis."""
    return np.concatenate(arrays, axis=axis)
def packable(arrays):
    """Return True iff `arrays` holds ndarrays sharing dtype and per-row
    shape, i.e. they can be concatenated along axis 0."""
    for arr in arrays:
        if not isinstance(arr, np.ndarray):
            return False
    for arr in arrays:
        if arr.dtype != arrays[0].dtype:
            return False
        if arr.shape[1:] != arrays[0].shape[1:]:
            return False
    return True
def pack(arrays):
    """Stack a sequence of arrays along the first axis into one array."""
    packed = np.concatenate(arrays, axis=0)
    return packed
def unpackable(a):
    """Return True iff `a` is a non-empty ndarray (i.e. splittable by row)."""
    if not isinstance(a, np.ndarray):
        return False
    return a.size > 0
def unpack(a):
    """Split `a` into a list of single-row arrays along axis 0."""
    num_rows = a.shape[0]
    return np.split(a, num_rows, axis=0)
def shape(a):
    # Return the ndarray's shape tuple.
    return a.shape
from __future__ import absolute_import
import torch
import torch as th
import scipy.sparse
# NOTE(review): `Tensor` is assigned twice below (once via the plain torch
# import, once via the `th` alias) — merge residue; both values are
# identical, so behavior is unaffected, but one should be deleted.
Tensor = torch.Tensor
# Tensor types
Tensor = th.Tensor
SparseTensor = scipy.sparse.spmatrix
# Data types
float16 = th.float16
float32 = th.float32
float64 = th.float64
uint8 = th.uint8
int8 = th.int8
int16 = th.int16
int32 = th.int32
int64 = th.int64
# Operators
tensor = th.tensor
# NOTE(review): these shadow the Python builtins `sum` and `max` at module
# scope; intentional for the backend namespace, but avoid star-importing.
sum = th.sum
max = th.max
def asnumpy(a):
    """Convert a torch tensor to a numpy ndarray.

    The tensor is moved to CPU first.  BUG FIX: `.detach()` is required
    because `.numpy()` raises a RuntimeError on tensors that require grad.
    """
    return a.detach().cpu().numpy()
def reduce_sum(a):
    """Sum all elements of the tensor `a`."""
    return th.sum(a)
def packable(tensors):
    """Return True iff all entries are tensors sharing dtype and per-row
    shape, i.e. they can be concatenated along dim 0."""
    for x in tensors:
        first = tensors[0]
        if not isinstance(x, th.Tensor):
            return False
        if x.dtype != first.dtype or x.shape[1:] != first.shape[1:]:
            return False
    return True
def pack(tensors):
    """Concatenate a sequence of tensors along the first dimension."""
    return th.cat(tensors, dim=0)
def unpack(x):
    """Split `x` into single-row tensors along the first dimension."""
    return th.split(x, 1, dim=0)
def shape(x):
    # Return the tensor's torch.Size (tuple-like of ints).
    return x.shape
def isinteger(x):
    """Return True iff tensor `x` has an integral dtype.

    Note: uint8 is deliberately not in the set, matching the id-tensor
    convention used elsewhere in the backend.
    """
    integral_dtypes = (th.int, th.int8, th.int16, th.int32, th.int64)
    return x.dtype in integral_dtypes
unique = th.unique
def gather_row(data, row_index):
    """Select the rows of `data` given by the 1-D index tensor `row_index`."""
    return data.index_select(0, row_index)
def scatter_row(data, row_index, value):
    """Return a copy of `data` with the rows at `row_index` replaced by
    `value` (out-of-place, so autograd history is preserved)."""
    updated = data.index_copy(0, row_index, value)
    return updated
def broadcast_to(x, to_array):
    """Broadcast `x` against the shape of `to_array` via a zero addend."""
    return th.zeros_like(to_array) + x
def reduce_max(a):
    """Concatenate the list of message tensors along dim 0 and take the
    elementwise max, keeping the reduced dimension."""
    stacked = th.cat(a, 0)
    best, _ = th.max(stacked, 0, keepdim=True)
    return best
nonzero = th.nonzero
def eq_scalar(x, val):
    """Elementwise comparison of tensor `x` against the scalar `val`."""
    return x == float(val)
# Re-exported shape utilities.
squeeze = th.squeeze
reshape = th.reshape
"""Built-in functors."""
from __future__ import absolute_import
import dgl.backend as F
def message_from_src(src, edge):
    """Built-in message function: the message is simply the source repr;
    the edge repr is ignored."""
    return src
def reduce_sum(node, msgs):
    """Built-in reduce: sum the incoming messages.

    Non-batched mode passes a Python list of messages; batched mode passes
    a tensor reduced along dim 1 (the per-node message dimension).
    """
    if isinstance(msgs, list):
        return sum(msgs)
    return F.sum(msgs, 1)
def reduce_max(node, msgs):
    """Built-in reduce: max over the incoming messages.

    Non-batched mode passes a Python list of messages; batched mode passes
    a tensor reduced along dim 1 (the per-node message dimension).
    """
    if isinstance(msgs, list):
        return max(msgs)
    return F.max(msgs, 1)
"""High-performance graph structure query component.
TODO: Currently implemented by igraph. Should replace with more efficient
solution later.
"""
from __future__ import absolute_import
import igraph
import dgl.backend as F
from dgl.backend import Tensor
import dgl.utils as utils
class CachedGraph:
    """Structural cache backed by a directed igraph graph.

    Edge ids are assigned in insertion order, which the frame storage
    relies on to pair edges with their attribute rows.
    """
    def __init__(self):
        self._graph = igraph.Graph(directed=True)

    def add_nodes(self, num_nodes):
        """Append `num_nodes` fresh vertices."""
        self._graph.add_vertices(num_nodes)

    def add_edges(self, u, v):
        """Add edges u->v; ids equal the order in which they are added."""
        # TODO(minjie): tensorize the loop
        for src, dst in utils.edge_iter(u, v):
            self._graph.add_edge(src, dst)

    def get_edge_id(self, u, v):
        """Return an int64 tensor with the edge id of each (u, v) pair."""
        # TODO(minjie): tensorize the loop
        pairs = list(utils.edge_iter(u, v))
        return F.tensor(self._graph.get_eids(pairs), dtype=F.int64)

    def in_edges(self, v):
        """Return (src, dst) int64 tensors of all edges entering `v`."""
        # TODO(minjie): tensorize the loop
        src, dst = [], []
        for dst_node in utils.node_iter(v):
            preds = self._graph.predecessors(dst_node)
            src.extend(preds)
            dst.extend([dst_node] * len(preds))
        return F.tensor(src, dtype=F.int64), F.tensor(dst, dtype=F.int64)

    def out_edges(self, u):
        """Return (src, dst) int64 tensors of all edges leaving `u`."""
        # TODO(minjie): tensorize the loop
        src, dst = [], []
        for src_node in utils.node_iter(u):
            succs = self._graph.successors(src_node)
            src.extend([src_node] * len(succs))
            dst.extend(succs)
        return F.tensor(src, dtype=F.int64), F.tensor(dst, dtype=F.int64)

    def edges(self):
        """Return (src, dst) int64 tensors of every edge, in edge-id order."""
        # TODO(minjie): tensorize
        pairs = self._graph.get_edgelist()
        src = F.tensor([s for s, _ in pairs], dtype=F.int64)
        dst = F.tensor([d for _, d in pairs], dtype=F.int64)
        return src, dst

    def in_degrees(self, v):
        """Return an int64 tensor with the in-degree of each node in `v`."""
        return F.tensor(self._graph.indegree(list(v)), dtype=F.int64)
def create_cached_graph(dglgraph):
    """Build a CachedGraph mirroring the nodes and edges of `dglgraph`."""
    # TODO: tensorize the loop
    cached = CachedGraph()
    cached.add_nodes(dglgraph.number_of_nodes())
    for src, dst in dglgraph.edges():
        cached.add_edges(src, dst)
    return cached
"""Columnar storage for graph attributes."""
from __future__ import absolute_import
import dgl.backend as F
from dgl.backend import Tensor
from dgl.utils import LazyDict
class Frame:
    """Columnar storage for graph attributes.

    Maps a column name (scheme) to a tensor whose first dimension indexes
    rows; every column must share the same number of rows.
    """
    def __init__(self, data=None):
        """Create a frame.

        Parameters
        ----------
        data : dict of str -> tensor, optional
            Initial columns; every column must have the same first-dim
            length.
        """
        if not data:
            # BUG FIX: Frame({}) used to raise IndexError while probing the
            # first column; an empty dict now behaves like no data.
            self._columns = dict() if data is None else data
            self._num_rows = 0
        else:
            self._columns = data
            self._num_rows = F.shape(list(data.values())[0])[0]
            for k, v in data.items():
                assert F.shape(v)[0] == self._num_rows

    @property
    def schemes(self):
        """set of str: the column names."""
        return set(self._columns.keys())

    @property
    def num_columns(self):
        """int: the number of columns."""
        return len(self._columns)

    @property
    def num_rows(self):
        """int: the number of rows."""
        return self._num_rows

    def __contains__(self, key):
        return key in self._columns

    def __getitem__(self, key):
        """Return a column tensor (str key) or a lazy row selection."""
        if isinstance(key, str):
            return self._columns[key]
        else:
            return self.select_rows(key)

    def __setitem__(self, key, val):
        """Replace a column (str key) or update rows (id-tensor key)."""
        if isinstance(key, str):
            self._columns[key] = val
        else:
            self.update_rows(key, val)

    def add_column(self, name, col):
        """Add (or overwrite) a column; its length must match existing rows."""
        if self.num_columns == 0:
            self._num_rows = F.shape(col)[0]
        else:
            assert F.shape(col)[0] == self._num_rows
        self._columns[name] = col

    def append(self, other):
        """Append the rows of `other` (a Frame or a dict of columns)."""
        if not isinstance(other, Frame):
            other = Frame(data=other)
        if len(self._columns) == 0:
            # BUG FIX: copy the dict instead of aliasing it, so mutating
            # this frame later cannot silently corrupt `other`.
            self._columns = dict(other._columns)
            self._num_rows = other._num_rows
        else:
            assert self.schemes == other.schemes
            self._columns = {key : F.pack([self[key], other[key]])
                             for key in self._columns}
            self._num_rows += other._num_rows

    def clear(self):
        """Drop all columns and rows."""
        self._columns = {}
        self._num_rows = 0

    def select_rows(self, rowids):
        """Lazily gather the given rows from every column."""
        def _lazy_select(key):
            return F.gather_row(self._columns[key], rowids)
        return LazyDict(_lazy_select, keys=self._columns.keys())

    def update_rows(self, rowids, other):
        """Out-of-place scatter of `other`'s columns into the given rows."""
        if not isinstance(other, Frame):
            other = Frame(data=other)
        for key in other.schemes:
            assert key in self._columns
            self._columns[key] = F.scatter_row(self[key], rowids, other[key])

    def __iter__(self):
        # Yield (name, column) pairs.
        for key, col in self._columns.items():
            yield key, col

    def __len__(self):
        return self.num_columns
"""Base graph class specialized for neural networks on graphs.
"""
from __future__ import absolute_import
from collections import defaultdict
from collections import MutableMapping
import networkx as nx
from networkx.classes.digraph import DiGraph
import dgl.backend as F
from dgl.backend import Tensor
import dgl.builtin as builtin
#import dgl.state as state
from dgl.frame import Frame
from dgl.cached_graph import CachedGraph, create_cached_graph
import dgl.scheduler as scheduler
import dgl.utils as utils
# Reserved attribute keys used internally by DGLGraph.
# NOTE(review): this region still contains both the pre- and post-merge key
# definitions — __MSG__ and __REPR__ are each assigned twice and the later
# (upper-case) values win. The stale lower-case ones (and the now-unused
# per-node/per-edge function keys) should eventually be deleted.
__MSG__ = "__msg__"
__REPR__ = "__repr__"
__MFUNC__ = "__mfunc__"
__EFUNC__ = "__efunc__"
__UFUNC__ = "__ufunc__"
__RFUNC__ = "__rfunc__"
__READOUT__ = "__readout__"
__MSG__ = "__MSG__"
__REPR__ = "__REPR__"
ALL = "__ALL__"
class DGLGraph(DiGraph):
"""Base graph class specialized for neural networks on graphs.
TODO(minjie): document of multi-node and multi-edge syntax.
TODO(minjie): document of batching semantics
TODO(minjie): document of __REPR__ semantics
Parameters
----------
......@@ -29,285 +32,352 @@ class DGLGraph(DiGraph):
attr : keyword arguments, optional
Attributes to add to graph as key=value pairs.
"""
#node_dict_factory = state.NodeDict
#adjlist_outer_dict_factory = state.AdjOuterDict
#adjlist_inner_dict_factory = state.AdjInnerDict
#edge_attr_dict_factory = state.EdgeAttrDict
def __init__(self, graph_data=None, **attr):
    """Initialize the graph.

    Parameters
    ----------
    graph_data : input graph, optional
        Data to initialize the underlying networkx DiGraph.
    attr : keyword arguments, optional
        Attributes to add to the graph as key=value pairs.
    """
    # call base class init
    super(DGLGraph, self).__init__(graph_data, **attr)
    # globally registered functions (e.g. readout), keyed by reserved names
    self._glb_func = {}
    # structure cache; None until built — presumably constructed on demand
    # by a cached_graph property (see create_cached_graph) — TODO confirm
    self._cached_graph = None
    # columnar storage for node and edge attributes
    self._node_frame = Frame()
    self._edge_frame = Frame()
    # other class members
    # message graph/frame buffer pending messages between sendto and recv
    self._msg_graph = None
    self._msg_frame = Frame()
    # globally registered (func, batchable) pairs; None until registered
    self._message_func = None
    self._reduce_func = None
    self._update_func = None
    self._edge_func = None
def set_n_repr(self, hu, u=ALL):
"""Set node(s) representation.
To set multiple node representations at once, pass `u` with a tensor or
a supported container of node ids. In this case, `hu` must be a tensor
of shape (B, D1, D2, ...), where B is the number of the nodes and
(D1, D2, ...) is the shape of the node representation tensor.
Dictionary type is also supported for `hu`. In this case, each item
will be treated as separate attribute of the nodes.
def init_reprs(self, h_init=None):
print("[DEPRECATED]: please directly set node attrs "
"(e.g. g.nodes[node]['x'] = val).")
for n in self.nodes:
self.set_repr(n, h_init)
Parameters
----------
hu : any
Node representation.
u : node, container or tensor
The node(s).
"""
# sanity check
if isinstance(u, str) and u == ALL:
num_nodes = self.number_of_nodes()
else:
u = utils.convert_to_id_tensor(u)
num_nodes = len(u)
if isinstance(hu, dict):
for key, val in hu.items():
assert F.shape(val)[0] == num_nodes
else:
F.shape(hu)[0] == num_nodes
# set
if isinstance(u, str) and u == ALL:
if isinstance(hu, dict):
for key, val in hu.items():
self._node_frame[key] = val
else:
self._node_frame[__REPR__] = hu
else:
if isinstance(hu, dict):
for key, val in hu.items():
self._node_frame[key][u] = val
else:
self._node_frame[__REPR__][u] = hu
def set_n_repr(self, u, h_u):
assert u in self.nodes
kwarg = {__REPR__: h_u}
self.add_node(u, **kwarg)
def get_n_repr(self, u=ALL):
"""Get node(s) representation.
def get_n_repr(self, u):
assert u in self.nodes
return self.nodes[u][__REPR__]
Parameters
----------
u : node, container or tensor
The node(s).
"""
if isinstance(u, str) and u == ALL:
if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
return self._node_frame[__REPR__]
else:
return dict(self._node_frame)
else:
u = utils.convert_to_id_tensor(u)
if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
return self._node_frame[__REPR__][u]
else:
return self._node_frame.select_rows(u)
def set_e_repr(self, u, v, h_uv):
assert (u, v) in self.edges
self.edges[u, v][__REPR__] = h_uv
def set_e_repr(self, h_uv, u=ALL, v=ALL):
"""Set edge(s) representation.
def get_e_repr(self, u, v):
assert (u, v) in self.edges
return self.edges[u, v][__REPR__]
To set multiple edge representations at once, pass `u` and `v` with tensors or
supported containers of node ids. In this case, `h_uv` must be a tensor
of shape (B, D1, D2, ...), where B is the number of the edges and
(D1, D2, ...) is the shape of the edge representation tensor.
def register_message_func(self,
message_func,
edges='all',
batchable=False):
"""Register computation on edges.
Dictionary type is also supported for `h_uv`. In this case, each item
will be treated as separate attribute of the edges.
The message function should be compatible with following signature:
Parameters
----------
h_uv : any
Edge representation.
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
"""
# sanity check
u_is_all = isinstance(u, str) and u == ALL
v_is_all = isinstance(v, str) and v == ALL
assert u_is_all == v_is_all
if u_is_all:
num_edges = self.number_of_edges()
else:
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
num_edges = max(len(u), len(v))
if isinstance(h_uv, dict):
for key, val in h_uv.items():
assert F.shape(val)[0] == num_edges
else:
F.shape(h_uv)[0] == num_edges
# set
if u_is_all:
if isinstance(h_uv, dict):
for key, val in h_uv.items():
self._edge_frame[key] = val
else:
self._edge_frame[__REPR__] = h_uv
else:
eid = self.cached_graph.get_edge_id(u, v)
if isinstance(h_uv, dict):
for key, val in h_uv.items():
self._edge_frame[key][eid] = val
else:
self._edge_frame[__REPR__][eid] = h_uv
(node_reprs, node_reprs, edge_reprs) -> msg
def get_e_repr(self, u=ALL, v=ALL):
"""Get node(s) representation.
It computes the representation of a message
using the representations of the source node, target node and the edge
itself. All node_reprs and edge_reprs are dictionaries.
Parameters
----------
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
"""
u_is_all = isinstance(u, str) and u == ALL
v_is_all = isinstance(v, str) and v == ALL
assert u_is_all == v_is_all
if u_is_all:
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__]
else:
return dict(self._edge_frame)
else:
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
eid = self.cached_graph.get_edge_id(u, v)
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__][eid]
else:
return self._edge_frame.select_rows(eid)
def register_message_func(self,
message_func,
batchable=False):
"""Register global message function.
Parameters
----------
message_func : callable
Message function on the edge.
edges : str, pair of nodes, pair of containers, pair of tensors
The edges for which the message function is registered. Default is
registering for all the edges. Registering for multiple edges is
supported.
batchable : bool
Whether the provided message function allows batch computing.
Examples
--------
Register for all edges.
>>> g.register_message_func(mfunc)
Register for a specific edge.
>>> g.register_message_func(mfunc, (u, v))
Register for multiple edges.
>>> u = [u1, u2, u3, ...]
>>> v = [v1, v2, v3, ...]
>>> g.register_message_func(mfunc, (u, v))
"""
def _msg_edge_func(u, v, e_uv):
return {__MSG__ : message_func(u, v, e_uv)}
self._internal_register_edge(__MFUNC__, _msg_edge_func, edges, batchable)
self._message_func = (message_func, batchable)
def register_edge_func(self,
edge_func,
edges='all',
batchable=False):
"""Register computation on edges.
The edge function should be compatible with following signature:
(node_reprs, node_reprs, edge_reprs) -> edge_reprs
It computes the new edge representations (the same concept as messages)
using the representations of the source node, target node and the edge
itself. All node_reprs and edge_reprs are dictionaries.
"""Register global edge update function.
Parameters
----------
edge_func : callable
Message function on the edge.
edges : str, pair of nodes, pair of containers, pair of tensors
The edges for which the message function is registered. Default is
registering for all the edges. Registering for multiple edges is
supported.
batchable : bool
Whether the provided message function allows batch computing.
Examples
--------
Register for all edges.
>>> g.register_edge_func(efunc)
Register for a specific edge.
>>> g.register_edge_func(efunc, (u, v))
Register for multiple edges.
>>> u = [u1, u2, u3, ...]
>>> v = [v1, v2, v3, ...]
>>> g.register_edge_func(mfunc, (u, v))
"""
self._internal_register_edge(__EFUNC__, edge_func, edges, batchable)
self._edge_func = (edge_func, batchable)
def register_reduce_func(self,
reduce_func,
nodes='all',
batchable=False):
"""Register message reduce function on incoming edges.
The reduce function should be compatible with following signature:
edge_reprs -> reduced_edge_repr
It computes the reduced edge representations using the representations
of the in-coming edges (the same concept as messages).
The reduce function can be any of the pre-defined functions ('sum',
'max'). If built-in function is used, computation will be performed
efficiently (using generic-SPMV kernels).
"""Register global message reduce function.
Parameters
----------
reduce_func : str or callable
Reduce function on incoming edges.
nodes : str, node, container or tensor
The nodes for which the reduce function is registered. Default is
registering for all the nodes. Registering for multiple nodes is
supported.
batchable : bool
Whether the provided reduce function allows batch computing.
Examples
--------
Register for all nodes.
>>> g.register_reduce_func(rfunc)
Register for a specific node.
>>> g.register_reduce_func(rfunc, u) # TODO Not implemented
Register for multiple nodes.
>>> u = [u1, u2, u3, ...]
>>> g.register_reduce_func(rfunc, u)
"""
if isinstance(reduce_func, str):
# built-in reduce func
if reduce_func == 'sum':
reduce_func = F.reduce_sum
elif reduce_func == 'max':
reduce_func = F.reduce_max
else:
raise NotImplementedError(
"Built-in function %s not implemented" % reduce_func)
self._internal_register_node(__RFUNC__, reduce_func, nodes, batchable)
self._reduce_func = (reduce_func, batchable)
def register_update_func(self,
update_func,
nodes='all',
batchable=False):
"""Register computation on nodes.
The update function should be compatible with following signature:
(node_reprs, reduced_edge_repr) -> node_reprs
It computes the new node representations using the representations
of the in-coming edges (the same concept as messages) and the node
itself. All node_reprs and edge_reprs are dictionaries.
"""Register global node update function.
Parameters
----------
update_func : callable
Update function on the node.
nodes : str, node, container or tensor
The nodes for which the update function is registered. Default is
registering for all the nodes. Registering for multiple nodes is
supported.
batchable : bool
Whether the provided update function allows batch computing.
name : str
The name of the function.
Examples
--------
Register for all nodes.
>>> g.register_update_func(ufunc)
Register for a specific node.
>>> g.register_update_func(ufunc, u) # TODO Not implemented
Register for multiple nodes.
>>> u = [u1, u2, u3, ...]
>>> g.register_update_func(ufunc, u)
"""
self._internal_register_node(__UFUNC__, update_func, nodes, batchable)
def register_readout_func(self, readout_func):
"""Register computation on the whole graph.
The readout_func should be compatible with following signature:
(node_reprs, edge_reprs) -> any
It takes the representations of selected nodes and edges and
returns readout values.
NOTE: readout function can be implemented outside of DGLGraph.
One can simple get the node/edge reprs of the graph and perform
arbitrary computation.
Parameters
----------
readout_func : callable
The readout function.
See Also
--------
readout
"""
self._glb_func[__READOUT__] = readout_func
self._update_func = (update_func, batchable)
def readout(self,
nodes='all',
edges='all',
**kwargs):
readout_func,
nodes=ALL,
edges=ALL):
"""Trigger the readout function on the specified nodes/edges.
Parameters
----------
readout_func : callable
Readout function.
nodes : str, node, container or tensor
The nodes to get reprs from.
edges : str, pair of nodes, pair of containers or pair of tensors
The edges to get reprs from.
kwargs : keyword arguments, optional
Arguments for the readout function.
"""
nodes = self._nodes_or_all(nodes)
edges = self._edges_or_all(edges)
assert __READOUT__ in self._glb_func, \
"Readout function has not been registered."
# TODO(minjie): tensorize following loop.
nstates = [self.nodes[n] for n in nodes]
estates = [self.edges[e] for e in edges]
return self._glb_func[__READOUT__](nstates, estates, **kwargs)
return readout_func(nstates, estates)
def sendto(self, u, v):
def sendto(self, u, v, message_func=None, batchable=False):
"""Trigger the message function on edge u->v
The message function should be compatible with following signature:
(node_reprs, edge_reprs) -> message
It computes the representation of a message using the
representations of the source node, and the edge u->v.
All node_reprs and edge_reprs are dictionaries.
The message function can be any of the pre-defined functions
('from_src').
Parameters
----------
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
message_func : str or callable
The message function.
batchable : bool
Whether the function allows batched computation.
"""
self._internal_trigger_edges(u, v, __MFUNC__)
if message_func is None:
message_func, batchable = self._message_func
assert message_func is not None
if batchable:
self._batch_sendto(u, v, message_func)
else:
self._nonbatch_sendto(u, v, message_func)
def update_edge(self, u, v):
def _nonbatch_sendto(self, u, v, message_func):
    """Run `message_func` edge-by-edge along u->v and stash each result
    under the reserved __MSG__ key of the corresponding edge."""
    func = _get_message_func(message_func)
    for src, dst in utils.edge_iter(u, v):
        msg = func(_get_repr(self.nodes[src]),
                   _get_repr(self.edges[src, dst]))
        self.edges[src, dst][__MSG__] = msg
def _batch_sendto(self, u, v, message_func):
    """Compute messages for edges u->v in one batched call and buffer them
    in the message frame (rows appended in msg-graph edge-id order)."""
    u = utils.convert_to_id_tensor(u)
    v = utils.convert_to_id_tensor(v)
    # Edge ids are looked up before any broadcasting of u.
    edge_id = self.cached_graph.get_edge_id(u, v)
    # Record the pending messages on the message graph.
    self.msg_graph.add_edges(u, v)
    # One-to-many broadcast: replicate the single source id to match v.
    # NOTE(review): the many-to-one case (len(v) == 1) is not broadcast
    # here, unlike _batch_update_edge — confirm this is intended.
    if len(u) != len(v) and len(u) == 1:
        u = F.broadcast_to(u, v)
    src_reprs = _get_repr(self._node_frame.select_rows(u))
    edge_reprs = _get_repr(self._edge_frame.select_rows(edge_id))
    msgs = message_func(src_reprs, edge_reprs)
    # Dict-valued messages keep their field names; a bare tensor is stored
    # under the reserved __MSG__ column.
    if isinstance(msgs, dict):
        self._msg_frame.append(msgs)
    else:
        self._msg_frame.append({__MSG__ : msgs})
def update_edge(self, u, v, edge_func=None, batchable=False):
"""Update representation on edge u->v
The edge function should be compatible with following signature:
(node_reprs, node_reprs, edge_reprs) -> edge_reprs
It computes the new edge representations using the representations
of the source node, target node and the edge itself.
All node_reprs and edge_reprs are dictionaries.
Parameters
----------
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
edge_func : str or callable
The update function.
batchable : bool
Whether the function allows batched computation.
"""
self._internal_trigger_edges(u, v, __EFUNC__)
if edge_func is None:
edge_func, batchable = self._edge_func
assert edge_func is not None
if batchable:
self._batch_update_edge(u, v, edge_func)
else:
self._nonbatch_update_edge(u, v, edge_func)
def recv(self, u):
def _nonbatch_update_edge(self, u, v, edge_func):
    """Apply `edge_func` to each edge u->v and store the new edge repr."""
    for src, dst in utils.edge_iter(u, v):
        new_repr = edge_func(_get_repr(self.nodes[src]),
                             _get_repr(self.nodes[dst]),
                             _get_repr(self.edges[src, dst]))
        _set_repr(self.edges[src, dst], new_repr)
def _batch_update_edge(self, u, v, edge_func):
    """Batched edge update: run `edge_func` on the gathered endpoint and
    edge representations and scatter the result back to the edge frame."""
    u = utils.convert_to_id_tensor(u)
    v = utils.convert_to_id_tensor(v)
    edge_id = self.cached_graph.get_edge_id(u, v)
    # Broadcast a singleton endpoint against the other side.
    if len(u) == 1 and len(v) != 1:
        u = F.broadcast_to(u, v)
    elif len(v) == 1 and len(u) != 1:
        v = F.broadcast_to(v, u)
    src_reprs = _get_repr(self._node_frame.select_rows(u))
    dst_reprs = _get_repr(self._node_frame.select_rows(v))
    edge_reprs = _get_repr(self._edge_frame.select_rows(edge_id))
    new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
    _batch_set_repr(self._edge_frame, edge_id, new_edge_reprs)
def recv(self,
u,
reduce_func=None,
update_func=None,
batchable=False):
"""Receive in-coming messages and update representation on node u.
It computes the new node state using the messages sent from the predecessors
......@@ -315,22 +385,51 @@ class DGLGraph(DiGraph):
will be skipped and a None type will be provided as the reduced messages for
the update function.
The reduce function should be compatible with following signature:
(node_reprs, batched_messages) -> reduced_messages
It computes the reduced edge representations using the representations
of the in-coming edges (the same concept as messages).
The reduce function can be any of the pre-defined functions ('sum',
'max'). If built-in function is used, computation will be performed
efficiently (using generic-SPMV kernels).
The update function should be compatible with following signature:
(node_reprs, reduced_messages) -> node_reprs
It computes the new node representations using the representations
of the in-coming edges (the same concept as messages) and the node
itself. All node_reprs and edge_reprs are dictionaries.
Parameters
----------
u : node, container or tensor
The node to be updated.
reduce_func : str or callable
The reduce function.
update_func : str or callable
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
u_is_container = isinstance(u, list)
u_is_tensor = isinstance(u, Tensor)
rfunc = self._glb_func.get(__RFUNC__)
ufunc = self._glb_func.get(__UFUNC__)
# TODO(minjie): tensorize the loop.
if reduce_func is None:
reduce_func, batchable = self._reduce_func
if update_func is None:
update_func, batchable = self._update_func
assert reduce_func is not None
assert update_func is not None
if batchable:
self._batch_recv(u, reduce_func, update_func)
else:
self._nonbatch_recv(u, reduce_func, update_func)
def _nonbatch_recv(self, u, reduce_func, update_func):
f_reduce = _get_reduce_func(reduce_func)
f_update = update_func
for i, uu in enumerate(utils.node_iter(u)):
# TODO(minjie): tensorize the message batching
# reduce phase
f_reduce = self.nodes[uu].get(__RFUNC__, rfunc)
assert f_reduce is not None, \
"Reduce function not registered for node %s" % uu
msgs_batch = [self.edges[vv, uu].pop(__MSG__)
for vv in self.pred[uu] if __MSG__ in self.edges[vv, uu]]
if len(msgs_batch) == 0:
......@@ -338,15 +437,60 @@ class DGLGraph(DiGraph):
elif len(msgs_batch) == 1:
msgs_reduced = msgs_batch[0]
else:
msgs_reduced = f_reduce(msgs_batch)
msgs_reduced = f_reduce(_get_repr(self.nodes[uu]), msgs_batch)
# update phase
f_update = self.nodes[uu].get(__UFUNC__, ufunc)
assert f_update is not None, \
"Update function not registered for node %s" % uu
ret = f_update(self._get_repr(self.nodes[uu]), msgs_reduced)
self._set_repr(self.nodes[uu], ret)
def update_by_edge(self, u, v):
ret = f_update(_get_repr(self.nodes[uu]), msgs_reduced)
_set_repr(self.nodes[uu], ret)
def _batch_recv(self, v, reduce_func, update_func):
    """Batched receive: reduce the pending messages of each node in `v`
    and update its representation.

    Uses degree bucketing so all nodes with the same in-degree are reduced
    in one batched call.
    """
    # sanity checks
    v = utils.convert_to_id_tensor(v)
    f_reduce = _get_reduce_func(reduce_func)
    f_update = update_func
    # degree bucketing
    degrees, v_buckets = scheduler.degree_bucketing(self.msg_graph, v)
    reduced_msgs = []
    for deg, v_bkt in zip(degrees, v_buckets):
        bkt_len = len(v_bkt)
        uu, vv = self.msg_graph.in_edges(v_bkt)
        in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
        # The in_msgs represents the rows selected. Since our storage
        # is column-based, it will only be materialized when user
        # tries to get the column (e.g. when user called `msgs['h']`)
        in_msgs = self._msg_frame.select_rows(in_msg_ids)
        # Reshape the column tensor to (B, Deg, ...).
        def _reshape_fn(msg):
            msg_shape = F.shape(msg)
            new_shape = (bkt_len, deg) + msg_shape[1:]
            return F.reshape(msg, new_shape)
        if len(in_msgs) == 1 and __MSG__ in in_msgs:
            reshaped_in_msgs = _reshape_fn(in_msgs[__MSG__])
        else:
            reshaped_in_msgs = utils.LazyDict(
                lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
        dst_reprs = _get_repr(self._node_frame.select_rows(v_bkt))
        reduced_msgs.append(f_reduce(dst_reprs, reshaped_in_msgs))
    # TODO: clear partial messages
    self.clear_messages()
    # Read the node states in the degree-bucketing order.
    reordered_v = F.pack(v_buckets)
    reordered_ns = _get_repr(self._node_frame.select_rows(reordered_v))
    # Pack all reduced msgs together.
    # BUG FIX: `reduced_msgs` is always a list (built by append above), so
    # the old `isinstance(reduced_msgs, dict)` branch was dead and dict-
    # valued reductions crashed in F.pack. Inspect the per-bucket results
    # instead and pack dict-valued reductions per key.
    if reduced_msgs and isinstance(reduced_msgs[0], dict):
        all_reduced_msgs = {
            key : F.pack([msg[key] for msg in reduced_msgs])
            for key in reduced_msgs[0]}
    else:
        all_reduced_msgs = F.pack(reduced_msgs)
    new_ns = f_update(reordered_ns, all_reduced_msgs)
    _batch_set_repr(self._node_frame, reordered_v, new_ns)
def update_by_edge(self,
u, v,
message_func=None,
reduce_func=None,
update_func=None,
batchable=False):
"""Trigger the message function on u->v and update v.
Parameters
......@@ -355,53 +499,183 @@ class DGLGraph(DiGraph):
The source node(s).
v : node, container or tensor
The destination node(s).
message_func : str or callable
The message function.
reduce_func : str or callable
The reduce function.
update_func : str or callable
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
self.sendto(u, v)
# TODO(minjie): tensorize the following loops.
if message_func is None:
message_func, batchable = self._message_func
if reduce_func is None:
reduce_func, batchable = self._reduce_func
if update_func is None:
update_func, batchable = self._update_func
assert message_func is not None
assert reduce_func is not None
assert update_func is not None
if batchable:
self._batch_update_by_edge(
u, v, message_func, reduce_func, update_func)
else:
self._nonbatch_update_by_edge(
u, v, message_func, reduce_func, update_func)
def _nonbatch_update_by_edge(
self,
u, v,
message_func=None,
reduce_func=None,
update_func=None):
self._nonbatch_sendto(u, v, message_func)
dst = set()
for uu, vv in utils.edge_iter(u, v):
dst.add(vv)
self.recv(list(dst))
def update_to(self, u):
self._nonbatch_recv(list(dst), reduce_func, update_func)
def _batch_update_by_edge(
        self,
        u, v,
        message_func=None,
        reduce_func=None,
        update_func=None):
    """Batched send-and-receive along the edges u->v.

    The (from_src, sum) combination is reserved for a generic-SPMV fast
    path that is not implemented yet; everything else falls back to an
    explicit sendto followed by a recv on the unique destinations.
    """
    if message_func == 'from_src' and reduce_func == 'sum':
        # Specialized to generic-SPMV
        raise NotImplementedError('SPVM specialization')
    else:
        self._batch_sendto(u, v, message_func)
        # Each destination must be received exactly once even if it appears
        # on several edges.
        # NOTE(review): F.unique expects a tensor — confirm `v` is always an
        # id tensor (not a plain list) on this path.
        unique_v = F.unique(v)
        self._batch_recv(unique_v, reduce_func, update_func)
def update_to(self,
              v,
              message_func=None,
              reduce_func=None,
              update_func=None,
              batchable=False):
    """Pull messages from the node's predecessors and then update it.

    Parameters
    ----------
    v : node, container or tensor
        The node(s) to be updated.
    message_func : str or callable
        The message function.
    reduce_func : str or callable
        The reduce function.
    update_func : str or callable
        The update function.
    batchable : bool
        Whether the reduce and update function allows batched computation.
    """
    # Resolve missing functions from the ones registered on the graph;
    # the registered batchable flag takes precedence.
    if message_func is None:
        message_func, batchable = self._message_func
    if reduce_func is None:
        reduce_func, batchable = self._reduce_func
    if update_func is None:
        update_func, batchable = self._update_func
    assert message_func is not None
    assert reduce_func is not None
    assert update_func is not None
    if batchable:
        # All incoming edges of v, processed in one batched pass.
        uu, vv = self.cached_graph.in_edges(v)
        self.update_by_edge(uu, vv, message_func,
                            reduce_func, update_func, batchable)
    else:
        for vv in utils.node_iter(v):
            assert vv in self.nodes
            uu = list(self.pred[vv])
            self.sendto(uu, vv, message_func, batchable)
            self.recv(vv, reduce_func, update_func, batchable)
def update_from(self,
                u,
                message_func=None,
                reduce_func=None,
                update_func=None,
                batchable=False):
    """Send messages from the node to its successors and update them.

    Parameters
    ----------
    u : node, container or tensor
        The node(s) that send out messages.
    message_func : str or callable
        The message function.
    reduce_func : str or callable
        The reduce function.
    update_func : str or callable
        The update function.
    batchable : bool
        Whether the reduce and update function allows batched computation.
    """
    # Resolve missing functions from the ones registered on the graph;
    # the registered batchable flag takes precedence.
    if message_func is None:
        message_func, batchable = self._message_func
    if reduce_func is None:
        reduce_func, batchable = self._reduce_func
    if update_func is None:
        update_func, batchable = self._update_func
    assert message_func is not None
    assert reduce_func is not None
    assert update_func is not None
    if batchable:
        # All outgoing edges of u, processed in one batched pass.
        uu, vv = self.cached_graph.out_edges(u)
        self.update_by_edge(uu, vv, message_func,
                            reduce_func, update_func, batchable)
    else:
        for uu in utils.node_iter(u):
            assert uu in self.nodes
            for v in self.succ[uu]:
                self.update_by_edge(uu, v,
                                    message_func, reduce_func, update_func, batchable)
def update_all(self,
               message_func=None,
               reduce_func=None,
               update_func=None,
               batchable=False):
    """Send messages through all the edges and update all nodes.

    Parameters
    ----------
    message_func : str or callable
        The message function.
    reduce_func : str or callable
        The reduce function.
    update_func : str or callable
        The update function.
    batchable : bool
        Whether the reduce and update function allows batched computation.
    """
    # Resolve missing functions from the ones registered on the graph;
    # the registered batchable flag takes precedence.
    if message_func is None:
        message_func, batchable = self._message_func
    if reduce_func is None:
        reduce_func, batchable = self._reduce_func
    if update_func is None:
        update_func, batchable = self._update_func
    assert message_func is not None
    assert reduce_func is not None
    assert update_func is not None
    if batchable:
        u, v = self.cached_graph.edges()
        self._batch_update_by_edge(u, v,
                                   message_func, reduce_func, update_func)
    else:
        u = [uu for uu, _ in self.edges]
        v = [vv for _, vv in self.edges]
        self._nonbatch_sendto(u, v, message_func)
        self._nonbatch_recv(list(self.nodes()), reduce_func, update_func)
def propagate(self,
message_func=None,
reduce_func=None,
update_func=None,
batchable=False,
iterator='bfs',
**kwargs):
"""Propagate messages and update nodes using iterator.
A convenient function for passing messages and updating
......@@ -413,6 +687,14 @@ class DGLGraph(DiGraph):
Parameters
----------
message_func : str or callable
The message function.
reduce_func : str or callable
The reduce function.
update_func : str or callable
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
iterator : str or generator of steps.
The iterator of the graph.
kwargs : keyword arguments, optional
......@@ -424,7 +706,8 @@ class DGLGraph(DiGraph):
else:
# NOTE: the iteration can return multiple edges at each step.
for u, v in iterator:
self.update_by_edge(u, v)
self.update_by_edge(u, v,
message_func, reduce_func, update_func, batchable)
def draw(self):
"""Plot the graph using dot."""
......@@ -433,50 +716,69 @@ class DGLGraph(DiGraph):
pos = graphviz_layout(self, prog='dot')
nx.draw(self, pos, with_labels=True)
@property
def cached_graph(self):
    """Lazily-built, read-only structural index of the graph topology."""
    # TODO: dirty flag when mutated
    if self._cached_graph is None:
        self._cached_graph = create_cached_graph(self)
    return self._cached_graph

@property
def msg_graph(self):
    """Graph tracking which edges currently carry unconsumed messages."""
    # TODO: dirty flag when mutated
    if self._msg_graph is None:
        self._msg_graph = CachedGraph()
        self._msg_graph.add_nodes(self.number_of_nodes())
    return self._msg_graph

def clear_messages(self):
    """Discard all pending messages and reset the message graph."""
    if self._msg_graph is not None:
        self._msg_graph = CachedGraph()
        self._msg_graph.add_nodes(self.number_of_nodes())
        self._msg_frame.clear()
def _nodes_or_all(self, nodes):
    """Resolve the ALL sentinel to the full node list."""
    if nodes == ALL:
        return self.nodes()
    return nodes
def _edges_or_all(self, edges):
    """Resolve the ALL sentinel to the full edge list."""
    if edges == ALL:
        return self.edges()
    return edges
def _get_repr(attr_dict):
if len(attr_dict) == 1 and __REPR__ in attr_dict:
return attr_dict[__REPR__]
else:
return attr_dict
def _set_repr(attr_dict, attr):
if isinstance(attr, dict):
attr_dict.update(attr)
else:
attr_dict[__REPR__] = attr
def _batch_set_repr(frame, rows, attr):
if isinstance(attr, dict):
frame.update_rows(rows, attr)
else:
frame.update_rows(rows, {__REPR__ : attr})
def _get_reduce_func(reduce_func):
if isinstance(reduce_func, str):
# built-in reduce func
if reduce_func == 'sum':
return builtin.reduce_sum
elif reduce_func == 'max':
return builtin.reduce_max
else:
for n in nodes:
self.nodes[n][name] = func
def _internal_register_edge(self, name, func, edges, batchable):
# TODO(minjie): handle batchable
# TODO(minjie): group edges based on their registered func
if edges == 'all':
self._glb_func[name] = func
raise ValueError(
"Unknown built-in reduce function: %s" % reduce_func)
return reduce_func
def _get_message_func(message_func):
if isinstance(message_func, str):
# built-in message func
if message_func == 'from_src':
return builtin.message_from_src
else:
for e in edges:
self.edges[e][name] = func
def _internal_trigger_edges(self, u, v, name):
# TODO(minjie): tensorize the loop.
efunc = self._glb_func.get(name)
for uu, vv in utils.edge_iter(u, v):
f_edge = self.edges[uu, vv].get(name, efunc)
assert f_edge is not None, \
"edge function \"%s\" not registered for edge (%s->%s)" % (name, uu, vv)
ret = f_edge(self._get_repr(self.nodes[uu]),
self._get_repr(self.nodes[vv]),
self._get_repr(self.edges[uu, vv]))
self._set_repr(self.edges[uu, vv], ret)
raise ValueError(
"Unknown built-in message function: %s" % message_func)
return message_func
"""Schedule policies for graph computation."""
from __future__ import absolute_import
import dgl.backend as F
def degree_bucketing(cached_graph, v):
    """Group the nodes in v by their in-degree.

    Returns (unique_degrees, buckets) where buckets[i] is the subset of
    v whose in-degree equals unique_degrees[i].
    """
    degrees = cached_graph.in_degrees(v)
    unique_degrees = list(F.asnumpy(F.unique(degrees)))
    buckets = [
        v[F.squeeze(F.nonzero(F.eq_scalar(degrees, deg)), 1)]
        for deg in unique_degrees
    ]
    return unique_degrees, buckets
"""Utility module."""
from __future__ import absolute_import
from collections import Mapping
import dgl.backend as F
from dgl.backend import Tensor
from dgl.backend import Tensor, SparseTensor
def is_id_tensor(u):
    """Return True iff u is a 1-D integer backend tensor of node/edge ids."""
    if not isinstance(u, Tensor):
        return False
    return F.isinteger(u) and len(F.shape(u)) == 1
def is_id_container(u):
    """Return True iff u is a plain Python list of node/edge ids."""
    return isinstance(u, list)
def node_iter(n):
    """Iterate over the node(s) in n.

    An id tensor is converted to a Python list first; a list yields its
    elements one by one; any other value is treated as a single node.
    """
    if is_id_tensor(n):
        n = list(F.asnumpy(n))
    if is_id_container(n):
        for nn in n:
            yield nn
    else:
        yield n
def edge_iter(u, v):
u_is_container = isinstance(u, list)
v_is_container = isinstance(v, list)
u_is_tensor = isinstance(u, Tensor)
v_is_tensor = isinstance(v, Tensor)
u_is_container = is_id_container(u)
v_is_container = is_id_container(v)
u_is_tensor = is_id_tensor(u)
v_is_tensor = is_id_tensor(v)
if u_is_tensor:
u = F.asnumpy(u)
u_is_tensor = False
......@@ -41,3 +47,43 @@ def edge_iter(u, v):
yield u, vv
else:
yield u, v
def homogeneous(x_list, type_x=None):
    """Return True iff every element of x_list has the same type.

    Parameters
    ----------
    x_list : list
        The elements to check.
    type_x : type, optional
        The required type; defaults to the type of the first element.

    An empty list is vacuously homogeneous (the previous version raised
    IndexError when no explicit type was given).
    """
    if not x_list:
        return True
    expected = type_x if type_x is not None else type(x_list[0])
    return all(type(x) is expected for x in x_list)
def convert_to_id_tensor(x):
    """Convert x to an id tensor.

    Accepts an int list (must be homogeneous ints), an existing id
    tensor (returned as-is), or a single int.

    Raises
    ------
    TypeError
        If x is none of the accepted forms.
    """
    if is_id_container(x):
        assert homogeneous(x, int)
        return F.tensor(x)
    if is_id_tensor(x):
        return x
    if isinstance(x, int):
        return F.tensor([x])
    # NOTE: the previous version had an unreachable `return None` after this.
    raise TypeError('Error node: %s' % str(x))
class LazyDict(Mapping):
    """A readonly dictionary that does not materialize the storage.

    Values are produced on demand by calling *fn* on the key; nothing is
    cached or stored.
    """
    def __init__(self, fn, keys):
        self._fn = fn
        self._keys = keys

    def keys(self):
        return self._keys

    def __getitem__(self, key):
        assert key in self._keys
        return self._fn(key)

    def __contains__(self, key):
        return key in self._keys

    def __iter__(self):
        # Mapping contract: __iter__ yields keys only. The previous
        # implementation yielded (key, value) pairs, which broke the
        # items()/values() mixins and plain iteration.
        return iter(self._keys)

    def __len__(self):
        return len(self._keys)
......@@ -17,6 +17,7 @@ setuptools.setup(
'numpy>=1.14.0',
'scipy>=1.1.0',
'networkx>=2.1',
'python-igraph>=0.7.0',
],
data_files=[('', ['VERSION'])],
url='https://github.com/jermainewang/dgl-1')
import torch as th
from dgl.graph import DGLGraph
D = 32
def update_func(hv, accum):
    """Node update used by the test: new repr is the old repr plus the
    accumulated messages (shapes must agree)."""
    assert accum.shape == hv.shape
    return accum + hv
def generate_graph():
    """Build the 10-node test graph: node 0 fans out to nodes 1..8, each
    of which feeds node 9, plus a 9->0 back edge. Node reprs 'h' are
    random (10, D) values."""
    g = DGLGraph()
    for i in range(10):
        g.add_node(i)  # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    # TODO: use internal interface to set data.
    g._node_frame['h'] = th.randn(10, D)
    return g
def test_spmv_specialize():
    """Exercise the batched from_src/sum pipeline end to end."""
    graph = generate_graph()
    graph.register_message_func('from_src', batchable=True)
    graph.register_reduce_func('sum', batchable=True)
    graph.register_update_func(update_func, batchable=True)
    graph.update_all()
# Allow running this test module directly, outside the nose runner.
if __name__ == '__main__':
    test_spmv_specialize()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment