"git@developer.sourcefind.cn:OpenDAS/deepspeed.git" did not exist on "dd03cff29f0bf5416afa03bc9be323be903d51aa"
Unverified commit 22167f72 authored by Minjie Wang, committed by GitHub

[Refactor] Enable new kernel in all message passing APIs (#1953)

* WIP: frame refactor

* new frame

* simple update_all builtin

* move all subgraph routines into the same file

* sddmm & spmm schedule; node & edge udf

* degree bucketing

* some tricky 0deg corner cases

* bug in frame append

* merge test_hetero_basics and test_basics

* some code rearrangement

* fix test_heterograph

* add mean spmm

* enable all builtin combinations

* pass gpu test

* pass pytorch tests

* wip

* fix some pt debugging codes

* fix bug in mxnet backward

* pass all mxnet utests

* passed tf tests

* docstring

* lint

* lint

* fix broadcasting bugs

* add warning and clamp for mean reducer

* add test for zero-degree mean

* address comments

* lint

* small fix
parent 5d5436ba
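
In practice, the refactor means every builtin message/reduce pair now routes through the new fused kernels. A minimal usage sketch (assuming the DGL 0.5-era public API that this PR targets; the graph and feature values are illustrative):

import torch as th
import dgl
import dgl.function as fn

g = dgl.graph(([0, 1, 2], [1, 2, 3]))      # node 0 has zero in-degree
g.ndata['h'] = th.randn(4, 5)

# copy_u + mean now lowers to one fused SpMM kernel instead of
# per-framework segment ops.
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_mean'))
# Zero-in-degree nodes (node 0 here) get 0 from the mean reducer
# (degree clamped to 1), with a warning rather than a NaN.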
@@ -26,6 +26,7 @@ from .convert import *
 from .generators import *
 from .heterograph import DGLHeteroGraph
 from .heterograph import DGLHeteroGraph as DGLGraph  # pylint: disable=reimported
+from .subgraph import *
 from .traversal import *
 from .transform import *
 from .propagate import *
...
This diff is collapsed.
@@ -12,7 +12,7 @@ import dgl
 from ..base import ALL, NID, EID, is_all, DGLError, dgl_warning
 from .. import backend as F
 from .. import init
-from ..frame import FrameRef, Frame, Scheme, sync_frame_initializer
+from .frame import FrameRef, Frame, Scheme, sync_frame_initializer
 from .. import graph_index
 from .runtime import ir, scheduler, Runtime, GraphAdapter
 from .. import utils
...
@@ -5,7 +5,7 @@ from .._ffi.object import register_object, ObjectBase
 from .._ffi.function import _init_api
 from ..base import ALL, is_all, DGLError, dgl_warning
 from .. import backend as F
-from ..frame import Frame, FrameRef
+from .frame import Frame, FrameRef
 from .graph import DGLBaseGraph
 from ..graph_index import transform_ids
 from .runtime import ir, scheduler, Runtime
...
@@ -5,7 +5,7 @@ from __future__ import absolute_import
 from abc import abstractmethod
 from .... import backend as F
-from ....frame import FrameRef, Frame
+from ...frame import FrameRef, Frame
 from .... import utils
 from .program import get_current_prog
...
@@ -5,7 +5,7 @@ from ... import utils
 from ..._ffi.function import _init_api
 from ...base import DGLError
 from ... import backend as F
-from ...frame import frame_like, FrameRef
+from ..frame import frame_like, FrameRef
 from ...function.base import BuiltinFunction
 from ..udf import EdgeBatch, NodeBatch
 from ... import ndarray as nd
...
@@ -971,58 +971,6 @@ def pack_padded_tensor(input, lengths):
     """
     pass
-
-def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
-    """Computes the sum along segments of a tensor.
-
-    Equivalent to tf.unsorted_segment_sum, but seg_id is required to be a
-    1D tensor.
-
-    Parameters
-    ----------
-    input : Tensor
-        The input tensor
-    seg_id : 1D Tensor
-        The segment IDs whose values are between 0 and n_segs - 1. Should
-        have the same length as input.
-    n_segs : int
-        Number of distinct segments
-    dim : int
-        Dimension to sum on
-
-    Returns
-    -------
-    Tensor
-        The result
-    """
-    pass
-
-def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
-    """Computes the mean along segments of a tensor.
-
-    Equivalent to tf.unsorted_segment_mean, but seg_id is required to be a
-    1D tensor.
-
-    Note that segments never appeared in seg_id will have results of 0.
-
-    Parameters
-    ----------
-    input : Tensor
-        The input tensor
-    seg_id : 1D Tensor
-        The segment IDs whose values are between 0 and n_segs - 1. Should
-        have the same length as input.
-    n_segs : int
-        Number of distinct segments
-    dim : int
-        Dimension to average on
-
-    Returns
-    -------
-    Tensor
-        The result
-    """
-    pass
 def boolean_mask(input, mask):
     """Selects elements in x according to the given mask from the first
     dimension.
@@ -1089,6 +1037,26 @@ def clone(input):
     """
     pass
+
+def clamp(data, min_val, max_val):
+    """Clamp all elements in :attr:`data` into the range [min_val, max_val]
+    and return a resulting tensor.
+
+    Parameters
+    ----------
+    data : Tensor
+        Input tensor.
+    min_val : Scalar
+        Min value.
+    max_val : Scalar
+        Max value.
+
+    Returns
+    -------
+    Tensor
+        The result.
+    """
+    pass
+
 ###############################################################################
 # Tensor functions used *only* on index tensor
 # ----------------
...
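
A minimal sketch of how such a clamp typically guards the new mean reducer against zero-degree nodes (the helper name mean_from_sum is hypothetical, shown with PyTorch and assuming 2-D node features):

import torch as th

def mean_from_sum(summed, degrees):
    # Clamp the denominator so zero-degree nodes receive 0 instead of NaN
    # from the division; a warning is emitted elsewhere.
    deg = th.clamp(degrees.float(), min=1.0).view(-1, 1)
    return summed / deg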
@@ -176,10 +176,17 @@ class GSpMM(mx.autograd.Function):
 def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
     func = GSpMM(gidx, op, reduce_op)
     ctx = to_backend_ctx(gidx.ctx)
+    # XXX(minjie): There is a bug in MXNet's autograd system when one of the
+    # inputs does not require gradient. Although it still invokes the backward
+    # function, it does not set the gradient value to the correct buffer,
+    # resulting in all the input gradients being zero. Fix this by enforcing
+    # all the inputs to require gradients.
     if lhs_data is None:
         lhs_data = nd.zeros((1,), ctx=ctx)
+        lhs_data.attach_grad()
     if rhs_data is None:
         rhs_data = nd.zeros((1,), ctx=ctx)
+        rhs_data.attach_grad()
     return func(lhs_data, rhs_data)
...
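
For context, a minimal standalone sketch of the attach_grad mechanism the fix relies on (standard MXNet 1.x autograd API; the values are illustrative):

import mxnet as mx
from mxnet import autograd, nd

x = nd.zeros((1,), ctx=mx.cpu())
x.attach_grad()              # mark the input as requiring gradient
with autograd.record():
    y = (x + 1.0) * 2.0
y.backward()
print(x.grad)                # [2.]; without attach_grad() no gradient
                             # buffer exists for x at all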
@@ -304,35 +304,6 @@ def pack_padded_tensor(input, lengths):
     index = nd.array(index, ctx=ctx)
     return gather_row(input.reshape(batch_size * max_len, -1), index)
-
-def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
-    # TODO: support other dimensions
-    assert dim == 0, 'MXNet only supports segment sum on first dimension'
-    # Use SPMV to simulate segment sum
-    ctx = input.context
-    n_inputs = input.shape[0]
-    input_shape_suffix = input.shape[1:]
-    input = input.reshape(n_inputs, -1)
-    n_range = nd.arange(n_inputs, dtype='int64').as_in_context(input.context)
-    w_nnz = nd.ones(n_inputs).as_in_context(input.context)
-    w_nid = nd.stack(seg_id, n_range, axis=0)
-    w = nd.sparse.csr_matrix((w_nnz, (seg_id, n_range)), (n_segs, n_inputs))
-    w = w.as_in_context(input.context)
-    y = nd.dot(w, input)
-    y = nd.reshape(y, (n_segs,) + input_shape_suffix)
-    return y
-
-def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
-    # TODO: support other dimensions
-    assert dim == 0, 'MXNet only supports segment mean on first dimension'
-    n_ones = nd.ones_like(seg_id).astype(input.dtype)
-    w = unsorted_1d_segment_sum(n_ones, seg_id, n_segs, 0)
-    w = nd.clip(w, a_min=1, a_max=np.inf)
-    y = unsorted_1d_segment_sum(input, seg_id, n_segs, dim)
-    y = y / w.reshape((-1,) + (1,) * (y.ndim - 1))
-    return y
 def boolean_mask(input, mask):
     return mx.contrib.nd.boolean_mask(input, mask)
...
@@ -348,6 +319,9 @@ def logical_and(input1, input2):
 def clone(input):
     return input.copy()
+
+def clamp(data, min_val, max_val):
+    return nd.clip(data, min_val, max_val)
+
 def unique(input):
     # TODO: fallback to numpy is unfortunate
     tmp = input.asnumpy()
...
@@ -106,7 +106,6 @@ class GSpMM(th.autograd.Function):
         else:  # max/min
             dY = th.zeros((Y.shape[0],) + dZ.shape[1:],
                           dtype=Y.dtype, device=Y.device)
-            print(X.shape, dZ.shape)
             if op in ['mul', 'div']:
                 grad = _expand(X, dZ.shape[1:]).gather(
                     0, argX.long()) * dZ
...
@@ -245,19 +245,6 @@ def pack_padded_tensor(input, lengths):
     index = th.tensor(index).to(device)
     return gather_row(input.view(batch_size * max_len, -1), index)
-
-def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
-    y = th.zeros(n_segs, *input.shape[1:]).to(input)
-    seg_id = seg_id.view((-1,) + (1,) * (input.dim() - 1)).expand_as(input)
-    y = y.scatter_add_(dim, seg_id, input)
-    return y
-
-def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
-    w = unsorted_1d_segment_sum(th.ones_like(seg_id), seg_id, n_segs, 0).to(input)
-    w = w.clamp(min=1)  # remove 0 entries
-    y = unsorted_1d_segment_sum(input, seg_id, n_segs, dim)
-    y = y / w.view((-1,) + (1,) * (y.dim() - 1))
-    return y
 def boolean_mask(input, mask):
     if 'bool' not in str(mask.dtype):
         mask = th.tensor(mask, dtype=th.bool)
...
@@ -275,6 +262,9 @@ def logical_and(input1, input2):
 def clone(input):
     return input.clone()
+
+def clamp(data, min_val, max_val):
+    return th.clamp(data, min_val, max_val)
+
 def unique(input):
     if input.dtype == th.bool:
         input = input.type(th.int8)
...
@@ -283,7 +283,11 @@ def narrow_row(x, start, stop):
 def scatter_row(data, row_index, value):
     row_index = tf.expand_dims(row_index, 1)
-    return tf.tensor_scatter_nd_update(data, row_index, value)
+    # XXX(minjie): Normally, the copy_to here is unnecessary. However, TF has
+    # the notorious legacy issue that int32 data always lives on CPU, which
+    # will crash the program since DGL requires feature data to be on the same
+    # device as the graph structure.
+    return copy_to(tf.tensor_scatter_nd_update(data, row_index, value), data.device)

 def index_add_inplace(data, row_idx, value):
...
@@ -366,18 +370,6 @@ def pack_padded_tensor(input, lengths):
     return tf.concat(out_list, axis=0)
-
-def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
-    assert dim == 0  # Why we need dim for 1d?
-    return tf.math.unsorted_segment_sum(input, seg_id, n_segs)
-
-def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
-    assert dim == 0  # Why we need dim for 1d?
-    return tf.math.unsorted_segment_mean(input, seg_id, n_segs)
-
-# TODO: TF has unsorted_segment_max, which can accelerate _max_on on batched graph
 def boolean_mask(input, mask):
     return tf.boolean_mask(input, mask)
...
@@ -396,6 +388,9 @@ def clone(input):
     # TF tensor is always immutable so returning the input is safe.
     return input
+
+def clamp(data, min_val, max_val):
+    return tf.clip_by_value(data, min_val, max_val)
+
 def unique(input):
     return tf.unique(input).y
...
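
For reference, a device-scoped identity is the usual way a copy_to like the one above is implemented in TF; a minimal sketch (the helper shown is an assumption, not necessarily DGL's exact code):

import tensorflow as tf

def copy_to(tensor, device):
    # Force-place ``tensor`` on ``device``; tf.identity materializes a copy
    # under the device scope, moving int32 data off the CPU when needed.
    with tf.device(device):
        return tf.identity(tensor)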
"""Implementation for core graph computation."""
# pylint: disable=not-callable
from .base import DGLError, is_all, NID, EID, ALL
from . import backend as F
from . import function as fn
from .frame import Frame
from .udf import NodeBatch, EdgeBatch
from . import ops
def is_builtin(func):
"""Return true if the function is a DGL builtin function."""
return isinstance(func, fn.BuiltinFunction)
def invoke_node_udf(graph, nid, ntype, func, *, ndata=None, orig_nid=None):
"""Invoke user-defined node function on the given nodes.
Parameters
----------
graph : DGLGraph
The input graph.
eid : Tensor
The IDs of the nodes to invoke UDF on.
ntype : str
Node type.
func : callable
The user-defined function.
ndata : dict[str, Tensor], optional
If provided, apply the UDF on this ndata instead of the ndata of the graph.
orig_nid : Tensor, optional
Original node IDs. Useful if the input graph is an extracted subgraph.
Returns
-------
dict[str, Tensor]
Results from running the UDF.
"""
ntid = graph.get_ntype_id(ntype)
if ndata is None:
if is_all(nid):
ndata = graph._node_frames[ntid]
nid = graph.nodes(ntype=ntype)
else:
ndata = graph._node_frames[ntid].subframe(nid)
nbatch = NodeBatch(graph, nid if orig_nid is None else orig_nid, ntype, ndata)
return func(nbatch)
def invoke_edge_udf(graph, eid, etype, func, *, orig_eid=None):
    """Invoke user-defined edge function on the given edges.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    eid : Tensor
        The IDs of the edges to invoke UDF on.
    etype : (str, str, str)
        Edge type.
    func : callable
        The user-defined function.
    orig_eid : Tensor, optional
        Original edge IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    etid = graph.get_etype_id(etype)
    stid, dtid = graph._graph.metagraph.find_edge(etid)
    if is_all(eid):
        u, v, eid = graph.edges(form='all')
        edata = graph._edge_frames[etid]
    else:
        u, v = graph.find_edges(eid)
        edata = graph._edge_frames[etid].subframe(eid)
    srcdata = graph._node_frames[stid].subframe(u)
    dstdata = graph._node_frames[dtid].subframe(v)
    ebatch = EdgeBatch(graph, eid if orig_eid is None else orig_eid,
                       etype, srcdata, edata, dstdata)
    return func(ebatch)
def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None):
    """Invoke user-defined reduce function on all the nodes in the graph.

    It analyzes the graph, groups nodes by their degrees and applies the UDF on each
    group -- a strategy called *degree bucketing*.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    func : callable
        The user-defined function.
    msgdata : dict[str, Tensor]
        Message data.
    orig_nid : Tensor, optional
        Original node IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    degs = graph.in_degrees()
    nodes = graph.dstnodes()
    if orig_nid is None:
        orig_nid = nodes
    ntype = graph.dsttypes[0]
    ntid = graph.get_ntype_id_from_dst(ntype)
    dstdata = graph._node_frames[ntid]
    msgdata = Frame(msgdata)

    # degree bucketing
    unique_degs, bucketor = _bucketing(degs)
    bkt_rsts = []
    bkt_nodes = []
    for deg, node_bkt, orig_nid_bkt in zip(unique_degs, bucketor(nodes), bucketor(orig_nid)):
        if deg == 0:
            # skip reduce function for zero-degree nodes
            continue
        bkt_nodes.append(node_bkt)
        ndata_bkt = dstdata.subframe(node_bkt)
        eid_bkt = graph.in_edges(node_bkt, form='eid')
        assert len(eid_bkt) == deg * len(node_bkt)
        msgdata_bkt = msgdata.subframe(eid_bkt)
        # reshape all msg tensors to (num_nodes_bkt, degree, feat_size)
        maildata = {}
        for k, msg in msgdata_bkt.items():
            newshape = (len(node_bkt), deg) + F.shape(msg)[1:]
            maildata[k] = F.reshape(msg, newshape)
        # invoke udf
        nbatch = NodeBatch(graph, orig_nid_bkt, ntype, ndata_bkt, msgs=maildata)
        bkt_rsts.append(func(nbatch))

    # prepare a result frame
    retf = Frame(num_rows=len(nodes))
    retf._initializers = dstdata._initializers
    retf._default_initializer = dstdata._default_initializer

    # merge bucket results and write to the result frame
    if len(bkt_rsts) != 0:  # if all the nodes have zero degree, no need to merge results.
        merged_rst = {}
        for k in bkt_rsts[0].keys():
            merged_rst[k] = F.cat([rst[k] for rst in bkt_rsts], dim=0)
        merged_nodes = F.cat(bkt_nodes, dim=0)
        retf.update_row(merged_nodes, merged_rst)
    return retf
def _bucketing(val):
    """Internal function to create groups on the values.

    Parameters
    ----------
    val : Tensor
        Value tensor.

    Returns
    -------
    unique_val : Tensor
        Unique values.
    bucketor : callable[Tensor -> list[Tensor]]
        A bucketing function that splits the given tensor data in the same
        way as the :attr:`val` tensor is grouped.
    """
    sorted_val, idx = F.sort_1d(val)
    unique_val = F.asnumpy(F.unique(sorted_val))
    bkt_idx = []
    for v in unique_val:
        eqidx = F.nonzero_1d(F.equal(sorted_val, v))
        bkt_idx.append(F.gather_row(idx, eqidx))
    def bucketor(data):
        bkts = [F.gather_row(data, idx) for idx in bkt_idx]
        return bkts
    return unique_val, bucketor
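
# Illustrative example: for in-degrees [2, 0, 2, 1], ``_bucketing`` returns
# unique_val == [0, 1, 2], and ``bucketor`` splits a node-ID tensor
# [n0, n1, n2, n3] into buckets [[n1], [n3], [n0, n2]], so each bucket can be
# reduced with a single fixed-degree call.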
def invoke_gsddmm(graph, func):
    """Invoke g-SDDMM computation on the graph.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    func : dgl.function.BaseMessageFunction
        Built-in message function.

    Returns
    -------
    dict[str, Tensor]
        Results from the g-SDDMM computation.
    """
    alldata = [graph.srcdata, graph.dstdata, graph.edata]
    if isinstance(func, fn.BinaryMessageFunction):
        x = alldata[func.lhs][func.lhs_field]
        y = alldata[func.rhs][func.rhs_field]
        op = getattr(ops, func.name)
        z = op(graph, x, y)
    else:
        x = alldata[func.target][func.in_field]
        op = getattr(ops, func.name)
        z = op(graph, x)
    return {func.out_field : z}

def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None):
    """Invoke g-SPMM computation on the graph.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    mfunc : dgl.function.BaseMessageFunction
        Built-in message function.
    rfunc : dgl.function.BaseReduceFunction
        Built-in reduce function.
    srcdata : dict[str, Tensor], optional
        Source node feature data. If not provided, it uses ``graph.srcdata``.
    dstdata : dict[str, Tensor], optional
        Destination node feature data. If not provided, it uses ``graph.dstdata``.
    edata : dict[str, Tensor], optional
        Edge feature data. If not provided, it uses ``graph.edata``.

    Returns
    -------
    dict[str, Tensor]
        Results from the g-SPMM computation.
    """
    # sanity check
    if mfunc.out_field != rfunc.msg_field:
        raise DGLError('Invalid message ({}) and reduce ({}) function pairs.'
                       ' The output field of the message function must be equal to the'
                       ' message field of the reduce function.'.format(mfunc, rfunc))
    if edata is None:
        edata = graph.edata
    if srcdata is None:
        srcdata = graph.srcdata
    if dstdata is None:
        dstdata = graph.dstdata
    alldata = [srcdata, dstdata, edata]
    if isinstance(mfunc, fn.BinaryMessageFunction):
        x = alldata[mfunc.lhs][mfunc.lhs_field]
        y = alldata[mfunc.rhs][mfunc.rhs_field]
        op = getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name))
        z = op(graph, x, y)
    else:
        x = alldata[mfunc.target][mfunc.in_field]
        op = getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name))
        z = op(graph, x)
    return {rfunc.out_field : z}
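
# Illustrative naming: ``fn.u_mul_e('h', 'w', 'm')`` paired with
# ``fn.sum('m', 'out')`` resolves to ``ops.u_mul_e_sum``, i.e. one fused
# g-SPMM kernel call over the whole graph.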
def message_passing(g, mfunc, rfunc, afunc):
    """Invoke message passing computation on the whole graph.

    Parameters
    ----------
    g : DGLGraph
        The input graph.
    mfunc : callable or dgl.function.BuiltinFunction
        Message function.
    rfunc : callable or dgl.function.BuiltinFunction
        Reduce function.
    afunc : callable or dgl.function.BuiltinFunction
        Apply function.

    Returns
    -------
    dict[str, Tensor]
        Results from the message passing computation.
    """
    if g.number_of_edges() == 0:
        # No message passing is triggered.
        ndata = {}
    elif (is_builtin(mfunc) and is_builtin(rfunc) and
          getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name), None) is not None):
        # invoke fused message passing
        ndata = invoke_gspmm(g, mfunc, rfunc)
    else:
        # invoke message passing in two separate steps
        # message phase
        if is_builtin(mfunc):
            msgdata = invoke_gsddmm(g, mfunc)
        else:
            orig_eid = g.edata.get(EID, None)
            msgdata = invoke_edge_udf(g, ALL, g.canonical_etypes[0], mfunc, orig_eid=orig_eid)
        # reduce phase
        if is_builtin(rfunc):
            msg = rfunc.msg_field
            ndata = invoke_gspmm(g, fn.copy_e(msg, msg), rfunc, edata=msgdata)
        else:
            orig_nid = g.dstdata.get(NID, None)
            ndata = invoke_udf_reduce(g, rfunc, msgdata, orig_nid=orig_nid)
    # apply phase
    if afunc is not None:
        for k, v in g.dstdata.items():  # include original node features
            if k not in ndata:
                ndata[k] = v
        orig_nid = g.dstdata.get(NID, None)
        ndata = invoke_node_udf(g, ALL, g.dsttypes[0], afunc, ndata=ndata, orig_nid=orig_nid)
    return ndata
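
To see how this dispatch plays out, a short usage sketch (assuming the public ``update_all`` API, which routes into ``message_passing``):

import torch as th
import dgl
import dgl.function as fn

g = dgl.graph(([0, 1, 2], [1, 2, 3]))
g.ndata['h'] = th.randn(4, 4)

# Builtin pair with a registered op -> invoke_gspmm (fused kernel).
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))

# UDF reduce -> invoke_gsddmm for the message, then degree-bucketed
# invoke_udf_reduce; nodes.mailbox['m'] has shape (bucket_size, deg, 4).
g.update_all(fn.copy_u('h', 'm'),
             lambda nodes: {'h_max': nodes.mailbox['m'].max(dim=1)[0]})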
"""For HeteroGraph Serialization""" """For HeteroGraph Serialization"""
from __future__ import absolute_import from __future__ import absolute_import
from ..heterograph import DGLHeteroGraph from ..heterograph import DGLHeteroGraph
from ..frame import Frame, FrameRef from ..frame import Frame
from .._ffi.object import ObjectBase, register_object from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
...@@ -51,10 +51,10 @@ class HeteroGraphData(ObjectBase): ...@@ -51,10 +51,10 @@ class HeteroGraphData(ObjectBase):
eframes = [] eframes = []
for ntid, ntensor in enumerate(ntensor_list): for ntid, ntensor in enumerate(ntensor_list):
ndict = {ntensor[i]: F.zerocopy_from_dgl_ndarray(ntensor[i+1]) for i in range(0, len(ntensor), 2)} ndict = {ntensor[i]: F.zerocopy_from_dgl_ndarray(ntensor[i+1]) for i in range(0, len(ntensor), 2)}
nframes.append(FrameRef(Frame(ndict, num_rows=gidx.number_of_nodes(ntid)))) nframes.append(Frame(ndict, num_rows=gidx.number_of_nodes(ntid)))
for etid, etensor in enumerate(etensor_list): for etid, etensor in enumerate(etensor_list):
edict = {etensor[i]: F.zerocopy_from_dgl_ndarray(etensor[i+1]) for i in range(0, len(etensor), 2)} edict = {etensor[i]: F.zerocopy_from_dgl_ndarray(etensor[i+1]) for i in range(0, len(etensor), 2)}
eframes.append(FrameRef(Frame(edict, num_rows=gidx.number_of_edges(etid)))) eframes.append(Frame(edict, num_rows=gidx.number_of_edges(etid)))
return DGLHeteroGraph(gidx, ntype_names, etype_names, nframes, eframes) return DGLHeteroGraph(gidx, ntype_names, etype_names, nframes, eframes)
@@ -3,7 +3,7 @@ from collections import namedtuple
 from .rpc import Request, Response, send_requests_to_machine, recv_responses
 from ..sampling import sample_neighbors as local_sample_neighbors
-from ..transform import in_subgraph as local_in_subgraph
+from ..subgraph import in_subgraph as local_in_subgraph
 from .rpc import register_service
 from ..convert import graph
 from ..base import NID, EID
...
This diff is collapsed.
@@ -15,9 +15,9 @@ class TargetCode(object):
     EDGE = 2

     CODE2STR = {
-        0: "src",
-        1: "dst",
-        2: "edge",
+        0: "u",
+        1: "v",
+        2: "e",
     }
...
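
The shorter codes match the naming used by the builtin functions and the fused kernels; e.g. (illustrative):

import dgl.function as fn

fn.u_mul_e('h', 'w', 'm')   # source ("u") feature times edge ("e") feature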
@@ -9,7 +9,8 @@ from .._deprecate.runtime import ir
 from .._deprecate.runtime.ir import var

-__all__ = ["src_mul_edge", "copy_src", "copy_edge", "copy_u", "copy_e"]
+__all__ = ["src_mul_edge", "copy_src", "copy_edge", "copy_u", "copy_e",
+           "BinaryMessageFunction", "CopyMessageFunction"]

 class MessageFunction(BuiltinFunction):
...
@@ -87,7 +87,7 @@ __all__ = []
 def _register_builtin_reduce_func():
     """Register builtin reduce functions"""
-    for reduce_op in ["max", "min", "sum", "mean", "prod"]:
+    for reduce_op in ["max", "min", "sum", "mean"]:
         builtin = _gen_reduce_builtin(reduce_op)
         setattr(sys.modules[__name__], reduce_op, builtin)
         __all__.append(reduce_op)
...
This diff is collapsed.