Unverified Commit 22167f72 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Refactor] Enable new kernel in all message passing APIs (#1953)

* WIP: frame refactor

* new frame

* simple update_all builtin

* move all subgraph routines into the same file

* sddmm & spmm schedule; node & edge udf

* degree bucketing

* some tricky 0deg corner cases

* bug in frame append

* merge test_hetero_basics and test_basics

* some code rearrange

* fix test_heterograph

* add mean spmm

* enable all builtin combinations

* pass gpu test

* pass pytorch tests

* wip

* fix some pt debugging codes

* fix bug in mxnet backward

* pass all mxnet utests

* passed tf tests

* docstring

* lint

* lint

* fix broadcasting bugs

* add warning and clamp for mean reducer

* add test for zero-degree mean

* address comments

* lint

* small fix
parent 5d5436ba
......@@ -26,6 +26,7 @@ from .convert import *
from .generators import *
from .heterograph import DGLHeteroGraph
from .heterograph import DGLHeteroGraph as DGLGraph # pylint: disable=reimported
from .subgraph import *
from .traversal import *
from .transform import *
from .propagate import *
......
This diff is collapsed.
......@@ -12,7 +12,7 @@ import dgl
from ..base import ALL, NID, EID, is_all, DGLError, dgl_warning
from .. import backend as F
from .. import init
from ..frame import FrameRef, Frame, Scheme, sync_frame_initializer
from .frame import FrameRef, Frame, Scheme, sync_frame_initializer
from .. import graph_index
from .runtime import ir, scheduler, Runtime, GraphAdapter
from .. import utils
......
......@@ -5,7 +5,7 @@ from .._ffi.object import register_object, ObjectBase
from .._ffi.function import _init_api
from ..base import ALL, is_all, DGLError, dgl_warning
from .. import backend as F
from ..frame import Frame, FrameRef
from .frame import Frame, FrameRef
from .graph import DGLBaseGraph
from ..graph_index import transform_ids
from .runtime import ir, scheduler, Runtime
......
......@@ -5,7 +5,7 @@ from __future__ import absolute_import
from abc import abstractmethod
from .... import backend as F
from ....frame import FrameRef, Frame
from ...frame import FrameRef, Frame
from .... import utils
from .program import get_current_prog
......
......@@ -5,7 +5,7 @@ from ... import utils
from ..._ffi.function import _init_api
from ...base import DGLError
from ... import backend as F
from ...frame import frame_like, FrameRef
from ..frame import frame_like, FrameRef
from ...function.base import BuiltinFunction
from ..udf import EdgeBatch, NodeBatch
from ... import ndarray as nd
......
......@@ -971,58 +971,6 @@ def pack_padded_tensor(input, lengths):
"""
pass
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
    """Computes the sum along segments of a tensor.

    Equivalent to tf.unsorted_segment_sum, but seg_id is required to be a
    1D tensor.

    Parameters
    ----------
    input : Tensor
        The input tensor
    seg_id : 1D Tensor
        The segment IDs whose values are between 0 and n_segs - 1. Should
        have the same length as input.
    n_segs : int
        Number of distinct segments
    dim : int
        Dimension to sum on

    Returns
    -------
    Tensor
        The result
    """
    # Interface stub: each backend (pytorch/mxnet/tensorflow) provides the
    # concrete implementation.
    pass
def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
    """Computes the mean along segments of a tensor.

    Equivalent to tf.unsorted_segment_mean, but seg_id is required to be a
    1D tensor.

    Note that segments that never appear in seg_id will have results of 0.

    Parameters
    ----------
    input : Tensor
        The input tensor
    seg_id : 1D Tensor
        The segment IDs whose values are between 0 and n_segs - 1. Should
        have the same length as input.
    n_segs : int
        Number of distinct segments
    dim : int
        Dimension to average on

    Returns
    -------
    Tensor
        The result
    """
    # Interface stub: each backend (pytorch/mxnet/tensorflow) provides the
    # concrete implementation.
    pass
def boolean_mask(input, mask):
"""Selects elements in x according to the given mask from the first
dimension.
......@@ -1089,6 +1037,26 @@ def clone(input):
"""
pass
def clamp(data, min_val, max_val):
    """Clamp all elements in :attr:`input` into the range [min_val, max_val]
    and return a resulting tensor.

    Parameters
    ----------
    data : Tensor
        Input tensor
    min_val : Scalar
        Min value.
    max_val : Scalar
        Max value.

    Returns
    -------
    Tensor
        The result.
    """
    # Interface stub: each backend (pytorch/mxnet/tensorflow) provides the
    # concrete implementation.
    pass
###############################################################################
# Tensor functions used *only* on index tensor
# ----------------
......
......@@ -176,10 +176,17 @@ class GSpMM(mx.autograd.Function):
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
    """Invoke the generalized SpMM autograd function for the MXNet backend.

    Either operand may be ``None`` when the fused kernel does not use it; a
    1-element dummy tensor is substituted so autograd still sees two inputs
    (see the XXX note below).
    """
    func = GSpMM(gidx, op, reduce_op)
    ctx = to_backend_ctx(gidx.ctx)
    # XXX(minjie): There is a bug in MXNet's autograd system when one of the inputs
    # does not require gradient. Although it still invokes the backward function,
    # it does not set the gradient value to the correct buffer, resulting all the
    # input gradients to be zero. Fix this by enforcing all the inputs to require
    # gradients.
    if lhs_data is None:
        lhs_data = nd.zeros((1,), ctx=ctx)
    lhs_data.attach_grad()
    if rhs_data is None:
        rhs_data = nd.zeros((1,), ctx=ctx)
    rhs_data.attach_grad()
    return func(lhs_data, rhs_data)
......
......@@ -304,35 +304,6 @@ def pack_padded_tensor(input, lengths):
index = nd.array(index, ctx=ctx)
return gather_row(input.reshape(batch_size * max_len, -1), index)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
    """Computes the sum along segments of a tensor (MXNet backend).

    Implemented as a sparse-dense product: a CSR selection matrix of shape
    (n_segs, n_inputs) with a single 1 at (seg_id[i], i) sums the rows that
    belong to each segment.

    Parameters
    ----------
    input : NDArray
        The input tensor.
    seg_id : 1D NDArray
        Segment IDs, one per row of ``input``; must be < n_segs so the CSR
        matrix shape is valid.
    n_segs : int
        Number of distinct segments.
    dim : int
        Dimension to sum on; only 0 is supported.

    Returns
    -------
    NDArray
        Tensor of shape (n_segs,) + input.shape[1:].
    """
    # TODO: support other dimensions
    assert dim == 0, 'MXNet only supports segment sum on first dimension'

    # Use SPMV to simulate segment sum
    ctx = input.context
    n_inputs = input.shape[0]
    input_shape_suffix = input.shape[1:]
    # Flatten trailing feature dims so the problem is a 2D matrix product.
    input = input.reshape(n_inputs, -1)
    n_range = nd.arange(n_inputs, dtype='int64').as_in_context(ctx)
    w_nnz = nd.ones(n_inputs).as_in_context(ctx)
    # One non-zero per input row: entry (seg_id[i], i) = 1.
    # (The original code also built an unused `w_nid = nd.stack(...)`; removed.)
    w = nd.sparse.csr_matrix((w_nnz, (seg_id, n_range)), (n_segs, n_inputs))
    w = w.as_in_context(ctx)
    y = nd.dot(w, input)
    # Restore the original feature shape.
    return nd.reshape(y, (n_segs,) + input_shape_suffix)
def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
    """Computes the mean along segments: segment sum divided by the
    per-segment element count (MXNet backend)."""
    # TODO: support other dimensions
    assert dim == 0, 'MXNet only supports segment mean on first dimension'
    counts = unsorted_1d_segment_sum(
        nd.ones_like(seg_id).astype(input.dtype), seg_id, n_segs, 0)
    # Clip empty-segment counts up to 1 so their mean is 0, not NaN.
    counts = nd.clip(counts, a_min=1, a_max=np.inf)
    totals = unsorted_1d_segment_sum(input, seg_id, n_segs, dim)
    return totals / counts.reshape((-1,) + (1,) * (totals.ndim - 1))
def boolean_mask(input, mask):
    """Select entries of ``input`` along the first dimension where ``mask``
    is true (delegates to MXNet's contrib boolean_mask operator)."""
    return mx.contrib.nd.boolean_mask(input, mask)
......@@ -348,6 +319,9 @@ def logical_and(input1, input2):
def clone(input):
    """Return a copy of the input NDArray."""
    return input.copy()
def clamp(data, min_val, max_val):
    """Clamp all elements of ``data`` into the range [min_val, max_val]."""
    return nd.clip(data, min_val, max_val)
def unique(input):
# TODO: fallback to numpy is unfortunate
tmp = input.asnumpy()
......
......@@ -106,7 +106,6 @@ class GSpMM(th.autograd.Function):
else: # max/min
dY = th.zeros((Y.shape[0],) + dZ.shape[1:],
dtype=Y.dtype, device=Y.device)
print(X.shape, dZ.shape)
if op in ['mul', 'div']:
grad = _expand(X, dZ.shape[1:]).gather(
0, argX.long()) * dZ
......
......@@ -245,19 +245,6 @@ def pack_padded_tensor(input, lengths):
index = th.tensor(index).to(device)
return gather_row(input.view(batch_size * max_len, -1), index)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
    """Computes the sum along segments of a tensor (PyTorch backend).

    ``seg_id`` is a 1D tensor with one segment ID per row of ``input``;
    the result has ``n_segs`` rows, one accumulated row per segment.
    """
    out = th.zeros(n_segs, *input.shape[1:]).to(input)
    # Broadcast the 1D segment IDs to the full input shape for scatter_add_.
    index = seg_id.view((-1,) + (1,) * (input.dim() - 1)).expand_as(input)
    return out.scatter_add_(dim, index, input)
def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
    """Computes the mean along segments: segment sum divided by the
    per-segment element count (PyTorch backend)."""
    counts = unsorted_1d_segment_sum(th.ones_like(seg_id), seg_id, n_segs, 0).to(input)
    counts = counts.clamp(min=1)  # remove 0 entries so empty segments yield 0
    totals = unsorted_1d_segment_sum(input, seg_id, n_segs, dim)
    return totals / counts.view((-1,) + (1,) * (totals.dim() - 1))
def boolean_mask(input, mask):
if 'bool' not in str(mask.dtype):
mask = th.tensor(mask, dtype=th.bool)
......@@ -275,6 +262,9 @@ def logical_and(input1, input2):
def clone(input):
    """Return a copy of the input tensor."""
    return input.clone()
def clamp(data, min_val, max_val):
    """Clamp all elements of ``data`` into the range [min_val, max_val]."""
    return th.clamp(data, min_val, max_val)
def unique(input):
if input.dtype == th.bool:
input = input.type(th.int8)
......
......@@ -283,7 +283,11 @@ def narrow_row(x, start, stop):
def scatter_row(data, row_index, value):
row_index = tf.expand_dims(row_index, 1)
return tf.tensor_scatter_nd_update(data, row_index, value)
# XXX(minjie): Normally, the copy_to here is unnecessary. However, TF has this
# notorious legacy issue that int32 type data is always on CPU, which will
# crash the program since DGL requires feature data to be on the same device
# as graph structure.
return copy_to(tf.tensor_scatter_nd_update(data, row_index, value), data.device)
def index_add_inplace(data, row_idx, value):
......@@ -366,18 +370,6 @@ def pack_padded_tensor(input, lengths):
return tf.concat(out_list, axis=0)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
    """Computes the sum along segments using TF's native unsorted segment sum."""
    assert dim == 0  # only the first dimension is supported; why we need dim for 1d?
    return tf.math.unsorted_segment_sum(input, seg_id, n_segs)
def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
    """Computes the mean along segments using TF's native unsorted segment mean."""
    assert dim == 0  # only the first dimension is supported; why we need dim for 1d?
    return tf.math.unsorted_segment_mean(input, seg_id, n_segs)
# TODO: TF has unsorted_segment_max, which can accelerate _max_on on batched graph
def boolean_mask(input, mask):
    """Select entries of ``input`` where ``mask`` is True."""
    return tf.boolean_mask(input, mask)
......@@ -396,6 +388,9 @@ def clone(input):
# TF tensor is always immutable so returning the input is safe.
return input
def clamp(data, min_val, max_val):
    """Clamp all elements of ``data`` into the range [min_val, max_val]."""
    return tf.clip_by_value(data, min_val, max_val)
def unique(input):
    """Return the unique elements of ``input``.

    NOTE(review): tf.unique operates on 1-D tensors — confirm callers only
    pass 1-D input.
    """
    return tf.unique(input).y
......
"""Implementation for core graph computation."""
# pylint: disable=not-callable
from .base import DGLError, is_all, NID, EID, ALL
from . import backend as F
from . import function as fn
from .frame import Frame
from .udf import NodeBatch, EdgeBatch
from . import ops
def is_builtin(func):
    """Return true if the function is a DGL builtin function (i.e. an
    instance of ``dgl.function``'s ``BuiltinFunction``)."""
    return isinstance(func, fn.BuiltinFunction)
def invoke_node_udf(graph, nid, ntype, func, *, ndata=None, orig_nid=None):
    """Invoke user-defined node function on the given nodes.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    nid : Tensor
        The IDs of the nodes to invoke UDF on.
    ntype : str
        Node type.
    func : callable
        The user-defined function.
    ndata : dict[str, Tensor], optional
        If provided, apply the UDF on this ndata instead of the ndata of the graph.
    orig_nid : Tensor, optional
        Original node IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    ntid = graph.get_ntype_id(ntype)
    if ndata is None:
        if is_all(nid):
            # Use the whole node frame and materialize a concrete ID tensor
            # so the NodeBatch below always carries explicit node IDs.
            ndata = graph._node_frames[ntid]
            nid = graph.nodes(ntype=ntype)
        else:
            ndata = graph._node_frames[ntid].subframe(nid)
    # Report the original (pre-extraction) IDs to the UDF when provided.
    nbatch = NodeBatch(graph, nid if orig_nid is None else orig_nid, ntype, ndata)
    return func(nbatch)
def invoke_edge_udf(graph, eid, etype, func, *, orig_eid=None):
    """Invoke user-defined edge function on the given edges.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    eid : Tensor
        The IDs of the edges to invoke UDF on.
    etype : (str, str, str)
        Edge type.
    func : callable
        The user-defined function.
    orig_eid : Tensor, optional
        Original edge IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    etid = graph.get_etype_id(etype)
    # Node type IDs of the source and destination endpoints of this edge type.
    stid, dtid = graph._graph.metagraph.find_edge(etid)
    if is_all(eid):
        # All edges: fetch endpoints and concrete edge IDs in one call.
        u, v, eid = graph.edges(form='all')
        edata = graph._edge_frames[etid]
    else:
        u, v = graph.find_edges(eid)
        edata = graph._edge_frames[etid].subframe(eid)
    # Per-edge features of the endpoint nodes, aligned with `eid`.
    srcdata = graph._node_frames[stid].subframe(u)
    dstdata = graph._node_frames[dtid].subframe(v)
    ebatch = EdgeBatch(graph, eid if orig_eid is None else orig_eid,
                       etype, srcdata, edata, dstdata)
    return func(ebatch)
def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None):
    """Invoke user-defined reduce function on all the nodes in the graph.

    It analyzes the graph, groups nodes by their degrees and applies the UDF on each
    group -- a strategy called *degree-bucketing*.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    func : callable
        The user-defined function.
    msgdata : dict[str, Tensor]
        Message data.
    orig_nid : Tensor, optional
        Original node IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    degs = graph.in_degrees()
    nodes = graph.dstnodes()
    if orig_nid is None:
        orig_nid = nodes
    ntype = graph.dsttypes[0]
    ntid = graph.get_ntype_id_from_dst(ntype)
    dstdata = graph._node_frames[ntid]
    # Wrap the raw message dict in a Frame so it supports subframe() below.
    msgdata = Frame(msgdata)

    # degree bucketing: nodes with equal in-degree are processed together so
    # their mailboxes can be stacked into dense tensors.
    unique_degs, bucketor = _bucketing(degs)
    bkt_rsts = []
    bkt_nodes = []
    for deg, node_bkt, orig_nid_bkt in zip(unique_degs, bucketor(nodes), bucketor(orig_nid)):
        if deg == 0:
            # skip reduce function for zero-degree nodes
            continue
        bkt_nodes.append(node_bkt)
        ndata_bkt = dstdata.subframe(node_bkt)

        # NOTE: in_edges is expected to return exactly `deg` edge IDs per node
        # in this bucket; the reshape below relies on that grouping.
        eid_bkt = graph.in_edges(node_bkt, form='eid')
        assert len(eid_bkt) == deg * len(node_bkt)
        msgdata_bkt = msgdata.subframe(eid_bkt)

        # reshape all msg tensors to (num_nodes_bkt, degree, feat_size)
        maildata = {}
        for k, msg in msgdata_bkt.items():
            newshape = (len(node_bkt), deg) + F.shape(msg)[1:]
            maildata[k] = F.reshape(msg, newshape)
        # invoke udf
        nbatch = NodeBatch(graph, orig_nid_bkt, ntype, ndata_bkt, msgs=maildata)
        bkt_rsts.append(func(nbatch))

    # prepare a result frame; inherit the initializers of the destination
    # frame so untouched (zero-degree) rows are filled consistently.
    retf = Frame(num_rows=len(nodes))
    retf._initializers = dstdata._initializers
    retf._default_initializer = dstdata._default_initializer

    # merge bucket results and write to the result frame
    if len(bkt_rsts) != 0:  # if all the nodes have zero degree, no need to merge results.
        merged_rst = {}
        for k in bkt_rsts[0].keys():
            merged_rst[k] = F.cat([rst[k] for rst in bkt_rsts], dim=0)
        merged_nodes = F.cat(bkt_nodes, dim=0)
        retf.update_row(merged_nodes, merged_rst)

    return retf
def _bucketing(val):
    """Internal helper to group positions by equal values.

    Parameters
    ----------
    val : Tensor
        Value tensor.

    Returns
    -------
    unique_val : ndarray
        The distinct values found in ``val``.
    bucketor : callable[Tensor -> list[Tensor]]
        Splits any tensor aligned with ``val`` into one chunk per distinct
        value, using the same grouping of positions.
    """
    sorted_val, perm = F.sort_1d(val)
    uniques = F.asnumpy(F.unique(sorted_val))
    # For each distinct value, the positions (in original order) holding it.
    groups = [
        F.gather_row(perm, F.nonzero_1d(F.equal(sorted_val, u)))
        for u in uniques
    ]

    def bucketor(data):
        return [F.gather_row(data, grp) for grp in groups]

    return uniques, bucketor
def invoke_gsddmm(graph, func):
    """Invoke g-SDDMM computation on the graph.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    func : dgl.function.BaseMessageFunction
        Built-in message function.

    Returns
    -------
    dict[str, Tensor]
        Results from the g-SDDMM computation.
    """
    # Feature dicts indexed by target code (src / dst / edge).
    feats = [graph.srcdata, graph.dstdata, graph.edata]
    # The builtin's name doubles as the kernel name in the ops module.
    kernel = getattr(ops, func.name)
    if isinstance(func, fn.BinaryMessageFunction):
        lhs = feats[func.lhs][func.lhs_field]
        rhs = feats[func.rhs][func.rhs_field]
        result = kernel(graph, lhs, rhs)
    else:
        result = kernel(graph, feats[func.target][func.in_field])
    return {func.out_field : result}
def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None):
    """Invoke g-SPMM computation on the graph.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    mfunc : dgl.function.BaseMessageFunction
        Built-in message function.
    rfunc : dgl.function.BaseReduceFunction
        Built-in reduce function.
    srcdata : dict[str, Tensor], optional
        Source node feature data. Defaults to ``graph.srcdata``.
    dstdata : dict[str, Tensor], optional
        Destination node feature data. Defaults to ``graph.dstdata``.
    edata : dict[str, Tensor], optional
        Edge feature data. Defaults to ``graph.edata``.

    Returns
    -------
    dict[str, Tensor]
        Results from the g-SPMM computation.
    """
    # sanity check
    if mfunc.out_field != rfunc.msg_field:
        raise DGLError('Invalid message ({}) and reduce ({}) function pairs.'
                       ' The output field of the message function must be equal to the'
                       ' message field of the reduce function.'.format(mfunc, rfunc))
    srcdata = graph.srcdata if srcdata is None else srcdata
    dstdata = graph.dstdata if dstdata is None else dstdata
    edata = graph.edata if edata is None else edata
    # Feature dicts indexed by target code (src / dst / edge).
    feats = [srcdata, dstdata, edata]
    # Fused kernels are looked up by the '<message>_<reduce>' naming convention.
    kernel = getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name))
    if isinstance(mfunc, fn.BinaryMessageFunction):
        lhs = feats[mfunc.lhs][mfunc.lhs_field]
        rhs = feats[mfunc.rhs][mfunc.rhs_field]
        result = kernel(graph, lhs, rhs)
    else:
        result = kernel(graph, feats[mfunc.target][mfunc.in_field])
    return {rfunc.out_field : result}
def message_passing(g, mfunc, rfunc, afunc):
    """Invoke message passing computation on the whole graph.

    Parameters
    ----------
    g : DGLGraph
        The input graph.
    mfunc : callable or dgl.function.BuiltinFunction
        Message function.
    rfunc : callable or dgl.function.BuiltinFunction
        Reduce function.
    afunc : callable or dgl.function.BuiltinFunction
        Apply function.

    Returns
    -------
    dict[str, Tensor]
        Results from the message passing computation.
    """
    if g.number_of_edges() == 0:
        # No message passing is triggered.
        ndata = {}
    elif (is_builtin(mfunc) and is_builtin(rfunc) and
          getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name), None) is not None):
        # Both functions are builtins and a fused '<msg>_<reduce>' kernel
        # exists: invoke fused message passing in one shot.
        ndata = invoke_gspmm(g, mfunc, rfunc)
    else:
        # invoke message passing in two separate steps
        # message phase
        if is_builtin(mfunc):
            msgdata = invoke_gsddmm(g, mfunc)
        else:
            # UDF message function; pass original edge IDs when present.
            orig_eid = g.edata.get(EID, None)
            msgdata = invoke_edge_udf(g, ALL, g.canonical_etypes[0], mfunc, orig_eid=orig_eid)
        # reduce phase
        if is_builtin(rfunc):
            # Feed precomputed messages through a copy_e message + the
            # builtin reducer, i.e. a plain g-SPMM over the message data.
            msg = rfunc.msg_field
            ndata = invoke_gspmm(g, fn.copy_e(msg, msg), rfunc, edata=msgdata)
        else:
            # UDF reducer: fall back to degree-bucketing execution.
            orig_nid = g.dstdata.get(NID, None)
            ndata = invoke_udf_reduce(g, rfunc, msgdata, orig_nid=orig_nid)
    # apply phase
    if afunc is not None:
        for k, v in g.dstdata.items():   # include original node features
            if k not in ndata:
                ndata[k] = v
        orig_nid = g.dstdata.get(NID, None)
        ndata = invoke_node_udf(g, ALL, g.dsttypes[0], afunc, ndata=ndata, orig_nid=orig_nid)
    return ndata
"""For HeteroGraph Serialization"""
from __future__ import absolute_import
from ..heterograph import DGLHeteroGraph
from ..frame import Frame, FrameRef
from ..frame import Frame
from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api
from .. import backend as F
......@@ -51,10 +51,10 @@ class HeteroGraphData(ObjectBase):
eframes = []
for ntid, ntensor in enumerate(ntensor_list):
ndict = {ntensor[i]: F.zerocopy_from_dgl_ndarray(ntensor[i+1]) for i in range(0, len(ntensor), 2)}
nframes.append(FrameRef(Frame(ndict, num_rows=gidx.number_of_nodes(ntid))))
nframes.append(Frame(ndict, num_rows=gidx.number_of_nodes(ntid)))
for etid, etensor in enumerate(etensor_list):
edict = {etensor[i]: F.zerocopy_from_dgl_ndarray(etensor[i+1]) for i in range(0, len(etensor), 2)}
eframes.append(FrameRef(Frame(edict, num_rows=gidx.number_of_edges(etid))))
eframes.append(Frame(edict, num_rows=gidx.number_of_edges(etid)))
return DGLHeteroGraph(gidx, ntype_names, etype_names, nframes, eframes)
......@@ -3,7 +3,7 @@ from collections import namedtuple
from .rpc import Request, Response, send_requests_to_machine, recv_responses
from ..sampling import sample_neighbors as local_sample_neighbors
from ..transform import in_subgraph as local_in_subgraph
from ..subgraph import in_subgraph as local_in_subgraph
from .rpc import register_service
from ..convert import graph
from ..base import NID, EID
......
This diff is collapsed.
......@@ -15,9 +15,9 @@ class TargetCode(object):
EDGE = 2
CODE2STR = {
0: "src",
1: "dst",
2: "edge",
0: "u",
1: "v",
2: "e",
}
......
......@@ -9,7 +9,8 @@ from .._deprecate.runtime import ir
from .._deprecate.runtime.ir import var
__all__ = ["src_mul_edge", "copy_src", "copy_edge", "copy_u", "copy_e"]
__all__ = ["src_mul_edge", "copy_src", "copy_edge", "copy_u", "copy_e",
"BinaryMessageFunction", "CopyMessageFunction"]
class MessageFunction(BuiltinFunction):
......
......@@ -87,7 +87,7 @@ __all__ = []
def _register_builtin_reduce_func():
"""Register builtin reduce functions"""
for reduce_op in ["max", "min", "sum", "mean", "prod"]:
for reduce_op in ["max", "min", "sum", "mean"]:
builtin = _gen_reduce_builtin(reduce_op)
setattr(sys.modules[__name__], reduce_op, builtin)
__all__.append(reduce_op)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment