Unverified Commit c7c0fd0e authored by Quan (Andy) Gan, committed by GitHub
Browse files

[Feature] Accessing node/edge types in UDFs. (#1243)

* typed udf

* docstrings
parent b133abb8
...@@ -3564,3 +3564,8 @@ class AdaptedDGLGraph(GraphAdapter): ...@@ -3564,3 +3564,8 @@ class AdaptedDGLGraph(GraphAdapter):
def bits_needed(self): def bits_needed(self):
return self.graph._graph.bits_needed() return self.graph._graph.bits_needed()
@property
def canonical_etype(self):
    """Canonical edge type of the adapted graph.

    A homogeneous graph has no type names, so each slot of the
    (source type, edge type, destination type) triplet is None.
    """
    return (None,) * 3
...@@ -2276,7 +2276,7 @@ class DGLHeteroGraph(object): ...@@ -2276,7 +2276,7 @@ class DGLHeteroGraph(object):
v_ntype = utils.toindex(v) v_ntype = utils.toindex(v)
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_apply_nodes(v_ntype, func, self._node_frames[ntid], scheduler.schedule_apply_nodes(v_ntype, func, self._node_frames[ntid],
inplace=inplace) inplace=inplace, ntype=self._ntypes[ntid])
Runtime.run(prog) Runtime.run(prog)
def apply_edges(self, func, edges=ALL, etype=None, inplace=False): def apply_edges(self, func, edges=ALL, etype=None, inplace=False):
...@@ -3496,7 +3496,7 @@ class DGLHeteroGraph(object): ...@@ -3496,7 +3496,7 @@ class DGLHeteroGraph(object):
v = utils.toindex(nodes) v = utils.toindex(nodes)
n_repr = self._get_n_repr(ntid, v) n_repr = self._get_n_repr(ntid, v)
nbatch = NodeBatch(v, n_repr) nbatch = NodeBatch(v, n_repr, ntype=self.ntypes[ntid])
n_mask = F.copy_to(predicate(nbatch), F.cpu()) n_mask = F.copy_to(predicate(nbatch), F.cpu())
if is_all(nodes): if is_all(nodes):
...@@ -3557,7 +3557,8 @@ class DGLHeteroGraph(object): ...@@ -3557,7 +3557,8 @@ class DGLHeteroGraph(object):
src_data = self._get_n_repr(stid, u) src_data = self._get_n_repr(stid, u)
edge_data = self._get_e_repr(etid, eid) edge_data = self._get_e_repr(etid, eid)
dst_data = self._get_n_repr(dtid, v) dst_data = self._get_n_repr(dtid, v)
ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data) ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data,
canonical_etype=self.canonical_etypes[etid])
e_mask = F.copy_to(predicate(ebatch), F.cpu()) e_mask = F.copy_to(predicate(ebatch), F.cpu())
if is_all(edges): if is_all(edges):
...@@ -3988,3 +3989,8 @@ class AdaptedHeteroGraph(GraphAdapter): ...@@ -3988,3 +3989,8 @@ class AdaptedHeteroGraph(GraphAdapter):
def bits_needed(self): def bits_needed(self):
return self.graph._graph.bits_needed(self.etid) return self.graph._graph.bits_needed(self.etid)
@property
def canonical_etype(self):
    """Canonical edge type selected by this adapter's edge type ID.

    Looks up ``self.etid`` in the wrapped graph's ``canonical_etypes``.
    """
    all_etypes = self.graph.canonical_etypes
    return all_etypes[self.etid]
...@@ -1017,6 +1017,12 @@ class NodeFlow(DGLBaseGraph): ...@@ -1017,6 +1017,12 @@ class NodeFlow(DGLBaseGraph):
self.block_compute(i, message_func, reduce_func, apply_node_func, self.block_compute(i, message_func, reduce_func, apply_node_func,
inplace=inplace) inplace=inplace)
@property
def canonical_etype(self):
    """Placeholder canonical edge type, kept for GraphAdapter compatibility.

    Always the all-None triplet.
    """
    untyped_triplet = (None, None, None)
    return untyped_triplet
def _copy_to_like(arr1, arr2): def _copy_to_like(arr1, arr2):
return F.copy_to(arr1, F.context(arr2)) return F.copy_to(arr1, F.context(arr2))
......
...@@ -16,7 +16,8 @@ def gen_degree_bucketing_schedule( ...@@ -16,7 +16,8 @@ def gen_degree_bucketing_schedule(
recv_nodes, recv_nodes,
var_nf, var_nf,
var_mf, var_mf,
var_out): var_out,
ntype=None):
"""Create degree bucketing schedule. """Create degree bucketing schedule.
The messages will be divided by their receivers into buckets. Each bucket The messages will be divided by their receivers into buckets. Each bucket
...@@ -44,6 +45,9 @@ def gen_degree_bucketing_schedule( ...@@ -44,6 +45,9 @@ def gen_degree_bucketing_schedule(
The variable for message frame. The variable for message frame.
var_out : var.FEAT_DICT var_out : var.FEAT_DICT
The variable for output feature dicts. The variable for output feature dicts.
ntype : str, optional
The node type, if running on a heterograph.
If None, assuming it's running on a homogeneous graph.
""" """
buckets = _degree_bucketing_schedule(message_ids, dst_nodes, recv_nodes) buckets = _degree_bucketing_schedule(message_ids, dst_nodes, recv_nodes)
# generate schedule # generate schedule
...@@ -53,7 +57,7 @@ def gen_degree_bucketing_schedule( ...@@ -53,7 +57,7 @@ def gen_degree_bucketing_schedule(
fd_list = [] fd_list = []
for deg, vbkt, mid in zip(degs, buckets, msg_ids): for deg, vbkt, mid in zip(degs, buckets, msg_ids):
# create per-bkt rfunc # create per-bkt rfunc
rfunc = _create_per_bkt_rfunc(reduce_udf, deg, vbkt) rfunc = _create_per_bkt_rfunc(reduce_udf, deg, vbkt, ntype=ntype)
# vars # vars
vbkt = var.IDX(vbkt) vbkt = var.IDX(vbkt)
mid = var.IDX(mid) mid = var.IDX(mid)
...@@ -141,7 +145,7 @@ def _process_node_buckets(buckets): ...@@ -141,7 +145,7 @@ def _process_node_buckets(buckets):
return v, degs, dsts, msg_ids, zero_deg_nodes return v, degs, dsts, msg_ids, zero_deg_nodes
def _create_per_bkt_rfunc(reduce_udf, deg, vbkt): def _create_per_bkt_rfunc(reduce_udf, deg, vbkt, ntype=None):
"""Internal function to generate the per degree bucket node UDF.""" """Internal function to generate the per degree bucket node UDF."""
def _rfunc_wrapper(node_data, mail_data): def _rfunc_wrapper(node_data, mail_data):
def _reshaped_getter(key): def _reshaped_getter(key):
...@@ -149,7 +153,7 @@ def _create_per_bkt_rfunc(reduce_udf, deg, vbkt): ...@@ -149,7 +153,7 @@ def _create_per_bkt_rfunc(reduce_udf, deg, vbkt):
new_shape = (len(vbkt), deg) + F.shape(msg)[1:] new_shape = (len(vbkt), deg) + F.shape(msg)[1:]
return F.reshape(msg, new_shape) return F.reshape(msg, new_shape)
reshaped_mail_data = utils.LazyDict(_reshaped_getter, mail_data.keys()) reshaped_mail_data = utils.LazyDict(_reshaped_getter, mail_data.keys())
nbatch = NodeBatch(vbkt, node_data, reshaped_mail_data) nbatch = NodeBatch(vbkt, node_data, reshaped_mail_data, ntype=ntype)
return reduce_udf(nbatch) return reduce_udf(nbatch)
return _rfunc_wrapper return _rfunc_wrapper
...@@ -160,7 +164,8 @@ def gen_group_apply_edge_schedule( ...@@ -160,7 +164,8 @@ def gen_group_apply_edge_schedule(
var_src_nf, var_src_nf,
var_dst_nf, var_dst_nf,
var_ef, var_ef,
var_out): var_out,
canonical_etype=(None, None, None)):
"""Create degree bucketing schedule for group_apply_edge """Create degree bucketing schedule for group_apply_edge
Edges will be grouped by either its source node or destination node Edges will be grouped by either its source node or destination node
...@@ -189,6 +194,9 @@ def gen_group_apply_edge_schedule( ...@@ -189,6 +194,9 @@ def gen_group_apply_edge_schedule(
The variable for edge frame. The variable for edge frame.
var_out : var.FEAT_DICT var_out : var.FEAT_DICT
The variable for output feature dicts. The variable for output feature dicts.
canonical_etype : tuple[str, str, str], optional
Canonical edge type if running on a heterograph.
Default: (None, None, None), if running on a homogeneous graph.
""" """
if group_by == "src": if group_by == "src":
buckets = _degree_bucketing_for_edge_grouping(u, v, eid) buckets = _degree_bucketing_for_edge_grouping(u, v, eid)
...@@ -204,7 +212,8 @@ def gen_group_apply_edge_schedule( ...@@ -204,7 +212,8 @@ def gen_group_apply_edge_schedule(
for deg, u_bkt, v_bkt, eid_bkt in zip(degs, uids, vids, eids): for deg, u_bkt, v_bkt, eid_bkt in zip(degs, uids, vids, eids):
# create per-bkt efunc # create per-bkt efunc
_efunc = var.FUNC(_create_per_bkt_efunc(apply_func, deg, _efunc = var.FUNC(_create_per_bkt_efunc(apply_func, deg,
u_bkt, v_bkt, eid_bkt)) u_bkt, v_bkt, eid_bkt,
canonical_etype=canonical_etype))
# vars # vars
var_u = var.IDX(u_bkt) var_u = var.IDX(u_bkt)
var_v = var.IDX(v_bkt) var_v = var.IDX(v_bkt)
...@@ -274,7 +283,7 @@ def _process_edge_buckets(buckets): ...@@ -274,7 +283,7 @@ def _process_edge_buckets(buckets):
eids = split(eids) eids = split(eids)
return degs, uids, vids, eids return degs, uids, vids, eids
def _create_per_bkt_efunc(apply_func, deg, u, v, eid): def _create_per_bkt_efunc(apply_func, deg, u, v, eid, canonical_etype=(None, None, None)):
"""Internal function to generate the per degree bucket edge UDF.""" """Internal function to generate the per degree bucket edge UDF."""
batch_size = len(u) // deg batch_size = len(u) // deg
def _efunc_wrapper(src_data, edge_data, dst_data): def _efunc_wrapper(src_data, edge_data, dst_data):
...@@ -297,7 +306,8 @@ def _create_per_bkt_efunc(apply_func, deg, u, v, eid): ...@@ -297,7 +306,8 @@ def _create_per_bkt_efunc(apply_func, deg, u, v, eid):
reshaped_dst_data = utils.LazyDict(_reshape_func(dst_data), reshaped_dst_data = utils.LazyDict(_reshape_func(dst_data),
dst_data.keys()) dst_data.keys())
ebatch = EdgeBatch((u, v, eid), reshaped_src_data, ebatch = EdgeBatch((u, v, eid), reshaped_src_data,
reshaped_edge_data, reshaped_dst_data) reshaped_edge_data, reshaped_dst_data,
canonical_etype=canonical_etype)
return {k: _reshape_back(v) for k, v in apply_func(ebatch).items()} return {k: _reshape_back(v) for k, v in apply_func(ebatch).items()}
return _efunc_wrapper return _efunc_wrapper
......
...@@ -104,7 +104,7 @@ def schedule_recv(graph, ...@@ -104,7 +104,7 @@ def schedule_recv(graph,
# 2) no send has been called # 2) no send has been called
if apply_func is not None: if apply_func is not None:
schedule_apply_nodes(recv_nodes, apply_func, graph.dstframe, schedule_apply_nodes(recv_nodes, apply_func, graph.dstframe,
inplace, outframe) inplace, outframe, ntype=graph.canonical_etype[-1])
else: else:
var_dst_nf = var.FEAT_DICT(graph.dstframe, 'dst_nf') var_dst_nf = var.FEAT_DICT(graph.dstframe, 'dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf') var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
...@@ -117,7 +117,8 @@ def schedule_recv(graph, ...@@ -117,7 +117,8 @@ def schedule_recv(graph,
recv_nodes) recv_nodes)
# apply # apply
final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf, final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf,
reduced_feat, apply_func) reduced_feat, apply_func,
ntype=graph.canonical_etype[-1])
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_out_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_INPLACE_(var_out_nf, var_recv_nodes, final_feat)
else: else:
...@@ -182,10 +183,11 @@ def schedule_snr(graph, ...@@ -182,10 +183,11 @@ def schedule_snr(graph,
var_reduce_nodes=var_recv_nodes, var_reduce_nodes=var_recv_nodes,
uv_getter=uv_getter, uv_getter=uv_getter,
adj_creator=adj_creator, adj_creator=adj_creator,
out_map_creator=out_map_creator) out_map_creator=out_map_creator,
canonical_etype=graph.canonical_etype)
# generate apply schedule # generate apply schedule
final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf, reduced_feat, final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf, reduced_feat,
apply_func) apply_func, ntype=graph.canonical_etype[-1])
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_out_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_INPLACE_(var_out_nf, var_recv_nodes, final_feat)
else: else:
...@@ -216,7 +218,8 @@ def schedule_update_all(graph, ...@@ -216,7 +218,8 @@ def schedule_update_all(graph,
if apply_func is not None: if apply_func is not None:
nodes = utils.toindex(slice(0, graph.num_dst())) nodes = utils.toindex(slice(0, graph.num_dst()))
schedule_apply_nodes(nodes, apply_func, graph.dstframe, schedule_apply_nodes(nodes, apply_func, graph.dstframe,
inplace=False, outframe=outframe) inplace=False, outframe=outframe,
ntype=graph.canonical_etype[-1])
else: else:
eid = utils.toindex(slice(0, graph.num_edges())) # ALL eid = utils.toindex(slice(0, graph.num_edges())) # ALL
recv_nodes = utils.toindex(slice(0, graph.num_dst())) # ALL recv_nodes = utils.toindex(slice(0, graph.num_dst())) # ALL
...@@ -240,17 +243,20 @@ def schedule_update_all(graph, ...@@ -240,17 +243,20 @@ def schedule_update_all(graph,
var_reduce_nodes=var_recv_nodes, var_reduce_nodes=var_recv_nodes,
uv_getter=uv_getter, uv_getter=uv_getter,
adj_creator=adj_creator, adj_creator=adj_creator,
out_map_creator=out_map_creator) out_map_creator=out_map_creator,
canonical_etype=graph.canonical_etype)
# generate optional apply # generate optional apply
final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf, final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf,
reduced_feat, apply_func) reduced_feat, apply_func,
ntype=graph.canonical_etype[-1])
ir.WRITE_DICT_(var_out_nf, final_feat) ir.WRITE_DICT_(var_out_nf, final_feat)
def schedule_apply_nodes(v, def schedule_apply_nodes(v,
apply_func, apply_func,
node_frame, node_frame,
inplace, inplace,
outframe=None): outframe=None,
ntype=None):
"""Get apply nodes schedule """Get apply nodes schedule
Parameters Parameters
...@@ -265,6 +271,9 @@ def schedule_apply_nodes(v, ...@@ -265,6 +271,9 @@ def schedule_apply_nodes(v,
If True, the update will be done in place If True, the update will be done in place
outframe : FrameRef, optional outframe : FrameRef, optional
The storage to write output data. If None, use the given node_frame. The storage to write output data. If None, use the given node_frame.
ntype : str, optional
The node type, if running on a heterograph.
If None, assuming it's running on a homogeneous graph.
Returns Returns
------- -------
...@@ -275,7 +284,7 @@ def schedule_apply_nodes(v, ...@@ -275,7 +284,7 @@ def schedule_apply_nodes(v,
var_out_nf = var_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf') var_out_nf = var_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
v_nf = ir.READ_ROW(var_nf, var_v) v_nf = ir.READ_ROW(var_nf, var_v)
def _afunc_wrapper(node_data): def _afunc_wrapper(node_data):
nbatch = NodeBatch(v, node_data) nbatch = NodeBatch(v, node_data, ntype=ntype)
return apply_func(nbatch) return apply_func(nbatch)
afunc = var.FUNC(_afunc_wrapper) afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf) applied_feat = ir.NODE_UDF(afunc, v_nf)
...@@ -472,7 +481,8 @@ def schedule_pull(graph, ...@@ -472,7 +481,8 @@ def schedule_pull(graph,
if len(eid) == 0: if len(eid) == 0:
# All the nodes are 0deg; downgrades to apply. # All the nodes are 0deg; downgrades to apply.
if apply_func is not None: if apply_func is not None:
schedule_apply_nodes(pull_nodes, apply_func, graph.dstframe, inplace, outframe) schedule_apply_nodes(pull_nodes, apply_func, graph.dstframe, inplace,
outframe, ntype=graph.canonical_etype[-1])
else: else:
pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor())) pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
pull_nodes = utils.toindex(pull_nodes) pull_nodes = utils.toindex(pull_nodes)
...@@ -492,10 +502,12 @@ def schedule_pull(graph, ...@@ -492,10 +502,12 @@ def schedule_pull(graph,
graph.dstframe, graph.edgeframe, graph.dstframe, graph.edgeframe,
message_func, reduce_func, var_eid, message_func, reduce_func, var_eid,
var_pull_nodes, uv_getter, adj_creator, var_pull_nodes, uv_getter, adj_creator,
out_map_creator) out_map_creator,
canonical_etype=graph.canonical_etype)
# generate optional apply # generate optional apply
final_feat = _apply_with_accum(var_pull_nodes, var_dst_nf, final_feat = _apply_with_accum(var_pull_nodes, var_dst_nf,
reduced_feat, apply_func) reduced_feat, apply_func,
ntype=graph.canonical_etype[-1])
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_out_nf, var_pull_nodes, final_feat) ir.WRITE_ROW_INPLACE_(var_out_nf, var_pull_nodes, final_feat)
else: else:
...@@ -535,7 +547,8 @@ def schedule_group_apply_edge(graph, ...@@ -535,7 +547,8 @@ def schedule_group_apply_edge(graph,
var_out_ef = var_ef if outframe is None else var.FEAT_DICT(outframe, name='out_ef') var_out_ef = var_ef if outframe is None else var.FEAT_DICT(outframe, name='out_ef')
var_out = var.FEAT_DICT(name='new_ef') var_out = var.FEAT_DICT(name='new_ef')
db.gen_group_apply_edge_schedule(apply_func, u, v, eid, group_by, db.gen_group_apply_edge_schedule(apply_func, u, v, eid, group_by,
var_src_nf, var_dst_nf, var_ef, var_out) var_src_nf, var_dst_nf, var_ef, var_out,
canonical_etype=graph.canonical_etype)
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_out_ef, var_eid, var_out) ir.WRITE_ROW_INPLACE_(var_out_ef, var_eid, var_out)
...@@ -700,7 +713,7 @@ def _standardize_func_usage(func, func_name): ...@@ -700,7 +713,7 @@ def _standardize_func_usage(func, func_name):
' Got: %s' % (func_name, str(func))) ' Got: %s' % (func_name, str(func)))
return func return func
def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func): def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func, ntype=None):
"""Apply with accumulated features. """Apply with accumulated features.
Paramters Paramters
...@@ -713,6 +726,9 @@ def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func): ...@@ -713,6 +726,9 @@ def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func):
The accumulated features. The accumulated features.
apply_func : callable, None apply_func : callable, None
The apply function. The apply function.
ntype : str, optional
The node type, if running on a heterograph.
If None, assuming it's running on a homogeneous graph.
""" """
if apply_func: if apply_func:
# To avoid writing reduced features back to node frame and reading # To avoid writing reduced features back to node frame and reading
...@@ -722,7 +738,7 @@ def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func): ...@@ -722,7 +738,7 @@ def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func):
v_nf = ir.UPDATE_DICT(v_nf, var_accum) v_nf = ir.UPDATE_DICT(v_nf, var_accum)
def _afunc_wrapper(node_data): def _afunc_wrapper(node_data):
nbatch = NodeBatch(var_nodes.data, node_data) nbatch = NodeBatch(var_nodes.data, node_data, ntype=ntype)
return apply_func(nbatch) return apply_func(nbatch)
afunc = var.FUNC(_afunc_wrapper) afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf) applied_feat = ir.NODE_UDF(afunc, v_nf)
...@@ -778,7 +794,8 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes): ...@@ -778,7 +794,8 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
else: else:
# gen degree bucketing schedule for UDF recv # gen degree bucketing schedule for UDF recv
db.gen_degree_bucketing_schedule(rfunc, eid, dst, recv_nodes, db.gen_degree_bucketing_schedule(rfunc, eid, dst, recv_nodes,
var_dst_nf, var_msg, var_out) var_dst_nf, var_msg, var_out,
ntype=graph.canonical_etype[-1])
return var_out return var_out
def _gen_send_reduce( def _gen_send_reduce(
...@@ -791,7 +808,8 @@ def _gen_send_reduce( ...@@ -791,7 +808,8 @@ def _gen_send_reduce(
var_reduce_nodes, var_reduce_nodes,
uv_getter, uv_getter,
adj_creator, adj_creator,
out_map_creator): out_map_creator,
canonical_etype=(None, None, None)):
"""Generate send and reduce schedule. """Generate send and reduce schedule.
The function generates symbolic program for computing The function generates symbolic program for computing
...@@ -832,9 +850,12 @@ def _gen_send_reduce( ...@@ -832,9 +850,12 @@ def _gen_send_reduce(
adj_creator : callable adj_creator : callable
Function that returns the adjmat, edge order of csr matrix, and Function that returns the adjmat, edge order of csr matrix, and
bit-width. bit-width.
out_map_creator: callable out_map_creator : callable
A function that returns a mapping from reduce_nodes to relabeled A function that returns a mapping from reduce_nodes to relabeled
consecutive ids consecutive ids
canonical_etype : tuple[str, str, str], optional
Canonical edge type if running on a heterograph.
Default: (None, None, None), if running on a homogeneous graph.
Returns Returns
------- -------
...@@ -917,7 +938,7 @@ def _gen_send_reduce( ...@@ -917,7 +938,7 @@ def _gen_send_reduce(
else: else:
# generate UDF send schedule # generate UDF send schedule
var_mf = _gen_udf_send(var_src_nf, var_dst_nf, var_ef, var_u, var_mf = _gen_udf_send(var_src_nf, var_dst_nf, var_ef, var_u,
var_v, var_eid, mfunc) var_v, var_eid, mfunc, canonical_etype=canonical_etype)
# 6. Generate reduce # 6. Generate reduce
if rfunc_is_list: if rfunc_is_list:
...@@ -935,17 +956,18 @@ def _gen_send_reduce( ...@@ -935,17 +956,18 @@ def _gen_send_reduce(
mid = utils.toindex(slice(0, len(var_v.data))) mid = utils.toindex(slice(0, len(var_v.data)))
db.gen_degree_bucketing_schedule(rfunc, mid, var_v.data, db.gen_degree_bucketing_schedule(rfunc, mid, var_v.data,
reduce_nodes, var_dst_nf, var_mf, reduce_nodes, var_dst_nf, var_mf,
var_out) var_out, ntype=canonical_etype[-1])
return var_out return var_out
def _gen_udf_send(var_src_nf, var_dst_nf, var_ef, u, v, eid, mfunc): def _gen_udf_send(var_src_nf, var_dst_nf, var_ef, u, v, eid, mfunc,
canonical_etype=(None, None, None)):
"""Internal function to generate send schedule for UDF message function.""" """Internal function to generate send schedule for UDF message function."""
fdsrc = ir.READ_ROW(var_src_nf, u) fdsrc = ir.READ_ROW(var_src_nf, u)
fddst = ir.READ_ROW(var_dst_nf, v) fddst = ir.READ_ROW(var_dst_nf, v)
fdedge = ir.READ_ROW(var_ef, eid) fdedge = ir.READ_ROW(var_ef, eid)
def _mfunc_wrapper(src_data, edge_data, dst_data): def _mfunc_wrapper(src_data, edge_data, dst_data):
ebatch = EdgeBatch((u.data, v.data, eid.data), ebatch = EdgeBatch((u.data, v.data, eid.data),
src_data, edge_data, dst_data) src_data, edge_data, dst_data, canonical_etype=canonical_etype)
return mfunc(ebatch) return mfunc(ebatch)
_mfunc_wrapper = var.FUNC(_mfunc_wrapper) _mfunc_wrapper = var.FUNC(_mfunc_wrapper)
msg = ir.EDGE_UDF(_mfunc_wrapper, fdsrc, fdedge, fddst) msg = ir.EDGE_UDF(_mfunc_wrapper, fdsrc, fdedge, fddst)
...@@ -982,7 +1004,8 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef): ...@@ -982,7 +1004,8 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
else: else:
# UDF send # UDF send
var_out = _gen_udf_send(var_src_nf, var_dst_nf, var_ef, var_u, var_out = _gen_udf_send(var_src_nf, var_dst_nf, var_ef, var_u,
var_v, var_eid, mfunc) var_v, var_eid, mfunc,
canonical_etype=graph.canonical_etype)
return var_out return var_out
def _build_idx_map(idx, nbits): def _build_idx_map(idx, nbits):
......
...@@ -17,12 +17,17 @@ class EdgeBatch(object): ...@@ -17,12 +17,17 @@ class EdgeBatch(object):
dst_data : dict of tensors dst_data : dict of tensors
The dst node features, in the form of ``dict`` The dst node features, in the form of ``dict``
with ``str`` keys and ``tensor`` values with ``str`` keys and ``tensor`` values
canonical_etype : tuple of (str, str, str), optional
Canonical edge type of the edge batch, if UDF is
running on a heterograph.
""" """
def __init__(self, edges, src_data, edge_data, dst_data): def __init__(self, edges, src_data, edge_data, dst_data,
canonical_etype=(None, None, None)):
self._edges = edges self._edges = edges
self._src_data = src_data self._src_data = src_data
self._edge_data = edge_data self._edge_data = edge_data
self._dst_data = dst_data self._dst_data = dst_data
self._canonical_etype = canonical_etype
@property @property
def src(self): def src(self):
...@@ -89,6 +94,12 @@ class EdgeBatch(object): ...@@ -89,6 +94,12 @@ class EdgeBatch(object):
""" """
return self.batch_size() return self.batch_size()
@property
def canonical_etype(self):
    """Canonical edge type of this edge batch, if available.

    Returns
    -------
    tuple
        The (source node type, edge type, destination node type) triplet
        stored on this batch.
    """
    etype_triplet = self._canonical_etype
    return etype_triplet
class NodeBatch(object): class NodeBatch(object):
"""The class that can represent a batch of nodes. """The class that can represent a batch of nodes.
...@@ -102,11 +113,15 @@ class NodeBatch(object): ...@@ -102,11 +113,15 @@ class NodeBatch(object):
msgs : dict, optional msgs : dict, optional
The messages, , in the form of ``dict`` The messages, , in the form of ``dict``
with ``str`` keys and ``tensor`` values with ``str`` keys and ``tensor`` values
ntype : str, optional
The node type of this node batch, if running
on a heterograph.
""" """
def __init__(self, nodes, data, msgs=None): def __init__(self, nodes, data, msgs=None, ntype=None):
self._nodes = nodes self._nodes = nodes
self._data = data self._data = data
self._msgs = msgs self._msgs = msgs
self._ntype = ntype
@property @property
def data(self): def data(self):
...@@ -160,3 +175,8 @@ class NodeBatch(object): ...@@ -160,3 +175,8 @@ class NodeBatch(object):
int int
""" """
return self.batch_size() return self.batch_size()
@property
def ntype(self):
    """Node type of this node batch, if available.

    Returns
    -------
    str or None
        The node type stored on this batch.
    """
    node_type = self._ntype
    return node_type
...@@ -1412,6 +1412,63 @@ def test_compact(): ...@@ -1412,6 +1412,63 @@ def test_compact():
_check(g2, new_g2, induced_nodes) _check(g2, new_g2, induced_nodes)
def test_types_in_function():
    """Verify that UDFs can read node/edge types from their batch argument.

    Every UDF-driven API is exercised once on a ('user', 'follow', 'user')
    graph and once on a ('user', 'plays', 'game') bipartite graph. The UDFs
    themselves assert on the type information they observe.
    """
    def _make_udfs(expected_etype, expected_ntype):
        # Build the four callbacks that check against the expected types.
        def on_edges(edges):
            assert edges.canonical_etype == expected_etype
            return {}

        def on_nodes(nodes):
            assert nodes.ntype == expected_ntype
            return {}

        def node_predicate(nodes):
            assert nodes.ntype == expected_ntype
            return F.zeros((3,))

        def edge_predicate(edges):
            assert edges.canonical_etype == expected_etype
            return F.zeros((2,))

        return on_edges, on_nodes, node_predicate, edge_predicate

    # Each scenario: graph builder, expected canonical edge type, expected
    # node type seen by node UDFs, and the extra keyword arguments that the
    # node-level APIs need to pick that node type.
    scenarios = (
        (lambda: dgl.graph([(0, 1), (1, 2)], 'user', 'follow'),
         ('user', 'follow', 'user'), 'user', {}),
        (lambda: dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game'),
         ('user', 'plays', 'game'), 'game', {'ntype': 'game'}),
    )

    for build_graph, etype, node_type, ntype_kwargs in scenarios:
        mfunc, rfunc, node_filter, edge_filter = _make_udfs(etype, node_type)
        g = build_graph()
        g.apply_nodes(rfunc, **ntype_kwargs)
        g.apply_edges(mfunc)
        g.update_all(mfunc, rfunc)
        g.send_and_recv([0, 1], mfunc, rfunc)
        g.send([0, 1], mfunc)
        g.recv([1, 2], rfunc)
        g.push([0], mfunc, rfunc)
        g.pull([1], mfunc, rfunc)
        g.filter_nodes(node_filter, **ntype_kwargs)
        g.filter_edges(edge_filter)
if __name__ == '__main__': if __name__ == '__main__':
test_create() test_create()
test_query() test_query()
...@@ -1433,3 +1490,4 @@ if __name__ == '__main__': ...@@ -1433,3 +1490,4 @@ if __name__ == '__main__':
test_backward() test_backward()
test_empty_heterograph() test_empty_heterograph()
test_compact() test_compact()
test_types_in_function()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment