Unverified Commit 4bd4d6e3 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Lint] Pylint (#330)

* fix lint for graph_index.py

* pylint for base.py

* pylint for batched_graph.py

* pylint for frame.py; simplify and fix bugs in frame when index is slice type

* pylint for graph.py

* pylint for immutable_graph_index.py

* pylint for init.py

* pylint for rest files in root package

* pylint for _ffi package

* pylint for function package

* pylint for runtime package

* pylint for runtime.ir package

* add pylint to ci

* fix mx tests

* fix lint errors

* fix ci

* fix as requested

* fix lint
parent 1e50cd2e
# One has to manually import dgl.data; fixes #125
#from . import data
"""DGL root package."""
from . import function
from . import nn
from . import contrib
......@@ -12,7 +11,6 @@ from .base import ALL
from .backend import load_backend
from .batched_graph import *
from .graph import DGLGraph
from .subgraph import DGLSubGraph
from .traversal import *
from .propagate import *
from .udf import NodeBatch, EdgeBatch
"""Namespace for internal apis."""
......@@ -26,8 +26,7 @@ else:
class DGLError(Exception):
    """Error thrown by DGL function"""
    # Defect fixed: the diff residue carried two duplicate `pass` statements
    # (the pre- and post-lint versions); a valid class body needs only one.
    pass  # pylint: disable=unnecessary-pass
def _load_lib():
"""Load libary by searching possible path."""
......
......@@ -51,7 +51,7 @@ class Function(_FunctionBase):
dgl.register_func: How to register global function.
dgl.get_global_func: How to get global function.
"""
pass
pass # pylint: disable=unnecessary-pass
class ModuleBase(object):
......
"""Common runtime ctypes."""
# pylint: disable=invalid-name
# pylint: disable=invalid-name, super-init-not-called
from __future__ import absolute_import
import ctypes
......
......@@ -349,6 +349,31 @@ class ImmutableGraphIndex(object):
self.__init__(mx.nd.sparse.csr_matrix((edge_ids, (dst, src)), shape=(size, size)).astype(np.int64),
mx.nd.sparse.csr_matrix((edge_ids, (src, dst)), shape=(size, size)).astype(np.int64))
def from_edge_list(self, elist):
    """Convert from an edge list.

    Builds both in-edge and out-edge CSR adjacency matrices from the list
    and re-initializes this graph index with them.

    Parameters
    ----------
    elist : list
        List of (u, v) edge tuple.

    Raises
    ------
    DGLError
        If the smallest node id in the edge list is not 0.
    """
    src, dst = zip(*elist)
    src = np.array(src)
    dst = np.array(dst)
    # Node ids are assumed to be consecutive starting from 0, so the
    # largest id + 1 gives the node count.
    num_nodes = max(src.max(), dst.max()) + 1
    min_nodes = min(src.min(), dst.min())
    if min_nodes != 0:
        raise DGLError('Invalid edge list. Nodes must start from 0.')
    # Edge ids double as the CSR data payload so that each nonzero entry
    # maps back to its edge.
    edge_ids = mx.nd.arange(0, len(src), step=1, repeat=1, dtype=np.int32)
    src = mx.nd.array(src, dtype=np.int64)
    dst = mx.nd.array(dst, dtype=np.int64)
    # TODO we can't generate a csr_matrix with np.int64 directly.
    # (int32 ids above, cast to int64 right after construction as a workaround)
    in_csr = mx.nd.sparse.csr_matrix((edge_ids, (dst, src)),
                                     shape=(num_nodes, num_nodes)).astype(np.int64)
    out_csr = mx.nd.sparse.csr_matrix((edge_ids, (src, dst)),
                                      shape=(num_nodes, num_nodes)).astype(np.int64)
    # Reuse __init__ to reset this index with the freshly built matrices.
    self.__init__(in_csr, out_csr)
def create_immutable_graph_index(in_csr=None, out_csr=None):
""" Create an empty backend-specific immutable graph index.
......
......@@ -3,12 +3,15 @@ from __future__ import absolute_import
import warnings
from ._ffi.base import DGLError
from ._ffi.base import DGLError # pylint: disable=unused-import
# Special marker used to select every node or edge in a graph.
ALL = "__ALL__"

def is_all(arg):
    """Return True if ``arg`` is the special select-all symbol."""
    if not isinstance(arg, str):
        return False
    return arg == ALL
dgl_warning = warnings.warn
def dgl_warning(msg):
    """Emit a warning message through Python's warnings machinery."""
    warnings.warn(msg)
"""Classes and functions for batching multiple graphs together."""
from __future__ import absolute_import
from collections.abc import Iterable
import numpy as np
from collections import Iterable
from .base import ALL, is_all
from .base import ALL, is_all, DGLError
from .frame import FrameRef, Frame
from .graph import DGLGraph
from . import graph_index as gi
......@@ -152,8 +152,7 @@ class BatchedDGLGraph(DGLGraph):
elif is_all(attrs):
attrs = set()
# Check if at least a graph has mode items and associated features.
for i in range(len(graph_list)):
g = graph_list[i]
for i, g in enumerate(graph_list):
g_num_items, g_attrs = _get_num_item_and_attr_types(g, mode)
if g_num_items > 0 and len(g_attrs) > 0:
attrs = g_attrs
......@@ -161,13 +160,13 @@ class BatchedDGLGraph(DGLGraph):
break
# Check if all the graphs with mode items have the same associated features.
if len(attrs) > 0:
for i in range(len(graph_list)):
for i, g in enumerate(graph_list):
g = graph_list[i]
g_num_items, g_attrs = _get_num_item_and_attr_types(g, mode)
if g_attrs != attrs and g_num_items > 0:
raise ValueError('Expect graph {} and {} to have the same {} '
'attributes when {}_attrs=ALL, got {} and '
'{}'.format(ref_g_index, i, mode, mode, attrs, g_attrs))
raise ValueError('Expect graph {0} and {1} to have the same {2} '
'attributes when {2}_attrs=ALL, got {3} and {4}.'
.format(ref_g_index, i, mode, attrs, g_attrs))
return attrs
elif isinstance(attrs, str):
return [attrs]
......@@ -200,25 +199,24 @@ class BatchedDGLGraph(DGLGraph):
for key in edge_attrs}
batched_edge_frame = FrameRef(Frame(cols))
super(BatchedDGLGraph, self).__init__(
graph_data=batched_index,
node_frame=batched_node_frame,
edge_frame=batched_edge_frame)
super(BatchedDGLGraph, self).__init__(graph_data=batched_index,
node_frame=batched_node_frame,
edge_frame=batched_edge_frame)
# extra members
self._batch_size = 0
self._batch_num_nodes = []
self._batch_num_edges = []
for gr in graph_list:
if isinstance(gr, BatchedDGLGraph):
for grh in graph_list:
if isinstance(grh, BatchedDGLGraph):
# handle the input is again a batched graph.
self._batch_size += gr._batch_size
self._batch_num_nodes += gr._batch_num_nodes
self._batch_num_edges += gr._batch_num_edges
self._batch_size += grh._batch_size
self._batch_num_nodes += grh._batch_num_nodes
self._batch_num_edges += grh._batch_num_edges
else:
self._batch_size += 1
self._batch_num_nodes.append(gr.number_of_nodes())
self._batch_num_edges.append(gr.number_of_edges())
self._batch_num_nodes.append(grh.number_of_nodes())
self._batch_num_edges.append(grh.number_of_edges())
@property
def batch_size(self):
......@@ -251,33 +249,33 @@ class BatchedDGLGraph(DGLGraph):
return self._batch_num_edges
# override APIs
def add_nodes(self, num, reprs=None):
def add_nodes(self, num, data=None):
"""Add nodes. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')
raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v, reprs=None):
def add_edge(self, u, v, data=None):
"""Add one edge. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')
raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v, reprs=None):
def add_edges(self, u, v, data=None):
"""Add many edges. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')
raise DGLError('Readonly graph. Mutation is not allowed.')
# new APIs
def __getitem__(self, idx):
"""Slice the batch and return the batch of graphs specified by the idx."""
# TODO
pass
raise NotImplementedError
def __setitem__(self, idx, val):
"""Set the value of the slice. The graph size cannot be changed."""
# TODO
pass
raise NotImplementedError
def split(graph_batch, num_or_size_splits):
def split(graph_batch, num_or_size_splits): # pylint: disable=unused-argument
"""Split the batch."""
# TODO(minjie): could follow torch.split syntax
pass
raise NotImplementedError
def unbatch(graph):
"""Return the list of graphs in this batch.
......@@ -308,18 +306,18 @@ def unbatch(graph):
"""
assert isinstance(graph, BatchedDGLGraph)
bsize = graph.batch_size
bn = graph.batch_num_nodes
be = graph.batch_num_edges
pttns = gi.disjoint_partition(graph._graph, utils.toindex(bn))
bnn = graph.batch_num_nodes
bne = graph.batch_num_edges
pttns = gi.disjoint_partition(graph._graph, utils.toindex(bnn))
# split the frames
node_frames = [FrameRef(Frame(num_rows=n)) for n in bn]
edge_frames = [FrameRef(Frame(num_rows=n)) for n in be]
node_frames = [FrameRef(Frame(num_rows=n)) for n in bnn]
edge_frames = [FrameRef(Frame(num_rows=n)) for n in bne]
for attr, col in graph._node_frame.items():
col_splits = F.split(col, bn, dim=0)
col_splits = F.split(col, bnn, dim=0)
for i in range(bsize):
node_frames[i][attr] = col_splits[i]
for attr, col in graph._edge_frame.items():
col_splits = F.split(col, be, dim=0)
col_splits = F.split(col, bne, dim=0)
for i in range(bsize):
edge_frames[i][attr] = col_splits[i]
return [DGLGraph(graph_data=pttns[i],
......@@ -355,47 +353,63 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
return BatchedDGLGraph(graph_list, node_attrs, edge_attrs)
_readout_on_attrs = {
'nodes': ('ndata', 'batch_num_nodes', 'number_of_nodes'),
'edges': ('edata', 'batch_num_edges', 'number_of_edges'),
}
READOUT_ON_ATTRS = {
'nodes': ('ndata', 'batch_num_nodes', 'number_of_nodes'),
'edges': ('edata', 'batch_num_edges', 'number_of_edges'),
}
def _sum_on(graph, on, input, weight):
data_attr, batch_num_objs_attr, num_objs_attr = _readout_on_attrs[on]
def _sum_on(graph, typestr, feat, weight):
"""Internal function to sum node or edge features.
Parameters
----------
graph : DGLGraph
The graph.
typestr : str
'nodes' or 'edges'
feat : str
The feature field name.
weight : str
The weight field name.
Returns
-------
Tensor
The (weighted) summed node or edge features.
"""
data_attr, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
data = getattr(graph, data_attr)
input = data[input]
feat = data[feat]
if weight is not None:
weight = data[weight]
weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(input) - 1))
input = weight * input
weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(feat) - 1))
feat = weight * feat
if isinstance(graph, BatchedDGLGraph):
n_graphs = graph.batch_size
batch_num_objs = getattr(graph, batch_num_objs_attr)
n_objs = getattr(graph, num_objs_attr)()
seg_id = F.zerocopy_from_numpy(
np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
seg_id = F.copy_to(seg_id, F.context(input))
y = F.unsorted_1d_segment_sum(input, seg_id, n_graphs, 0)
seg_id = F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
seg_id = F.copy_to(seg_id, F.context(feat))
y = F.unsorted_1d_segment_sum(feat, seg_id, n_graphs, 0)
return y
else:
return F.sum(input, 0)
return F.sum(feat, 0)
def sum_nodes(graph, input, weight=None):
"""Sums all the values of node field :attr:`input` in :attr:`graph`, optionally
def sum_nodes(graph, feat, weight=None):
"""Sums all the values of node field :attr:`feat` in :attr:`graph`, optionally
multiplies the field by a scalar node field :attr:`weight`.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph
input : str
The input field
graph : DGLGraph.
The graph.
feat : str
The feature field.
weight : str, optional
The weight field. If None, no weighting will be performed,
otherwise, weight each node feature with field :attr:`input`.
otherwise, weight each node feature with field :attr:`feat`.
for summation. The weight feature associated in the :attr:`graph`
should be a tensor of shape ``[graph.number_of_nodes(), 1]``.
......@@ -450,21 +464,21 @@ def sum_nodes(graph, input, weight=None):
sum_edges
mean_edges
"""
return _sum_on(graph, 'nodes', input, weight)
return _sum_on(graph, 'nodes', feat, weight)
def sum_edges(graph, input, weight=None):
"""Sums all the values of edge field :attr:`input` in :attr:`graph`,
def sum_edges(graph, feat, weight=None):
"""Sums all the values of edge field :attr:`feat` in :attr:`graph`,
optionally multiplies the field by a scalar edge field :attr:`weight`.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph
input : str
The input field
graph : DGLGraph
The graph.
feat : str
The feature field.
weight : str, optional
The weight field. If None, no weighting will be performed,
otherwise, weight each edge feature with field :attr:`input`.
otherwise, weight each edge feature with field :attr:`feat`.
for summation. The weight feature associated in the :attr:`graph`
should be a tensor of shape ``[graph.number_of_edges(), 1]``.
......@@ -521,54 +535,70 @@ def sum_edges(graph, input, weight=None):
mean_nodes
mean_edges
"""
return _sum_on(graph, 'edges', input, weight)
return _sum_on(graph, 'edges', feat, weight)
def _mean_on(graph, on, input, weight):
data_attr, batch_num_objs_attr, num_objs_attr = _readout_on_attrs[on]
def _mean_on(graph, typestr, feat, weight):
"""Internal function to sum node or edge features.
Parameters
----------
graph : DGLGraph
The graph.
typestr : str
'nodes' or 'edges'
feat : str
The feature field name.
weight : str
The weight field name.
Returns
-------
Tensor
The (weighted) summed node or edge features.
"""
data_attr, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
data = getattr(graph, data_attr)
input = data[input]
feat = data[feat]
if weight is not None:
weight = data[weight]
weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(input) - 1))
input = weight * input
weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(feat) - 1))
feat = weight * feat
if isinstance(graph, BatchedDGLGraph):
n_graphs = graph.batch_size
batch_num_objs = getattr(graph, batch_num_objs_attr)
n_objs = getattr(graph, num_objs_attr)()
seg_id = F.zerocopy_from_numpy(
np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
seg_id = F.copy_to(seg_id, F.context(input))
seg_id = F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
seg_id = F.copy_to(seg_id, F.context(feat))
if weight is not None:
w = F.unsorted_1d_segment_sum(weight, seg_id, n_graphs, 0)
y = F.unsorted_1d_segment_sum(input, seg_id, n_graphs, 0)
y = F.unsorted_1d_segment_sum(feat, seg_id, n_graphs, 0)
y = y / w
else:
y = F.unsorted_1d_segment_mean(input, seg_id, n_graphs, 0)
y = F.unsorted_1d_segment_mean(feat, seg_id, n_graphs, 0)
return y
else:
if weight is None:
return F.mean(input, 0)
return F.mean(feat, 0)
else:
y = F.sum(input, 0) / F.sum(weight, 0)
y = F.sum(feat, 0) / F.sum(weight, 0)
return y
def mean_nodes(graph, input, weight=None):
"""Averages all the values of node field :attr:`input` in :attr:`graph`,
def mean_nodes(graph, feat, weight=None):
"""Averages all the values of node field :attr:`feat` in :attr:`graph`,
optionally multiplies the field by a scalar node field :attr:`weight`.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph
input : str
The input field
The graph.
feat : str
The feature field.
weight : str, optional
The weight field. If None, no weighting will be performed,
otherwise, weight each node feature with field :attr:`input`.
otherwise, weight each node feature with field :attr:`feat`.
for calculating mean. The weight feature associated in the :attr:`graph`
should be a tensor of shape ``[graph.number_of_nodes(), 1]``.
......@@ -623,21 +653,21 @@ def mean_nodes(graph, input, weight=None):
sum_edges
mean_edges
"""
return _mean_on(graph, 'nodes', input, weight)
return _mean_on(graph, 'nodes', feat, weight)
def mean_edges(graph, input, weight=None):
"""Averages all the values of edge field :attr:`input` in :attr:`graph`,
def mean_edges(graph, feat, weight=None):
"""Averages all the values of edge field :attr:`feat` in :attr:`graph`,
optionally multiplies the field by a scalar edge field :attr:`weight`.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph
input : str
The input field
graph : DGLGraph
The graph.
feat : str
The feature field.
weight : optional, str
The weight field. If None, no weighting will be performed,
otherwise, weight each edge feature with field :attr:`input`.
otherwise, weight each edge feature with field :attr:`feat`.
for calculating mean. The weight feature associated in the :attr:`graph`
should be a tensor of shape ``[graph.number_of_edges(), 1]``.
......@@ -694,4 +724,4 @@ def mean_edges(graph, input, weight=None):
mean_nodes
sum_edges
"""
return _mean_on(graph, 'edges', input, weight)
return _mean_on(graph, 'edges', feat, weight)
"""Columnar storage for DGLGraph."""
from __future__ import absolute_import
from collections import MutableMapping, namedtuple
from collections import namedtuple
from collections.abc import MutableMapping
import sys
import numpy as np
......@@ -39,6 +40,18 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
return cls(shape, dtype)
def infer_scheme(tensor):
    """Infer column scheme from the given tensor data.

    The scheme records the per-row feature shape (all dimensions after the
    first, which indexes rows) and the tensor's dtype.

    Parameters
    ----------
    tensor : Tensor
        The tensor data.

    Returns
    -------
    Scheme
        The column scheme.
    """
    return Scheme(tuple(F.shape(tensor)[1:]), F.dtype(tensor))
class Column(object):
......@@ -64,6 +77,7 @@ class Column(object):
@property
def shape(self):
"""Return the scheme shape (feature shape) of this column."""
return self.scheme.shape
def __getitem__(self, idx):
......@@ -71,7 +85,7 @@ class Column(object):
Parameters
----------
idx : slice or utils.Index
idx : utils.Index
The index.
Returns
......@@ -79,8 +93,9 @@ class Column(object):
Tensor
The feature data
"""
if isinstance(idx, slice):
return self.data[idx]
if idx.slice_data() is not None:
slc = idx.slice_data()
return F.narrow_row(self.data, slc.start, slc.stop)
else:
user_idx = idx.tousertensor(F.context(self.data))
return F.gather_row(self.data, user_idx)
......@@ -105,7 +120,7 @@ class Column(object):
Parameters
----------
idx : utils.Index or slice
idx : utils.Index
The index.
feats : Tensor
The new features.
......@@ -115,22 +130,21 @@ class Column(object):
feat_scheme = infer_scheme(feats)
if feat_scheme != self.scheme:
raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
% (feat_scheme, self.scheme))
if isinstance(idx, utils.Index):
idx = idx.tousertensor(F.context(self.data))
% (feat_scheme, self.scheme))
if inplace:
idx = idx.tousertensor(F.context(self.data))
F.scatter_row_inplace(self.data, idx, feats)
elif idx.slice_data() is not None:
# for contiguous indices narrow+concat is usually faster than scatter row
slc = idx.slice_data()
part1 = F.narrow_row(self.data, 0, slc.start)
part2 = feats
part3 = F.narrow_row(self.data, slc.stop, len(self))
self.data = F.cat([part1, part2, part3], dim=0)
else:
if isinstance(idx, slice):
# for contiguous indices pack is usually faster than scatter row
part1 = F.narrow_row(self.data, 0, idx.start)
part2 = feats
part3 = F.narrow_row(self.data, idx.stop, len(self))
self.data = F.cat([part1, part2, part3], dim=0)
else:
self.data = F.scatter_row(self.data, idx, feats)
idx = idx.tousertensor(F.context(self.data))
self.data = F.scatter_row(self.data, idx, feats)
def extend(self, feats, feat_scheme=None):
"""Extend the feature data.
......@@ -143,11 +157,11 @@ class Column(object):
The scheme
"""
if feat_scheme is None:
feat_scheme = Scheme.infer_scheme(feats)
feat_scheme = infer_scheme(feats)
if feat_scheme != self.scheme:
raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
% (feat_scheme, self.scheme))
% (feat_scheme, self.scheme))
feats = F.copy_to(feats, F.context(self.data))
self.data = F.cat([self.data, feats], dim=0)
......@@ -314,9 +328,9 @@ class Frame(MutableMapping):
return
if self.get_initializer(name) is None:
self._warn_and_set_initializer()
init_data = self.get_initializer(name)(
(self.num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(0, self.num_rows))
initializer = self.get_initializer(name)
init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(0, self.num_rows))
self._columns[name] = Column(init_data, scheme)
def add_rows(self, num_rows):
......@@ -336,9 +350,9 @@ class Frame(MutableMapping):
ctx = F.context(col.data)
if self.get_initializer(key) is None:
self._warn_and_set_initializer()
new_data = self.get_initializer(key)(
(num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(self._num_rows, self._num_rows + num_rows))
initializer = self.get_initializer(key)
new_data = initializer((num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(self._num_rows, self._num_rows + num_rows))
feat_placeholders[key] = new_data
self._append(Frame(feat_placeholders))
self._num_rows += num_rows
......@@ -368,17 +382,17 @@ class Frame(MutableMapping):
else:
# pad columns that are not provided in the other frame with initial values
for key, col in self.items():
if key not in other:
scheme = col.scheme
ctx = F.context(col.data)
if self.get_initializer(key) is None:
self._warn_and_set_initializer()
new_data = self.get_initializer(key)(
(other.num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(self._num_rows,
self._num_rows + other.num_rows)
)
other[key] = new_data
if key in other:
continue
scheme = col.scheme
ctx = F.context(col.data)
if self.get_initializer(key) is None:
self._warn_and_set_initializer()
initializer = self.get_initializer(key)
new_data = initializer((other.num_rows,) + scheme.shape,
scheme.dtype, ctx,
slice(self._num_rows, self._num_rows + other.num_rows))
other[key] = new_data
# append other to self
for key, col in other.items():
if key not in self._columns:
......@@ -428,23 +442,19 @@ class FrameRef(MutableMapping):
frame : Frame, optional
The underlying frame. If not given, the reference will point to a
new empty frame.
index : iterable, slice, or int, optional
index : utils.Index, optional
The rows that are referenced in the underlying frame. If not given,
the whole frame is referenced. The index should be distinct (no
duplication is allowed).
Note that if a slice is given, the step must be None.
"""
def __init__(self, frame=None, index=None):
self._frame = frame if frame is not None else Frame()
# TODO(minjie): check no duplication
assert index is None or isinstance(index, utils.Index)
if index is None:
# _index_data can be either a slice or an iterable
self._index_data = slice(0, self._frame.num_rows)
self._index = utils.toindex(slice(0, self._frame.num_rows))
else:
# TODO(minjie): check no duplication
self._index_data = index
self._index = None
self._index_or_slice = None
self._index = index
@property
def schemes(self):
......@@ -465,11 +475,7 @@ class FrameRef(MutableMapping):
@property
def num_rows(self):
"""Return the number of rows referred."""
if isinstance(self._index_data, slice):
# NOTE: we always assume that slice.step is None
return self._index_data.stop - self._index_data.start
else:
return len(self._index_data)
return len(self._index)
def set_initializer(self, initializer, column=None):
"""Set the initializer for empty values.
......@@ -500,38 +506,6 @@ class FrameRef(MutableMapping):
"""
return self._frame.get_initializer(column)
def index(self):
"""Return the index object.
Returns
-------
utils.Index
The index.
"""
if self._index is None:
if self.is_contiguous():
self._index = utils.toindex(
F.arange(self._index_data.start,
self._index_data.stop))
else:
self._index = utils.toindex(self._index_data)
return self._index
def index_or_slice(self):
"""Returns the index object or the slice
Returns
-------
utils.Index or slice
The index or slice
"""
if self._index_or_slice is None:
if self.is_contiguous():
self._index_or_slice = self._index_data
else:
self._index_or_slice = utils.toindex(self._index_data)
return self._index_or_slice
def __contains__(self, name):
"""Return whether the column name exists."""
return name in self._frame
......@@ -567,7 +541,7 @@ class FrameRef(MutableMapping):
Parameters
----------
key : str or utils.Index or slice
key : str or utils.Index
The key.
Returns
......@@ -575,12 +549,11 @@ class FrameRef(MutableMapping):
Tensor or lazy dict or tensors
Depends on whether it is a column selection or row selection.
"""
if not isinstance(key, (str, utils.Index)):
raise DGLError('Argument "key" must be either str or utils.Index type.')
if isinstance(key, str):
return self.select_column(key)
elif isinstance(key, slice) and key == slice(0, self.num_rows):
# shortcut for selecting all the rows
return self
elif isinstance(key, utils.Index) and key.is_slice(0, self.num_rows):
elif key.is_slice(0, self.num_rows):
# shortcut for selecting all the rows
return self
else:
......@@ -606,7 +579,7 @@ class FrameRef(MutableMapping):
if self.is_span_whole_column():
return col.data
else:
return col[self.index_or_slice()]
return col[self._index]
def select_rows(self, query):
"""Return the rows given the query.
......@@ -625,9 +598,22 @@ class FrameRef(MutableMapping):
return utils.LazyDict(lambda key: self._frame[key][rows], keys=self.keys())
def __setitem__(self, key, val):
self.set_item_inplace(key, val, inplace=False)
"""Update the data in the frame. The update is done out-of-place.
Parameters
----------
key : str or utils.Index
The key.
val : Tensor or dict of tensors
The value.
def set_item_inplace(self, key, val, inplace):
See Also
--------
update
"""
self.update_data(key, val, inplace=False)
def update_data(self, key, val, inplace):
"""Update the data in the frame.
If the provided key is string, the corresponding column data will be updated.
......@@ -649,14 +635,14 @@ class FrameRef(MutableMapping):
inplace: bool
If True, update will be done in place
"""
if not isinstance(key, (str, utils.Index)):
raise DGLError('Argument "key" must be either str or utils.Index type.')
if isinstance(key, str):
self.update_column(key, val, inplace=inplace)
elif isinstance(key, slice) and key == slice(0, self.num_rows):
elif key.is_slice(0, self.num_rows):
# shortcut for updating all the rows
return self.update(val)
elif isinstance(key, utils.Index) and key.is_slice(0, self.num_rows):
# shortcut for selecting all the rows
return self.update(val)
for colname, col in val.items():
self.update_column(colname, col, inplace=inplace)
else:
self.update_rows(key, val, inplace=inplace)
......@@ -683,15 +669,14 @@ class FrameRef(MutableMapping):
col = Column.create(data)
if self.num_columns == 0:
# the frame is empty
self._index_data = slice(0, len(col))
self._clear_cache()
self._index = utils.toindex(slice(0, len(col)))
self._frame[name] = col
else:
if name not in self._frame:
ctx = F.context(data)
self._frame.add_column(name, infer_scheme(data), ctx)
fcol = self._frame[name]
fcol.update(self.index_or_slice(), data, inplace)
fcol.update(self._index, data, inplace)
def add_rows(self, num_rows):
"""Add blank rows to the underlying frame.
......@@ -700,7 +685,7 @@ class FrameRef(MutableMapping):
initializers.
Note: only available for FrameRef that spans the whole column. The row
span will extend to new rows. Other FrameRefs referencing the same
span will extend to new rows. Other FrameRefs referencing the same
frame will not be affected.
Parameters
......@@ -711,10 +696,14 @@ class FrameRef(MutableMapping):
if not self.is_span_whole_column():
raise RuntimeError('FrameRef not spanning whole column.')
self._frame.add_rows(num_rows)
if self.is_contiguous():
self._index_data = slice(0, self._index_data.stop + num_rows)
if self._index.slice_data() is not None:
# the index is a slice
slc = self._index.slice_data()
self._index = utils.toindex(slice(slc.start, slc.stop + num_rows))
else:
self._index_data.extend(range(self.num_rows, self.num_rows + num_rows))
selfidxdata = self._index.tousertensor()
newdata = F.arange(self.num_rows, self.num_rows + num_rows)
self._index = utils.toindex(F.cat([selfidxdata, newdata], dim=0))
def update_rows(self, query, data, inplace):
"""Update the rows.
......@@ -759,6 +748,8 @@ class FrameRef(MutableMapping):
key : str or utils.Index
The key.
"""
if not isinstance(key, (str, utils.Index)):
raise DGLError('Argument "key" must be either str or utils.Index type.')
if isinstance(key, str):
del self._frame[key]
else:
......@@ -769,22 +760,16 @@ class FrameRef(MutableMapping):
Please note that "deleted" rows are not really deleted, but simply removed
in the reference. As a result, if two FrameRefs point to the same Frame, deleting
from one ref will not relect on the other. By contrast, deleting columns is real.
from one ref will not reflect on the other. By contrast, deleting columns is real.
Parameters
----------
query : utils.Index or slice
query : utils.Index
The rows to be deleted.
"""
if isinstance(query, slice):
query = range(query.start, query.stop)
else:
query = query.tonumpy()
if isinstance(self._index_data, slice):
self._index_data = range(self._index_data.start, self._index_data.stop)
self._index_data = list(np.delete(self._index_data, query))
self._clear_cache()
query = query.tonumpy()
index = self._index.tonumpy()
self._index = utils.toindex(np.delete(index, query))
def append(self, other):
"""Append another frame into this one.
......@@ -794,59 +779,50 @@ class FrameRef(MutableMapping):
other : dict of str to tensor
The data to be appended.
"""
span_whole = self.is_span_whole_column()
contiguous = self.is_contiguous()
old_nrows = self._frame.num_rows
self._frame.append(other)
new_nrows = self._frame.num_rows
# update index
if span_whole:
self._index_data = slice(0, self._frame.num_rows)
elif contiguous:
if self._index_data.stop == old_nrows:
new_idx = slice(self._index_data.start, self._frame.num_rows)
else:
new_idx = list(range(self._index_data.start, self._index_data.stop))
new_idx.extend(range(old_nrows, self._frame.num_rows))
self._index_data = new_idx
self._clear_cache()
if (self._index.slice_data() is not None
and self._index.slice_data().stop == old_nrows):
# Self index is a slice and index.stop is equal to the size of the
# underlying frame. Can still use a slice for the new index.
oldstart = self._index.slice_data().start
self._index = utils.toindex(slice(oldstart, new_nrows))
else:
# convert it to user tensor and concat
selfidxdata = self._index.tousertensor()
newdata = F.arange(old_nrows, new_nrows)
self._index = utils.toindex(F.cat([selfidxdata, newdata], dim=0))
def clear(self):
"""Clear the frame."""
self._frame.clear()
self._index_data = slice(0, 0)
self._clear_cache()
self._index = utils.toindex(slice(0, 0))
def is_contiguous(self):
"""Return whether this refers to a contiguous range of rows."""
# NOTE: this check could have false negatives
# NOTE: we always assume that slice.step is None
return isinstance(self._index_data, slice)
return self._index.slice_data() is not None
def is_span_whole_column(self):
    """Return whether this refers to all the rows.

    True only when the referenced rows form a contiguous range AND the
    number of referenced rows equals the underlying frame's row count.
    """
    return self.is_contiguous() and self.num_rows == self._frame.num_rows
def _getrows(self, query):
"""Internal function to convert from the local row ids to the row ids of the frame."""
if self.is_contiguous():
start = self._index_data.start
if start == 0:
# shortcut for identical mapping
return query
elif isinstance(query, slice):
return slice(query.start + start, query.stop + start)
else:
query = query.tousertensor()
return utils.toindex(query + start)
else:
idxtensor = self.index().tousertensor()
query = query.tousertensor()
return utils.toindex(F.gather_row(idxtensor, query))
def _clear_cache(self):
"""Internal function to clear the cached object."""
self._index = None
self._index_or_slice = None
"""Internal function to convert from the local row ids to the row ids of the frame.
Parameters
----------
query : utils.Index
The query index.
Returns
-------
utils.Index
The actual index to the underlying frame.
"""
return self._index.get_items(query)
def frame_like(other, num_rows):
"""Create a new frame that has the same scheme as the given one.
......
"""DGL builtin functors"""
# pylint: disable=redefined-builtin
from __future__ import absolute_import
from .message import *
......
"""Built-in function base class"""
from __future__ import absolute_import
__all__ = ['BuiltinFunction', 'BundledFunction']
class BuiltinFunction(object):
    """Abstract base class for DGL builtin functions.

    Concrete builtins must override ``__call__`` and the ``name`` property.
    """

    def __call__(self):
        """Fallback (non-optimized) evaluation of this builtin.

        Used when no optimized execution path is available.
        """
        raise NotImplementedError

    @property
    def name(self):
        """Identifier of this builtin function."""
        raise NotImplementedError
class BundledFunction(object):
"""A utility class that bundles multiple functions.
Parameters
----------
fn_list : list of callable
The function list.
"""
def __init__(self, fn_list):
self.fn_list = fn_list
def __call__(self, *args, **kwargs):
"""Regular computation of this builtin function
This will be used when optimization is not available and should
ONLY be called by DGL framework.
"""
ret = {}
for fn in self.fn_list:
ret.update(fn(*args, **kwargs))
......@@ -28,4 +34,5 @@ class BundledFunction(object):
@property
def name(self):
"""Return the name."""
return "bundled"
"""Built-in message function."""
from __future__ import absolute_import
from .base import BuiltinFunction
import operator
import dgl.backend as F
from .base import BuiltinFunction
from .. import backend as F
__all__ = ["src_mul_edge", "copy_src", "copy_edge"]
......@@ -12,9 +13,10 @@ class MessageFunction(BuiltinFunction):
"""Base builtin message function class."""
def __call__(self, edges):
"""Regular computation of this builtin.
"""Regular computation of this builtin function
This will be used when optimization is not available.
This will be used when optimization is not available and should
ONLY be called by DGL framework.
"""
raise NotImplementedError
......@@ -29,9 +31,9 @@ class MessageFunction(BuiltinFunction):
@property
def use_edge_feature(self):
    """Return true if the message function uses edge feature data.

    Abstract; concrete message functions must override.
    """
    raise NotImplementedError
def _is_spmv_supported_edge_feat(g, field):
"""Return whether the edge feature shape supports SPMV optimization.
......@@ -43,6 +45,12 @@ def _is_spmv_supported_edge_feat(g, field):
class SrcMulEdgeMessageFunction(MessageFunction):
"""Class for the src_mul_edge builtin message function.
See Also
--------
src_mul_edge
"""
def __init__(self, mul_op, src_field, edge_field, out_field):
self.mul_op = mul_op
self.src_field = src_field
......@@ -50,9 +58,26 @@ class SrcMulEdgeMessageFunction(MessageFunction):
self.out_field = out_field
def is_spmv_supported(self, g):
    """Return true if this supports SPMV optimization.

    Delegates to ``_is_spmv_supported_edge_feat`` on this builtin's
    edge feature field, so support depends on the edge feature shape.

    Parameters
    ----------
    g : DGLGraph
        The graph.

    Returns
    -------
    bool
        True if this supports SPMV optimization.
    """
    return _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, edges):
"""Regular computation of this builtin function
This will be used when optimization is not available and should
ONLY be called by DGL framework.
"""
sdata = edges.src[self.src_field]
edata = edges.data[self.edge_field]
# Due to the different broadcasting semantics of different backends,
......@@ -71,17 +96,41 @@ class SrcMulEdgeMessageFunction(MessageFunction):
@property
def use_edge_feature(self):
    """Return true if the message function uses edge feature data.

    Always True: src_mul_edge reads the edge feature field.
    """
    return True
class CopySrcMessageFunction(MessageFunction):
"""Class for the copy_src builtin message function.
See Also
--------
copy_src
"""
def __init__(self, src_field, out_field):
    # Name of the source-node feature field to copy from.
    self.src_field = src_field
    # Name of the message field to write the copy into.
    self.out_field = out_field
def is_spmv_supported(self, g):
    """Tell whether the SPMV fast path can handle this builtin.

    Parameters
    ----------
    g : DGLGraph
        The graph (unused: support does not depend on it).

    Returns
    -------
    bool
        Always True.
    """
    return True
def __call__(self, edges):
    """Regular computation of this builtin function.

    This will be used when optimization is not available and should
    ONLY be called by DGL framework.

    Copies the source nodes' ``src_field`` feature into the outgoing
    message under ``out_field``.
    """
    return {self.out_field : edges.src[self.src_field]}
@property
......@@ -90,19 +139,43 @@ class CopySrcMessageFunction(MessageFunction):
@property
def use_edge_feature(self):
    """Return true if the message function uses edge feature data.

    Always False: copy_src reads only source-node features.
    """
    return False
class CopyEdgeMessageFunction(MessageFunction):
"""Class for the copy_edge builtin message function.
See Also
--------
copy_edge
"""
def __init__(self, edge_field=None, out_field=None):
    # Name of the edge feature field to copy from.
    self.edge_field = edge_field
    # Name of the message field to write the copy into.
    self.out_field = out_field
def is_spmv_supported(self, g):
    """Tell whether the SPMV fast path can handle this builtin.

    Parameters
    ----------
    g : DGLGraph
        The graph.

    Returns
    -------
    bool
        Always False for now.
    """
    # TODO: support this with e2v spmv; when it lands, delegate to
    # _is_spmv_supported_edge_feat(g, self.edge_field).
    return False
def __call__(self, edges):
    """Regular computation of this builtin function.

    This will be used when optimization is not available and should
    ONLY be called by DGL framework.

    Copies the edges' ``edge_field`` feature into the outgoing message
    under ``out_field``.
    """
    return {self.out_field : edges.data[self.edge_field]}
@property
......@@ -111,9 +184,9 @@ class CopyEdgeMessageFunction(MessageFunction):
@property
def use_edge_feature(self):
    """Return true if the message function uses edge feature data.

    Always True: copy_edge reads the edge feature directly.
    """
    return True
def src_mul_edge(src, edge, out):
"""Builtin message function that computes message by multiplying source
node features with edge features.
......
"""Built-in reducer function."""
# pylint: disable=redefined-builtin
from __future__ import absolute_import
from .. import backend as F
......@@ -10,9 +11,10 @@ class ReduceFunction(BuiltinFunction):
"""Base builtin reduce function class."""
def __call__(self, nodes):
"""Regular computation of this builtin.
"""Regular computation of this builtin function
This will be used when optimization is not available.
This will be used when optimization is not available and should
ONLY be called by DGL framework.
"""
raise NotImplementedError
......@@ -29,18 +31,19 @@ class ReduceFunction(BuiltinFunction):
class SimpleReduceFunction(ReduceFunction):
"""Builtin reduce function that aggregates a single field into another
single field."""
def __init__(self, name, op, msg_field, out_field):
def __init__(self, name, reduce_op, msg_field, out_field):
self._name = name
self.op = op
self.reduce_op = reduce_op
self.msg_field = msg_field
self.out_field = out_field
def is_spmv_supported(self):
    """Return whether the SPMV optimization is supported.

    Returns
    -------
    bool
        True only when this reducer is "sum".
    """
    # NOTE: only sum is supported right now.
    return self._name == "sum"
def __call__(self, nodes):
return {self.out_field : self.op(nodes.mailbox[self.msg_field], 1)}
return {self.out_field : self.reduce_op(nodes.mailbox[self.msg_field], 1)}
@property
def name(self):
......
"""Base graph class specialized for neural networks on graphs."""
from __future__ import absolute_import
import networkx as nx
import numpy as np
from collections import defaultdict
import dgl
from .base import ALL, is_all, DGLError, dgl_warning
from .base import ALL, is_all, DGLError
from . import backend as F
from . import init
from .frame import FrameRef, Frame
from .graph_index import GraphIndex, create_graph_index
from .graph_index import create_graph_index
from .runtime import ir, scheduler, Runtime
from . import subgraph
from . import utils
from .view import NodeView, EdgeView
from .udf import NodeBatch, EdgeBatch
__all__ = ['DGLGraph']
class DGLGraph(object):
......@@ -177,7 +175,6 @@ class DGLGraph(object):
multigraph=False,
readonly=False):
# graph
self._readonly=readonly
self._graph = create_graph_index(graph_data, multigraph, readonly)
# node and edge frame
if node_frame is None:
......@@ -194,7 +191,7 @@ class DGLGraph(object):
# message frame
self._msg_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
# set initializer for message frame
self._msg_frame.set_initializer(dgl.init.zero_initializer)
self._msg_frame.set_initializer(init.zero_initializer)
# registered functions
self._message_func = None
self._reduce_func = None
......@@ -916,7 +913,7 @@ class DGLGraph(object):
else:
raise DGLError('Invalid form:', form)
def all_edges(self, form='uv', sorted=False):
def all_edges(self, form='uv', return_sorted=False):
"""Return all the edges.
Parameters
......@@ -927,7 +924,7 @@ class DGLGraph(object):
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
sorted : bool
return_sorted : bool
True if the returned edges are sorted by their src and dst ids.
Returns
......@@ -954,7 +951,7 @@ class DGLGraph(object):
>>> G.all_edges('all')
(tensor([0, 0, 1]), tensor([1, 2, 2]), tensor([0, 1, 2]))
"""
src, dst, eid = self._graph.edges(sorted)
src, dst, eid = self._graph.edges(return_sorted)
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
......@@ -1021,7 +1018,7 @@ class DGLGraph(object):
in_degree
"""
if is_all(v):
v = utils.toindex(slice(0, self.number_of_nodes()))
v = utils.toindex(slice(0, self.number_of_nodes()))
else:
v = utils.toindex(v)
return self._graph.in_degrees(v).tousertensor()
......@@ -1083,7 +1080,7 @@ class DGLGraph(object):
out_degree
"""
if is_all(v):
v = utils.toindex(slice(0, self.number_of_nodes()))
v = utils.toindex(slice(0, self.number_of_nodes()))
else:
v = utils.toindex(v)
return self._graph.out_degrees(v).tousertensor()
......@@ -1121,13 +1118,13 @@ class DGLGraph(object):
nx_graph = self._graph.to_networkx()
if node_attrs is not None:
for nid, attr in nx_graph.nodes(data=True):
nf = self.get_n_repr(nid)
attr.update({key: nf[key].squeeze(0) for key in node_attrs})
feat_dict = self.get_n_repr(nid)
attr.update({key: feat_dict[key].squeeze(0) for key in node_attrs})
if edge_attrs is not None:
for u, v, attr in nx_graph.edges(data=True):
for _, _, attr in nx_graph.edges(data=True):
eid = attr['id']
ef = self.get_e_repr(eid)
attr.update({key: ef[key].squeeze(0) for key in edge_attrs})
feat_dict = self.get_e_repr(eid)
attr.update({key: feat_dict[key].squeeze(0) for key in edge_attrs})
return nx_graph
def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
......@@ -1208,12 +1205,12 @@ class DGLGraph(object):
for attr in edge_attrs:
self._edge_frame[attr] = _batcher(attr_dict[attr])
def from_scipy_sparse_matrix(self, a):
def from_scipy_sparse_matrix(self, spmat):
""" Convert from scipy sparse matrix.
Parameters
----------
a : scipy sparse matrix
spmat : scipy sparse matrix
The graph's adjacency matrix
Examples
......@@ -1227,7 +1224,7 @@ class DGLGraph(object):
>>> g.from_scipy_sparse_matrix(a)
"""
self.clear()
self._graph.from_scipy_sparse_matrix(a)
self._graph.from_scipy_sparse_matrix(spmat)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_index = utils.zero_index(self.number_of_edges())
......@@ -1502,10 +1499,10 @@ class DGLGraph(object):
"""
return self.edges[:].data
def set_n_repr(self, hu, u=ALL, inplace=False):
def set_n_repr(self, data, u=ALL, inplace=False):
"""Set node(s) representation.
`hu` is a dictionary from the feature name to feature tensor. Each tensor
`data` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of nodes to be updated,
and (D1, D2, ...) be the shape of the node representation tensor. The
length of the given node ids must match B (i.e, len(u) == B).
......@@ -1515,7 +1512,7 @@ class DGLGraph(object):
Parameters
----------
hu : dict of tensor
data : dict of tensor
Node representation.
u : node, container or tensor
The node(s).
......@@ -1523,25 +1520,25 @@ class DGLGraph(object):
If True, update will be done in place, but autograd will break.
"""
# sanity check
if not utils.is_dict_like(hu):
if not utils.is_dict_like(data):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(hu))
' Got "%s" instead.' % type(data))
if is_all(u):
num_nodes = self.number_of_nodes()
else:
u = utils.toindex(u)
num_nodes = len(u)
for key, val in hu.items():
for key, val in data.items():
nfeats = F.shape(val)[0]
if nfeats != num_nodes:
raise DGLError('Expect number of features to match number of nodes (len(u)).'
' Got %d and %d instead.' % (nfeats, num_nodes))
# set
if is_all(u):
for key, val in hu.items():
for key, val in data.items():
self._node_frame[key] = val
else:
self._node_frame.update_rows(u, hu, inplace=inplace)
self._node_frame.update_rows(u, data, inplace=inplace)
def get_n_repr(self, u=ALL):
"""Get node(s) representation.
......@@ -1581,10 +1578,10 @@ class DGLGraph(object):
"""
return self._node_frame.pop(key)
def set_e_repr(self, he, edges=ALL, inplace=False):
def set_e_repr(self, data, edges=ALL, inplace=False):
"""Set edge(s) representation.
`he` is a dictionary from the feature name to feature tensor. Each tensor
`data` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
and (D1, D2, ...) be the shape of the edge representation tensor.
......@@ -1593,7 +1590,7 @@ class DGLGraph(object):
Parameters
----------
he : tensor or dict of tensor
data : tensor or dict of tensor
Edge representation.
edges : edges
Edges can be a pair of endpoint nodes (u, v), or a
......@@ -1614,16 +1611,16 @@ class DGLGraph(object):
eid = utils.toindex(edges)
# sanity check
if not utils.is_dict_like(he):
if not utils.is_dict_like(data):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(he))
' Got "%s" instead.' % type(data))
if is_all(eid):
num_edges = self.number_of_edges()
else:
eid = utils.toindex(eid)
num_edges = len(eid)
for key, val in he.items():
for key, val in data.items():
nfeats = F.shape(val)[0]
if nfeats != num_edges:
raise DGLError('Expect number of features to match number of edges.'
......@@ -1631,11 +1628,11 @@ class DGLGraph(object):
# set
if is_all(eid):
# update column
for key, val in he.items():
for key, val in data.items():
self._edge_frame[key] = val
else:
# update row
self._edge_frame.update_rows(eid, he, inplace=inplace)
self._edge_frame.update_rows(eid, data, inplace=inplace)
def get_e_repr(self, edges=ALL):
"""Get node(s) representation.
......@@ -2491,8 +2488,7 @@ class DGLGraph(object):
prop_edges
"""
for node_frontier in nodes_generator:
self.pull(node_frontier,
message_func, reduce_func, apply_node_func)
self.pull(node_frontier, message_func, reduce_func, apply_node_func)
def prop_edges(self,
edges_generator,
......@@ -2573,8 +2569,7 @@ class DGLGraph(object):
prop_nodes
"""
for edge_frontier in edges_generator:
self.send_and_recv(edge_frontier,
message_func, reduce_func, apply_node_func)
self.send_and_recv(edge_frontier, message_func, reduce_func, apply_node_func)
def subgraph(self, nodes):
"""Return the subgraph induced on given nodes.
......@@ -2621,7 +2616,7 @@ class DGLGraph(object):
"""
induced_nodes = utils.toindex(nodes)
sgi = self._graph.node_subgraph(induced_nodes)
return dgl.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
return subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
def subgraphs(self, nodes):
"""Return a list of subgraphs, each induced in the corresponding given
......@@ -2648,8 +2643,8 @@ class DGLGraph(object):
"""
induced_nodes = [utils.toindex(n) for n in nodes]
sgis = self._graph.node_subgraphs(induced_nodes)
return [dgl.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges,
sgi) for sgi in sgis]
return [subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
for sgi in sgis]
def edge_subgraph(self, edges):
"""Return the subgraph induced on given edges.
......@@ -2695,7 +2690,7 @@ class DGLGraph(object):
"""
induced_edges = utils.toindex(edges)
sgi = self._graph.edge_subgraph(induced_edges)
return dgl.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
return subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
"""Return the adjacency matrix representation of this graph.
......@@ -2720,7 +2715,7 @@ class DGLGraph(object):
"""
return self._graph.adjacency_matrix(transpose, ctx)[0]
def incidence_matrix(self, type, ctx=F.cpu()):
def incidence_matrix(self, typestr, ctx=F.cpu()):
"""Return the incidence matrix representation of this graph.
An incidence matrix is an n x m sparse matrix, where n is
......@@ -2750,7 +2745,7 @@ class DGLGraph(object):
Parameters
----------
type : str
typestr : str
Can be either ``in``, ``out`` or ``both``
ctx : context, optional (default=cpu)
The context of returned incidence matrix.
......@@ -2760,7 +2755,7 @@ class DGLGraph(object):
SparseTensor
The incidence matrix.
"""
return self._graph.incidence_matrix(type, ctx)[0]
return self._graph.incidence_matrix(typestr, ctx)[0]
def line_graph(self, backtracking=True, shared=False):
"""Return the line graph of this graph.
......@@ -2833,8 +2828,8 @@ class DGLGraph(object):
v = utils.toindex(nodes)
n_repr = self.get_n_repr(v)
nb = NodeBatch(self, v, n_repr)
n_mask = predicate(nb)
nbatch = NodeBatch(self, v, n_repr)
n_mask = predicate(nbatch)
if is_all(nodes):
return F.nonzero_1d(n_mask)
......@@ -2906,10 +2901,8 @@ class DGLGraph(object):
src_data = self.get_n_repr(u)
edge_data = self.get_e_repr(eid)
dst_data = self.get_n_repr(v)
eb = EdgeBatch(self, (u, v, eid),
src_data, edge_data, dst_data)
e_mask = predicate(eb)
ebatch = EdgeBatch(self, (u, v, eid), src_data, edge_data, dst_data)
e_mask = predicate(ebatch)
if is_all(edges):
return F.nonzero_1d(e_mask)
......@@ -2918,7 +2911,9 @@ class DGLGraph(object):
return edges[e_mask]
def __repr__(self):
s = 'DGLGraph with {node} nodes and {edge} edges.\nNode data: {ndata}\nEdge data: {edata}'
return s.format(node=self.number_of_nodes(), edge=self.number_of_edges(),
ndata=str(self.node_attr_schemes()),
edata=str(self.edge_attr_schemes()))
ret = ('DGLGraph(num_nodes={node}, num_edges={edge},\n'
' ndata_schemes={ndata}\n'
' edata_schemes={edata})')
return ret.format(node=self.number_of_nodes(), edge=self.number_of_edges(),
ndata=str(self.node_attr_schemes()),
edata=str(self.edge_attr_schemes()))
"""Module for graph index class definition."""
from __future__ import absolute_import
import ctypes
......@@ -7,7 +8,7 @@ import scipy
from ._ffi.base import c_array
from ._ffi.function import _init_api
from .base import DGLError, is_all
from .base import DGLError
from . import backend as F
from . import utils
from .immutable_graph_index import create_immutable_graph_index
......@@ -58,7 +59,7 @@ class GraphIndex(object):
num : int
Number of nodes to be added.
"""
_CAPI_DGLGraphAddVertices(self._handle, num);
_CAPI_DGLGraphAddVertices(self._handle, num)
self.clear_cache()
def add_edge(self, u, v):
......@@ -71,7 +72,7 @@ class GraphIndex(object):
v : int
The dst node.
"""
_CAPI_DGLGraphAddEdge(self._handle, u, v);
_CAPI_DGLGraphAddEdge(self._handle, u, v)
self.clear_cache()
def add_edges(self, u, v):
......@@ -366,12 +367,12 @@ class GraphIndex(object):
return src, dst, eid
@utils.cached_member(cache='_cache', prefix='edges')
def edges(self, sorted=False):
def edges(self, return_sorted=False):
"""Return all the edges
Parameters
----------
sorted : bool
return_sorted : bool
True if the returned edges are sorted by their src and dst ids.
Returns
......@@ -383,9 +384,9 @@ class GraphIndex(object):
utils.Index
The edge ids.
"""
key = 'edges_s%d' % sorted
key = 'edges_s%d' % return_sorted
if key not in self._cache:
edge_array = _CAPI_DGLGraphEdges(self._handle, sorted)
edge_array = _CAPI_DGLGraphEdges(self._handle, return_sorted)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
......@@ -505,7 +506,6 @@ class GraphIndex(object):
"""
e_array = e.todgltensor()
rst = _CAPI_DGLGraphEdgeSubgraph(self._handle, e_array)
gi = GraphIndex(rst(0))
induced_nodes = utils.toindex(rst(1))
return SubgraphIndex(rst(0), self, induced_nodes, e)
......@@ -555,7 +555,7 @@ class GraphIndex(object):
return adj, shuffle_idx
@utils.cached_member(cache='_cache', prefix='inc')
def incidence_matrix(self, type, ctx):
def incidence_matrix(self, typestr, ctx):
"""Return the incidence matrix representation of this graph.
An incidence matrix is an n x m sparse matrix, where n is
......@@ -577,7 +577,7 @@ class GraphIndex(object):
Parameters
----------
type : str
typestr : str
Can be either "in", "out" or "both"
ctx : context
The context of returned incidence matrix.
......@@ -596,21 +596,21 @@ class GraphIndex(object):
eid = eid.tousertensor(ctx) # the index of the ctx will be cached
n = self.number_of_nodes()
m = self.number_of_edges()
if type == 'in':
if typestr == 'in':
row = F.unsqueeze(dst, 0)
col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0)
# FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
elif type == 'out':
elif typestr == 'out':
row = F.unsqueeze(src, 0)
col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0)
# FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
elif type == 'both':
elif typestr == 'both':
# create index
row = F.unsqueeze(F.cat([src, dst], dim=0), 0)
col = F.unsqueeze(F.cat([eid, eid], dim=0), 0)
......@@ -625,7 +625,7 @@ class GraphIndex(object):
dat = F.cat([x, y], dim=0)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
else:
raise DGLError('Invalid incidence matrix type: %s' % str(type))
raise DGLError('Invalid incidence matrix type: %s' % str(typestr))
shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
return inc, shuffle_idx
......@@ -642,8 +642,8 @@ class GraphIndex(object):
src, dst, eid = self.edges()
ret = nx.MultiDiGraph() if self.is_multigraph() else nx.DiGraph()
ret.add_nodes_from(range(self.number_of_nodes()))
for u, v, id in zip(src, dst, eid):
ret.add_edge(u, v, id=id)
for u, v, e in zip(src, dst, eid):
ret.add_edge(u, v, id=e)
return ret
def from_networkx(self, nx_graph):
......@@ -661,7 +661,7 @@ class GraphIndex(object):
if not isinstance(nx_graph, nx.Graph):
nx_graph = (nx.MultiDiGraph(nx_graph) if self.is_multigraph()
else nx.DiGraph(nx_graph))
else nx.DiGraph(nx_graph))
else:
nx_graph = nx_graph.to_directed()
......@@ -797,11 +797,11 @@ class SubgraphIndex(GraphIndex):
def __getstate__(self):
raise NotImplementedError(
"SubgraphIndex pickling is not supported yet.")
"SubgraphIndex pickling is not supported yet.")
def __setstate__(self, state):
raise NotImplementedError(
"SubgraphIndex unpickling is not supported yet.")
"SubgraphIndex unpickling is not supported yet.")
def map_to_subgraph_nid(subgraph, parent_nids):
"""Map parent node Ids to the subgraph node Ids.
......@@ -820,7 +820,7 @@ def map_to_subgraph_nid(subgraph, parent_nids):
Node Ids in the subgraph.
"""
return utils.toindex(_CAPI_DGLMapSubgraphNID(subgraph.induced_nodes.todgltensor(),
parent_nids.todgltensor()))
parent_nids.todgltensor()))
def disjoint_union(graphs):
"""Return a disjoint union of the input graphs.
......@@ -868,12 +868,12 @@ def disjoint_partition(graph, num_or_size_splits):
"""
if isinstance(num_or_size_splits, utils.Index):
rst = _CAPI_DGLDisjointPartitionBySizes(
graph._handle,
num_or_size_splits.todgltensor())
graph._handle,
num_or_size_splits.todgltensor())
else:
rst = _CAPI_DGLDisjointPartitionByNum(
graph._handle,
int(num_or_size_splits))
graph._handle,
int(num_or_size_splits))
graphs = []
for val in rst.asnumpy():
handle = ctypes.cast(int(val), ctypes.c_void_p)
......@@ -891,46 +891,41 @@ def create_graph_index(graph_data=None, multigraph=False, readonly=False):
Whether the graph is multigraph (default is False)
"""
if isinstance(graph_data, GraphIndex):
# FIXME(minjie): this return is not correct for mutable graph index
return graph_data
if readonly and graph_data is not None:
try:
gi = create_immutable_graph_index(graph_data)
except:
gi = None
# If we can't create an immutable graph index, we'll have to fall back.
if gi is not None:
return gi
if readonly:
return create_immutable_graph_index(graph_data)
handle = _CAPI_DGLGraphCreate(multigraph)
gi = GraphIndex(handle)
gidx = GraphIndex(handle)
if graph_data is None:
return gi
return gidx
# edge list
if isinstance(graph_data, (list, tuple)):
try:
gi.from_edge_list(graph_data)
return gi
except:
gidx.from_edge_list(graph_data)
return gidx
except Exception: # pylint: disable=broad-except
raise DGLError('Graph data is not a valid edge list.')
# scipy format
if isinstance(graph_data, scipy.sparse.spmatrix):
try:
gi.from_scipy_sparse_matrix(graph_data)
return gi
except:
gidx.from_scipy_sparse_matrix(graph_data)
return gidx
except Exception: # pylint: disable=broad-except
raise DGLError('Graph data is not a valid scipy sparse matrix.')
# networkx - any format
try:
gi.from_networkx(graph_data)
except:
gidx.from_networkx(graph_data)
except Exception: # pylint: disable=broad-except
raise DGLError('Error while creating graph from input of type "%s".'
% type(graph_data))
% type(graph_data))
return gi
return gidx
_init_api("dgl.graph_index")
"""Module for immutable graph index.
NOTE: this is currently a temporary solution.
"""
# pylint: disable=abstract-method,unused-argument
from __future__ import absolute_import
import ctypes
import numpy as np
import networkx as nx
import scipy.sparse as sp
......@@ -8,7 +13,7 @@ import scipy.sparse as sp
from ._ffi.function import _init_api
from . import backend as F
from . import utils
from .base import ALL, is_all, DGLError
from .base import DGLError
class ImmutableGraphIndex(object):
"""Graph index object on immutable graphs.
......@@ -27,7 +32,7 @@ class ImmutableGraphIndex(object):
def add_nodes(self, num):
"""Add nodes.
Parameters
----------
num : int
......@@ -37,7 +42,7 @@ class ImmutableGraphIndex(object):
def add_edge(self, u, v):
"""Add one edge.
Parameters
----------
u : int
......@@ -49,7 +54,7 @@ class ImmutableGraphIndex(object):
def add_edges(self, u, v):
"""Add many edges.
Parameters
----------
u : utils.Index
......@@ -229,8 +234,8 @@ class ImmutableGraphIndex(object):
"""
u = F.tensor([u], dtype=F.int64)
v = F.tensor([v], dtype=F.int64)
_, _, id = self._sparse.edge_ids(u, v)
return utils.toindex(id)
_, _, eid = self._sparse.edge_ids(u, v)
return utils.toindex(eid)
def edge_ids(self, u, v):
"""Return the edge ids.
......@@ -282,7 +287,7 @@ class ImmutableGraphIndex(object):
----------
v : utils.Index
The node(s).
Returns
-------
utils.Index
......@@ -305,7 +310,7 @@ class ImmutableGraphIndex(object):
----------
v : utils.Index
The node(s).
Returns
-------
utils.Index
......@@ -321,14 +326,14 @@ class ImmutableGraphIndex(object):
src = _CAPI_DGLExpandIds(v.todgltensor(), off.todgltensor())
return utils.toindex(src), utils.toindex(dst), utils.toindex(edges)
def edges(self, sorted=False):
def edges(self, return_sorted=False):
"""Return all the edges
Parameters
----------
sorted : bool
return_sorted : bool
True if the returned edges are sorted by their src and dst ids.
Returns
-------
utils.Index
......@@ -340,7 +345,7 @@ class ImmutableGraphIndex(object):
"""
if "all_edges" in self._cache:
return self._cache["all_edges"]
src, dst, edges = self._sparse.edges(sorted)
src, dst, edges = self._sparse.edges(return_sorted)
self._cache["all_edges"] = (utils.toindex(src), utils.toindex(dst), utils.toindex(edges))
return self._cache["all_edges"]
......@@ -440,8 +445,8 @@ class ImmutableGraphIndex(object):
The subgraph index.
"""
v = v.tousertensor()
gi, induced_n, induced_e = self._sparse.node_subgraph(v)
return ImmutableSubgraphIndex(gi, self, induced_n, induced_e)
gidx, induced_n, induced_e = self._sparse.node_subgraph(v)
return ImmutableSubgraphIndex(gidx, self, induced_n, induced_e)
def node_subgraphs(self, vs_arr):
"""Return the induced node subgraphs.
......@@ -458,8 +463,8 @@ class ImmutableGraphIndex(object):
"""
vs_arr = [v.tousertensor() for v in vs_arr]
gis, induced_nodes, induced_edges = self._sparse.node_subgraphs(vs_arr)
return [ImmutableSubgraphIndex(gi, self, induced_n,
induced_e) for gi, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)]
return [ImmutableSubgraphIndex(gidx, self, induced_n, induced_e)
for gidx, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)]
def edge_subgraph(self, e):
"""Return the induced edge subgraph.
......@@ -478,6 +483,7 @@ class ImmutableGraphIndex(object):
def neighbor_sampling(self, seed_ids, expand_factor, num_hops, neighbor_type,
node_prob, max_subgraph_size):
"""Neighborhood sampling"""
if len(seed_ids) == 0:
return []
seed_ids = [v.tousertensor() for v in seed_ids]
......@@ -486,8 +492,8 @@ class ImmutableGraphIndex(object):
node_prob,
max_subgraph_size)
induced_nodes = [utils.toindex(v) for v in induced_nodes]
return [ImmutableSubgraphIndex(gi, self, induced_n,
induced_e) for gi, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)]
return [ImmutableSubgraphIndex(gidx, self, induced_n, induced_e)
for gidx, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)]
def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
"""Return the adjacency matrix representation of this graph.
......@@ -511,12 +517,9 @@ class ImmutableGraphIndex(object):
A index for data shuffling due to sparse format change. Return None
if shuffle is not required.
"""
def get_adj(ctx):
new_mat = self._sparse.adjacency_matrix(transpose)
return F.copy_to(new_mat, ctx)
return self._sparse.adjacency_matrix(transpose, ctx), None
def incidence_matrix(self, type, ctx):
def incidence_matrix(self, typestr, ctx):
"""Return the incidence matrix representation of this graph.
An incidence matrix is an n x m sparse matrix, where n is
......@@ -538,7 +541,7 @@ class ImmutableGraphIndex(object):
Parameters
----------
type : str
typestr : str
Can be either "in", "out" or "both"
ctx : context
The context of returned incidence matrix.
......@@ -565,8 +568,8 @@ class ImmutableGraphIndex(object):
"""
src, dst, eid = self.edges()
ret = nx.DiGraph()
for u, v, id in zip(src, dst, eid):
ret.add_edge(u, v, id=id)
for u, v, e in zip(src, dst, eid):
ret.add_edge(u, v, id=e)
return ret
def from_networkx(self, nx_graph):
......@@ -574,7 +577,7 @@ class ImmutableGraphIndex(object):
If 'id' edge attribute exists, the edge will be added follows
the edge id order. Otherwise, order is undefined.
Parameters
----------
nx_graph : networkx.DiGraph
......@@ -582,7 +585,7 @@ class ImmutableGraphIndex(object):
"""
if not isinstance(nx_graph, nx.Graph):
nx_graph = (nx.MultiDiGraph(nx_graph) if self.is_multigraph()
else nx.DiGraph(nx_graph))
else nx.DiGraph(nx_graph))
else:
nx_graph = nx_graph.to_directed()
......@@ -626,8 +629,8 @@ class ImmutableGraphIndex(object):
----------
adj : scipy sparse matrix
"""
assert isinstance(adj, sp.csr_matrix) or isinstance(adj, sp.coo_matrix), \
"The input matrix has to be a SciPy sparse matrix."
if not isinstance(adj, (sp.csr_matrix, sp.coo_matrix)):
raise DGLError("The input matrix has to be a SciPy sparse matrix.")
out_mat = adj.tocoo()
self._sparse.from_coo_matrix(out_mat)
......@@ -639,23 +642,7 @@ class ImmutableGraphIndex(object):
elist : list
List of (u, v) edge tuple.
"""
self.clear()
src, dst = zip(*elist)
src = np.array(src)
dst = np.array(dst)
num_nodes = max(src.max(), dst.max()) + 1
min_nodes = min(src.min(), dst.min())
if min_nodes != 0:
raise DGLError('Invalid edge list. Nodes must start from 0.')
edge_ids = mx.nd.arange(0, len(src), step=1, repeat=1, dtype=np.int32)
src = mx.nd.array(src, dtype=np.int64)
dst = mx.nd.array(dst, dtype=np.int64)
# TODO we can't generate a csr_matrix with np.int64 directly.
in_csr = mx.nd.sparse.csr_matrix((edge_ids, (dst, src)),
shape=(num_nodes, num_nodes)).astype(np.int64)
out_csr = mx.nd.sparse.csr_matrix((edge_ids, (src, dst)),
shape=(num_nodes, num_nodes)).astype(np.int64)
self.__init__(in_csr, out_csr)
self._sparse.from_edge_list(elist)
def line_graph(self, backtracking=True):
"""Return the line graph of this graph.
......@@ -778,35 +765,35 @@ def create_immutable_graph_index(graph_data=None):
# If graph_data is None, we return an empty graph index.
# If we can't create a graph index, we'll use the code below to handle the graph.
return ImmutableGraphIndex(F.create_immutable_graph_index(graph_data))
except:
except Exception: # pylint: disable=broad-except
pass
# Let's create an empty graph index first.
gi = ImmutableGraphIndex(F.create_immutable_graph_index())
gidx = ImmutableGraphIndex(F.create_immutable_graph_index())
# edge list
if isinstance(graph_data, (list, tuple)):
try:
gi.from_edge_list(graph_data)
return gi
except:
gidx.from_edge_list(graph_data)
return gidx
except Exception: # pylint: disable=broad-except
raise DGLError('Graph data is not a valid edge list.')
# scipy format
if isinstance(graph_data, sp.spmatrix):
try:
gi.from_scipy_sparse_matrix(graph_data)
return gi
except:
gidx.from_scipy_sparse_matrix(graph_data)
return gidx
except Exception: # pylint: disable=broad-except
raise DGLError('Graph data is not a valid scipy sparse matrix.')
# networkx - any format
try:
gi.from_networkx(graph_data)
except:
gidx.from_networkx(graph_data)
except Exception: # pylint: disable=broad-except
raise DGLError('Error while creating graph from input of type "%s".'
% type(graph_data))
% type(graph_data))
return gi
return gidx
_init_api("dgl.immutable_graph_index")
......@@ -5,7 +5,7 @@ from . import backend as F
__all__ = ['base_initializer', 'zero_initializer']
def base_initializer(shape, dtype, ctx, range):
def base_initializer(shape, dtype, ctx, id_range): # pylint: disable=unused-argument
"""The function signature for feature initializer.
Any customized feature initializer should follow this signature (see
......@@ -20,7 +20,7 @@ def base_initializer(shape, dtype, ctx, range):
The data type of the returned features.
ctx : context object
The device context of the returned features.
range : slice
id_range : slice
The start id and the end id of the features to be initialized.
The id could be node or edge id depending on the scenario.
Note that the step is always None.
......@@ -32,7 +32,7 @@ def base_initializer(shape, dtype, ctx, range):
>>> import torch
>>> import dgl
>>> def initializer(shape, dtype, ctx, range):
>>> def initializer(shape, dtype, ctx, id_range):
>>> return torch.ones(shape, dtype=dtype, device=ctx)
>>> g = dgl.DGLGraph()
>>> g.set_n_initializer(initializer)
......@@ -44,7 +44,7 @@ def base_initializer(shape, dtype, ctx, range):
"""
raise NotImplementedError
def zero_initializer(shape, dtype, ctx, range):
def zero_initializer(shape, dtype, ctx, id_range): # pylint: disable=unused-argument
"""Zero feature initializer
Examples
......
......@@ -56,7 +56,7 @@ def prop_edges(graph,
def prop_nodes_bfs(graph,
source,
reversed=False,
reverse=False,
message_func='default',
reduce_func='default',
apply_node_func='default'):
......@@ -68,7 +68,7 @@ def prop_nodes_bfs(graph,
The graph object.
source : list, tensor of nodes
Source nodes.
reversed : bool, optional
reverse : bool, optional
If true, traverse following the in-edge direction.
message_func : callable, optional
The message function.
......@@ -81,11 +81,11 @@ def prop_nodes_bfs(graph,
--------
dgl.traversal.bfs_nodes_generator
"""
nodes_gen = trv.bfs_nodes_generator(graph, source, reversed)
nodes_gen = trv.bfs_nodes_generator(graph, source, reverse)
prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_nodes_topo(graph,
reversed=False,
reverse=False,
message_func='default',
reduce_func='default',
apply_node_func='default'):
......@@ -95,7 +95,7 @@ def prop_nodes_topo(graph,
----------
graph : DGLGraph
The graph object.
reversed : bool, optional
reverse : bool, optional
If true, traverse following the in-edge direction.
message_func : callable, optional
The message function.
......@@ -108,12 +108,12 @@ def prop_nodes_topo(graph,
--------
dgl.traversal.topological_nodes_generator
"""
nodes_gen = trv.topological_nodes_generator(graph, reversed)
nodes_gen = trv.topological_nodes_generator(graph, reverse)
prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_edges_dfs(graph,
source,
reversed=False,
reverse=False,
has_reverse_edge=False,
has_nontree_edge=False,
message_func='default',
......@@ -127,7 +127,7 @@ def prop_edges_dfs(graph,
The graph object.
source : list, tensor of nodes
Source nodes.
reversed : bool, optional
reverse : bool, optional
If true, traverse following the in-edge direction.
message_func : callable, optional
The message function.
......@@ -141,6 +141,6 @@ def prop_edges_dfs(graph,
dgl.traversal.dfs_labeled_edges_generator
"""
edges_gen = trv.dfs_labeled_edges_generator(
graph, source, reversed, has_reverse_edge, has_nontree_edge,
return_labels=False)
graph, source, reverse, has_reverse_edge, has_nontree_edge,
return_labels=False)
prop_edges(graph, edges_gen, message_func, reduce_func, apply_node_func)
"""DGL Runtime"""
"""Package for DGL scheduler and runtime."""
from __future__ import absolute_import
from . import scheduler
......
"""Module for degree bucketing schedulers"""
"""Module for degree bucketing schedulers."""
from __future__ import absolute_import
from .._ffi.function import _init_api
from ..base import is_all, ALL
from ..base import is_all
from .. import backend as F
from ..immutable_graph_index import ImmutableGraphIndex
from ..udf import EdgeBatch, NodeBatch
from ..udf import NodeBatch
from .. import utils
from . import ir
from .ir import var as var
from .ir import var
def gen_degree_bucketing_schedule(
graph,
......@@ -52,23 +51,23 @@ def gen_degree_bucketing_schedule(
"""
buckets = _degree_bucketing_schedule(message_ids, dst_nodes, recv_nodes)
# generate schedule
unique_dst, degs, buckets, msg_ids, zero_deg_nodes = buckets
_, degs, buckets, msg_ids, zero_deg_nodes = buckets
# loop over each bucket
idx_list = []
fd_list = []
for deg, vb, mid in zip(degs, buckets, msg_ids):
for deg, vbkt, mid in zip(degs, buckets, msg_ids):
# create per-bkt rfunc
rfunc = _create_per_bkt_rfunc(graph, reduce_udf, deg, vb)
rfunc = _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt)
# vars
vb = var.IDX(vb)
vbkt = var.IDX(vbkt)
mid = var.IDX(mid)
rfunc = var.FUNC(rfunc)
# recv on each bucket
fdvb = ir.READ_ROW(var_nf, vb)
fdvb = ir.READ_ROW(var_nf, vbkt)
fdmail = ir.READ_ROW(var_mf, mid)
fdvb = ir.NODE_UDF(rfunc, fdvb, fdmail, ret=fdvb) # reuse var
# save for merge
idx_list.append(vb)
idx_list.append(vbkt)
fd_list.append(fdvb)
if zero_deg_nodes is not None:
# NOTE: there must be at least one non-zero-deg node; otherwise,
......@@ -178,15 +177,16 @@ def _process_buckets(buckets):
return v, degs, dsts, msg_ids, zero_deg_nodes
def _create_per_bkt_rfunc(graph, reduce_udf, deg, vb):
def _create_per_bkt_rfunc(graph, reduce_udf, deg, vbkt):
"""Internal function to generate the per degree bucket node UDF."""
def _rfunc_wrapper(node_data, mail_data):
def _reshaped_getter(key):
msg = mail_data[key]
new_shape = (len(vb), deg) + F.shape(msg)[1:]
new_shape = (len(vbkt), deg) + F.shape(msg)[1:]
return F.reshape(msg, new_shape)
reshaped_mail_data = utils.LazyDict(_reshaped_getter, mail_data.keys())
nb = NodeBatch(graph, vb, node_data, reshaped_mail_data)
return reduce_udf(nb)
nbatch = NodeBatch(graph, vbkt, node_data, reshaped_mail_data)
return reduce_udf(nbatch)
return _rfunc_wrapper
_init_api("dgl.runtime.degree_bucketing")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment