Unverified Commit 9c135fd5 authored by VoVAllen's avatar VoVAllen Committed by GitHub
Browse files

Merge pull request #4 from jermainewang/master

Sync with latest commit
parents 9d3f299d 00add9f2
"""Columnar storage for graph attributes.""" """Columnar storage for DGLGraph."""
from __future__ import absolute_import from __future__ import absolute_import
from collections import MutableMapping from collections import MutableMapping
...@@ -6,178 +6,598 @@ import numpy as np ...@@ -6,178 +6,598 @@ import numpy as np
from . import backend as F from . import backend as F
from .backend import Tensor from .backend import Tensor
from .base import DGLError, dgl_warning
from . import utils from . import utils
class Scheme(object):
    """Schema of one feature column: the per-row feature shape and dtype.

    Parameters
    ----------
    shape : tuple of int
        The feature shape (excluding the leading batch dimension).
    dtype : TVMType
        The feature data type.
    """
    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype

    def __repr__(self):
        # Same textual form as before: {shape=..., dtype=...} with repr()'d fields.
        return f'{{shape={self.shape!r}, dtype={self.dtype!r}}}'

    def __eq__(self, other):
        # Attribute access on `other` intentionally unguarded: comparing against
        # an object without shape/dtype raises AttributeError, as before.
        return (self.shape, self.dtype) == (other.shape, other.dtype)

    def __ne__(self, other):
        return not (self == other)

    @staticmethod
    def infer_scheme(tensor):
        """Build a Scheme from a tensor, dropping the first (batch) dimension."""
        per_row_shape = tuple(F.shape(tensor)[1:])
        return Scheme(per_row_shape, F.get_tvmtype(tensor))
class Column(object):
    """A column is a compact store of features of multiple nodes/edges.

    Currently, we use one dense tensor to batch all the feature tensors
    together (along the first dimension).

    Parameters
    ----------
    data : Tensor
        The initial data of the column.
    scheme : Scheme, optional
        The scheme of the column. Will be inferred if not provided.
    """
    def __init__(self, data, scheme=None):
        self.data = data
        # Infer the per-row shape/dtype from the data when no scheme is given.
        self.scheme = scheme if scheme else Scheme.infer_scheme(data)

    def __len__(self):
        """The column length (number of rows, i.e. first dimension of data)."""
        return F.shape(self.data)[0]

    def __getitem__(self, idx):
        """Return the feature data given the index.

        Parameters
        ----------
        idx : utils.Index
            The index.

        Returns
        -------
        Tensor
            The feature data.
        """
        # Materialize the index on the same device as the stored data.
        user_idx = idx.tousertensor(F.get_context(self.data))
        return F.gather_row(self.data, user_idx)

    def __setitem__(self, idx, feats):
        """Update the feature data given the index.

        The update is performed out-placely so it can be used in autograd mode.
        For inplace write, please use ``update``.

        Parameters
        ----------
        idx : utils.Index
            The index.
        feats : Tensor
            The new features.
        """
        self.update(idx, feats, inplace=False)

    def update(self, idx, feats, inplace):
        """Update the feature data given the index.

        Parameters
        ----------
        idx : utils.Index
            The index.
        feats : Tensor
            The new features.
        inplace : bool
            If true, use inplace write.

        Raises
        ------
        DGLError
            If the scheme of ``feats`` does not match this column's scheme.
        """
        feat_scheme = Scheme.infer_scheme(feats)
        if feat_scheme != self.scheme:
            # BUG FIX: the original passed (feat_scheme, self.scheme), which
            # transposed the operands in the message text. The column's own
            # scheme is the one being updated; the feature's scheme is the
            # incoming one.
            raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
                           % (self.scheme, feat_scheme))
        user_idx = idx.tousertensor(F.get_context(self.data))
        if inplace:
            # TODO(minjie): do not use [] operator directly
            self.data[user_idx] = feats
        else:
            # Out-of-place scatter keeps the op differentiable under autograd.
            self.data = F.scatter_row(self.data, user_idx, feats)

    @staticmethod
    def create(data):
        """Create a new column using the given data.

        If ``data`` is already a Column, a new Column sharing the same
        underlying tensor is returned (the scheme is re-inferred).
        """
        if isinstance(data, Column):
            return Column(data.data)
        else:
            return Column(data)
class Frame(MutableMapping): class Frame(MutableMapping):
"""The columnar storage for node/edge features.
The frame is a dictionary from feature fields to feature columns.
All columns should have the same number of rows (i.e. the same first dimension).
Parameters
----------
data : dict-like, optional
The frame data in dictionary. If the provided data is another frame,
this frame will NOT share columns with the given frame. So any out-place
update on one will not reflect to the other. The inplace update will
be seen by both. This follows the semantic of python's container.
"""
def __init__(self, data=None): def __init__(self, data=None):
if data is None: if data is None:
self._columns = dict() self._columns = dict()
self._num_rows = 0 self._num_rows = 0
else: else:
self._columns = dict(data) # Note that we always create a new column for the given data.
self._num_rows = F.shape(list(data.values())[0])[0] # This avoids two frames accidentally sharing the same column.
for k, v in data.items(): self._columns = {k : Column.create(v) for k, v in data.items()}
assert F.shape(v)[0] == self._num_rows if len(self._columns) != 0:
self._num_rows = len(next(iter(self._columns.values())))
else:
self._num_rows = 0
# sanity check
for name, col in self._columns.items():
if len(col) != self._num_rows:
raise DGLError('Expected all columns to have same # rows (%d), '
'got %d on %r.' % (self._num_rows, len(col), name))
# Initializer for empty values. Initializer is a callable.
# If is none, then a warning will be raised
# in the first call and zero initializer will be used later.
self._initializer = None
def set_initializer(self, initializer):
"""Set the initializer for empty values.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
"""
self._initializer = initializer
@property
def initializer(self):
"""Return the initializer of this frame."""
return self._initializer
@property @property
def schemes(self): def schemes(self):
return set(self._columns.keys()) """Return a dictionary of column name to column schemes."""
return {k : col.scheme for k, col in self._columns.items()}
@property @property
def num_columns(self): def num_columns(self):
"""Return the number of columns in this frame."""
return len(self._columns) return len(self._columns)
@property @property
def num_rows(self): def num_rows(self):
"""Return the number of rows in this frame."""
return self._num_rows return self._num_rows
def __contains__(self, key): def __contains__(self, name):
return key in self._columns """Return true if the given column name exists."""
return name in self._columns
def __getitem__(self, key):
# get column def __getitem__(self, name):
return self._columns[key] """Return the column of the given name.
def __setitem__(self, key, val): Parameters
# set column ----------
self.add_column(key, val) name : str
The column name.
def __delitem__(self, key):
# delete column Returns
del self._columns[key] -------
Column
The column.
"""
return self._columns[name]
def __setitem__(self, name, data):
"""Update the whole column.
Parameters
----------
name : str
The column name.
col : Column or data convertible to Column
The column data.
"""
self.update_column(name, data)
def __delitem__(self, name):
"""Delete the whole column.
Parameters
----------
name : str
The column name.
"""
del self._columns[name]
if len(self._columns) == 0: if len(self._columns) == 0:
self._num_rows = 0 self._num_rows = 0
def add_column(self, name, col): def add_column(self, name, scheme, ctx):
"""Add a new column to the frame.
The frame will be initialized by the initializer.
Parameters
----------
name : str
The column name.
scheme : Scheme
The column scheme.
ctx : TVMContext
The column context.
"""
if name in self:
dgl_warning('Column "%s" already exists. Ignore adding this column again.' % name)
return
if self.num_rows == 0:
raise DGLError('Cannot add column "%s" using column schemes because'
' number of rows is unknown. Make sure there is at least'
' one column in the frame so number of rows can be inferred.')
if self.initializer is None:
dgl_warning('Initializer is not set. Use zero initializer instead.'
' To suppress this warning, use `set_initializer` to'
' explicitly specify which initializer to use.')
# TODO(minjie): handle data type
self.set_initializer(lambda shape, dtype : F.zeros(shape))
# TODO(minjie): directly init data on the targer device.
init_data = self.initializer((self.num_rows,) + scheme.shape, scheme.dtype)
init_data = F.to_context(init_data, ctx)
self._columns[name] = Column(init_data, scheme)
def update_column(self, name, data):
"""Add or replace the column with the given name and data.
Parameters
----------
name : str
The column name.
data : Column or data convertible to Column
The column data.
"""
col = Column.create(data)
if self.num_columns == 0: if self.num_columns == 0:
self._num_rows = F.shape(col)[0] self._num_rows = len(col)
else: elif len(col) != self._num_rows:
assert F.shape(col)[0] == self._num_rows raise DGLError('Expected data to have %d rows, got %d.' %
(self._num_rows, len(col)))
self._columns[name] = col self._columns[name] = col
def append(self, other): def append(self, other):
"""Append another frame's data into this frame.
If the current frame is empty, it will just use the columns of the
given frame. Otherwise, the given data should contain all the
column keys of this frame.
Parameters
----------
other : Frame or dict-like
The frame data to be appended.
"""
if not isinstance(other, Frame):
other = Frame(other)
if len(self._columns) == 0: if len(self._columns) == 0:
for key, col in other.items(): for key, col in other.items():
self._columns[key] = col self._columns[key] = col
self._num_rows = other.num_rows
else: else:
for key, col in other.items(): for key, col in other.items():
self._columns[key] = F.pack([self[key], col]) sch = self._columns[key].scheme
# TODO(minjie): sanity check for num_rows other_sch = col.scheme
if len(self._columns) != 0: if sch != other_sch:
self._num_rows = F.shape(list(self._columns.values())[0])[0] raise DGLError("Cannot append column of scheme %s to column of scheme %s."
% (other_scheme, sch))
self._columns[key].data = F.pack(
[self._columns[key].data, col.data])
self._num_rows += other.num_rows
def clear(self): def clear(self):
"""Clear this frame. Remove all the columns."""
self._columns = {} self._columns = {}
self._num_rows = 0 self._num_rows = 0
def __iter__(self): def __iter__(self):
"""Return an iterator of columns."""
return iter(self._columns) return iter(self._columns)
def __len__(self): def __len__(self):
"""Return the number of columns."""
return self.num_columns return self.num_columns
def keys(self):
"""Return the keys."""
return self._columns.keys()
class FrameRef(MutableMapping): class FrameRef(MutableMapping):
"""Frame reference """Reference object to a frame on a subset of rows.
Parameters Parameters
---------- ----------
frame : dgl.frame.Frame frame : Frame, optional
The underlying frame. The underlying frame. If not given, the reference will point to a
index : iterable of int new empty frame.
The rows that are referenced in the underlying frame. index : iterable of int, optional
The rows that are referenced in the underlying frame. If not given,
the whole frame is referenced. The index should be distinct (no
duplication is allowed).
""" """
def __init__(self, frame=None, index=None): def __init__(self, frame=None, index=None):
self._frame = frame if frame is not None else Frame() self._frame = frame if frame is not None else Frame()
if index is None: if index is None:
self._index_data = slice(0, self._frame.num_rows) self._index_data = slice(0, self._frame.num_rows)
else: else:
# check no duplication # TODO(minjie): check no duplication
assert len(index) == len(np.unique(index))
self._index_data = index self._index_data = index
self._index = None self._index = None
@property @property
def schemes(self): def schemes(self):
"""Return the frame schemes.
Returns
-------
dict of str to Scheme
The frame schemes.
"""
return self._frame.schemes return self._frame.schemes
@property @property
def num_columns(self): def num_columns(self):
"""Return the number of columns in the referred frame."""
return self._frame.num_columns return self._frame.num_columns
@property @property
def num_rows(self): def num_rows(self):
"""Return the number of rows referred."""
if isinstance(self._index_data, slice): if isinstance(self._index_data, slice):
# NOTE: we are assuming that the index is a slice ONLY IF
# index=None during construction.
# As such, start is always 0, and step is always 1.
return self._index_data.stop return self._index_data.stop
else: else:
return len(self._index_data) return len(self._index_data)
def __contains__(self, key): def set_initializer(self, initializer):
return key in self._frame """Set the initializer for empty values.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
"""
self._frame.set_initializer(initializer)
def index(self):
    """Return the index object for the referenced rows.

    The index is computed lazily and cached in ``self._index``.

    Returns
    -------
    utils.Index
        The index.
    """
    # NOTE(review): `_clear_cache` (elsewhere in this file) resets
    # `_index_tensor`, not `_index`, so this cached value may go stale after
    # mutations — verify against `_clear_cache`.
    if self._index is None:
        if self.is_contiguous():
            # Contiguous case: the slice always starts at 0 with step 1
            # (it is only a slice when index=None at construction), so an
            # arange up to `stop` reproduces the referenced row ids.
            self._index = utils.toindex(
                F.arange(self._index_data.stop, dtype=F.int64))
        else:
            self._index = utils.toindex(self._index_data)
    return self._index
def __contains__(self, name):
"""Return whether the column name exists."""
return name in self._frame
def __iter__(self):
"""Return the iterator of the columns."""
return iter(self._frame)
def __len__(self):
"""Return the number of columns."""
return self.num_columns
def keys(self):
"""Return the keys."""
return self._frame.keys()
def __getitem__(self, key): def __getitem__(self, key):
"""Get data from the frame.
If the provided key is string, the corresponding column data will be returned.
If the provided key is an index, the corresponding rows will be selected. The
returned rows are saved in a lazy dictionary so only the real selection happens
when the explicit column name is provided.
Examples (using pytorch)
------------------------
>>> # create a frame of two columns and five rows
>>> f = Frame({'c1' : torch.zeros([5, 2]), 'c2' : torch.ones([5, 2])})
>>> fr = FrameRef(f)
>>> # select the row 1 and 2, the returned `rows` is a lazy dictionary.
>>> rows = fr[Index([1, 2])]
>>> rows['c1'] # only select rows for 'c1' column; 'c2' column is not sliced.
Parameters
----------
key : str or utils.Index
The key.
Returns
-------
Tensor or lazy dict or tensors
Depends on whether it is a column selection or row selection.
"""
if isinstance(key, str): if isinstance(key, str):
return self.get_column(key) return self.select_column(key)
else: else:
return self.select_rows(key) return self.select_rows(key)
def select_rows(self, query): def select_column(self, name):
rowids = self._getrowid(query) """Return the column of the given name.
def _lazy_select(key):
idx = rowids.tousertensor(F.get_context(self._frame[key])) If only part of the rows are referenced, the fetching the whole column will
return F.gather_row(self._frame[key], idx) also slice out the referenced rows.
return utils.LazyDict(_lazy_select, keys=self.schemes)
Parameters
----------
name : str
The column name.
def get_column(self, name): Returns
-------
Tensor
The column data.
"""
col = self._frame[name] col = self._frame[name]
if self.is_span_whole_column(): if self.is_span_whole_column():
return col return col.data
else: else:
idx = self.index().tousertensor(F.get_context(col)) return col[self.index()]
return F.gather_row(col, idx)
def select_rows(self, query):
    """Return the rows given the query.

    The selection is lazy: no column is actually sliced until it is
    looked up by name in the returned dictionary.

    Parameters
    ----------
    query : utils.Index
        The rows to be selected.

    Returns
    -------
    utils.LazyDict
        The lazy dictionary from str to the selected data.
    """
    row_index = self._getrowid(query)

    def _fetch(name):
        # Slice the named column only when it is actually requested.
        return self._frame[name][row_index]

    return utils.LazyDict(_fetch, keys=self.keys())
def __setitem__(self, key, val): def __setitem__(self, key, val):
"""Update the data in the frame.
If the provided key is string, the corresponding column data will be updated.
The provided value should be one tensor that have the same scheme and length
as the column.
If the provided key is an index, the corresponding rows will be updated. The
value provided should be a dictionary of string to the data of each column.
All updates are performed out-placely to be work with autograd. For inplace
update, use ``update_column`` or ``update_rows``.
Parameters
----------
key : str or utils.Index
The key.
val : Tensor or dict of tensors
The value.
"""
if isinstance(key, str): if isinstance(key, str):
self.add_column(key, val) self.update_column(key, val, inplace=False)
else: else:
self.update_rows(key, val) self.update_rows(key, val, inplace=False)
def add_column(self, name, col, inplace=False): def update_column(self, name, data, inplace):
shp = F.shape(col) """Update the column.
If this frameref spans the whole column of the underlying frame, this is
equivalent to update the column of the frame.
If this frameref only points to part of the rows, then update the column
here will correspond to update part of the column in the frame. Raise error
if the given column name does not exist.
Parameters
----------
name : str
The column name.
data : Tensor
The update data.
inplace : bool
True if the update is performed inplacely.
"""
if self.is_span_whole_column(): if self.is_span_whole_column():
col = Column.create(data)
if self.num_columns == 0: if self.num_columns == 0:
self._index_data = slice(0, shp[0]) # the frame is empty
self._index_data = slice(0, len(col))
self._clear_cache() self._clear_cache()
assert shp[0] == self.num_rows
self._frame[name] = col self._frame[name] = col
else: else:
colctx = F.get_context(col) if name not in self._frame:
if name in self._frame: feat_shape = F.shape(data)[1:]
fcol = self._frame[name] feat_dtype = F.get_tvmtype(data)
else: ctx = F.get_context(data)
fcol = F.zeros((self._frame.num_rows,) + shp[1:]) self._frame.add_column(name, Scheme(feat_shape, feat_dtype), ctx)
fcol = F.to_context(fcol, colctx) #raise DGLError('Cannot update column. Column "%s" does not exist.'
idx = self.index().tousertensor(colctx) # ' Did you forget to init the column using `set_n_repr`'
if inplace: # ' or `set_e_repr`?' % name)
self._frame[name] = fcol fcol = self._frame[name]
self._frame[name][idx] = col fcol.update(self.index(), data, inplace)
else:
newfcol = F.scatter_row(fcol, idx, col) def update_rows(self, query, data, inplace):
self._frame[name] = newfcol """Update the rows.
def update_rows(self, query, other, inplace=False): If the provided data has new column, it will be added to the frame.
See Also
--------
``update_column``
Parameters
----------
query : utils.Index
The rows to be updated.
data : dict-like
The row data.
inplace : bool
True if the update is performed inplacely.
"""
rowids = self._getrowid(query) rowids = self._getrowid(query)
for key, col in other.items(): for key, col in data.items():
if key not in self: if key not in self:
# add new column # add new column
tmpref = FrameRef(self._frame, rowids) tmpref = FrameRef(self._frame, rowids)
tmpref.add_column(key, col, inplace) tmpref.update_column(key, col, inplace)
idx = rowids.tousertensor(F.get_context(self._frame[key])) #raise DGLError('Cannot update rows. Column "%s" does not exist.'
if inplace: # ' Did you forget to init the column using `set_n_repr`'
self._frame[key][idx] = col # ' or `set_e_repr`?' % key)
else: else:
self._frame[key] = F.scatter_row(self._frame[key], idx, col) self._frame[key].update(rowids, col, inplace)
def __delitem__(self, key): def __delitem__(self, key):
"""Delete data in the frame.
If the provided key is a string, the corresponding column will be deleted.
If the provided key is an index object, the corresponding rows will be deleted.
Please note that "deleted" rows are not really deleted, but simply removed
in the reference. As a result, if two FrameRefs point to the same Frame, deleting
from one ref will not relect on the other. By contrast, deleting columns is real.
Parameters
----------
key : str or utils.Index
The key.
"""
if isinstance(key, str): if isinstance(key, str):
del self._frame[key] del self._frame[key]
if len(self._frame) == 0: if len(self._frame) == 0:
...@@ -186,7 +606,18 @@ class FrameRef(MutableMapping): ...@@ -186,7 +606,18 @@ class FrameRef(MutableMapping):
self.delete_rows(key) self.delete_rows(key)
def delete_rows(self, query): def delete_rows(self, query):
query = F.asnumpy(query) """Delete rows.
Please note that "deleted" rows are not really deleted, but simply removed
in the reference. As a result, if two FrameRefs point to the same Frame, deleting
from one ref will not relect on the other. By contrast, deleting columns is real.
Parameters
----------
query : utils.Index
The rows to be deleted.
"""
query = query.tolist()
if isinstance(self._index_data, slice): if isinstance(self._index_data, slice):
self._index_data = list(range(self._index_data.start, self._index_data.stop)) self._index_data = list(range(self._index_data.start, self._index_data.stop))
arr = np.array(self._index_data, dtype=np.int32) arr = np.array(self._index_data, dtype=np.int32)
...@@ -194,6 +625,13 @@ class FrameRef(MutableMapping): ...@@ -194,6 +625,13 @@ class FrameRef(MutableMapping):
self._clear_cache() self._clear_cache()
def append(self, other): def append(self, other):
"""Append another frame into this one.
Parameters
----------
other : dict of str to tensor
The data to be appended.
"""
span_whole = self.is_span_whole_column() span_whole = self.is_span_whole_column()
contiguous = self.is_contiguous() contiguous = self.is_contiguous()
old_nrows = self._frame.num_rows old_nrows = self._frame.num_rows
...@@ -208,24 +646,23 @@ class FrameRef(MutableMapping): ...@@ -208,24 +646,23 @@ class FrameRef(MutableMapping):
self._clear_cache() self._clear_cache()
def clear(self): def clear(self):
"""Clear the frame."""
self._frame.clear() self._frame.clear()
self._index_data = slice(0, 0) self._index_data = slice(0, 0)
self._clear_cache() self._clear_cache()
def __iter__(self):
return iter(self._frame)
def __len__(self):
return self.num_columns
def is_contiguous(self): def is_contiguous(self):
# NOTE: this check could have false negative """Return whether this refers to a contiguous range of rows."""
# NOTE: this check could have false negatives and false positives
# (step other than 1)
return isinstance(self._index_data, slice) return isinstance(self._index_data, slice)
def is_span_whole_column(self): def is_span_whole_column(self):
"""Return whether this refers to all the rows."""
return self.is_contiguous() and self.num_rows == self._frame.num_rows return self.is_contiguous() and self.num_rows == self._frame.num_rows
def _getrowid(self, query): def _getrowid(self, query):
"""Internal function to convert from the local row ids to the row ids of the frame."""
if self.is_contiguous(): if self.is_contiguous():
# shortcut for identical mapping # shortcut for identical mapping
return query return query
...@@ -233,16 +670,8 @@ class FrameRef(MutableMapping): ...@@ -233,16 +670,8 @@ class FrameRef(MutableMapping):
idxtensor = self.index().tousertensor() idxtensor = self.index().tousertensor()
return utils.toindex(F.gather_row(idxtensor, query.tousertensor())) return utils.toindex(F.gather_row(idxtensor, query.tousertensor()))
def index(self):
if self._index is None:
if self.is_contiguous():
self._index = utils.toindex(
F.arange(self._index_data.stop, dtype=F.int64))
else:
self._index = utils.toindex(self._index_data)
return self._index
def _clear_cache(self): def _clear_cache(self):
"""Internal function to clear the cached object."""
self._index_tensor = None self._index_tensor = None
def merge_frames(frames, indices, max_index, reduce_func): def merge_frames(frames, indices, max_index, reduce_func):
...@@ -267,6 +696,8 @@ def merge_frames(frames, indices, max_index, reduce_func): ...@@ -267,6 +696,8 @@ def merge_frames(frames, indices, max_index, reduce_func):
merged : FrameRef merged : FrameRef
The merged frame. The merged frame.
""" """
# TODO(minjie)
assert False, 'Buggy code, disabled for now.'
assert reduce_func == 'sum' assert reduce_func == 'sum'
assert len(frames) > 0 assert len(frames) > 0
schemes = frames[0].schemes schemes = frames[0].schemes
......
...@@ -4,17 +4,25 @@ from __future__ import absolute_import ...@@ -4,17 +4,25 @@ from __future__ import absolute_import
import operator import operator
import dgl.backend as F import dgl.backend as F
__all__ = ["MessageFunction", "src_mul_edge", "copy_src", "copy_edge"] __all__ = ["src_mul_edge", "copy_src", "copy_edge"]
class MessageFunction(object): class MessageFunction(object):
"""Base builtin message function class."""
def __call__(self, src, edge): def __call__(self, src, edge):
"""Regular computation of this builtin.
This will be used when optimization is not available.
"""
raise NotImplementedError raise NotImplementedError
def name(self): def name(self):
"""Return the name of this builtin function."""
raise NotImplementedError raise NotImplementedError
def is_spmv_supported(self, g): def is_spmv_supported(self, g):
"""Return whether the SPMV optimization is supported."""
raise NotImplementedError raise NotImplementedError
...@@ -22,12 +30,6 @@ class BundledMessageFunction(MessageFunction): ...@@ -22,12 +30,6 @@ class BundledMessageFunction(MessageFunction):
def __init__(self, fn_list): def __init__(self, fn_list):
if not isinstance(fn_list, (list, tuple)): if not isinstance(fn_list, (list, tuple)):
fn_list = [fn_list] fn_list = [fn_list]
else:
# sanity check on out field
for fn in fn_list:
# cannot perform check for udf
if isinstance(fn, MessageFunction) and fn.out_field is None:
raise RuntimeError("Not specifying out field for multiple message is ambiguous")
self.fn_list = fn_list self.fn_list = fn_list
def is_spmv_supported(self, g): def is_spmv_supported(self, g):
...@@ -43,11 +45,8 @@ class BundledMessageFunction(MessageFunction): ...@@ -43,11 +45,8 @@ class BundledMessageFunction(MessageFunction):
if ret is None: if ret is None:
ret = msg ret = msg
else: else:
try: # ret and msg must be dict
# ret and msg must be dict ret.update(msg)
ret.update(msg)
except:
raise RuntimeError("Must specify out field for multiple message")
return ret return ret
def name(self): def name(self):
...@@ -55,25 +54,26 @@ class BundledMessageFunction(MessageFunction): ...@@ -55,25 +54,26 @@ class BundledMessageFunction(MessageFunction):
def _is_spmv_supported_node_feat(g, field): def _is_spmv_supported_node_feat(g, field):
if field is None: """Return whether the node feature shape supports SPMV optimization.
feat = g.get_n_repr()
else: Only scalar and vector features are supported currently.
feat = g.get_n_repr()[field] """
feat = g.get_n_repr()[field]
shape = F.shape(feat) shape = F.shape(feat)
return len(shape) == 1 or len(shape) == 2 return len(shape) == 1 or len(shape) == 2
def _is_spmv_supported_edge_feat(g, field): def _is_spmv_supported_edge_feat(g, field):
# check shape, only scalar edge feature can be optimized at the moment """Return whether the edge feature shape supports SPMV optimization.
if field is None:
feat = g.get_e_repr() Only scalar feature is supported currently.
else: """
feat = g.get_e_repr()[field] feat = g.get_e_repr()[field]
shape = F.shape(feat) shape = F.shape(feat)
return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1) return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1)
class SrcMulEdgeMessageFunction(MessageFunction): class SrcMulEdgeMessageFunction(MessageFunction):
def __init__(self, mul_op, src_field=None, edge_field=None, out_field=None): def __init__(self, mul_op, src_field, edge_field, out_field):
self.mul_op = mul_op self.mul_op = mul_op
self.src_field = src_field self.src_field = src_field
self.edge_field = edge_field self.edge_field = edge_field
...@@ -84,21 +84,14 @@ class SrcMulEdgeMessageFunction(MessageFunction): ...@@ -84,21 +84,14 @@ class SrcMulEdgeMessageFunction(MessageFunction):
and _is_spmv_supported_edge_feat(g, self.edge_field) and _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, src, edge): def __call__(self, src, edge):
if self.src_field is not None: ret = self.mul_op(src[self.src_field], edge[self.edge_field])
src = src[self.src_field] return {self.out_field : ret}
if self.edge_field is not None:
edge = edge[self.edge_field]
ret = self.mul_op(src, edge)
if self.out_field is None:
return ret
else:
return {self.out_field : ret}
def name(self): def name(self):
return "src_mul_edge" return "src_mul_edge"
class CopySrcMessageFunction(MessageFunction): class CopySrcMessageFunction(MessageFunction):
def __init__(self, src_field=None, out_field=None): def __init__(self, src_field, out_field):
self.src_field = src_field self.src_field = src_field
self.out_field = out_field self.out_field = out_field
...@@ -106,14 +99,7 @@ class CopySrcMessageFunction(MessageFunction): ...@@ -106,14 +99,7 @@ class CopySrcMessageFunction(MessageFunction):
return _is_spmv_supported_node_feat(g, self.src_field) return _is_spmv_supported_node_feat(g, self.src_field)
def __call__(self, src, edge): def __call__(self, src, edge):
if self.src_field is not None: return {self.out_field : src[self.src_field]}
ret = src[self.src_field]
else:
ret = src
if self.out_field is None:
return ret
else:
return {self.out_field : ret}
def name(self): def name(self):
return "copy_src" return "copy_src"
...@@ -142,14 +128,41 @@ class CopyEdgeMessageFunction(MessageFunction): ...@@ -142,14 +128,41 @@ class CopyEdgeMessageFunction(MessageFunction):
return "copy_edge" return "copy_edge"
def src_mul_edge(src=None, edge=None, out=None): def src_mul_edge(src, edge, out):
"""TODO(minjie): docstring """ """Builtin message function that computes message by multiplying source node features
with edge features.
Parameters
----------
src : str
The source feature name.
edge : str
The edge feature name.
out : str
The output message name.
"""
return SrcMulEdgeMessageFunction(operator.mul, src, edge, out) return SrcMulEdgeMessageFunction(operator.mul, src, edge, out)
def copy_src(src=None, out=None): def copy_src(src, out):
"""TODO(minjie): docstring """ """Builtin message function that computes message using source node feature.
Parameters
----------
src : str
The source feature name.
out : str
The output message name.
"""
return CopySrcMessageFunction(src, out) return CopySrcMessageFunction(src, out)
def copy_edge(edge=None, out=None): def copy_edge(edge, out):
"""TODO(minjie): docstring """ """Builtin message function that computes message using edge feature.
Parameters
----------
edge : str
The edge feature name.
out : str
The output message name.
"""
return CopyEdgeMessageFunction(edge, out) return CopyEdgeMessageFunction(edge, out)
...@@ -3,27 +3,30 @@ from __future__ import absolute_import ...@@ -3,27 +3,30 @@ from __future__ import absolute_import
from .. import backend as F from .. import backend as F
__all__ = ["ReduceFunction", "sum", "max"] __all__ = ["sum", "max"]
class ReduceFunction(object): class ReduceFunction(object):
"""Base builtin reduce function class."""
def __call__(self, node, msgs): def __call__(self, node, msgs):
"""Regular computation of this builtin.
This will be used when optimization is not available.
"""
raise NotImplementedError raise NotImplementedError
def name(self): def name(self):
"""Return the name of this builtin function."""
raise NotImplementedError raise NotImplementedError
def is_spmv_supported(self): def is_spmv_supported(self):
"""Return whether the SPMV optimization is supported."""
raise NotImplementedError raise NotImplementedError
class BundledReduceFunction(ReduceFunction): class BundledReduceFunction(ReduceFunction):
def __init__(self, fn_list): def __init__(self, fn_list):
if not isinstance(fn_list, (list, tuple)): if not isinstance(fn_list, (list, tuple)):
fn_list = [fn_list] fn_list = [fn_list]
else:
# sanity check on out field
for fn in fn_list:
if isinstance(fn, ReduceFunction) and fn.out_field is None:
raise RuntimeError("Not specifying out field for multiple reduce is ambiguous")
self.fn_list = fn_list self.fn_list = fn_list
def is_spmv_supported(self): def is_spmv_supported(self):
...@@ -39,51 +42,50 @@ class BundledReduceFunction(ReduceFunction): ...@@ -39,51 +42,50 @@ class BundledReduceFunction(ReduceFunction):
if ret is None: if ret is None:
ret = rpr ret = rpr
else: else:
try: # ret and rpr must be dict
# ret and rpr must be dict ret.update(rpr)
ret.update(rpr)
except:
raise RuntimeError("Must specify out field for multiple reudce")
return ret return ret
def name(self): def name(self):
return "bundled" return "bundled"
class ReducerFunctionTemplate(ReduceFunction): class ReducerFunctionTemplate(ReduceFunction):
def __init__(self, name, batch_op, nonbatch_op, msg_field=None, out_field=None): def __init__(self, name, op, msg_field, out_field):
self.name = name self.name = name
self.batch_op = batch_op self.op = op
self.nonbatch_op = nonbatch_op
self.msg_field = msg_field self.msg_field = msg_field
self.out_field = out_field self.out_field = out_field
def is_spmv_supported(self): def is_spmv_supported(self):
# TODO: support max # NOTE: only sum is supported right now.
return self.name == "sum" return self.name == "sum"
def __call__(self, node, msgs): def __call__(self, node, msgs):
if isinstance(msgs, list): return {self.out_field : self.op(msgs[self.msg_field], 1)}
if self.msg_field is None:
ret = self.nonbatch_op(msgs)
else:
ret = self.nonbatch_op([msg[self.msg_field] for msg in msgs])
else:
if self.msg_field is None:
ret = self.batch_op(msgs, 1)
else:
ret = self.batch_op(msgs[self.msg_field], 1)
if self.out_field is None:
return ret
else:
return {self.out_field : ret}
def name(self): def name(self):
return self.name return self.name
_python_sum = sum def sum(msg, out):
def sum(msgs=None, out=None): """Builtin reduce function that aggregates messages by sum.
return ReducerFunctionTemplate("sum", F.sum, _python_sum, msgs, out)
Parameters
----------
msg : str
The message name.
out : str
The output node feature name.
"""
return ReducerFunctionTemplate("sum", F.sum, msg, out)
def max(msg, out):
"""Builtin reduce function that aggregates messages by max.
_python_max = max Parameters
def max(msgs=None, out=None): ----------
return ReducerFunctionTemplate("max", F.max, _python_max, msgs, out) msg : str
The message name.
out : str
The output node feature name.
"""
return ReducerFunctionTemplate("max", F.max, msg, out)
"""Package for graph generators"""
from __future__ import absolute_import
from .line import *
"""Line graph generator."""
from __future__ import absolute_import
import networkx as nx
import numpy as np
from .. import backend as F
from ..graph import DGLGraph
from ..frame import FrameRef
def line_graph(G, no_backtracking=False):
    """Create the line graph that shares the underlying features.

    Each edge of ``G`` becomes a node of the line graph; the node features
    of the result share the edge feature frame of the given graph
    (``G._edge_frame`` is passed as the node frame of the new graph).

    Parameters
    ----------
    G : DGLGraph
        The input graph.
    no_backtracking : bool
        Whether the backtracking edges are excluded from the line graph.
        If i~j and j~i are two edges in the original graph G, then
        (i,j)~(j,i) and (j,i)~(i,j) are the "backtracking" edges on
        the line graph.

    Returns
    -------
    DGLGraph
        The line graph, whose node features alias G's edge features.
    """
    lg = nx.DiGraph()
    for src_edge in G.edge_list:
        lg.add_node(src_edge)
        # Connect (u, v) to every edge (v, w) leaving v.
        for dst_edge in G.edges(src_edge[1]):
            # Skip (u, v) -> (v, u) when backtracking is disallowed.
            if no_backtracking and dst_edge[1] == src_edge[0]:
                continue
            lg.add_edge(src_edge, dst_edge)
    # Rename each edge-tuple node to its edge id in the original graph.
    relabel_map = {edge: idx for idx, edge in enumerate(G.edge_list)}
    nx.relabel.relabel_nodes(lg, relabel_map, copy=False)
    return DGLGraph(lg, node_frame=G._edge_frame)
...@@ -6,7 +6,7 @@ import networkx as nx ...@@ -6,7 +6,7 @@ import networkx as nx
import numpy as np import numpy as np
import dgl import dgl
from .base import ALL, is_all, __MSG__, __REPR__ from .base import ALL, is_all, DGLError, dgl_warning
from . import backend as F from . import backend as F
from .backend import Tensor from .backend import Tensor
from .frame import FrameRef, merge_frames from .frame import FrameRef, merge_frames
...@@ -22,7 +22,6 @@ class DGLGraph(object): ...@@ -22,7 +22,6 @@ class DGLGraph(object):
"""Base graph class specialized for neural networks on graphs. """Base graph class specialized for neural networks on graphs.
TODO(minjie): document of batching semantics TODO(minjie): document of batching semantics
TODO(minjie): document of __REPR__ semantics
Parameters Parameters
---------- ----------
...@@ -448,7 +447,9 @@ class DGLGraph(object): ...@@ -448,7 +447,9 @@ class DGLGraph(object):
The nx graph The nx graph
""" """
nx_graph = self._graph.to_networkx() nx_graph = self._graph.to_networkx()
#TODO: attributes #TODO(minjie): attributes
dgl_warning('to_networkx currently does not support converting'
' node/edge features automatically.')
return nx_graph return nx_graph
def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None): def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
...@@ -504,70 +505,95 @@ class DGLGraph(object): ...@@ -504,70 +505,95 @@ class DGLGraph(object):
self._msg_graph.add_nodes(self._graph.number_of_nodes()) self._msg_graph.add_nodes(self._graph.number_of_nodes())
def node_attr_schemes(self): def node_attr_schemes(self):
"""Return the node attribute schemes. """Return the node feature schemes.
Returns Returns
------- -------
iterable dict of str to schemes
The set of attribute names The schemes of node feature columns.
""" """
return self._node_frame.schemes return self._node_frame.schemes
def edge_attr_schemes(self): def edge_attr_schemes(self):
"""Return the edge attribute schemes. """Return the edge feature schemes.
Returns Returns
------- -------
iterable dict of str to schemes
The set of attribute names The schemes of edge feature columns.
""" """
return self._edge_frame.schemes return self._edge_frame.schemes
def set_n_initializer(self, initializer):
    """Set the initializer for empty node features.

    The initializer is stored on the underlying node frame and is used to
    create feature values when a column has no data.

    Parameters
    ----------
    initializer : callable
        A callable that returns a tensor given the shape and data type.
    """
    self._node_frame.set_initializer(initializer)
def set_e_initializer(self, initializer):
    """Set the initializer for empty edge features.

    The initializer is stored on the underlying edge frame and is used to
    create feature values when a column has no data.

    Parameters
    ----------
    initializer : callable
        A callable that returns a tensor given the shape and data type.
    """
    self._edge_frame.set_initializer(initializer)
def set_n_repr(self, hu, u=ALL, inplace=False): def set_n_repr(self, hu, u=ALL, inplace=False):
"""Set node(s) representation. """Set node(s) representation.
To set multiple node representations at once, pass `u` with a tensor or `hu` is a dictionary from the feature name to feature tensor. Each tensor
a supported container of node ids. In this case, `hu` must be a tensor is of shape (B, D1, D2, ...), where B is the number of nodes to be updated,
of shape (B, D1, D2, ...), where B is the number of the nodes and and (D1, D2, ...) be the shape of the node representation tensor. The
(D1, D2, ...) is the shape of the node representation tensor. length of the given node ids must match B (i.e, len(u) == B).
Dictionary type is also supported for `hu`. In this case, each item All update will be done out-placely to work with autograd unless the inplace
will be treated as separate attribute of the nodes. flag is true.
Parameters Parameters
---------- ----------
hu : tensor or dict of tensor hu : dict of tensor
Node representation. Node representation.
u : node, container or tensor u : node, container or tensor
The node(s). The node(s).
inplace : bool
True if the update is done inplacely
""" """
# sanity check # sanity check
if not utils.is_dict_like(hu):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(hu))
if is_all(u): if is_all(u):
num_nodes = self.number_of_nodes() num_nodes = self.number_of_nodes()
else: else:
u = utils.toindex(u) u = utils.toindex(u)
num_nodes = len(u) num_nodes = len(u)
if utils.is_dict_like(hu): for key, val in hu.items():
for key, val in hu.items(): nfeats = F.shape(val)[0]
assert F.shape(val)[0] == num_nodes if nfeats != num_nodes:
else: raise DGLError('Expect number of features to match number of nodes (len(u)).'
assert F.shape(hu)[0] == num_nodes ' Got %d and %d instead.' % (nfeats, num_nodes))
# set # set
if is_all(u): if is_all(u):
if utils.is_dict_like(hu): for key, val in hu.items():
for key, val in hu.items(): self._node_frame[key] = val
self._node_frame[key] = val
else:
self._node_frame[__REPR__] = hu
else: else:
if utils.is_dict_like(hu): self._node_frame.update_rows(u, hu, inplace=inplace)
self._node_frame.update_rows(u, hu, inplace=inplace)
else:
self._node_frame.update_rows(u, {__REPR__ : hu}, inplace=inplace)
def get_n_repr(self, u=ALL): def get_n_repr(self, u=ALL):
"""Get node(s) representation. """Get node(s) representation.
The returned feature tensor batches multiple node features on the first dimension.
Parameters Parameters
---------- ----------
u : node, container or tensor u : node, container or tensor
...@@ -576,23 +602,17 @@ class DGLGraph(object): ...@@ -576,23 +602,17 @@ class DGLGraph(object):
Returns Returns
------- -------
dict dict
Representation dict Representation dict from feature name to feature tensor.
""" """
if len(self.node_attr_schemes()) == 0: if len(self.node_attr_schemes()) == 0:
return dict() return dict()
if is_all(u): if is_all(u):
if len(self._node_frame) == 1 and __REPR__ in self._node_frame: return dict(self._node_frame)
return self._node_frame[__REPR__]
else:
return dict(self._node_frame)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
if len(self._node_frame) == 1 and __REPR__ in self._node_frame: return self._node_frame.select_rows(u)
return self._node_frame.select_rows(u)[__REPR__]
else:
return self._node_frame.select_rows(u)
def pop_n_repr(self, key=__REPR__): def pop_n_repr(self, key):
"""Get and remove the specified node repr. """Get and remove the specified node repr.
Parameters Parameters
...@@ -607,71 +627,83 @@ class DGLGraph(object): ...@@ -607,71 +627,83 @@ class DGLGraph(object):
""" """
return self._node_frame.pop(key) return self._node_frame.pop(key)
def set_e_repr(self, h_uv, u=ALL, v=ALL): def set_e_repr(self, he, u=ALL, v=ALL, inplace=False):
"""Set edge(s) representation. """Set edge(s) representation.
To set multiple edge representations at once, pass `u` and `v` with tensors or `he` is a dictionary from the feature name to feature tensor. Each tensor
supported containers of node ids. In this case, `h_uv` must be a tensor is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
of shape (B, D1, D2, ...), where B is the number of the edges and and (D1, D2, ...) be the shape of the edge representation tensor.
(D1, D2, ...) is the shape of the edge representation tensor.
Dictionary type is also supported for `h_uv`. In this case, each item All update will be done out-placely to work with autograd unless the inplace
will be treated as separate attribute of the edges. flag is true.
Parameters Parameters
---------- ----------
h_uv : tensor or dict of tensor he : tensor or dict of tensor
Edge representation. Edge representation.
u : node, container or tensor u : node, container or tensor
The source node(s). The source node(s).
v : node, container or tensor v : node, container or tensor
The destination node(s). The destination node(s).
inplace : bool
True if the update is done inplacely
""" """
# sanity check # sanity check
if not utils.is_dict_like(he):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(he))
u_is_all = is_all(u) u_is_all = is_all(u)
v_is_all = is_all(v) v_is_all = is_all(v)
assert u_is_all == v_is_all assert u_is_all == v_is_all
if u_is_all: if u_is_all:
self.set_e_repr_by_id(h_uv, eid=ALL) self.set_e_repr_by_id(he, eid=ALL, inplace=inplace)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
_, _, eid = self._graph.edge_ids(u, v) _, _, eid = self._graph.edge_ids(u, v)
self.set_e_repr_by_id(h_uv, eid=eid) self.set_e_repr_by_id(he, eid=eid, inplace=inplace)
def set_e_repr_by_id(self, h_uv, eid=ALL): def set_e_repr_by_id(self, he, eid=ALL, inplace=False):
"""Set edge(s) representation by edge id. """Set edge(s) representation by edge id.
`he` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
and (D1, D2, ...) be the shape of the edge representation tensor.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
Parameters Parameters
---------- ----------
h_uv : tensor or dict of tensor he : tensor or dict of tensor
Edge representation. Edge representation.
eid : int, container or tensor eid : int, container or tensor
The edge id(s). The edge id(s).
inplace : bool
True if the update is done inplacely
""" """
# sanity check # sanity check
if not utils.is_dict_like(he):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(he))
if is_all(eid): if is_all(eid):
num_edges = self.number_of_edges() num_edges = self.number_of_edges()
else: else:
eid = utils.toindex(eid) eid = utils.toindex(eid)
num_edges = len(eid) num_edges = len(eid)
if utils.is_dict_like(h_uv): for key, val in he.items():
for key, val in h_uv.items(): nfeats = F.shape(val)[0]
assert F.shape(val)[0] == num_edges if nfeats != num_edges:
else: raise DGLError('Expect number of features to match number of edges.'
assert F.shape(h_uv)[0] == num_edges ' Got %d and %d instead.' % (nfeats, num_edges))
# set # set
if is_all(eid): if is_all(eid):
if utils.is_dict_like(h_uv): # update column
for key, val in h_uv.items(): for key, val in he.items():
self._edge_frame[key] = val self._edge_frame[key] = val
else:
self._edge_frame[__REPR__] = h_uv
else: else:
if utils.is_dict_like(h_uv): # update row
self._edge_frame[eid] = h_uv self._edge_frame.update_rows(eid, he, inplace=inplace)
else:
self._edge_frame[eid] = {__REPR__ : h_uv}
def get_e_repr(self, u=ALL, v=ALL): def get_e_repr(self, u=ALL, v=ALL):
"""Get node(s) representation. """Get node(s) representation.
...@@ -701,7 +733,7 @@ class DGLGraph(object): ...@@ -701,7 +733,7 @@ class DGLGraph(object):
_, _, eid = self._graph.edge_ids(u, v) _, _, eid = self._graph.edge_ids(u, v)
return self.get_e_repr_by_id(eid=eid) return self.get_e_repr_by_id(eid=eid)
def pop_e_repr(self, key=__REPR__): def pop_e_repr(self, key):
"""Get and remove the specified edge repr. """Get and remove the specified edge repr.
Parameters Parameters
...@@ -727,21 +759,15 @@ class DGLGraph(object): ...@@ -727,21 +759,15 @@ class DGLGraph(object):
Returns Returns
------- -------
dict dict
Representation dict Representation dict from feature name to feature tensor.
""" """
if len(self.edge_attr_schemes()) == 0: if len(self.edge_attr_schemes()) == 0:
return dict() return dict()
if is_all(eid): if is_all(eid):
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: return dict(self._edge_frame)
return self._edge_frame[__REPR__]
else:
return dict(self._edge_frame)
else: else:
eid = utils.toindex(eid) eid = utils.toindex(eid)
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: return self._edge_frame.select_rows(eid)
return self._edge_frame.select_rows(eid)[__REPR__]
else:
return self._edge_frame.select_rows(eid)
def register_edge_func(self, edge_func): def register_edge_func(self, edge_func):
"""Register global edge update function. """Register global edge update function.
...@@ -793,12 +819,14 @@ class DGLGraph(object): ...@@ -793,12 +819,14 @@ class DGLGraph(object):
""" """
self._apply_edge_func = apply_edge_func self._apply_edge_func = apply_edge_func
def apply_nodes(self, v, apply_node_func="default"): def apply_nodes(self, v=ALL, apply_node_func="default"):
"""Apply the function on node representations. """Apply the function on node representations.
Applying a None function will be ignored.
Parameters Parameters
---------- ----------
v : int, iterable of int, tensor v : int, iterable of int, tensor, optional
The node id(s). The node id(s).
apply_node_func : callable apply_node_func : callable
The apply node function. The apply node function.
...@@ -827,7 +855,7 @@ class DGLGraph(object): ...@@ -827,7 +855,7 @@ class DGLGraph(object):
# merge current node_repr with reduce output # merge current node_repr with reduce output
curr_repr = utils.HybridDict(reduce_accum, curr_repr) curr_repr = utils.HybridDict(reduce_accum, curr_repr)
new_repr = apply_node_func(curr_repr) new_repr = apply_node_func(curr_repr)
if reduce_accum is not None and utils.is_dict_like(new_repr) : if reduce_accum is not None:
# merge new node_repr with reduce output # merge new node_repr with reduce output
reduce_accum.update(new_repr) reduce_accum.update(new_repr)
new_repr = reduce_accum new_repr = reduce_accum
...@@ -836,6 +864,8 @@ class DGLGraph(object): ...@@ -836,6 +864,8 @@ class DGLGraph(object):
def apply_edges(self, u=None, v=None, apply_edge_func="default", eid=None): def apply_edges(self, u=None, v=None, apply_edge_func="default", eid=None):
"""Apply the function on edge representations. """Apply the function on edge representations.
Applying a None function will be ignored.
Parameters Parameters
---------- ----------
u : optional, int, iterable of int, tensor u : optional, int, iterable of int, tensor
...@@ -852,7 +882,6 @@ class DGLGraph(object): ...@@ -852,7 +882,6 @@ class DGLGraph(object):
if not apply_edge_func: if not apply_edge_func:
# Skip none function call. # Skip none function call.
return return
if eid is None: if eid is None:
new_repr = apply_edge_func(self.get_e_repr(u, v)) new_repr = apply_edge_func(self.get_e_repr(u, v))
self.set_e_repr(new_repr, u, v) self.set_e_repr(new_repr, u, v)
...@@ -873,9 +902,8 @@ class DGLGraph(object): ...@@ -873,9 +902,8 @@ class DGLGraph(object):
The message function can be any of the pre-defined functions The message function can be any of the pre-defined functions
('from_src'). ('from_src').
Currently, we require the message functions of consecutive send's and Currently, we require the message functions of consecutive send's to
send_on's to return the same keys. Otherwise the behavior will be return the same keys. Otherwise the behavior will be undefined.
undefined.
Parameters Parameters
---------- ----------
...@@ -922,7 +950,11 @@ class DGLGraph(object): ...@@ -922,7 +950,11 @@ class DGLGraph(object):
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid) edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs) msgs = message_func(src_reprs, edge_reprs)
self._msg_graph.add_edges(u, v)
self._msg_frame.append(msgs)
# TODO(minjie): Fix these codes in next PR.
"""
new_uv = [] new_uv = []
msg_target_rows = [] msg_target_rows = []
msg_update_rows = [] msg_update_rows = []
...@@ -945,8 +977,8 @@ class DGLGraph(object): ...@@ -945,8 +977,8 @@ class DGLGraph(object):
self._msg_frame.update_rows( self._msg_frame.update_rows(
msg_target_rows, msg_target_rows,
{k: F.gather_row(msgs[k], msg_update_rows.tousertensor()) {k: F.gather_row(msgs[k], msg_update_rows.tousertensor())
for k in msgs} for k in msgs},
) inplace=False)
if len(msg_append_rows) > 0: if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv) new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u) new_u = utils.toindex(new_u)
...@@ -954,14 +986,13 @@ class DGLGraph(object): ...@@ -954,14 +986,13 @@ class DGLGraph(object):
self._msg_graph.add_edges(new_u, new_v) self._msg_graph.add_edges(new_u, new_v)
self._msg_frame.append( self._msg_frame.append(
{k: F.gather_row(msgs[k], msg_append_rows.tousertensor()) {k: F.gather_row(msgs[k], msg_append_rows.tousertensor())
for k in msgs} for k in msgs})
)
else: else:
if len(msg_target_rows) > 0: if len(msg_target_rows) > 0:
self._msg_frame.update_rows( self._msg_frame.update_rows(
msg_target_rows, msg_target_rows,
{__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())} {__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())},
) inplace=False)
if len(msg_append_rows) > 0: if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv) new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u) new_u = utils.toindex(new_u)
...@@ -970,6 +1001,7 @@ class DGLGraph(object): ...@@ -970,6 +1001,7 @@ class DGLGraph(object):
self._msg_frame.append( self._msg_frame.append(
{__MSG__: F.gather_row(msgs, msg_append_rows.tousertensor())} {__MSG__: F.gather_row(msgs, msg_append_rows.tousertensor())}
) )
"""
def update_edge(self, u=ALL, v=ALL, edge_func="default", eid=None): def update_edge(self, u=ALL, v=ALL, edge_func="default", eid=None):
"""Update representation on edge u->v """Update representation on edge u->v
...@@ -1013,7 +1045,6 @@ class DGLGraph(object): ...@@ -1013,7 +1045,6 @@ class DGLGraph(object):
v = utils.toindex(v) v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v) u, v = utils.edge_broadcasting(u, v)
_, _, eid = self._graph.edge_ids(u, v) _, _, eid = self._graph.edge_ids(u, v)
# call the UDF # call the UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v) dst_reprs = self.get_n_repr(v)
...@@ -1100,25 +1131,19 @@ class DGLGraph(object): ...@@ -1100,25 +1131,19 @@ class DGLGraph(object):
msg_shape = F.shape(msg) msg_shape = F.shape(msg)
new_shape = (bkt_len, deg) + msg_shape[1:] new_shape = (bkt_len, deg) + msg_shape[1:]
return F.reshape(msg, new_shape) return F.reshape(msg, new_shape)
if len(in_msgs) == 1 and __MSG__ in in_msgs: reshaped_in_msgs = utils.LazyDict(
reshaped_in_msgs = _reshape_fn(in_msgs[__MSG__]) lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
else:
reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
reordered_v.append(v_bkt.tousertensor()) reordered_v.append(v_bkt.tousertensor())
new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs)) new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))
# TODO: clear partial messages # TODO(minjie): clear partial messages
self.reset_messages() self.reset_messages()
# Pack all reducer results together # Pack all reducer results together
reordered_v = F.pack(reordered_v) reordered_v = F.pack(reordered_v)
if utils.is_dict_like(new_reprs[0]): keys = new_reprs[0].keys()
keys = new_reprs[0].keys() new_reprs = {key : F.pack([repr[key] for repr in new_reprs])
new_reprs = {key : F.pack([repr[key] for repr in new_reprs]) for key in keys}
for key in keys}
else:
new_reprs = {__REPR__ : F.pack(new_reprs)}
if v_is_all and not has_zero_degree: if v_is_all and not has_zero_degree:
# First do reorder and then replace the whole column. # First do reorder and then replace the whole column.
...@@ -1189,15 +1214,13 @@ class DGLGraph(object): ...@@ -1189,15 +1214,13 @@ class DGLGraph(object):
if executor: if executor:
new_reprs = executor.run() new_reprs = executor.run()
if not utils.is_dict_like(new_reprs):
new_reprs = {__REPR__: new_reprs}
unique_v = executor.recv_nodes unique_v = executor.recv_nodes
self._apply_nodes(unique_v, apply_node_func, reduce_accum=new_reprs) self._apply_nodes(unique_v, apply_node_func, reduce_accum=new_reprs)
elif eid is not None: elif eid is not None:
_, v, _ = self._graph.find_edges(eid) _, v, _ = self._graph.find_edges(eid)
unique_v = utils.toindex(F.unique(v.tousertensor())) unique_v = utils.toindex(F.unique(v.tousertensor()))
# TODO: replace with the new DegreeBucketingScheduler # TODO(quan): replace with the new DegreeBucketingScheduler
self.send(eid=eid, message_func=message_func) self.send(eid=eid, message_func=message_func)
self.recv(unique_v, reduce_func, apply_node_func) self.recv(unique_v, reduce_func, apply_node_func)
else: else:
...@@ -1213,10 +1236,7 @@ class DGLGraph(object): ...@@ -1213,10 +1236,7 @@ class DGLGraph(object):
edge_reprs = self.get_e_repr(u, v) edge_reprs = self.get_e_repr(u, v)
msgs = message_func(src_reprs, edge_reprs) msgs = message_func(src_reprs, edge_reprs)
msg_frame = FrameRef() msg_frame = FrameRef()
if utils.is_dict_like(msgs): msg_frame.append(msgs)
msg_frame.append(msgs)
else:
msg_frame.append({__MSG__: msgs})
# recv with degree bucketing # recv with degree bucketing
executor = scheduler.get_recv_executor(graph=self, executor = scheduler.get_recv_executor(graph=self,
...@@ -1305,8 +1325,6 @@ class DGLGraph(object): ...@@ -1305,8 +1325,6 @@ class DGLGraph(object):
"update_all", self, message_func=message_func, reduce_func=reduce_func) "update_all", self, message_func=message_func, reduce_func=reduce_func)
if executor: if executor:
new_reprs = executor.run() new_reprs = executor.run()
if not utils.is_dict_like(new_reprs):
new_reprs = {__REPR__: new_reprs}
self._apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs) self._apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs)
else: else:
self.send(ALL, ALL, message_func) self.send(ALL, ALL, message_func)
...@@ -1339,7 +1357,7 @@ class DGLGraph(object): ...@@ -1339,7 +1357,7 @@ class DGLGraph(object):
Arguments for pre-defined iterators. Arguments for pre-defined iterators.
""" """
if isinstance(traverser, str): if isinstance(traverser, str):
# TODO Call pre-defined routine to unroll the computation. # TODO(minjie): Call pre-defined routine to unroll the computation.
raise RuntimeError('Not implemented.') raise RuntimeError('Not implemented.')
else: else:
# NOTE: the iteration can return multiple edges at each step. # NOTE: the iteration can return multiple edges at each step.
......
...@@ -3,7 +3,7 @@ from __future__ import absolute_import ...@@ -3,7 +3,7 @@ from __future__ import absolute_import
import ctypes import ctypes
import numpy as np import numpy as np
import networkx as nx import networkx as nx
import scipy.sparse as sp import scipy
from ._ffi.base import c_array from ._ffi.base import c_array
from ._ffi.function import _init_api from ._ffi.function import _init_api
...@@ -600,30 +600,59 @@ class GraphIndex(object): ...@@ -600,30 +600,59 @@ class GraphIndex(object):
return GraphIndex(handle) return GraphIndex(handle)
class SubgraphIndex(GraphIndex): class SubgraphIndex(GraphIndex):
def __init__(self, handle, parent, induced_nodes, induced_edges): """Graph index for subgraph.
super().__init__(handle)
Parameters
----------
handle : GraphIndexHandle
The capi handle.
parent : GraphIndex
The parent graph index.
induced_nodes : utils.Index
The parent node ids in this subgraph.
induced_edges : utils.Index
The parent edge ids in this subgraph.
"""
def __init__(self, handle, parent, induced_nodes, induced_edges):
super(SubgraphIndex, self).__init__(handle)
self._parent = parent self._parent = parent
self._induced_nodes = induced_nodes self._induced_nodes = induced_nodes
self._induced_edges = induced_edges self._induced_edges = induced_edges
def add_nodes(self, num): def add_nodes(self, num):
"""Add nodes. Disabled because SubgraphIndex is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.') raise RuntimeError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v): def add_edge(self, u, v):
"""Add edges. Disabled because SubgraphIndex is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.') raise RuntimeError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v): def add_edges(self, u, v):
"""Add edges. Disabled because SubgraphIndex is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.') raise RuntimeError('Readonly graph. Mutation is not allowed.')
@property
def induced_edges(self):
return self._induced_edges
@property @property
def induced_nodes(self): def induced_nodes(self):
"""Return parent node ids.
Returns
-------
utils.Index
The parent node ids.
"""
return self._induced_nodes return self._induced_nodes
@property
def induced_edges(self):
"""Return parent edge ids.
Returns
-------
utils.Index
The parent edge ids.
"""
return self._induced_edges
def disjoint_union(graphs): def disjoint_union(graphs):
"""Return a disjoint union of the input graphs. """Return a disjoint union of the input graphs.
...@@ -697,8 +726,25 @@ def create_graph_index(graph_data=None, multigraph=False): ...@@ -697,8 +726,25 @@ def create_graph_index(graph_data=None, multigraph=False):
handle = _CAPI_DGLGraphCreate(multigraph) handle = _CAPI_DGLGraphCreate(multigraph)
gi = GraphIndex(handle) gi = GraphIndex(handle)
if graph_data is not None:
if graph_data is None:
return gi
# scipy format
if isinstance(graph_data, scipy.sparse.spmatrix):
try:
gi.from_scipy_sparse_matrix(graph_data)
return gi
except:
raise Exception('Graph data is not a valid scipy sparse matrix.')
# networkx - any format
try:
gi.from_networkx(graph_data) gi.from_networkx(graph_data)
except:
raise Exception('Error while creating graph from input of type "%s".'
% type(graph_data))
return gi return gi
_init_api("dgl.graph_index") _init_api("dgl.graph_index")
...@@ -3,7 +3,7 @@ from __future__ import absolute_import ...@@ -3,7 +3,7 @@ from __future__ import absolute_import
import numpy as np import numpy as np
from .base import ALL, __MSG__, __REPR__ from .base import ALL, DGLError
from . import backend as F from . import backend as F
from .function import message as fmsg from .function import message as fmsg
from .function import reducer as fred from .function import reducer as fred
...@@ -111,7 +111,15 @@ def light_degree_bucketing_for_graph(graph): ...@@ -111,7 +111,15 @@ def light_degree_bucketing_for_graph(graph):
class Executor(object): class Executor(object):
"""Base class for executing graph computation."""
def run(self): def run(self):
"""Run this executor.
This should return the new node features.
TODO(minjie): extend this to support computation on edges.
"""
raise NotImplementedError raise NotImplementedError
class SPMVOperator(Executor): class SPMVOperator(Executor):
...@@ -126,10 +134,7 @@ class SPMVOperator(Executor): ...@@ -126,10 +134,7 @@ class SPMVOperator(Executor):
def run(self): def run(self):
# get src col # get src col
if self.src_field is None: srccol = self.node_repr[self.src_field]
srccol = self.node_repr
else:
srccol = self.node_repr[self.src_field]
ctx = F.get_context(srccol) ctx = F.get_context(srccol)
# build adjmat # build adjmat
...@@ -142,10 +147,7 @@ class SPMVOperator(Executor): ...@@ -142,10 +147,7 @@ class SPMVOperator(Executor):
dstcol = F.squeeze(dstcol) dstcol = F.squeeze(dstcol)
else: else:
dstcol = F.spmm(adjmat, srccol) dstcol = F.spmm(adjmat, srccol)
if self.dst_field is None: return {self.dst_field : dstcol}
return dstcol
else:
return {self.dst_field : dstcol}
# FIXME: refactorize in scheduler/executor redesign # FIXME: refactorize in scheduler/executor redesign
...@@ -180,20 +182,14 @@ class DegreeBucketingExecutor(Executor): ...@@ -180,20 +182,14 @@ class DegreeBucketingExecutor(Executor):
msg_shape = F.shape(msg) msg_shape = F.shape(msg)
new_shape = (len(vv), deg) + msg_shape[1:] new_shape = (len(vv), deg) + msg_shape[1:]
return F.reshape(msg, new_shape) return F.reshape(msg, new_shape)
if len(in_msgs) == 1 and __MSG__ in in_msgs: reshaped_in_msgs = utils.LazyDict(
reshaped_in_msgs = _reshape_fn(in_msgs[__MSG__]) lambda key: _reshape_fn(in_msgs[key]), self.msg_frame.schemes)
else:
reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self.msg_frame.schemes)
new_reprs.append(self.rfunc(dst_reprs, reshaped_in_msgs)) new_reprs.append(self.rfunc(dst_reprs, reshaped_in_msgs))
# Pack all reducer results together # Pack all reducer results together
if utils.is_dict_like(new_reprs[0]): keys = new_reprs[0].keys()
keys = new_reprs[0].keys() new_reprs = {key : F.pack([repr[key] for repr in new_reprs])
new_reprs = {key : F.pack([repr[key] for repr in new_reprs]) for key in keys}
for key in keys}
else:
new_reprs = {__REPR__ : F.pack(new_reprs)}
return new_reprs return new_reprs
...@@ -249,12 +245,6 @@ class UpdateAllExecutor(BasicExecutor): ...@@ -249,12 +245,6 @@ class UpdateAllExecutor(BasicExecutor):
self._graph_shape = None self._graph_shape = None
self._recv_nodes = None self._recv_nodes = None
@property
def graph_idx(self):
if self._graph_idx is None:
self._graph_idx = self.g._graph.adjacency_matrix()
return self._graph_idx
@property @property
def graph_shape(self): def graph_shape(self):
if self._graph_shape is None: if self._graph_shape is None:
...@@ -280,16 +270,13 @@ class UpdateAllExecutor(BasicExecutor): ...@@ -280,16 +270,13 @@ class UpdateAllExecutor(BasicExecutor):
def _adj_build_fn(self, edge_field, ctx, use_edge_feat): def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
if use_edge_feat: if use_edge_feat:
if edge_field is None: dat = self.edge_repr[edge_field]
dat = self.edge_repr
else:
dat = self.edge_repr[edge_field]
dat = F.squeeze(dat) dat = F.squeeze(dat)
# TODO(minjie): should not directly use _indices # TODO(minjie): should not directly use _indices
idx = self.graph_idx.get(ctx)._indices() idx = self.g.adjacency_matrix(ctx)._indices()
adjmat = F.sparse_tensor(idx, dat, self.graph_shape) adjmat = F.sparse_tensor(idx, dat, self.graph_shape)
else: else:
adjmat = self.graph_idx.get(ctx) adjmat = self.g.adjacency_matrix(ctx)
return adjmat return adjmat
...@@ -351,10 +338,7 @@ class SendRecvExecutor(BasicExecutor): ...@@ -351,10 +338,7 @@ class SendRecvExecutor(BasicExecutor):
def _adj_build_fn(self, edge_field, ctx, use_edge_feat): def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
if use_edge_feat: if use_edge_feat:
if edge_field is None: dat = self.edge_repr[edge_field]
dat = self.edge_repr
else:
dat = self.edge_repr[edge_field]
dat = F.squeeze(dat) dat = F.squeeze(dat)
else: else:
dat = F.ones((len(self.u), )) dat = F.ones((len(self.u), ))
...@@ -386,9 +370,8 @@ class BundledExecutor(BasicExecutor): ...@@ -386,9 +370,8 @@ class BundledExecutor(BasicExecutor):
func_pairs = [] func_pairs = []
for rfn in rfunc.fn_list: for rfn in rfunc.fn_list:
mfn = out2mfunc.get(rfn.msg_field, None) mfn = out2mfunc.get(rfn.msg_field, None)
# field check if mfn is None:
assert mfn is not None, \ raise DGLError('Cannot find message field "%s".' % rfn.msg_field)
"cannot find message func for reduce func in-field {}".format(rfn.msg_field)
func_pairs.append((mfn, rfn)) func_pairs.append((mfn, rfn))
return func_pairs return func_pairs
...@@ -409,7 +392,6 @@ class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor): ...@@ -409,7 +392,6 @@ class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor):
self._init_state() self._init_state()
BundledExecutor.__init__(self, graph, mfunc, rfunc) BundledExecutor.__init__(self, graph, mfunc, rfunc)
class BundledSendRecvExecutor(BundledExecutor, SendRecvExecutor): class BundledSendRecvExecutor(BundledExecutor, SendRecvExecutor):
def __init__(self, graph, src, dst, mfunc, rfunc): def __init__(self, graph, src, dst, mfunc, rfunc):
self._init_state(src, dst) self._init_state(src, dst)
......
/*!
* Copyright (c) 2018 by Contributors
* \file c_runtime_api.cc
* \brief DGL C API common implementations
*/
#include "c_api_common.h" #include "c_api_common.h"
using tvm::runtime::TVMArgs; using tvm::runtime::TVMArgs;
...@@ -29,5 +34,5 @@ PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) { ...@@ -29,5 +34,5 @@ PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) {
return PackedFunc(body); return PackedFunc(body);
} }
} // namespace dgl } // namespace dgl
// DGL C API common util functions /*!
* Copyright (c) 2018 by Contributors
* \file c_api_common.h
* \brief DGL C API common util functions
*/
#ifndef DGL_C_API_COMMON_H_ #ifndef DGL_C_API_COMMON_H_
#define DGL_C_API_COMMON_H_ #define DGL_C_API_COMMON_H_
...@@ -12,12 +16,20 @@ namespace dgl { ...@@ -12,12 +16,20 @@ namespace dgl {
// Graph handler type // Graph handler type
typedef void* GraphHandle; typedef void* GraphHandle;
// Convert the given DLTensor to a temporary DLManagedTensor that does not own memory. /*!
DLManagedTensor* CreateTmpDLManagedTensor(const tvm::runtime::TVMArgValue& arg); * \brief Convert the given DLTensor to DLManagedTensor.
*
* Return a temporary DLManagedTensor that does not own memory.
*/
DLManagedTensor* CreateTmpDLManagedTensor(
const tvm::runtime::TVMArgValue& arg);
// Convert a vector of NDArray to PackedFunc /*!
tvm::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<tvm::runtime::NDArray>& vec); * \brief Convert a vector of NDArray to PackedFunc.
*/
tvm::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(
const std::vector<tvm::runtime::NDArray>& vec);
} // namespace dgl } // namespace dgl
#endif // DGL_C_API_COMMON_H_ #endif // DGL_C_API_COMMON_H_
// Graph class implementation /*!
* Copyright (c) 2018 by Contributors
* \file graph/graph.cc
* \brief DGL graph index implementation
*/
#include <dgl/graph.h>
#include <algorithm> #include <algorithm>
#include <unordered_map> #include <unordered_map>
#include <set> #include <set>
#include <functional> #include <functional>
#include <dgl/graph.h>
namespace dgl { namespace dgl {
namespace { namespace {
...@@ -193,9 +197,9 @@ Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const { ...@@ -193,9 +197,9 @@ Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
const auto& succ = adjlist_[src_id].succ; const auto& succ = adjlist_[src_id].succ;
for (size_t k = 0; k < succ.size(); ++k) { for (size_t k = 0; k < succ.size(); ++k) {
if (succ[k] == dst_id) { if (succ[k] == dst_id) {
src.push_back(src_id); src.push_back(src_id);
dst.push_back(dst_id); dst.push_back(dst_id);
eid.push_back(adjlist_[src_id].edge_id[k]); eid.push_back(adjlist_[src_id].edge_id[k]);
} }
} }
} }
...@@ -351,7 +355,7 @@ Graph::EdgeArray Graph::Edges(bool sorted) const { ...@@ -351,7 +355,7 @@ Graph::EdgeArray Graph::Edges(bool sorted) const {
return std::get<0>(t1) < std::get<0>(t2) return std::get<0>(t1) < std::get<0>(t2)
|| (std::get<0>(t1) == std::get<0>(t2) && std::get<1>(t1) < std::get<1>(t2)); || (std::get<0>(t1) == std::get<0>(t2) && std::get<1>(t1) < std::get<1>(t2));
}); });
// make return arrays // make return arrays
int64_t* src_ptr = static_cast<int64_t*>(src->data); int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data); int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
...@@ -461,7 +465,8 @@ Subgraph Graph::EdgeSubgraph(IdArray eids) const { ...@@ -461,7 +465,8 @@ Subgraph Graph::EdgeSubgraph(IdArray eids) const {
rst.graph.AddEdge(oldv2newv[src_id], oldv2newv[dst_id]); rst.graph.AddEdge(oldv2newv[src_id], oldv2newv[dst_id]);
} }
rst.induced_vertices = IdArray::Empty({static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx); rst.induced_vertices = IdArray::Empty(
{static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data)); std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data));
return rst; return rst;
......
/*!
* Copyright (c) 2018 by Contributors
* \file graph/graph.cc
* \brief DGL graph index APIs
*/
#include <dgl/graph.h> #include <dgl/graph.h>
#include <dgl/graph_op.h> #include <dgl/graph_op.h>
#include "../c_api_common.h" #include "../c_api_common.h"
......
// Graph operation implementation /*!
* Copyright (c) 2018 by Contributors
* \file graph/graph.cc
* \brief Graph operation implementation
*/
#include <dgl/graph_op.h> #include <dgl/graph_op.h>
#include <algorithm> #include <algorithm>
namespace dgl { namespace dgl {
Graph GraphOp::LineGraph(const Graph* g, bool backtracking){ Graph GraphOp::LineGraph(const Graph* g, bool backtracking) {
typedef std::pair<dgl_id_t, dgl_id_t> entry; typedef std::pair<dgl_id_t, dgl_id_t> entry;
typedef std::map<dgl_id_t, std::vector<entry>> csm; // Compressed Sparse Matrix typedef std::map<dgl_id_t, std::vector<entry>> csm; // Compressed Sparse Matrix
csm adj; csm adj;
std::vector<entry> vec; std::vector<entry> vec;
...@@ -67,7 +71,7 @@ std::vector<Graph> GraphOp::DisjointPartitionByNum(const Graph* graph, int64_t n ...@@ -67,7 +71,7 @@ std::vector<Graph> GraphOp::DisjointPartitionByNum(const Graph* graph, int64_t n
std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num); std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num);
return DisjointPartitionBySizes(graph, sizes); return DisjointPartitionBySizes(graph, sizes);
} }
std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray sizes) { std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray sizes) {
const int64_t len = sizes->shape[0]; const int64_t len = sizes->shape[0];
const int64_t* sizes_data = static_cast<int64_t*>(sizes->data); const int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
...@@ -117,32 +121,6 @@ std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray ...@@ -117,32 +121,6 @@ std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray
node_offset += sizes_data[i]; node_offset += sizes_data[i];
edge_offset += num_edges; edge_offset += num_edges;
} }
/*for (int64_t i = 0; i < len; ++i) {
rst[i].AddVertices(sizes_data[i]);
}
for (dgl_id_t eid = 0; eid < graph->num_edges_; ++eid) {
const dgl_id_t src = graph->all_edges_src_[eid];
const dgl_id_t dst = graph->all_edges_dst_[eid];
size_t src_select = 0, dst_select = 0;
for (size_t i = 1; i < cumsum.size(); ++i) { // TODO: replace with binary search
if (cumsum[i] > src) {
src_select = i;
break;
}
}
for (size_t i = 1; i < cumsum.size(); ++i) { // TODO: replace with binary search
if (cumsum[i] > dst) {
dst_select = i;
break;
}
}
if (src_select != dst_select) {
// the edge is ignored if across two partitions
continue;
}
const int64_t offset = cumsum[src_select - 1];
rst[src_select - 1].AddEdge(src - offset, dst - offset);
}*/
return rst; return rst;
} }
......
# C API and runtime
Borrowed and adapted from TVM project.
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
* \file file_util.h * \file file_util.h
* \brief Minimum file manipulation util for runtime. * \brief Minimum file manipulation util for runtime.
*/ */
#ifndef TVM_RUNTIME_FILE_UTIL_H_ #ifndef DGL_RUNTIME_FILE_UTIL_H_
#define TVM_RUNTIME_FILE_UTIL_H_ #define DGL_RUNTIME_FILE_UTIL_H_
#include <string> #include <string>
#include "meta_data.h" #include "meta_data.h"
...@@ -73,4 +73,4 @@ void LoadMetaDataFromFile( ...@@ -73,4 +73,4 @@ void LoadMetaDataFromFile(
std::unordered_map<std::string, FunctionInfo>* fmap); std::unordered_map<std::string, FunctionInfo>* fmap);
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_FILE_UTIL_H_ #endif // DGL_RUNTIME_FILE_UTIL_H_
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
* \file meta_data.h * \file meta_data.h
* \brief Meta data related utilities * \brief Meta data related utilities
*/ */
#ifndef TVM_RUNTIME_META_DATA_H_ #ifndef DGL_RUNTIME_META_DATA_H_
#define TVM_RUNTIME_META_DATA_H_ #define DGL_RUNTIME_META_DATA_H_
#include <dmlc/json.h> #include <dmlc/json.h>
#include <dmlc/io.h> #include <dmlc/io.h>
...@@ -33,4 +33,4 @@ struct FunctionInfo { ...@@ -33,4 +33,4 @@ struct FunctionInfo {
namespace dmlc { namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::FunctionInfo, true); DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::FunctionInfo, true);
} // namespace dmlc } // namespace dmlc
#endif // TVM_RUNTIME_META_DATA_H_ #endif // DGL_RUNTIME_META_DATA_H_
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
* \file module_util.h * \file module_util.h
* \brief Helper utilities for module building * \brief Helper utilities for module building
*/ */
#ifndef TVM_RUNTIME_MODULE_UTIL_H_ #ifndef DGL_RUNTIME_MODULE_UTIL_H_
#define TVM_RUNTIME_MODULE_UTIL_H_ #define DGL_RUNTIME_MODULE_UTIL_H_
#include <dgl/runtime/module.h> #include <dgl/runtime/module.h>
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
...@@ -58,4 +58,4 @@ void InitContextFunctions(FLookup flookup) { ...@@ -58,4 +58,4 @@ void InitContextFunctions(FLookup flookup) {
} }
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_MODULE_UTIL_H_ #endif // DGL_RUNTIME_MODULE_UTIL_H_
...@@ -10,8 +10,8 @@ ...@@ -10,8 +10,8 @@
* union_32bit args[N], int num_args); * union_32bit args[N], int num_args);
* - Pack buffer by address, pack rest parameter into 32bit union buffer. * - Pack buffer by address, pack rest parameter into 32bit union buffer.
*/ */
#ifndef TVM_RUNTIME_PACK_ARGS_H_ #ifndef DGL_RUNTIME_PACK_ARGS_H_
#define TVM_RUNTIME_PACK_ARGS_H_ #define DGL_RUNTIME_PACK_ARGS_H_
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
#include <vector> #include <vector>
...@@ -307,4 +307,4 @@ inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types) ...@@ -307,4 +307,4 @@ inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types)
} }
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_PACK_ARGS_H_ #endif // DGL_RUNTIME_PACK_ARGS_H_
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
* \file runtime_base.h * \file runtime_base.h
* \brief Base of all C APIs * \brief Base of all C APIs
*/ */
#ifndef TVM_RUNTIME_RUNTIME_BASE_H_ #ifndef DGL_RUNTIME_RUNTIME_BASE_H_
#define TVM_RUNTIME_RUNTIME_BASE_H_ #define DGL_RUNTIME_RUNTIME_BASE_H_
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
#include <stdexcept> #include <stdexcept>
...@@ -31,4 +31,4 @@ inline int TVMAPIHandleException(const std::runtime_error &e) { ...@@ -31,4 +31,4 @@ inline int TVMAPIHandleException(const std::runtime_error &e) {
return -1; return -1;
} }
#endif // TVM_RUNTIME_RUNTIME_BASE_H_ #endif // DGL_RUNTIME_RUNTIME_BASE_H_
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
* \file thread_storage_scope.h * \file thread_storage_scope.h
* \brief Extract thread axis configuration from TVMArgs. * \brief Extract thread axis configuration from TVMArgs.
*/ */
#ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #ifndef DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_
#define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #define DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <string> #include <string>
...@@ -204,4 +204,4 @@ struct hash<::tvm::runtime::StorageScope> { ...@@ -204,4 +204,4 @@ struct hash<::tvm::runtime::StorageScope> {
} }
}; };
} // namespace std } // namespace std
#endif // TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #endif // DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment