Unverified Commit 4bd4d6e3 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Lint] Pylint (#330)

* fix lint for graph_index.py

* pylint for base.py

* pylint for batched_graph.py

* pylint for frame.py; simplify and fix bugs in frame when index is slice type

* pylint for graph.py

* pylint for immutable_graph_index.py

* pylint for init.py

* pylint for rest files in root package

* pylint for _ffi package

* pylint for function package

* pylint for runtime package

* pylint for runtime.ir package

* add pylint to ci

* fix mx tests

* fix lint errors

* fix ci

* fix as requested

* fix lint
parent 1e50cd2e
# One has to manually import dgl.data; fixes #125 """DGL root package."""
#from . import data
from . import function from . import function
from . import nn from . import nn
from . import contrib from . import contrib
...@@ -12,7 +11,6 @@ from .base import ALL ...@@ -12,7 +11,6 @@ from .base import ALL
from .backend import load_backend from .backend import load_backend
from .batched_graph import * from .batched_graph import *
from .graph import DGLGraph from .graph import DGLGraph
from .subgraph import DGLSubGraph
from .traversal import * from .traversal import *
from .propagate import * from .propagate import *
from .udf import NodeBatch, EdgeBatch from .udf import NodeBatch, EdgeBatch
"""Namespace for internal apis."""
...@@ -26,8 +26,7 @@ else: ...@@ -26,8 +26,7 @@ else:
class DGLError(Exception): class DGLError(Exception):
"""Error thrown by DGL function""" """Error thrown by DGL function"""
pass pass # pylint: disable=unnecessary-pass
def _load_lib(): def _load_lib():
"""Load libary by searching possible path.""" """Load libary by searching possible path."""
......
...@@ -51,7 +51,7 @@ class Function(_FunctionBase): ...@@ -51,7 +51,7 @@ class Function(_FunctionBase):
dgl.register_func: How to register global function. dgl.register_func: How to register global function.
dgl.get_global_func: How to get global function. dgl.get_global_func: How to get global function.
""" """
pass pass # pylint: disable=unnecessary-pass
class ModuleBase(object): class ModuleBase(object):
......
"""Common runtime ctypes.""" """Common runtime ctypes."""
# pylint: disable=invalid-name # pylint: disable=invalid-name, super-init-not-called
from __future__ import absolute_import from __future__ import absolute_import
import ctypes import ctypes
......
...@@ -349,6 +349,31 @@ class ImmutableGraphIndex(object): ...@@ -349,6 +349,31 @@ class ImmutableGraphIndex(object):
self.__init__(mx.nd.sparse.csr_matrix((edge_ids, (dst, src)), shape=(size, size)).astype(np.int64), self.__init__(mx.nd.sparse.csr_matrix((edge_ids, (dst, src)), shape=(size, size)).astype(np.int64),
mx.nd.sparse.csr_matrix((edge_ids, (src, dst)), shape=(size, size)).astype(np.int64)) mx.nd.sparse.csr_matrix((edge_ids, (src, dst)), shape=(size, size)).astype(np.int64))
def from_edge_list(self, elist):
    """Convert from an edge list.

    Builds both the in-edge and out-edge CSR adjacency matrices from the
    given edge list and re-initializes this graph index with them.

    Parameters
    ----------
    elist : list
        List of (u, v) edge tuple.

    Raises
    ------
    DGLError
        If the smallest node id in the edge list is not 0.
    """
    src, dst = zip(*elist)
    src = np.array(src)
    dst = np.array(dst)
    # The number of nodes is inferred from the largest endpoint id.
    num_nodes = max(src.max(), dst.max()) + 1
    min_nodes = min(src.min(), dst.min())
    if min_nodes != 0:
        raise DGLError('Invalid edge list. Nodes must start from 0.')
    # Edge ids follow the order of the input edge list.
    edge_ids = mx.nd.arange(0, len(src), step=1, repeat=1, dtype=np.int32)
    src = mx.nd.array(src, dtype=np.int64)
    dst = mx.nd.array(dst, dtype=np.int64)
    # TODO we can't generate a csr_matrix with np.int64 directly.
    # (in_csr is indexed by destination, out_csr by source.)
    in_csr = mx.nd.sparse.csr_matrix((edge_ids, (dst, src)),
                                     shape=(num_nodes, num_nodes)).astype(np.int64)
    out_csr = mx.nd.sparse.csr_matrix((edge_ids, (src, dst)),
                                      shape=(num_nodes, num_nodes)).astype(np.int64)
    self.__init__(in_csr, out_csr)
def create_immutable_graph_index(in_csr=None, out_csr=None): def create_immutable_graph_index(in_csr=None, out_csr=None):
""" Create an empty backend-specific immutable graph index. """ Create an empty backend-specific immutable graph index.
......
...@@ -3,12 +3,15 @@ from __future__ import absolute_import ...@@ -3,12 +3,15 @@ from __future__ import absolute_import
import warnings import warnings
from ._ffi.base import DGLError from ._ffi.base import DGLError # pylint: disable=unused-import
# A special argument for selecting all nodes/edges. # A special symbol for selecting all nodes or edges.
ALL = "__ALL__" ALL = "__ALL__"
def is_all(arg): def is_all(arg):
"""Return true if the argument is a special symbol for all nodes or edges."""
return isinstance(arg, str) and arg == ALL return isinstance(arg, str) and arg == ALL
dgl_warning = warnings.warn def dgl_warning(msg):
"""Print out warning messages."""
warnings.warn(msg)
"""Classes and functions for batching multiple graphs together.""" """Classes and functions for batching multiple graphs together."""
from __future__ import absolute_import from __future__ import absolute_import
from collections.abc import Iterable
import numpy as np import numpy as np
from collections import Iterable
from .base import ALL, is_all from .base import ALL, is_all, DGLError
from .frame import FrameRef, Frame from .frame import FrameRef, Frame
from .graph import DGLGraph from .graph import DGLGraph
from . import graph_index as gi from . import graph_index as gi
...@@ -152,8 +152,7 @@ class BatchedDGLGraph(DGLGraph): ...@@ -152,8 +152,7 @@ class BatchedDGLGraph(DGLGraph):
elif is_all(attrs): elif is_all(attrs):
attrs = set() attrs = set()
# Check if at least a graph has mode items and associated features. # Check if at least a graph has mode items and associated features.
for i in range(len(graph_list)): for i, g in enumerate(graph_list):
g = graph_list[i]
g_num_items, g_attrs = _get_num_item_and_attr_types(g, mode) g_num_items, g_attrs = _get_num_item_and_attr_types(g, mode)
if g_num_items > 0 and len(g_attrs) > 0: if g_num_items > 0 and len(g_attrs) > 0:
attrs = g_attrs attrs = g_attrs
...@@ -161,13 +160,13 @@ class BatchedDGLGraph(DGLGraph): ...@@ -161,13 +160,13 @@ class BatchedDGLGraph(DGLGraph):
break break
# Check if all the graphs with mode items have the same associated features. # Check if all the graphs with mode items have the same associated features.
if len(attrs) > 0: if len(attrs) > 0:
for i in range(len(graph_list)): for i, g in enumerate(graph_list):
g = graph_list[i] g = graph_list[i]
g_num_items, g_attrs = _get_num_item_and_attr_types(g, mode) g_num_items, g_attrs = _get_num_item_and_attr_types(g, mode)
if g_attrs != attrs and g_num_items > 0: if g_attrs != attrs and g_num_items > 0:
raise ValueError('Expect graph {} and {} to have the same {} ' raise ValueError('Expect graph {0} and {1} to have the same {2} '
'attributes when {}_attrs=ALL, got {} and ' 'attributes when {2}_attrs=ALL, got {3} and {4}.'
'{}'.format(ref_g_index, i, mode, mode, attrs, g_attrs)) .format(ref_g_index, i, mode, attrs, g_attrs))
return attrs return attrs
elif isinstance(attrs, str): elif isinstance(attrs, str):
return [attrs] return [attrs]
...@@ -200,8 +199,7 @@ class BatchedDGLGraph(DGLGraph): ...@@ -200,8 +199,7 @@ class BatchedDGLGraph(DGLGraph):
for key in edge_attrs} for key in edge_attrs}
batched_edge_frame = FrameRef(Frame(cols)) batched_edge_frame = FrameRef(Frame(cols))
super(BatchedDGLGraph, self).__init__( super(BatchedDGLGraph, self).__init__(graph_data=batched_index,
graph_data=batched_index,
node_frame=batched_node_frame, node_frame=batched_node_frame,
edge_frame=batched_edge_frame) edge_frame=batched_edge_frame)
...@@ -209,16 +207,16 @@ class BatchedDGLGraph(DGLGraph): ...@@ -209,16 +207,16 @@ class BatchedDGLGraph(DGLGraph):
self._batch_size = 0 self._batch_size = 0
self._batch_num_nodes = [] self._batch_num_nodes = []
self._batch_num_edges = [] self._batch_num_edges = []
for gr in graph_list: for grh in graph_list:
if isinstance(gr, BatchedDGLGraph): if isinstance(grh, BatchedDGLGraph):
# handle the input is again a batched graph. # handle the input is again a batched graph.
self._batch_size += gr._batch_size self._batch_size += grh._batch_size
self._batch_num_nodes += gr._batch_num_nodes self._batch_num_nodes += grh._batch_num_nodes
self._batch_num_edges += gr._batch_num_edges self._batch_num_edges += grh._batch_num_edges
else: else:
self._batch_size += 1 self._batch_size += 1
self._batch_num_nodes.append(gr.number_of_nodes()) self._batch_num_nodes.append(grh.number_of_nodes())
self._batch_num_edges.append(gr.number_of_edges()) self._batch_num_edges.append(grh.number_of_edges())
@property @property
def batch_size(self): def batch_size(self):
...@@ -251,33 +249,33 @@ class BatchedDGLGraph(DGLGraph): ...@@ -251,33 +249,33 @@ class BatchedDGLGraph(DGLGraph):
return self._batch_num_edges return self._batch_num_edges
# override APIs # override APIs
def add_nodes(self, num, reprs=None): def add_nodes(self, num, data=None):
"""Add nodes. Disabled because BatchedDGLGraph is read-only.""" """Add nodes. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.') raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v, reprs=None): def add_edge(self, u, v, data=None):
"""Add one edge. Disabled because BatchedDGLGraph is read-only.""" """Add one edge. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.') raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v, reprs=None): def add_edges(self, u, v, data=None):
"""Add many edges. Disabled because BatchedDGLGraph is read-only.""" """Add many edges. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.') raise DGLError('Readonly graph. Mutation is not allowed.')
# new APIs # new APIs
def __getitem__(self, idx): def __getitem__(self, idx):
"""Slice the batch and return the batch of graphs specified by the idx.""" """Slice the batch and return the batch of graphs specified by the idx."""
# TODO # TODO
pass raise NotImplementedError
def __setitem__(self, idx, val): def __setitem__(self, idx, val):
"""Set the value of the slice. The graph size cannot be changed.""" """Set the value of the slice. The graph size cannot be changed."""
# TODO # TODO
pass raise NotImplementedError
def split(graph_batch, num_or_size_splits): def split(graph_batch, num_or_size_splits): # pylint: disable=unused-argument
"""Split the batch.""" """Split the batch."""
# TODO(minjie): could follow torch.split syntax # TODO(minjie): could follow torch.split syntax
pass raise NotImplementedError
def unbatch(graph): def unbatch(graph):
"""Return the list of graphs in this batch. """Return the list of graphs in this batch.
...@@ -308,18 +306,18 @@ def unbatch(graph): ...@@ -308,18 +306,18 @@ def unbatch(graph):
""" """
assert isinstance(graph, BatchedDGLGraph) assert isinstance(graph, BatchedDGLGraph)
bsize = graph.batch_size bsize = graph.batch_size
bn = graph.batch_num_nodes bnn = graph.batch_num_nodes
be = graph.batch_num_edges bne = graph.batch_num_edges
pttns = gi.disjoint_partition(graph._graph, utils.toindex(bn)) pttns = gi.disjoint_partition(graph._graph, utils.toindex(bnn))
# split the frames # split the frames
node_frames = [FrameRef(Frame(num_rows=n)) for n in bn] node_frames = [FrameRef(Frame(num_rows=n)) for n in bnn]
edge_frames = [FrameRef(Frame(num_rows=n)) for n in be] edge_frames = [FrameRef(Frame(num_rows=n)) for n in bne]
for attr, col in graph._node_frame.items(): for attr, col in graph._node_frame.items():
col_splits = F.split(col, bn, dim=0) col_splits = F.split(col, bnn, dim=0)
for i in range(bsize): for i in range(bsize):
node_frames[i][attr] = col_splits[i] node_frames[i][attr] = col_splits[i]
for attr, col in graph._edge_frame.items(): for attr, col in graph._edge_frame.items():
col_splits = F.split(col, be, dim=0) col_splits = F.split(col, bne, dim=0)
for i in range(bsize): for i in range(bsize):
edge_frames[i][attr] = col_splits[i] edge_frames[i][attr] = col_splits[i]
return [DGLGraph(graph_data=pttns[i], return [DGLGraph(graph_data=pttns[i],
...@@ -355,47 +353,63 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL): ...@@ -355,47 +353,63 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
return BatchedDGLGraph(graph_list, node_attrs, edge_attrs) return BatchedDGLGraph(graph_list, node_attrs, edge_attrs)
_readout_on_attrs = { READOUT_ON_ATTRS = {
'nodes': ('ndata', 'batch_num_nodes', 'number_of_nodes'), 'nodes': ('ndata', 'batch_num_nodes', 'number_of_nodes'),
'edges': ('edata', 'batch_num_edges', 'number_of_edges'), 'edges': ('edata', 'batch_num_edges', 'number_of_edges'),
} }
def _sum_on(graph, typestr, feat, weight):
"""Internal function to sum node or edge features.
Parameters
----------
graph : DGLGraph
The graph.
typestr : str
'nodes' or 'edges'
feat : str
The feature field name.
weight : str
The weight field name.
def _sum_on(graph, on, input, weight): Returns
data_attr, batch_num_objs_attr, num_objs_attr = _readout_on_attrs[on] -------
Tensor
The (weighted) summed node or edge features.
"""
data_attr, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
data = getattr(graph, data_attr) data = getattr(graph, data_attr)
input = data[input] feat = data[feat]
if weight is not None: if weight is not None:
weight = data[weight] weight = data[weight]
weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(input) - 1)) weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(feat) - 1))
input = weight * input feat = weight * feat
if isinstance(graph, BatchedDGLGraph): if isinstance(graph, BatchedDGLGraph):
n_graphs = graph.batch_size n_graphs = graph.batch_size
batch_num_objs = getattr(graph, batch_num_objs_attr) batch_num_objs = getattr(graph, batch_num_objs_attr)
n_objs = getattr(graph, num_objs_attr)()
seg_id = F.zerocopy_from_numpy( seg_id = F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
np.arange(n_graphs, dtype='int64').repeat(batch_num_objs)) seg_id = F.copy_to(seg_id, F.context(feat))
seg_id = F.copy_to(seg_id, F.context(input)) y = F.unsorted_1d_segment_sum(feat, seg_id, n_graphs, 0)
y = F.unsorted_1d_segment_sum(input, seg_id, n_graphs, 0)
return y return y
else: else:
return F.sum(input, 0) return F.sum(feat, 0)
def sum_nodes(graph, input, weight=None): def sum_nodes(graph, feat, weight=None):
"""Sums all the values of node field :attr:`input` in :attr:`graph`, optionally """Sums all the values of node field :attr:`feat` in :attr:`graph`, optionally
multiplies the field by a scalar node field :attr:`weight`. multiplies the field by a scalar node field :attr:`weight`.
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph.
The graph The graph.
input : str feat : str
The input field The feature field.
weight : str, optional weight : str, optional
The weight field. If None, no weighting will be performed, The weight field. If None, no weighting will be performed,
otherwise, weight each node feature with field :attr:`input`. otherwise, weight each node feature with field :attr:`feat`.
for summation. The weight feature associated in the :attr:`graph` for summation. The weight feature associated in the :attr:`graph`
should be a tensor of shape ``[graph.number_of_nodes(), 1]``. should be a tensor of shape ``[graph.number_of_nodes(), 1]``.
...@@ -450,21 +464,21 @@ def sum_nodes(graph, input, weight=None): ...@@ -450,21 +464,21 @@ def sum_nodes(graph, input, weight=None):
sum_edges sum_edges
mean_edges mean_edges
""" """
return _sum_on(graph, 'nodes', input, weight) return _sum_on(graph, 'nodes', feat, weight)
def sum_edges(graph, input, weight=None): def sum_edges(graph, feat, weight=None):
"""Sums all the values of edge field :attr:`input` in :attr:`graph`, """Sums all the values of edge field :attr:`feat` in :attr:`graph`,
optionally multiplies the field by a scalar edge field :attr:`weight`. optionally multiplies the field by a scalar edge field :attr:`weight`.
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph The graph.
input : str feat : str
The input field The feature field.
weight : str, optional weight : str, optional
The weight field. If None, no weighting will be performed, The weight field. If None, no weighting will be performed,
otherwise, weight each edge feature with field :attr:`input`. otherwise, weight each edge feature with field :attr:`feat`.
for summation. The weight feature associated in the :attr:`graph` for summation. The weight feature associated in the :attr:`graph`
should be a tensor of shape ``[graph.number_of_edges(), 1]``. should be a tensor of shape ``[graph.number_of_edges(), 1]``.
...@@ -521,54 +535,70 @@ def sum_edges(graph, input, weight=None): ...@@ -521,54 +535,70 @@ def sum_edges(graph, input, weight=None):
mean_nodes mean_nodes
mean_edges mean_edges
""" """
return _sum_on(graph, 'edges', input, weight) return _sum_on(graph, 'edges', feat, weight)
def _mean_on(graph, typestr, feat, weight):
"""Internal function to sum node or edge features.
Parameters
----------
graph : DGLGraph
The graph.
typestr : str
'nodes' or 'edges'
feat : str
The feature field name.
weight : str
The weight field name.
def _mean_on(graph, on, input, weight): Returns
data_attr, batch_num_objs_attr, num_objs_attr = _readout_on_attrs[on] -------
Tensor
The (weighted) averaged node or edge features.
"""
data_attr, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
data = getattr(graph, data_attr) data = getattr(graph, data_attr)
input = data[input] feat = data[feat]
if weight is not None: if weight is not None:
weight = data[weight] weight = data[weight]
weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(input) - 1)) weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(feat) - 1))
input = weight * input feat = weight * feat
if isinstance(graph, BatchedDGLGraph): if isinstance(graph, BatchedDGLGraph):
n_graphs = graph.batch_size n_graphs = graph.batch_size
batch_num_objs = getattr(graph, batch_num_objs_attr) batch_num_objs = getattr(graph, batch_num_objs_attr)
n_objs = getattr(graph, num_objs_attr)()
seg_id = F.zerocopy_from_numpy( seg_id = F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
np.arange(n_graphs, dtype='int64').repeat(batch_num_objs)) seg_id = F.copy_to(seg_id, F.context(feat))
seg_id = F.copy_to(seg_id, F.context(input))
if weight is not None: if weight is not None:
w = F.unsorted_1d_segment_sum(weight, seg_id, n_graphs, 0) w = F.unsorted_1d_segment_sum(weight, seg_id, n_graphs, 0)
y = F.unsorted_1d_segment_sum(input, seg_id, n_graphs, 0) y = F.unsorted_1d_segment_sum(feat, seg_id, n_graphs, 0)
y = y / w y = y / w
else: else:
y = F.unsorted_1d_segment_mean(input, seg_id, n_graphs, 0) y = F.unsorted_1d_segment_mean(feat, seg_id, n_graphs, 0)
return y return y
else: else:
if weight is None: if weight is None:
return F.mean(input, 0) return F.mean(feat, 0)
else: else:
y = F.sum(input, 0) / F.sum(weight, 0) y = F.sum(feat, 0) / F.sum(weight, 0)
return y return y
def mean_nodes(graph, input, weight=None): def mean_nodes(graph, feat, weight=None):
"""Averages all the values of node field :attr:`input` in :attr:`graph`, """Averages all the values of node field :attr:`feat` in :attr:`graph`,
optionally multiplies the field by a scalar node field :attr:`weight`. optionally multiplies the field by a scalar node field :attr:`weight`.
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph or BatchedDGLGraph
The graph The graph.
input : str feat : str
The input field The feature field.
weight : str, optional weight : str, optional
The weight field. If None, no weighting will be performed, The weight field. If None, no weighting will be performed,
otherwise, weight each node feature with field :attr:`input`. otherwise, weight each node feature with field :attr:`feat`.
for calculating mean. The weight feature associated in the :attr:`graph` for calculating mean. The weight feature associated in the :attr:`graph`
should be a tensor of shape ``[graph.number_of_nodes(), 1]``. should be a tensor of shape ``[graph.number_of_nodes(), 1]``.
...@@ -623,21 +653,21 @@ def mean_nodes(graph, input, weight=None): ...@@ -623,21 +653,21 @@ def mean_nodes(graph, input, weight=None):
sum_edges sum_edges
mean_edges mean_edges
""" """
return _mean_on(graph, 'nodes', input, weight) return _mean_on(graph, 'nodes', feat, weight)
def mean_edges(graph, input, weight=None): def mean_edges(graph, feat, weight=None):
"""Averages all the values of edge field :attr:`input` in :attr:`graph`, """Averages all the values of edge field :attr:`feat` in :attr:`graph`,
optionally multiplies the field by a scalar edge field :attr:`weight`. optionally multiplies the field by a scalar edge field :attr:`weight`.
Parameters Parameters
---------- ----------
graph : DGLGraph or BatchedDGLGraph graph : DGLGraph
The graph The graph.
input : str feat : str
The input field The feature field.
weight : optional, str weight : optional, str
The weight field. If None, no weighting will be performed, The weight field. If None, no weighting will be performed,
otherwise, weight each edge feature with field :attr:`input`. otherwise, weight each edge feature with field :attr:`feat`.
for calculating mean. The weight feature associated in the :attr:`graph` for calculating mean. The weight feature associated in the :attr:`graph`
should be a tensor of shape ``[graph.number_of_edges(), 1]``. should be a tensor of shape ``[graph.number_of_edges(), 1]``.
...@@ -694,4 +724,4 @@ def mean_edges(graph, input, weight=None): ...@@ -694,4 +724,4 @@ def mean_edges(graph, input, weight=None):
mean_nodes mean_nodes
sum_edges sum_edges
""" """
return _mean_on(graph, 'edges', input, weight) return _mean_on(graph, 'edges', feat, weight)
"""Columnar storage for DGLGraph.""" """Columnar storage for DGLGraph."""
from __future__ import absolute_import from __future__ import absolute_import
from collections import MutableMapping, namedtuple from collections import namedtuple
from collections.abc import MutableMapping
import sys import sys
import numpy as np import numpy as np
...@@ -39,6 +40,18 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])): ...@@ -39,6 +40,18 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
return cls(shape, dtype) return cls(shape, dtype)
def infer_scheme(tensor): def infer_scheme(tensor):
"""Infer column scheme from the given tensor data.
Parameters
---------
tensor : Tensor
The tensor data.
Returns
-------
Scheme
The column scheme.
"""
return Scheme(tuple(F.shape(tensor)[1:]), F.dtype(tensor)) return Scheme(tuple(F.shape(tensor)[1:]), F.dtype(tensor))
class Column(object): class Column(object):
...@@ -64,6 +77,7 @@ class Column(object): ...@@ -64,6 +77,7 @@ class Column(object):
@property @property
def shape(self): def shape(self):
"""Return the scheme shape (feature shape) of this column."""
return self.scheme.shape return self.scheme.shape
def __getitem__(self, idx): def __getitem__(self, idx):
...@@ -71,7 +85,7 @@ class Column(object): ...@@ -71,7 +85,7 @@ class Column(object):
Parameters Parameters
---------- ----------
idx : slice or utils.Index idx : utils.Index
The index. The index.
Returns Returns
...@@ -79,8 +93,9 @@ class Column(object): ...@@ -79,8 +93,9 @@ class Column(object):
Tensor Tensor
The feature data The feature data
""" """
if isinstance(idx, slice): if idx.slice_data() is not None:
return self.data[idx] slc = idx.slice_data()
return F.narrow_row(self.data, slc.start, slc.stop)
else: else:
user_idx = idx.tousertensor(F.context(self.data)) user_idx = idx.tousertensor(F.context(self.data))
return F.gather_row(self.data, user_idx) return F.gather_row(self.data, user_idx)
...@@ -105,7 +120,7 @@ class Column(object): ...@@ -105,7 +120,7 @@ class Column(object):
Parameters Parameters
---------- ----------
idx : utils.Index or slice idx : utils.Index
The index. The index.
feats : Tensor feats : Tensor
The new features. The new features.
...@@ -117,19 +132,18 @@ class Column(object): ...@@ -117,19 +132,18 @@ class Column(object):
raise DGLError("Cannot update column of scheme %s using feature of scheme %s." raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
% (feat_scheme, self.scheme)) % (feat_scheme, self.scheme))
if isinstance(idx, utils.Index):
idx = idx.tousertensor(F.context(self.data))
if inplace: if inplace:
idx = idx.tousertensor(F.context(self.data))
F.scatter_row_inplace(self.data, idx, feats) F.scatter_row_inplace(self.data, idx, feats)
else: elif idx.slice_data() is not None:
if isinstance(idx, slice): # for contiguous indices narrow+concat is usually faster than scatter row
# for contiguous indices pack is usually faster than scatter row slc = idx.slice_data()
part1 = F.narrow_row(self.data, 0, idx.start) part1 = F.narrow_row(self.data, 0, slc.start)
part2 = feats part2 = feats
part3 = F.narrow_row(self.data, idx.stop, len(self)) part3 = F.narrow_row(self.data, slc.stop, len(self))
self.data = F.cat([part1, part2, part3], dim=0) self.data = F.cat([part1, part2, part3], dim=0)
else: else:
idx = idx.tousertensor(F.context(self.data))
self.data = F.scatter_row(self.data, idx, feats) self.data = F.scatter_row(self.data, idx, feats)
def extend(self, feats, feat_scheme=None): def extend(self, feats, feat_scheme=None):
...@@ -143,7 +157,7 @@ class Column(object): ...@@ -143,7 +157,7 @@ class Column(object):
The scheme The scheme
""" """
if feat_scheme is None: if feat_scheme is None:
feat_scheme = Scheme.infer_scheme(feats) feat_scheme = infer_scheme(feats)
if feat_scheme != self.scheme: if feat_scheme != self.scheme:
raise DGLError("Cannot update column of scheme %s using feature of scheme %s." raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
...@@ -314,8 +328,8 @@ class Frame(MutableMapping): ...@@ -314,8 +328,8 @@ class Frame(MutableMapping):
return return
if self.get_initializer(name) is None: if self.get_initializer(name) is None:
self._warn_and_set_initializer() self._warn_and_set_initializer()
init_data = self.get_initializer(name)( initializer = self.get_initializer(name)
(self.num_rows,) + scheme.shape, scheme.dtype, init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(0, self.num_rows)) ctx, slice(0, self.num_rows))
self._columns[name] = Column(init_data, scheme) self._columns[name] = Column(init_data, scheme)
...@@ -336,8 +350,8 @@ class Frame(MutableMapping): ...@@ -336,8 +350,8 @@ class Frame(MutableMapping):
ctx = F.context(col.data) ctx = F.context(col.data)
if self.get_initializer(key) is None: if self.get_initializer(key) is None:
self._warn_and_set_initializer() self._warn_and_set_initializer()
new_data = self.get_initializer(key)( initializer = self.get_initializer(key)
(num_rows,) + scheme.shape, scheme.dtype, new_data = initializer((num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(self._num_rows, self._num_rows + num_rows)) ctx, slice(self._num_rows, self._num_rows + num_rows))
feat_placeholders[key] = new_data feat_placeholders[key] = new_data
self._append(Frame(feat_placeholders)) self._append(Frame(feat_placeholders))
...@@ -368,16 +382,16 @@ class Frame(MutableMapping): ...@@ -368,16 +382,16 @@ class Frame(MutableMapping):
else: else:
# pad columns that are not provided in the other frame with initial values # pad columns that are not provided in the other frame with initial values
for key, col in self.items(): for key, col in self.items():
if key not in other: if key in other:
continue
scheme = col.scheme scheme = col.scheme
ctx = F.context(col.data) ctx = F.context(col.data)
if self.get_initializer(key) is None: if self.get_initializer(key) is None:
self._warn_and_set_initializer() self._warn_and_set_initializer()
new_data = self.get_initializer(key)( initializer = self.get_initializer(key)
(other.num_rows,) + scheme.shape, scheme.dtype, new_data = initializer((other.num_rows,) + scheme.shape,
ctx, slice(self._num_rows, scheme.dtype, ctx,
self._num_rows + other.num_rows) slice(self._num_rows, self._num_rows + other.num_rows))
)
other[key] = new_data other[key] = new_data
# append other to self # append other to self
for key, col in other.items(): for key, col in other.items():
...@@ -428,23 +442,19 @@ class FrameRef(MutableMapping): ...@@ -428,23 +442,19 @@ class FrameRef(MutableMapping):
frame : Frame, optional frame : Frame, optional
The underlying frame. If not given, the reference will point to a The underlying frame. If not given, the reference will point to a
new empty frame. new empty frame.
index : iterable, slice, or int, optional index : utils.Index, optional
The rows that are referenced in the underlying frame. If not given, The rows that are referenced in the underlying frame. If not given,
the whole frame is referenced. The index should be distinct (no the whole frame is referenced. The index should be distinct (no
duplication is allowed). duplication is allowed).
Note that if a slice is given, the step must be None.
""" """
def __init__(self, frame=None, index=None): def __init__(self, frame=None, index=None):
self._frame = frame if frame is not None else Frame() self._frame = frame if frame is not None else Frame()
# TODO(minjie): check no duplication
assert index is None or isinstance(index, utils.Index)
if index is None: if index is None:
# _index_data can be either a slice or an iterable self._index = utils.toindex(slice(0, self._frame.num_rows))
self._index_data = slice(0, self._frame.num_rows)
else: else:
# TODO(minjie): check no duplication self._index = index
self._index_data = index
self._index = None
self._index_or_slice = None
@property @property
def schemes(self): def schemes(self):
...@@ -465,11 +475,7 @@ class FrameRef(MutableMapping): ...@@ -465,11 +475,7 @@ class FrameRef(MutableMapping):
@property @property
def num_rows(self): def num_rows(self):
"""Return the number of rows referred.""" """Return the number of rows referred."""
if isinstance(self._index_data, slice): return len(self._index)
# NOTE: we always assume that slice.step is None
return self._index_data.stop - self._index_data.start
else:
return len(self._index_data)
def set_initializer(self, initializer, column=None): def set_initializer(self, initializer, column=None):
"""Set the initializer for empty values. """Set the initializer for empty values.
...@@ -500,38 +506,6 @@ class FrameRef(MutableMapping): ...@@ -500,38 +506,6 @@ class FrameRef(MutableMapping):
""" """
return self._frame.get_initializer(column) return self._frame.get_initializer(column)
def index(self):
"""Return the index object.
Returns
-------
utils.Index
The index.
"""
if self._index is None:
if self.is_contiguous():
self._index = utils.toindex(
F.arange(self._index_data.start,
self._index_data.stop))
else:
self._index = utils.toindex(self._index_data)
return self._index
def index_or_slice(self):
"""Returns the index object or the slice
Returns
-------
utils.Index or slice
The index or slice
"""
if self._index_or_slice is None:
if self.is_contiguous():
self._index_or_slice = self._index_data
else:
self._index_or_slice = utils.toindex(self._index_data)
return self._index_or_slice
def __contains__(self, name): def __contains__(self, name):
"""Return whether the column name exists.""" """Return whether the column name exists."""
return name in self._frame return name in self._frame
...@@ -567,7 +541,7 @@ class FrameRef(MutableMapping): ...@@ -567,7 +541,7 @@ class FrameRef(MutableMapping):
Parameters Parameters
---------- ----------
key : str or utils.Index or slice key : str or utils.Index
The key. The key.
Returns Returns
...@@ -575,12 +549,11 @@ class FrameRef(MutableMapping): ...@@ -575,12 +549,11 @@ class FrameRef(MutableMapping):
Tensor or lazy dict or tensors Tensor or lazy dict or tensors
Depends on whether it is a column selection or row selection. Depends on whether it is a column selection or row selection.
""" """
if not isinstance(key, (str, utils.Index)):
raise DGLError('Argument "key" must be either str or utils.Index type.')
if isinstance(key, str): if isinstance(key, str):
return self.select_column(key) return self.select_column(key)
elif isinstance(key, slice) and key == slice(0, self.num_rows): elif key.is_slice(0, self.num_rows):
# shortcut for selecting all the rows
return self
elif isinstance(key, utils.Index) and key.is_slice(0, self.num_rows):
# shortcut for selecting all the rows # shortcut for selecting all the rows
return self return self
else: else:
...@@ -606,7 +579,7 @@ class FrameRef(MutableMapping): ...@@ -606,7 +579,7 @@ class FrameRef(MutableMapping):
if self.is_span_whole_column(): if self.is_span_whole_column():
return col.data return col.data
else: else:
return col[self.index_or_slice()] return col[self._index]
def select_rows(self, query): def select_rows(self, query):
"""Return the rows given the query. """Return the rows given the query.
...@@ -625,9 +598,22 @@ class FrameRef(MutableMapping): ...@@ -625,9 +598,22 @@ class FrameRef(MutableMapping):
return utils.LazyDict(lambda key: self._frame[key][rows], keys=self.keys()) return utils.LazyDict(lambda key: self._frame[key][rows], keys=self.keys())
def __setitem__(self, key, val): def __setitem__(self, key, val):
self.set_item_inplace(key, val, inplace=False) """Update the data in the frame. The update is done out-of-place.
Parameters
----------
key : str or utils.Index
The key.
val : Tensor or dict of tensors
The value.
See Also
--------
update
"""
self.update_data(key, val, inplace=False)
def update_data(self, key, val, inplace):
    """Update the data in the frame.

    If ``key`` is a string, the corresponding column is updated.
    Otherwise ``key`` is treated as a row index and the given rows are
    updated.

    Parameters
    ----------
    key : str or utils.Index
        The key.
    val : Tensor or dict of tensors
        The value.
    inplace : bool
        If True, the update is done in place.
    """
    if not isinstance(key, (str, utils.Index)):
        raise DGLError('Argument "key" must be either str or utils.Index type.')
    if isinstance(key, str):
        self.update_column(key, val, inplace=inplace)
    elif key.is_slice(0, self.num_rows):
        # shortcut for updating all the rows: update column by column
        for colname, col in val.items():
            self.update_column(colname, col, inplace=inplace)
    else:
        self.update_rows(key, val, inplace=inplace)
...@@ -683,15 +669,14 @@ class FrameRef(MutableMapping): ...@@ -683,15 +669,14 @@ class FrameRef(MutableMapping):
col = Column.create(data) col = Column.create(data)
if self.num_columns == 0: if self.num_columns == 0:
# the frame is empty # the frame is empty
self._index_data = slice(0, len(col)) self._index = utils.toindex(slice(0, len(col)))
self._clear_cache()
self._frame[name] = col self._frame[name] = col
else: else:
if name not in self._frame: if name not in self._frame:
ctx = F.context(data) ctx = F.context(data)
self._frame.add_column(name, infer_scheme(data), ctx) self._frame.add_column(name, infer_scheme(data), ctx)
fcol = self._frame[name] fcol = self._frame[name]
fcol.update(self.index_or_slice(), data, inplace) fcol.update(self._index, data, inplace)
def add_rows(self, num_rows):
    """Add blank rows to the underlying frame.

    Only allowed when this reference spans the whole column; otherwise
    the new rows could not be reflected consistently in the index.

    Parameters
    ----------
    num_rows : int
        Number of rows to add.

    Raises
    ------
    RuntimeError
        If this reference does not span the whole column.
    """
    if not self.is_span_whole_column():
        raise RuntimeError('FrameRef not spanning whole column.')
    self._frame.add_rows(num_rows)
    if self._index.slice_data() is not None:
        # the index is a slice: just extend its stop bound
        slc = self._index.slice_data()
        self._index = utils.toindex(slice(slc.start, slc.stop + num_rows))
    else:
        # materialize the index and concatenate the new row ids
        selfidxdata = self._index.tousertensor()
        newdata = F.arange(self.num_rows, self.num_rows + num_rows)
        self._index = utils.toindex(F.cat([selfidxdata, newdata], dim=0))
def update_rows(self, query, data, inplace): def update_rows(self, query, data, inplace):
"""Update the rows. """Update the rows.
...@@ -759,6 +748,8 @@ class FrameRef(MutableMapping): ...@@ -759,6 +748,8 @@ class FrameRef(MutableMapping):
key : str or utils.Index key : str or utils.Index
The key. The key.
""" """
if not isinstance(key, (str, utils.Index)):
raise DGLError('Argument "key" must be either str or utils.Index type.')
if isinstance(key, str): if isinstance(key, str):
del self._frame[key] del self._frame[key]
else: else:
...@@ -769,22 +760,16 @@ class FrameRef(MutableMapping): ...@@ -769,22 +760,16 @@ class FrameRef(MutableMapping):
Please note that "deleted" rows are not really deleted, but simply removed Please note that "deleted" rows are not really deleted, but simply removed
in the reference. As a result, if two FrameRefs point to the same Frame, deleting in the reference. As a result, if two FrameRefs point to the same Frame, deleting
from one ref will not relect on the other. By contrast, deleting columns is real. from one ref will not reflect on the other. By contrast, deleting columns is real.
Parameters Parameters
---------- ----------
query : utils.Index or slice query : utils.Index
The rows to be deleted. The rows to be deleted.
""" """
if isinstance(query, slice):
query = range(query.start, query.stop)
else:
query = query.tonumpy() query = query.tonumpy()
index = self._index.tonumpy()
if isinstance(self._index_data, slice): self._index = utils.toindex(np.delete(index, query))
self._index_data = range(self._index_data.start, self._index_data.stop)
self._index_data = list(np.delete(self._index_data, query))
self._clear_cache()
def append(self, other):
    """Append another frame's rows into this one.

    Parameters
    ----------
    other : dict of str to tensor
        The data to be appended.
    """
    old_nrows = self._frame.num_rows
    self._frame.append(other)
    new_nrows = self._frame.num_rows
    # update index
    if (self._index.slice_data() is not None
            and self._index.slice_data().stop == old_nrows):
        # Self index is a slice and index.stop is equal to the size of the
        # underlying frame. Can still use a slice for the new index.
        oldstart = self._index.slice_data().start
        self._index = utils.toindex(slice(oldstart, new_nrows))
    else:
        # convert it to a user tensor and concatenate
        selfidxdata = self._index.tousertensor()
        newdata = F.arange(old_nrows, new_nrows)
        self._index = utils.toindex(F.cat([selfidxdata, newdata], dim=0))
def clear(self):
    """Remove all data and reset the row index to an empty slice."""
    self._frame.clear()
    self._index = utils.toindex(slice(0, 0))
def is_contiguous(self):
    """Return whether this reference covers a contiguous range of rows."""
    # NOTE: this check could have false negatives
    return self._index.slice_data() is not None
def is_span_whole_column(self):
    """Return whether this reference covers every row of the frame."""
    return self.is_contiguous() and self.num_rows == self._frame.num_rows
def _getrows(self, query): def _getrows(self, query):
"""Internal function to convert from the local row ids to the row ids of the frame.""" """Internal function to convert from the local row ids to the row ids of the frame.
if self.is_contiguous():
start = self._index_data.start Parameters
if start == 0: ----------
# shortcut for identical mapping query : utils.Index
return query The query index.
elif isinstance(query, slice):
return slice(query.start + start, query.stop + start) Returns
else: -------
query = query.tousertensor() utils.Index
return utils.toindex(query + start) The actual index to the underlying frame.
else: """
idxtensor = self.index().tousertensor() return self._index.get_items(query)
query = query.tousertensor()
return utils.toindex(F.gather_row(idxtensor, query))
def _clear_cache(self):
"""Internal function to clear the cached object."""
self._index = None
self._index_or_slice = None
def frame_like(other, num_rows): def frame_like(other, num_rows):
"""Create a new frame that has the same scheme as the given one. """Create a new frame that has the same scheme as the given one.
......
"""DGL builtin functors""" """DGL builtin functors"""
# pylint: disable=redefined-builtin
from __future__ import absolute_import from __future__ import absolute_import
from .message import * from .message import *
......
"""Built-in function base class""" """Built-in function base class"""
from __future__ import absolute_import from __future__ import absolute_import
__all__ = ['BuiltinFunction', 'BundledFunction']
class BuiltinFunction(object):
    """Base builtin function class."""

    def __call__(self):
        """Regular computation of this builtin function.

        This will be used when optimization is not available.
        """
        raise NotImplementedError

    @property
    def name(self):
        """Return the name of this builtin function."""
        raise NotImplementedError
class BundledFunction(object):
    """A utility class that bundles multiple functions.

    Parameters
    ----------
    fn_list : list of callable
        The function list.
    """

    def __init__(self, fn_list):
        self.fn_list = fn_list

    def __call__(self, *args, **kwargs):
        """Call every bundled function and merge their dict results.

        This will be used when optimization is not available and should
        ONLY be called by DGL framework.
        """
        ret = {}
        for fn in self.fn_list:
            ret.update(fn(*args, **kwargs))
        return ret

    @property
    def name(self):
        """Return the name."""
        return "bundled"
"""Built-in message function.""" """Built-in message function."""
from __future__ import absolute_import from __future__ import absolute_import
from .base import BuiltinFunction
import operator import operator
import dgl.backend as F
from .base import BuiltinFunction
from .. import backend as F
__all__ = ["src_mul_edge", "copy_src", "copy_edge"] __all__ = ["src_mul_edge", "copy_src", "copy_edge"]
...@@ -12,9 +13,10 @@ class MessageFunction(BuiltinFunction): ...@@ -12,9 +13,10 @@ class MessageFunction(BuiltinFunction):
"""Base builtin message function class.""" """Base builtin message function class."""
def __call__(self, edges): def __call__(self, edges):
"""Regular computation of this builtin. """Regular computation of this builtin function
This will be used when optimization is not available. This will be used when optimization is not available and should
ONLY be called by DGL framework.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -29,9 +31,9 @@ class MessageFunction(BuiltinFunction): ...@@ -29,9 +31,9 @@ class MessageFunction(BuiltinFunction):
@property @property
def use_edge_feature(self): def use_edge_feature(self):
"""Return true if the message function uses edge feature data."""
raise NotImplementedError raise NotImplementedError
def _is_spmv_supported_edge_feat(g, field): def _is_spmv_supported_edge_feat(g, field):
"""Return whether the edge feature shape supports SPMV optimization. """Return whether the edge feature shape supports SPMV optimization.
...@@ -43,6 +45,12 @@ def _is_spmv_supported_edge_feat(g, field): ...@@ -43,6 +45,12 @@ def _is_spmv_supported_edge_feat(g, field):
class SrcMulEdgeMessageFunction(MessageFunction): class SrcMulEdgeMessageFunction(MessageFunction):
"""Class for the src_mul_edge builtin message function.
See Also
--------
src_mul_edge
"""
def __init__(self, mul_op, src_field, edge_field, out_field): def __init__(self, mul_op, src_field, edge_field, out_field):
self.mul_op = mul_op self.mul_op = mul_op
self.src_field = src_field self.src_field = src_field
...@@ -50,9 +58,26 @@ class SrcMulEdgeMessageFunction(MessageFunction): ...@@ -50,9 +58,26 @@ class SrcMulEdgeMessageFunction(MessageFunction):
self.out_field = out_field self.out_field = out_field
def is_spmv_supported(self, g):
    """Return True if this supports SPMV optimization.

    Parameters
    ----------
    g : DGLGraph
        The graph.

    Returns
    -------
    bool
        True if this supports SPMV optimization.
    """
    return _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, edges): def __call__(self, edges):
"""Regular computation of this builtin function
This will be used when optimization is not available and should
ONLY be called by DGL framework.
"""
sdata = edges.src[self.src_field] sdata = edges.src[self.src_field]
edata = edges.data[self.edge_field] edata = edges.data[self.edge_field]
# Due to the different broadcasting semantics of different backends, # Due to the different broadcasting semantics of different backends,
...@@ -71,17 +96,41 @@ class SrcMulEdgeMessageFunction(MessageFunction): ...@@ -71,17 +96,41 @@ class SrcMulEdgeMessageFunction(MessageFunction):
@property @property
def use_edge_feature(self): def use_edge_feature(self):
"""Return true if the message function uses edge feature data."""
return True return True
class CopySrcMessageFunction(MessageFunction): class CopySrcMessageFunction(MessageFunction):
"""Class for the copy_src builtin message function.
See Also
--------
copy_src
"""
def __init__(self, src_field, out_field): def __init__(self, src_field, out_field):
self.src_field = src_field self.src_field = src_field
self.out_field = out_field self.out_field = out_field
def is_spmv_supported(self, g):
    """Return True if this supports SPMV optimization.

    Parameters
    ----------
    g : DGLGraph
        The graph.

    Returns
    -------
    bool
        True if this supports SPMV optimization.
    """
    return True
def __call__(self, edges): def __call__(self, edges):
"""Regular computation of this builtin function
This will be used when optimization is not available and should
ONLY be called by DGL framework.
"""
return {self.out_field : edges.src[self.src_field]} return {self.out_field : edges.src[self.src_field]}
@property @property
...@@ -90,19 +139,43 @@ class CopySrcMessageFunction(MessageFunction): ...@@ -90,19 +139,43 @@ class CopySrcMessageFunction(MessageFunction):
@property @property
def use_edge_feature(self): def use_edge_feature(self):
"""Return true if the message function uses edge feature data."""
return False return False
class CopyEdgeMessageFunction(MessageFunction): class CopyEdgeMessageFunction(MessageFunction):
"""Class for the copy_edge builtin message function.
See Also
--------
copy_edge
"""
def __init__(self, edge_field=None, out_field=None): def __init__(self, edge_field=None, out_field=None):
self.edge_field = edge_field self.edge_field = edge_field
self.out_field = out_field self.out_field = out_field
def is_spmv_supported(self, g):
    """Return True if this supports SPMV optimization.

    Parameters
    ----------
    g : DGLGraph
        The graph.

    Returns
    -------
    bool
        True if this supports SPMV optimization.
    """
    # TODO: support this with e2v spmv
    return False
    # return _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, edges): def __call__(self, edges):
"""Regular computation of this builtin function
This will be used when optimization is not available and should
ONLY be called by DGL framework.
"""
return {self.out_field : edges.data[self.edge_field]} return {self.out_field : edges.data[self.edge_field]}
@property @property
...@@ -111,9 +184,9 @@ class CopyEdgeMessageFunction(MessageFunction): ...@@ -111,9 +184,9 @@ class CopyEdgeMessageFunction(MessageFunction):
@property @property
def use_edge_feature(self): def use_edge_feature(self):
"""Return true if the message function uses edge feature data."""
return True return True
def src_mul_edge(src, edge, out): def src_mul_edge(src, edge, out):
"""Builtin message function that computes message by multiplying source """Builtin message function that computes message by multiplying source
node features with edge features. node features with edge features.
......
"""Built-in reducer function.""" """Built-in reducer function."""
# pylint: disable=redefined-builtin
from __future__ import absolute_import from __future__ import absolute_import
from .. import backend as F from .. import backend as F
...@@ -10,9 +11,10 @@ class ReduceFunction(BuiltinFunction): ...@@ -10,9 +11,10 @@ class ReduceFunction(BuiltinFunction):
"""Base builtin reduce function class.""" """Base builtin reduce function class."""
def __call__(self, nodes): def __call__(self, nodes):
"""Regular computation of this builtin. """Regular computation of this builtin function
This will be used when optimization is not available. This will be used when optimization is not available and should
ONLY be called by DGL framework.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -29,18 +31,19 @@ class ReduceFunction(BuiltinFunction): ...@@ -29,18 +31,19 @@ class ReduceFunction(BuiltinFunction):
class SimpleReduceFunction(ReduceFunction): class SimpleReduceFunction(ReduceFunction):
"""Builtin reduce function that aggregates a single field into another """Builtin reduce function that aggregates a single field into another
single field.""" single field."""
def __init__(self, name, op, msg_field, out_field): def __init__(self, name, reduce_op, msg_field, out_field):
self._name = name self._name = name
self.op = op self.reduce_op = reduce_op
self.msg_field = msg_field self.msg_field = msg_field
self.out_field = out_field self.out_field = out_field
def is_spmv_supported(self):
    """Return whether the SPMV optimization is supported."""
    # NOTE: only sum is supported right now.
    return self._name == "sum"
def __call__(self, nodes): def __call__(self, nodes):
return {self.out_field : self.op(nodes.mailbox[self.msg_field], 1)} return {self.out_field : self.reduce_op(nodes.mailbox[self.msg_field], 1)}
@property @property
def name(self): def name(self):
......
"""Base graph class specialized for neural networks on graphs.""" """Base graph class specialized for neural networks on graphs."""
from __future__ import absolute_import from __future__ import absolute_import
import networkx as nx
import numpy as np
from collections import defaultdict from collections import defaultdict
import dgl from .base import ALL, is_all, DGLError
from .base import ALL, is_all, DGLError, dgl_warning
from . import backend as F from . import backend as F
from . import init
from .frame import FrameRef, Frame from .frame import FrameRef, Frame
from .graph_index import GraphIndex, create_graph_index from .graph_index import create_graph_index
from .runtime import ir, scheduler, Runtime from .runtime import ir, scheduler, Runtime
from . import subgraph
from . import utils from . import utils
from .view import NodeView, EdgeView from .view import NodeView, EdgeView
from .udf import NodeBatch, EdgeBatch from .udf import NodeBatch, EdgeBatch
__all__ = ['DGLGraph'] __all__ = ['DGLGraph']
class DGLGraph(object): class DGLGraph(object):
...@@ -177,7 +175,6 @@ class DGLGraph(object): ...@@ -177,7 +175,6 @@ class DGLGraph(object):
multigraph=False, multigraph=False,
readonly=False): readonly=False):
# graph # graph
self._readonly=readonly
self._graph = create_graph_index(graph_data, multigraph, readonly) self._graph = create_graph_index(graph_data, multigraph, readonly)
# node and edge frame # node and edge frame
if node_frame is None: if node_frame is None:
...@@ -194,7 +191,7 @@ class DGLGraph(object): ...@@ -194,7 +191,7 @@ class DGLGraph(object):
# message frame # message frame
self._msg_frame = FrameRef(Frame(num_rows=self.number_of_edges())) self._msg_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
# set initializer for message frame # set initializer for message frame
self._msg_frame.set_initializer(dgl.init.zero_initializer) self._msg_frame.set_initializer(init.zero_initializer)
# registered functions # registered functions
self._message_func = None self._message_func = None
self._reduce_func = None self._reduce_func = None
...@@ -916,7 +913,7 @@ class DGLGraph(object): ...@@ -916,7 +913,7 @@ class DGLGraph(object):
else: else:
raise DGLError('Invalid form:', form) raise DGLError('Invalid form:', form)
def all_edges(self, form='uv', sorted=False): def all_edges(self, form='uv', return_sorted=False):
"""Return all the edges. """Return all the edges.
Parameters Parameters
...@@ -927,7 +924,7 @@ class DGLGraph(object): ...@@ -927,7 +924,7 @@ class DGLGraph(object):
- 'all' : a tuple (u, v, eid) - 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default - 'uv' : a pair (u, v), default
- 'eid' : one eid tensor - 'eid' : one eid tensor
sorted : bool return_sorted : bool
True if the returned edges are sorted by their src and dst ids. True if the returned edges are sorted by their src and dst ids.
Returns Returns
...@@ -954,7 +951,7 @@ class DGLGraph(object): ...@@ -954,7 +951,7 @@ class DGLGraph(object):
>>> G.all_edges('all') >>> G.all_edges('all')
(tensor([0, 0, 1]), tensor([1, 2, 2]), tensor([0, 1, 2])) (tensor([0, 0, 1]), tensor([1, 2, 2]), tensor([0, 1, 2]))
""" """
src, dst, eid = self._graph.edges(sorted) src, dst, eid = self._graph.edges(return_sorted)
if form == 'all': if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor()) return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv': elif form == 'uv':
...@@ -1121,13 +1118,13 @@ class DGLGraph(object): ...@@ -1121,13 +1118,13 @@ class DGLGraph(object):
nx_graph = self._graph.to_networkx() nx_graph = self._graph.to_networkx()
if node_attrs is not None: if node_attrs is not None:
for nid, attr in nx_graph.nodes(data=True): for nid, attr in nx_graph.nodes(data=True):
nf = self.get_n_repr(nid) feat_dict = self.get_n_repr(nid)
attr.update({key: nf[key].squeeze(0) for key in node_attrs}) attr.update({key: feat_dict[key].squeeze(0) for key in node_attrs})
if edge_attrs is not None: if edge_attrs is not None:
for u, v, attr in nx_graph.edges(data=True): for _, _, attr in nx_graph.edges(data=True):
eid = attr['id'] eid = attr['id']
ef = self.get_e_repr(eid) feat_dict = self.get_e_repr(eid)
attr.update({key: ef[key].squeeze(0) for key in edge_attrs}) attr.update({key: feat_dict[key].squeeze(0) for key in edge_attrs})
return nx_graph return nx_graph
def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None): def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
...@@ -1208,12 +1205,12 @@ class DGLGraph(object): ...@@ -1208,12 +1205,12 @@ class DGLGraph(object):
for attr in edge_attrs: for attr in edge_attrs:
self._edge_frame[attr] = _batcher(attr_dict[attr]) self._edge_frame[attr] = _batcher(attr_dict[attr])
def from_scipy_sparse_matrix(self, spmat):
    """Convert from a scipy sparse adjacency matrix.

    Clears the current graph first, then rebuilds the structure and
    resizes the node/edge frames to match.

    Parameters
    ----------
    spmat : scipy sparse matrix
        The graph's adjacency matrix.
    """
    self.clear()
    self._graph.from_scipy_sparse_matrix(spmat)
    self._node_frame.add_rows(self.number_of_nodes())
    self._edge_frame.add_rows(self.number_of_edges())
    self._msg_index = utils.zero_index(self.number_of_edges())
""" """
return self.edges[:].data return self.edges[:].data
def set_n_repr(self, data, u=ALL, inplace=False):
    """Set node(s) representation.

    ``data`` is a dictionary from the feature name to feature tensor. Each
    tensor is of shape (B, D1, D2, ...), where B is the number of nodes to
    be updated, and (D1, D2, ...) is the shape of the node representation
    tensor. The length of the given node ids must match B (i.e.,
    len(u) == B).

    Parameters
    ----------
    data : dict of tensor
        Node representation.
    u : node, container or tensor
        The node(s).
    inplace : bool
        If True, update will be done in place, but autograd will break.
    """
    # sanity check
    if not utils.is_dict_like(data):
        raise DGLError('Expect dictionary type for feature data.'
                       ' Got "%s" instead.' % type(data))
    if is_all(u):
        num_nodes = self.number_of_nodes()
    else:
        u = utils.toindex(u)
        num_nodes = len(u)
    for key, val in data.items():
        nfeats = F.shape(val)[0]
        if nfeats != num_nodes:
            raise DGLError('Expect number of features to match number of nodes (len(u)).'
                           ' Got %d and %d instead.' % (nfeats, num_nodes))
    # set
    if is_all(u):
        for key, val in data.items():
            self._node_frame[key] = val
    else:
        self._node_frame.update_rows(u, data, inplace=inplace)
def get_n_repr(self, u=ALL): def get_n_repr(self, u=ALL):
"""Get node(s) representation. """Get node(s) representation.
...@@ -1581,10 +1578,10 @@ class DGLGraph(object): ...@@ -1581,10 +1578,10 @@ class DGLGraph(object):
""" """
return self._node_frame.pop(key) return self._node_frame.pop(key)
def set_e_repr(self, he, edges=ALL, inplace=False): def set_e_repr(self, data, edges=ALL, inplace=False):
"""Set edge(s) representation. """Set edge(s) representation.
`he` is a dictionary from the feature name to feature tensor. Each tensor `data` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of edges to be updated, is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
and (D1, D2, ...) be the shape of the edge representation tensor. and (D1, D2, ...) be the shape of the edge representation tensor.
...@@ -1593,7 +1590,7 @@ class DGLGraph(object): ...@@ -1593,7 +1590,7 @@ class DGLGraph(object):
Parameters Parameters
---------- ----------
he : tensor or dict of tensor data : tensor or dict of tensor
Edge representation. Edge representation.
edges : edges edges : edges
Edges can be a pair of endpoint nodes (u, v), or a Edges can be a pair of endpoint nodes (u, v), or a
...@@ -1614,16 +1611,16 @@ class DGLGraph(object): ...@@ -1614,16 +1611,16 @@ class DGLGraph(object):
eid = utils.toindex(edges) eid = utils.toindex(edges)
# sanity check # sanity check
if not utils.is_dict_like(he): if not utils.is_dict_like(data):
raise DGLError('Expect dictionary type for feature data.' raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(he)) ' Got "%s" instead.' % type(data))
if is_all(eid): if is_all(eid):
num_edges = self.number_of_edges() num_edges = self.number_of_edges()
else: else:
eid = utils.toindex(eid) eid = utils.toindex(eid)
num_edges = len(eid) num_edges = len(eid)
for key, val in he.items(): for key, val in data.items():
nfeats = F.shape(val)[0] nfeats = F.shape(val)[0]
if nfeats != num_edges: if nfeats != num_edges:
raise DGLError('Expect number of features to match number of edges.' raise DGLError('Expect number of features to match number of edges.'
...@@ -1631,11 +1628,11 @@ class DGLGraph(object): ...@@ -1631,11 +1628,11 @@ class DGLGraph(object):
# set # set
if is_all(eid): if is_all(eid):
# update column # update column
for key, val in he.items(): for key, val in data.items():
self._edge_frame[key] = val self._edge_frame[key] = val
else: else:
# update row # update row
self._edge_frame.update_rows(eid, he, inplace=inplace) self._edge_frame.update_rows(eid, data, inplace=inplace)
def get_e_repr(self, edges=ALL): def get_e_repr(self, edges=ALL):
"""Get node(s) representation. """Get node(s) representation.
...@@ -2491,8 +2488,7 @@ class DGLGraph(object): ...@@ -2491,8 +2488,7 @@ class DGLGraph(object):
prop_edges prop_edges
""" """
for node_frontier in nodes_generator: for node_frontier in nodes_generator:
self.pull(node_frontier, self.pull(node_frontier, message_func, reduce_func, apply_node_func)
message_func, reduce_func, apply_node_func)
def prop_edges(self, def prop_edges(self,
edges_generator, edges_generator,
...@@ -2573,8 +2569,7 @@ class DGLGraph(object): ...@@ -2573,8 +2569,7 @@ class DGLGraph(object):
prop_nodes prop_nodes
""" """
for edge_frontier in edges_generator: for edge_frontier in edges_generator:
self.send_and_recv(edge_frontier, self.send_and_recv(edge_frontier, message_func, reduce_func, apply_node_func)
message_func, reduce_func, apply_node_func)
def subgraph(self, nodes):
    """Return the subgraph induced on the given nodes.

    Parameters
    ----------
    nodes : list or tensor
        The node ids.

    Returns
    -------
    DGLSubGraph
        The induced subgraph.
    """
    induced_nodes = utils.toindex(nodes)
    sgi = self._graph.node_subgraph(induced_nodes)
    return subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
def subgraphs(self, nodes):
    """Return a list of subgraphs, each induced on the corresponding
    node list in ``nodes``.

    Parameters
    ----------
    nodes : list of lists or tensors
        One node-id collection per requested subgraph.

    Returns
    -------
    list of DGLSubGraph
        The induced subgraphs.
    """
    induced_nodes = [utils.toindex(n) for n in nodes]
    sgis = self._graph.node_subgraphs(induced_nodes)
    return [subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
            for sgi in sgis]
def edge_subgraph(self, edges): def edge_subgraph(self, edges):
"""Return the subgraph induced on given edges. """Return the subgraph induced on given edges.
...@@ -2695,7 +2690,7 @@ class DGLGraph(object): ...@@ -2695,7 +2690,7 @@ class DGLGraph(object):
""" """
induced_edges = utils.toindex(edges) induced_edges = utils.toindex(edges)
sgi = self._graph.edge_subgraph(induced_edges) sgi = self._graph.edge_subgraph(induced_edges)
return dgl.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi) return subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
def adjacency_matrix(self, transpose=False, ctx=F.cpu()): def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
"""Return the adjacency matrix representation of this graph. """Return the adjacency matrix representation of this graph.
...@@ -2720,7 +2715,7 @@ class DGLGraph(object): ...@@ -2720,7 +2715,7 @@ class DGLGraph(object):
""" """
return self._graph.adjacency_matrix(transpose, ctx)[0] return self._graph.adjacency_matrix(transpose, ctx)[0]
def incidence_matrix(self, type, ctx=F.cpu()): def incidence_matrix(self, typestr, ctx=F.cpu()):
"""Return the incidence matrix representation of this graph. """Return the incidence matrix representation of this graph.
An incidence matrix is an n x m sparse matrix, where n is An incidence matrix is an n x m sparse matrix, where n is
...@@ -2750,7 +2745,7 @@ class DGLGraph(object): ...@@ -2750,7 +2745,7 @@ class DGLGraph(object):
Parameters Parameters
---------- ----------
type : str typestr : str
Can be either ``in``, ``out`` or ``both`` Can be either ``in``, ``out`` or ``both``
ctx : context, optional (default=cpu) ctx : context, optional (default=cpu)
The context of returned incidence matrix. The context of returned incidence matrix.
...@@ -2760,7 +2755,7 @@ class DGLGraph(object): ...@@ -2760,7 +2755,7 @@ class DGLGraph(object):
SparseTensor SparseTensor
The incidence matrix. The incidence matrix.
""" """
return self._graph.incidence_matrix(type, ctx)[0] return self._graph.incidence_matrix(typestr, ctx)[0]
def line_graph(self, backtracking=True, shared=False): def line_graph(self, backtracking=True, shared=False):
"""Return the line graph of this graph. """Return the line graph of this graph.
...@@ -2833,8 +2828,8 @@ class DGLGraph(object): ...@@ -2833,8 +2828,8 @@ class DGLGraph(object):
v = utils.toindex(nodes) v = utils.toindex(nodes)
n_repr = self.get_n_repr(v) n_repr = self.get_n_repr(v)
nb = NodeBatch(self, v, n_repr) nbatch = NodeBatch(self, v, n_repr)
n_mask = predicate(nb) n_mask = predicate(nbatch)
if is_all(nodes): if is_all(nodes):
return F.nonzero_1d(n_mask) return F.nonzero_1d(n_mask)
...@@ -2906,10 +2901,8 @@ class DGLGraph(object): ...@@ -2906,10 +2901,8 @@ class DGLGraph(object):
src_data = self.get_n_repr(u) src_data = self.get_n_repr(u)
edge_data = self.get_e_repr(eid) edge_data = self.get_e_repr(eid)
dst_data = self.get_n_repr(v) dst_data = self.get_n_repr(v)
eb = EdgeBatch(self, (u, v, eid), ebatch = EdgeBatch(self, (u, v, eid), src_data, edge_data, dst_data)
src_data, edge_data, dst_data) e_mask = predicate(ebatch)
e_mask = predicate(eb)
if is_all(edges): if is_all(edges):
return F.nonzero_1d(e_mask) return F.nonzero_1d(e_mask)
...@@ -2918,7 +2911,9 @@ class DGLGraph(object): ...@@ -2918,7 +2911,9 @@ class DGLGraph(object):
return edges[e_mask] return edges[e_mask]
def __repr__(self): def __repr__(self):
s = 'DGLGraph with {node} nodes and {edge} edges.\nNode data: {ndata}\nEdge data: {edata}' ret = ('DGLGraph(num_nodes={node}, num_edges={edge},\n'
return s.format(node=self.number_of_nodes(), edge=self.number_of_edges(), ' ndata_schemes={ndata}\n'
' edata_schemes={edata})')
return ret.format(node=self.number_of_nodes(), edge=self.number_of_edges(),
ndata=str(self.node_attr_schemes()), ndata=str(self.node_attr_schemes()),
edata=str(self.edge_attr_schemes())) edata=str(self.edge_attr_schemes()))
"""Module for graph index class definition."""
from __future__ import absolute_import from __future__ import absolute_import
import ctypes import ctypes
...@@ -7,7 +8,7 @@ import scipy ...@@ -7,7 +8,7 @@ import scipy
from ._ffi.base import c_array from ._ffi.base import c_array
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import DGLError, is_all from .base import DGLError
from . import backend as F from . import backend as F
from . import utils from . import utils
from .immutable_graph_index import create_immutable_graph_index from .immutable_graph_index import create_immutable_graph_index
...@@ -58,7 +59,7 @@ class GraphIndex(object): ...@@ -58,7 +59,7 @@ class GraphIndex(object):
num : int num : int
Number of nodes to be added. Number of nodes to be added.
""" """
_CAPI_DGLGraphAddVertices(self._handle, num); _CAPI_DGLGraphAddVertices(self._handle, num)
self.clear_cache() self.clear_cache()
def add_edge(self, u, v): def add_edge(self, u, v):
...@@ -71,7 +72,7 @@ class GraphIndex(object): ...@@ -71,7 +72,7 @@ class GraphIndex(object):
v : int v : int
The dst node. The dst node.
""" """
_CAPI_DGLGraphAddEdge(self._handle, u, v); _CAPI_DGLGraphAddEdge(self._handle, u, v)
self.clear_cache() self.clear_cache()
def add_edges(self, u, v): def add_edges(self, u, v):
...@@ -366,12 +367,12 @@ class GraphIndex(object): ...@@ -366,12 +367,12 @@ class GraphIndex(object):
return src, dst, eid return src, dst, eid
@utils.cached_member(cache='_cache', prefix='edges') @utils.cached_member(cache='_cache', prefix='edges')
def edges(self, sorted=False): def edges(self, return_sorted=False):
"""Return all the edges """Return all the edges
Parameters Parameters
---------- ----------
sorted : bool return_sorted : bool
True if the returned edges are sorted by their src and dst ids. True if the returned edges are sorted by their src and dst ids.
Returns Returns
...@@ -383,9 +384,9 @@ class GraphIndex(object): ...@@ -383,9 +384,9 @@ class GraphIndex(object):
utils.Index utils.Index
The edge ids. The edge ids.
""" """
key = 'edges_s%d' % sorted key = 'edges_s%d' % return_sorted
if key not in self._cache: if key not in self._cache:
edge_array = _CAPI_DGLGraphEdges(self._handle, sorted) edge_array = _CAPI_DGLGraphEdges(self._handle, return_sorted)
src = utils.toindex(edge_array(0)) src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1)) dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2)) eid = utils.toindex(edge_array(2))
...@@ -505,7 +506,6 @@ class GraphIndex(object): ...@@ -505,7 +506,6 @@ class GraphIndex(object):
""" """
e_array = e.todgltensor() e_array = e.todgltensor()
rst = _CAPI_DGLGraphEdgeSubgraph(self._handle, e_array) rst = _CAPI_DGLGraphEdgeSubgraph(self._handle, e_array)
gi = GraphIndex(rst(0))
induced_nodes = utils.toindex(rst(1)) induced_nodes = utils.toindex(rst(1))
return SubgraphIndex(rst(0), self, induced_nodes, e) return SubgraphIndex(rst(0), self, induced_nodes, e)
...@@ -555,7 +555,7 @@ class GraphIndex(object): ...@@ -555,7 +555,7 @@ class GraphIndex(object):
return adj, shuffle_idx return adj, shuffle_idx
@utils.cached_member(cache='_cache', prefix='inc') @utils.cached_member(cache='_cache', prefix='inc')
def incidence_matrix(self, type, ctx): def incidence_matrix(self, typestr, ctx):
"""Return the incidence matrix representation of this graph. """Return the incidence matrix representation of this graph.
An incidence matrix is an n x m sparse matrix, where n is An incidence matrix is an n x m sparse matrix, where n is
...@@ -577,7 +577,7 @@ class GraphIndex(object): ...@@ -577,7 +577,7 @@ class GraphIndex(object):
Parameters Parameters
---------- ----------
type : str typestr : str
Can be either "in", "out" or "both" Can be either "in", "out" or "both"
ctx : context ctx : context
The context of returned incidence matrix. The context of returned incidence matrix.
...@@ -596,21 +596,21 @@ class GraphIndex(object): ...@@ -596,21 +596,21 @@ class GraphIndex(object):
eid = eid.tousertensor(ctx) # the index of the ctx will be cached eid = eid.tousertensor(ctx) # the index of the ctx will be cached
n = self.number_of_nodes() n = self.number_of_nodes()
m = self.number_of_edges() m = self.number_of_edges()
if type == 'in': if typestr == 'in':
row = F.unsqueeze(dst, 0) row = F.unsqueeze(dst, 0)
col = F.unsqueeze(eid, 0) col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0) idx = F.cat([row, col], dim=0)
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
elif type == 'out': elif typestr == 'out':
row = F.unsqueeze(src, 0) row = F.unsqueeze(src, 0)
col = F.unsqueeze(eid, 0) col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0) idx = F.cat([row, col], dim=0)
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
elif type == 'both': elif typestr == 'both':
# create index # create index
row = F.unsqueeze(F.cat([src, dst], dim=0), 0) row = F.unsqueeze(F.cat([src, dst], dim=0), 0)
col = F.unsqueeze(F.cat([eid, eid], dim=0), 0) col = F.unsqueeze(F.cat([eid, eid], dim=0), 0)
...@@ -625,7 +625,7 @@ class GraphIndex(object): ...@@ -625,7 +625,7 @@ class GraphIndex(object):
dat = F.cat([x, y], dim=0) dat = F.cat([x, y], dim=0)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
else: else:
raise DGLError('Invalid incidence matrix type: %s' % str(type)) raise DGLError('Invalid incidence matrix type: %s' % str(typestr))
shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
return inc, shuffle_idx return inc, shuffle_idx
...@@ -642,8 +642,8 @@ class GraphIndex(object): ...@@ -642,8 +642,8 @@ class GraphIndex(object):
src, dst, eid = self.edges() src, dst, eid = self.edges()
ret = nx.MultiDiGraph() if self.is_multigraph() else nx.DiGraph() ret = nx.MultiDiGraph() if self.is_multigraph() else nx.DiGraph()
ret.add_nodes_from(range(self.number_of_nodes())) ret.add_nodes_from(range(self.number_of_nodes()))
for u, v, id in zip(src, dst, eid): for u, v, e in zip(src, dst, eid):
ret.add_edge(u, v, id=id) ret.add_edge(u, v, id=e)
return ret return ret
def from_networkx(self, nx_graph): def from_networkx(self, nx_graph):
...@@ -891,46 +891,41 @@ def create_graph_index(graph_data=None, multigraph=False, readonly=False): ...@@ -891,46 +891,41 @@ def create_graph_index(graph_data=None, multigraph=False, readonly=False):
Whether the graph is multigraph (default is False) Whether the graph is multigraph (default is False)
""" """
if isinstance(graph_data, GraphIndex): if isinstance(graph_data, GraphIndex):
# FIXME(minjie): this return is not correct for mutable graph index
return graph_data return graph_data
if readonly and graph_data is not None: if readonly:
try: return create_immutable_graph_index(graph_data)
gi = create_immutable_graph_index(graph_data)
except:
gi = None
# If we can't create an immutable graph index, we'll have to fall back.
if gi is not None:
return gi
handle = _CAPI_DGLGraphCreate(multigraph) handle = _CAPI_DGLGraphCreate(multigraph)
gi = GraphIndex(handle) gidx = GraphIndex(handle)
if graph_data is None: if graph_data is None:
return gi return gidx
# edge list # edge list
if isinstance(graph_data, (list, tuple)): if isinstance(graph_data, (list, tuple)):
try: try:
gi.from_edge_list(graph_data) gidx.from_edge_list(graph_data)
return gi return gidx
except: except Exception: # pylint: disable=broad-except
raise DGLError('Graph data is not a valid edge list.') raise DGLError('Graph data is not a valid edge list.')
# scipy format # scipy format
if isinstance(graph_data, scipy.sparse.spmatrix): if isinstance(graph_data, scipy.sparse.spmatrix):
try: try:
gi.from_scipy_sparse_matrix(graph_data) gidx.from_scipy_sparse_matrix(graph_data)
return gi return gidx
except: except Exception: # pylint: disable=broad-except
raise DGLError('Graph data is not a valid scipy sparse matrix.') raise DGLError('Graph data is not a valid scipy sparse matrix.')
# networkx - any format # networkx - any format
try: try:
gi.from_networkx(graph_data) gidx.from_networkx(graph_data)
except: except Exception: # pylint: disable=broad-except
raise DGLError('Error while creating graph from input of type "%s".' raise DGLError('Error while creating graph from input of type "%s".'
% type(graph_data)) % type(graph_data))
return gi return gidx
_init_api("dgl.graph_index") _init_api("dgl.graph_index")
"""Module for immutable graph index.
NOTE: this is currently a temporary solution.
"""
# pylint: disable=abstract-method,unused-argument
from __future__ import absolute_import from __future__ import absolute_import
import ctypes
import numpy as np import numpy as np
import networkx as nx import networkx as nx
import scipy.sparse as sp import scipy.sparse as sp
...@@ -8,7 +13,7 @@ import scipy.sparse as sp ...@@ -8,7 +13,7 @@ import scipy.sparse as sp
from ._ffi.function import _init_api from ._ffi.function import _init_api
from . import backend as F from . import backend as F
from . import utils from . import utils
from .base import ALL, is_all, DGLError from .base import DGLError
class ImmutableGraphIndex(object): class ImmutableGraphIndex(object):
"""Graph index object on immutable graphs. """Graph index object on immutable graphs.
...@@ -229,8 +234,8 @@ class ImmutableGraphIndex(object): ...@@ -229,8 +234,8 @@ class ImmutableGraphIndex(object):
""" """
u = F.tensor([u], dtype=F.int64) u = F.tensor([u], dtype=F.int64)
v = F.tensor([v], dtype=F.int64) v = F.tensor([v], dtype=F.int64)
_, _, id = self._sparse.edge_ids(u, v) _, _, eid = self._sparse.edge_ids(u, v)
return utils.toindex(id) return utils.toindex(eid)
def edge_ids(self, u, v): def edge_ids(self, u, v):
"""Return the edge ids. """Return the edge ids.
...@@ -321,12 +326,12 @@ class ImmutableGraphIndex(object): ...@@ -321,12 +326,12 @@ class ImmutableGraphIndex(object):
src = _CAPI_DGLExpandIds(v.todgltensor(), off.todgltensor()) src = _CAPI_DGLExpandIds(v.todgltensor(), off.todgltensor())
return utils.toindex(src), utils.toindex(dst), utils.toindex(edges) return utils.toindex(src), utils.toindex(dst), utils.toindex(edges)
def edges(self, sorted=False): def edges(self, return_sorted=False):
"""Return all the edges """Return all the edges
Parameters Parameters
---------- ----------
sorted : bool return_sorted : bool
True if the returned edges are sorted by their src and dst ids. True if the returned edges are sorted by their src and dst ids.
Returns Returns
...@@ -340,7 +345,7 @@ class ImmutableGraphIndex(object): ...@@ -340,7 +345,7 @@ class ImmutableGraphIndex(object):
""" """
if "all_edges" in self._cache: if "all_edges" in self._cache:
return self._cache["all_edges"] return self._cache["all_edges"]
src, dst, edges = self._sparse.edges(sorted) src, dst, edges = self._sparse.edges(return_sorted)
self._cache["all_edges"] = (utils.toindex(src), utils.toindex(dst), utils.toindex(edges)) self._cache["all_edges"] = (utils.toindex(src), utils.toindex(dst), utils.toindex(edges))
return self._cache["all_edges"] return self._cache["all_edges"]
...@@ -440,8 +445,8 @@ class ImmutableGraphIndex(object): ...@@ -440,8 +445,8 @@ class ImmutableGraphIndex(object):
The subgraph index. The subgraph index.
""" """
v = v.tousertensor() v = v.tousertensor()
gi, induced_n, induced_e = self._sparse.node_subgraph(v) gidx, induced_n, induced_e = self._sparse.node_subgraph(v)
return ImmutableSubgraphIndex(gi, self, induced_n, induced_e) return ImmutableSubgraphIndex(gidx, self, induced_n, induced_e)
def node_subgraphs(self, vs_arr): def node_subgraphs(self, vs_arr):
"""Return the induced node subgraphs. """Return the induced node subgraphs.
...@@ -458,8 +463,8 @@ class ImmutableGraphIndex(object): ...@@ -458,8 +463,8 @@ class ImmutableGraphIndex(object):
""" """
vs_arr = [v.tousertensor() for v in vs_arr] vs_arr = [v.tousertensor() for v in vs_arr]
gis, induced_nodes, induced_edges = self._sparse.node_subgraphs(vs_arr) gis, induced_nodes, induced_edges = self._sparse.node_subgraphs(vs_arr)
return [ImmutableSubgraphIndex(gi, self, induced_n, return [ImmutableSubgraphIndex(gidx, self, induced_n, induced_e)
induced_e) for gi, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)] for gidx, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)]
def edge_subgraph(self, e): def edge_subgraph(self, e):
"""Return the induced edge subgraph. """Return the induced edge subgraph.
...@@ -478,6 +483,7 @@ class ImmutableGraphIndex(object): ...@@ -478,6 +483,7 @@ class ImmutableGraphIndex(object):
def neighbor_sampling(self, seed_ids, expand_factor, num_hops, neighbor_type, def neighbor_sampling(self, seed_ids, expand_factor, num_hops, neighbor_type,
node_prob, max_subgraph_size): node_prob, max_subgraph_size):
"""Neighborhood sampling"""
if len(seed_ids) == 0: if len(seed_ids) == 0:
return [] return []
seed_ids = [v.tousertensor() for v in seed_ids] seed_ids = [v.tousertensor() for v in seed_ids]
...@@ -486,8 +492,8 @@ class ImmutableGraphIndex(object): ...@@ -486,8 +492,8 @@ class ImmutableGraphIndex(object):
node_prob, node_prob,
max_subgraph_size) max_subgraph_size)
induced_nodes = [utils.toindex(v) for v in induced_nodes] induced_nodes = [utils.toindex(v) for v in induced_nodes]
return [ImmutableSubgraphIndex(gi, self, induced_n, return [ImmutableSubgraphIndex(gidx, self, induced_n, induced_e)
induced_e) for gi, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)] for gidx, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)]
def adjacency_matrix(self, transpose=False, ctx=F.cpu()): def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
"""Return the adjacency matrix representation of this graph. """Return the adjacency matrix representation of this graph.
...@@ -511,12 +517,9 @@ class ImmutableGraphIndex(object): ...@@ -511,12 +517,9 @@ class ImmutableGraphIndex(object):
A index for data shuffling due to sparse format change. Return None A index for data shuffling due to sparse format change. Return None
if shuffle is not required. if shuffle is not required.
""" """
def get_adj(ctx):
new_mat = self._sparse.adjacency_matrix(transpose)
return F.copy_to(new_mat, ctx)
return self._sparse.adjacency_matrix(transpose, ctx), None return self._sparse.adjacency_matrix(transpose, ctx), None
def incidence_matrix(self, type, ctx): def incidence_matrix(self, typestr, ctx):
"""Return the incidence matrix representation of this graph. """Return the incidence matrix representation of this graph.
An incidence matrix is an n x m sparse matrix, where n is An incidence matrix is an n x m sparse matrix, where n is
...@@ -538,7 +541,7 @@ class ImmutableGraphIndex(object): ...@@ -538,7 +541,7 @@ class ImmutableGraphIndex(object):
Parameters Parameters
---------- ----------
type : str typestr : str
Can be either "in", "out" or "both" Can be either "in", "out" or "both"
ctx : context ctx : context
The context of returned incidence matrix. The context of returned incidence matrix.
...@@ -565,8 +568,8 @@ class ImmutableGraphIndex(object): ...@@ -565,8 +568,8 @@ class ImmutableGraphIndex(object):
""" """
src, dst, eid = self.edges() src, dst, eid = self.edges()
ret = nx.DiGraph() ret = nx.DiGraph()
for u, v, id in zip(src, dst, eid): for u, v, e in zip(src, dst, eid):
ret.add_edge(u, v, id=id) ret.add_edge(u, v, id=e)
return ret return ret
def from_networkx(self, nx_graph): def from_networkx(self, nx_graph):
...@@ -626,8 +629,8 @@ class ImmutableGraphIndex(object): ...@@ -626,8 +629,8 @@ class ImmutableGraphIndex(object):
---------- ----------
adj : scipy sparse matrix adj : scipy sparse matrix
""" """
assert isinstance(adj, sp.csr_matrix) or isinstance(adj, sp.coo_matrix), \ if not isinstance(adj, (sp.csr_matrix, sp.coo_matrix)):
"The input matrix has to be a SciPy sparse matrix." raise DGLError("The input matrix has to be a SciPy sparse matrix.")
out_mat = adj.tocoo() out_mat = adj.tocoo()
self._sparse.from_coo_matrix(out_mat) self._sparse.from_coo_matrix(out_mat)
...@@ -639,23 +642,7 @@ class ImmutableGraphIndex(object): ...@@ -639,23 +642,7 @@ class ImmutableGraphIndex(object):
elist : list elist : list
List of (u, v) edge tuple. List of (u, v) edge tuple.
""" """
self.clear() self._sparse.from_edge_list(elist)
src, dst = zip(*elist)
src = np.array(src)
dst = np.array(dst)
num_nodes = max(src.max(), dst.max()) + 1
min_nodes = min(src.min(), dst.min())
if min_nodes != 0:
raise DGLError('Invalid edge list. Nodes must start from 0.')
edge_ids = mx.nd.arange(0, len(src), step=1, repeat=1, dtype=np.int32)
src = mx.nd.array(src, dtype=np.int64)
dst = mx.nd.array(dst, dtype=np.int64)
# TODO we can't generate a csr_matrix with np.int64 directly.
in_csr = mx.nd.sparse.csr_matrix((edge_ids, (dst, src)),
shape=(num_nodes, num_nodes)).astype(np.int64)
out_csr = mx.nd.sparse.csr_matrix((edge_ids, (src, dst)),
shape=(num_nodes, num_nodes)).astype(np.int64)
self.__init__(in_csr, out_csr)
def line_graph(self, backtracking=True): def line_graph(self, backtracking=True):
"""Return the line graph of this graph. """Return the line graph of this graph.
...@@ -778,35 +765,35 @@ def create_immutable_graph_index(graph_data=None): ...@@ -778,35 +765,35 @@ def create_immutable_graph_index(graph_data=None):
# If graph_data is None, we return an empty graph index. # If graph_data is None, we return an empty graph index.
# If we can't create a graph index, we'll use the code below to handle the graph. # If we can't create a graph index, we'll use the code below to handle the graph.
return ImmutableGraphIndex(F.create_immutable_graph_index(graph_data)) return ImmutableGraphIndex(F.create_immutable_graph_index(graph_data))
except: except Exception: # pylint: disable=broad-except
pass pass
# Let's create an empty graph index first. # Let's create an empty graph index first.
gi = ImmutableGraphIndex(F.create_immutable_graph_index()) gidx = ImmutableGraphIndex(F.create_immutable_graph_index())
# edge list # edge list
if isinstance(graph_data, (list, tuple)): if isinstance(graph_data, (list, tuple)):
try: try:
gi.from_edge_list(graph_data) gidx.from_edge_list(graph_data)
return gi return gidx
except: except Exception: # pylint: disable=broad-except
raise DGLError('Graph data is not a valid edge list.') raise DGLError('Graph data is not a valid edge list.')
# scipy format # scipy format
if isinstance(graph_data, sp.spmatrix): if isinstance(graph_data, sp.spmatrix):
try: try:
gi.from_scipy_sparse_matrix(graph_data) gidx.from_scipy_sparse_matrix(graph_data)
return gi return gidx
except: except Exception: # pylint: disable=broad-except
raise DGLError('Graph data is not a valid scipy sparse matrix.') raise DGLError('Graph data is not a valid scipy sparse matrix.')
# networkx - any format # networkx - any format
try: try:
gi.from_networkx(graph_data) gidx.from_networkx(graph_data)
except: except Exception: # pylint: disable=broad-except
raise DGLError('Error while creating graph from input of type "%s".' raise DGLError('Error while creating graph from input of type "%s".'
% type(graph_data)) % type(graph_data))
return gi return gidx
_init_api("dgl.immutable_graph_index") _init_api("dgl.immutable_graph_index")
...@@ -5,7 +5,7 @@ from . import backend as F ...@@ -5,7 +5,7 @@ from . import backend as F
__all__ = ['base_initializer', 'zero_initializer'] __all__ = ['base_initializer', 'zero_initializer']
def base_initializer(shape, dtype, ctx, range): def base_initializer(shape, dtype, ctx, id_range): # pylint: disable=unused-argument
"""The function signature for feature initializer. """The function signature for feature initializer.
Any customized feature initializer should follow this signature (see Any customized feature initializer should follow this signature (see
...@@ -20,7 +20,7 @@ def base_initializer(shape, dtype, ctx, range): ...@@ -20,7 +20,7 @@ def base_initializer(shape, dtype, ctx, range):
The data type of the returned features. The data type of the returned features.
ctx : context object ctx : context object
The device context of the returned features. The device context of the returned features.
range : slice id_range : slice
The start id and the end id of the features to be initialized. The start id and the end id of the features to be initialized.
The id could be node or edge id depending on the scenario. The id could be node or edge id depending on the scenario.
Note that the step is always None. Note that the step is always None.
...@@ -32,7 +32,7 @@ def base_initializer(shape, dtype, ctx, range): ...@@ -32,7 +32,7 @@ def base_initializer(shape, dtype, ctx, range):
>>> import torch >>> import torch
>>> import dgl >>> import dgl
>>> def initializer(shape, dtype, ctx, range): >>> def initializer(shape, dtype, ctx, id_range):
>>> return torch.ones(shape, dtype=dtype, device=ctx) >>> return torch.ones(shape, dtype=dtype, device=ctx)
>>> g = dgl.DGLGraph() >>> g = dgl.DGLGraph()
>>> g.set_n_initializer(initializer) >>> g.set_n_initializer(initializer)
...@@ -44,7 +44,7 @@ def base_initializer(shape, dtype, ctx, range): ...@@ -44,7 +44,7 @@ def base_initializer(shape, dtype, ctx, range):
""" """
raise NotImplementedError raise NotImplementedError
def zero_initializer(shape, dtype, ctx, range): def zero_initializer(shape, dtype, ctx, id_range): # pylint: disable=unused-argument
"""Zero feature initializer """Zero feature initializer
Examples Examples
......
...@@ -56,7 +56,7 @@ def prop_edges(graph, ...@@ -56,7 +56,7 @@ def prop_edges(graph,
def prop_nodes_bfs(graph, def prop_nodes_bfs(graph,
source, source,
reversed=False, reverse=False,
message_func='default', message_func='default',
reduce_func='default', reduce_func='default',
apply_node_func='default'): apply_node_func='default'):
...@@ -68,7 +68,7 @@ def prop_nodes_bfs(graph, ...@@ -68,7 +68,7 @@ def prop_nodes_bfs(graph,
The graph object. The graph object.
source : list, tensor of nodes source : list, tensor of nodes
Source nodes. Source nodes.
reversed : bool, optional reverse : bool, optional
If true, traverse following the in-edge direction. If true, traverse following the in-edge direction.
message_func : callable, optional message_func : callable, optional
The message function. The message function.
...@@ -81,11 +81,11 @@ def prop_nodes_bfs(graph, ...@@ -81,11 +81,11 @@ def prop_nodes_bfs(graph,
-------- --------
dgl.traversal.bfs_nodes_generator dgl.traversal.bfs_nodes_generator
""" """
nodes_gen = trv.bfs_nodes_generator(graph, source, reversed) nodes_gen = trv.bfs_nodes_generator(graph, source, reverse)
prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func) prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_nodes_topo(graph, def prop_nodes_topo(graph,
reversed=False, reverse=False,
message_func='default', message_func='default',
reduce_func='default', reduce_func='default',
apply_node_func='default'): apply_node_func='default'):
...@@ -95,7 +95,7 @@ def prop_nodes_topo(graph, ...@@ -95,7 +95,7 @@ def prop_nodes_topo(graph,
---------- ----------
graph : DGLGraph graph : DGLGraph
The graph object. The graph object.
reversed : bool, optional reverse : bool, optional
If true, traverse following the in-edge direction. If true, traverse following the in-edge direction.
message_func : callable, optional message_func : callable, optional
The message function. The message function.
...@@ -108,12 +108,12 @@ def prop_nodes_topo(graph, ...@@ -108,12 +108,12 @@ def prop_nodes_topo(graph,
-------- --------
dgl.traversal.topological_nodes_generator dgl.traversal.topological_nodes_generator
""" """
nodes_gen = trv.topological_nodes_generator(graph, reversed) nodes_gen = trv.topological_nodes_generator(graph, reverse)
prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func) prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_edges_dfs(graph, def prop_edges_dfs(graph,
source, source,
reversed=False, reverse=False,
has_reverse_edge=False, has_reverse_edge=False,
has_nontree_edge=False, has_nontree_edge=False,
message_func='default', message_func='default',
...@@ -127,7 +127,7 @@ def prop_edges_dfs(graph, ...@@ -127,7 +127,7 @@ def prop_edges_dfs(graph,
The graph object. The graph object.
source : list, tensor of nodes source : list, tensor of nodes
Source nodes. Source nodes.
reversed : bool, optional reverse : bool, optional
If true, traverse following the in-edge direction. If true, traverse following the in-edge direction.
message_func : callable, optional message_func : callable, optional
The message function. The message function.
...@@ -141,6 +141,6 @@ def prop_edges_dfs(graph, ...@@ -141,6 +141,6 @@ def prop_edges_dfs(graph,
dgl.traversal.dfs_labeled_edges_generator dgl.traversal.dfs_labeled_edges_generator
""" """
edges_gen = trv.dfs_labeled_edges_generator( edges_gen = trv.dfs_labeled_edges_generator(
graph, source, reversed, has_reverse_edge, has_nontree_edge, graph, source, reverse, has_reverse_edge, has_nontree_edge,
return_labels=False) return_labels=False)
prop_edges(graph, edges_gen, message_func, reduce_func, apply_node_func) prop_edges(graph, edges_gen, message_func, reduce_func, apply_node_func)
"""DGL Runtime""" """Package for DGL scheduler and runtime."""
from __future__ import absolute_import from __future__ import absolute_import
from . import scheduler from . import scheduler
......
"""Module for degree bucketing schedulers""" """Module for degree bucketing schedulers."""
from __future__ import absolute_import from __future__ import absolute_import
from .._ffi.function import _init_api from .._ffi.function import _init_api
from ..base import is_all, ALL from ..base import is_all
from .. import backend as F from .. import backend as F
from ..immutable_graph_index import ImmutableGraphIndex from ..udf import NodeBatch
from ..udf import EdgeBatch, NodeBatch
from .. import utils from .. import utils
from . import ir from . import ir
from .ir import var as var from .ir import var
def gen_degree_bucketing_schedule( def gen_degree_bucketing_schedule(
graph, graph,
...@@ -52,23 +51,23 @@ def gen_degree_bucketing_schedule( ...@@ -52,23 +51,23 @@ def gen_degree_bucketing_schedule(
""" """
buckets = _degree_bucketing_schedule(message_ids, dst_nodes, recv_nodes) buckets = _degree_bucketing_schedule(message_ids, dst_nodes, recv_nodes)
# generate schedule # generate schedule
unique_dst, degs, buckets, msg_ids, zero_deg_nodes = buckets _, degs, buckets, msg_ids, zero_deg_nodes = buckets
# loop over each bucket # loop over each bucket
idx_list = [] idx_list = []
fd_list = [] fd_list = []
for deg, vb, mid in zip(degs, buckets, msg_ids): for deg, vbkt, mid in zip(degs, buckets, msg_ids):
# create per-bkt rfunc # create per-bkt rfunc
rfunc = _create_per_bkt_rfunc(graph, reduce_udf, deg, vb) rfunc = _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt)
# vars # vars
vb = var.IDX(vb) vbkt = var.IDX(vbkt)
mid = var.IDX(mid) mid = var.IDX(mid)
rfunc = var.FUNC(rfunc) rfunc = var.FUNC(rfunc)
# recv on each bucket # recv on each bucket
fdvb = ir.READ_ROW(var_nf, vb) fdvb = ir.READ_ROW(var_nf, vbkt)
fdmail = ir.READ_ROW(var_mf, mid) fdmail = ir.READ_ROW(var_mf, mid)
fdvb = ir.NODE_UDF(rfunc, fdvb, fdmail, ret=fdvb) # reuse var fdvb = ir.NODE_UDF(rfunc, fdvb, fdmail, ret=fdvb) # reuse var
# save for merge # save for merge
idx_list.append(vb) idx_list.append(vbkt)
fd_list.append(fdvb) fd_list.append(fdvb)
if zero_deg_nodes is not None: if zero_deg_nodes is not None:
# NOTE: there must be at least one non-zero-deg node; otherwise, # NOTE: there must be at least one non-zero-deg node; otherwise,
...@@ -178,15 +177,16 @@ def _process_buckets(buckets): ...@@ -178,15 +177,16 @@ def _process_buckets(buckets):
return v, degs, dsts, msg_ids, zero_deg_nodes return v, degs, dsts, msg_ids, zero_deg_nodes
def _create_per_bkt_rfunc(graph, reduce_udf, deg, vb): def _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt):
"""Internal function to generate the per degree bucket node UDF."""
def _rfunc_wrapper(node_data, mail_data): def _rfunc_wrapper(node_data, mail_data):
def _reshaped_getter(key): def _reshaped_getter(key):
msg = mail_data[key] msg = mail_data[key]
new_shape = (len(vb), deg) + F.shape(msg)[1:] new_shape = (len(vbkt), deg) + F.shape(msg)[1:]
return F.reshape(msg, new_shape) return F.reshape(msg, new_shape)
reshaped_mail_data = utils.LazyDict(_reshaped_getter, mail_data.keys()) reshaped_mail_data = utils.LazyDict(_reshaped_getter, mail_data.keys())
nb = NodeBatch(graph, vb, node_data, reshaped_mail_data) nbatch = NodeBatch(graph, vbkt, node_data, reshaped_mail_data)
return reduce_udf(nb) return reduce_udf(nbatch)
return _rfunc_wrapper return _rfunc_wrapper
_init_api("dgl.runtime.degree_bucketing") _init_api("dgl.runtime.degree_bucketing")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment