Unverified Commit c7c0fd0e authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Feature] Accessing node/edge types in UDFs. (#1243)

* typed udf

* docstrings
parent b133abb8
......@@ -3564,3 +3564,8 @@ class AdaptedDGLGraph(GraphAdapter):
def bits_needed(self):
    """Forward to ``bits_needed()`` of the wrapped graph's ``_graph`` object."""
    graph_index = self.graph._graph
    return graph_index.bits_needed()
@property
def canonical_etype(self):
    """Canonical edge type; every slot is ``None`` for a homogeneous graph."""
    return None, None, None
......@@ -2276,7 +2276,7 @@ class DGLHeteroGraph(object):
v_ntype = utils.toindex(v)
with ir.prog() as prog:
scheduler.schedule_apply_nodes(v_ntype, func, self._node_frames[ntid],
inplace=inplace)
inplace=inplace, ntype=self._ntypes[ntid])
Runtime.run(prog)
def apply_edges(self, func, edges=ALL, etype=None, inplace=False):
......@@ -3496,7 +3496,7 @@ class DGLHeteroGraph(object):
v = utils.toindex(nodes)
n_repr = self._get_n_repr(ntid, v)
nbatch = NodeBatch(v, n_repr)
nbatch = NodeBatch(v, n_repr, ntype=self.ntypes[ntid])
n_mask = F.copy_to(predicate(nbatch), F.cpu())
if is_all(nodes):
......@@ -3557,7 +3557,8 @@ class DGLHeteroGraph(object):
src_data = self._get_n_repr(stid, u)
edge_data = self._get_e_repr(etid, eid)
dst_data = self._get_n_repr(dtid, v)
ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data)
ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data,
canonical_etype=self.canonical_etypes[etid])
e_mask = F.copy_to(predicate(ebatch), F.cpu())
if is_all(edges):
......@@ -3988,3 +3989,8 @@ class AdaptedHeteroGraph(GraphAdapter):
def bits_needed(self):
    """Forward to ``bits_needed`` of the wrapped graph's ``_graph`` object,
    passing this adapter's ``etid``."""
    graph_index = self.graph._graph
    return graph_index.bits_needed(self.etid)
@property
def canonical_etype(self):
    """Canonical edge type, read from the wrapped graph's
    ``canonical_etypes`` at index ``etid``."""
    all_etypes = self.graph.canonical_etypes
    return all_etypes[self.etid]
......@@ -1017,6 +1017,12 @@ class NodeFlow(DGLBaseGraph):
self.block_compute(i, message_func, reduce_func, apply_node_func,
inplace=inplace)
@property
def canonical_etype(self):
    """Canonical edge type, provided for compatibility with GraphAdapter.

    Always the all-``None`` triplet.
    """
    return None, None, None
def _copy_to_like(arr1, arr2):
    """Copy ``arr1`` to the context reported by ``F.context(arr2)``."""
    target_ctx = F.context(arr2)
    return F.copy_to(arr1, target_ctx)
......
......@@ -16,7 +16,8 @@ def gen_degree_bucketing_schedule(
recv_nodes,
var_nf,
var_mf,
var_out):
var_out,
ntype=None):
"""Create degree bucketing schedule.
The messages will be divided by their receivers into buckets. Each bucket
......@@ -44,6 +45,9 @@ def gen_degree_bucketing_schedule(
The variable for message frame.
var_out : var.FEAT_DICT
The variable for output feature dicts.
ntype : str, optional
The node type, if running on a heterograph.
If None, assuming it's running on a homogeneous graph.
"""
buckets = _degree_bucketing_schedule(message_ids, dst_nodes, recv_nodes)
# generate schedule
......@@ -53,7 +57,7 @@ def gen_degree_bucketing_schedule(
fd_list = []
for deg, vbkt, mid in zip(degs, buckets, msg_ids):
# create per-bkt rfunc
rfunc = _create_per_bkt_rfunc(reduce_udf, deg, vbkt)
rfunc = _create_per_bkt_rfunc(reduce_udf, deg, vbkt, ntype=ntype)
# vars
vbkt = var.IDX(vbkt)
mid = var.IDX(mid)
......@@ -141,7 +145,7 @@ def _process_node_buckets(buckets):
return v, degs, dsts, msg_ids, zero_deg_nodes
def _create_per_bkt_rfunc(reduce_udf, deg, vbkt):
def _create_per_bkt_rfunc(reduce_udf, deg, vbkt, ntype=None):
"""Internal function to generate the per degree bucket node UDF."""
def _rfunc_wrapper(node_data, mail_data):
def _reshaped_getter(key):
......@@ -149,7 +153,7 @@ def _create_per_bkt_rfunc(reduce_udf, deg, vbkt):
new_shape = (len(vbkt), deg) + F.shape(msg)[1:]
return F.reshape(msg, new_shape)
reshaped_mail_data = utils.LazyDict(_reshaped_getter, mail_data.keys())
nbatch = NodeBatch(vbkt, node_data, reshaped_mail_data)
nbatch = NodeBatch(vbkt, node_data, reshaped_mail_data, ntype=ntype)
return reduce_udf(nbatch)
return _rfunc_wrapper
......@@ -160,7 +164,8 @@ def gen_group_apply_edge_schedule(
var_src_nf,
var_dst_nf,
var_ef,
var_out):
var_out,
canonical_etype=(None, None, None)):
"""Create degree bucketing schedule for group_apply_edge
Edges will be grouped by either its source node or destination node
......@@ -189,6 +194,9 @@ def gen_group_apply_edge_schedule(
The variable for edge frame.
var_out : var.FEAT_DICT
The variable for output feature dicts.
canonical_etype : tuple[str, str, str], optional
Canonical edge type if running on a heterograph.
Default: (None, None, None), if running on a homogeneous graph.
"""
if group_by == "src":
buckets = _degree_bucketing_for_edge_grouping(u, v, eid)
......@@ -204,7 +212,8 @@ def gen_group_apply_edge_schedule(
for deg, u_bkt, v_bkt, eid_bkt in zip(degs, uids, vids, eids):
# create per-bkt efunc
_efunc = var.FUNC(_create_per_bkt_efunc(apply_func, deg,
u_bkt, v_bkt, eid_bkt))
u_bkt, v_bkt, eid_bkt,
canonical_etype=canonical_etype))
# vars
var_u = var.IDX(u_bkt)
var_v = var.IDX(v_bkt)
......@@ -274,7 +283,7 @@ def _process_edge_buckets(buckets):
eids = split(eids)
return degs, uids, vids, eids
def _create_per_bkt_efunc(apply_func, deg, u, v, eid):
def _create_per_bkt_efunc(apply_func, deg, u, v, eid, canonical_etype=(None, None, None)):
"""Internal function to generate the per degree bucket edge UDF."""
batch_size = len(u) // deg
def _efunc_wrapper(src_data, edge_data, dst_data):
......@@ -297,7 +306,8 @@ def _create_per_bkt_efunc(apply_func, deg, u, v, eid):
reshaped_dst_data = utils.LazyDict(_reshape_func(dst_data),
dst_data.keys())
ebatch = EdgeBatch((u, v, eid), reshaped_src_data,
reshaped_edge_data, reshaped_dst_data)
reshaped_edge_data, reshaped_dst_data,
canonical_etype=canonical_etype)
return {k: _reshape_back(v) for k, v in apply_func(ebatch).items()}
return _efunc_wrapper
......
......@@ -104,7 +104,7 @@ def schedule_recv(graph,
# 2) no send has been called
if apply_func is not None:
schedule_apply_nodes(recv_nodes, apply_func, graph.dstframe,
inplace, outframe)
inplace, outframe, ntype=graph.canonical_etype[-1])
else:
var_dst_nf = var.FEAT_DICT(graph.dstframe, 'dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
......@@ -117,7 +117,8 @@ def schedule_recv(graph,
recv_nodes)
# apply
final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf,
reduced_feat, apply_func)
reduced_feat, apply_func,
ntype=graph.canonical_etype[-1])
if inplace:
ir.WRITE_ROW_INPLACE_(var_out_nf, var_recv_nodes, final_feat)
else:
......@@ -182,10 +183,11 @@ def schedule_snr(graph,
var_reduce_nodes=var_recv_nodes,
uv_getter=uv_getter,
adj_creator=adj_creator,
out_map_creator=out_map_creator)
out_map_creator=out_map_creator,
canonical_etype=graph.canonical_etype)
# generate apply schedule
final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf, reduced_feat,
apply_func)
apply_func, ntype=graph.canonical_etype[-1])
if inplace:
ir.WRITE_ROW_INPLACE_(var_out_nf, var_recv_nodes, final_feat)
else:
......@@ -216,7 +218,8 @@ def schedule_update_all(graph,
if apply_func is not None:
nodes = utils.toindex(slice(0, graph.num_dst()))
schedule_apply_nodes(nodes, apply_func, graph.dstframe,
inplace=False, outframe=outframe)
inplace=False, outframe=outframe,
ntype=graph.canonical_etype[-1])
else:
eid = utils.toindex(slice(0, graph.num_edges())) # ALL
recv_nodes = utils.toindex(slice(0, graph.num_dst())) # ALL
......@@ -240,17 +243,20 @@ def schedule_update_all(graph,
var_reduce_nodes=var_recv_nodes,
uv_getter=uv_getter,
adj_creator=adj_creator,
out_map_creator=out_map_creator)
out_map_creator=out_map_creator,
canonical_etype=graph.canonical_etype)
# generate optional apply
final_feat = _apply_with_accum(var_recv_nodes, var_dst_nf,
reduced_feat, apply_func)
reduced_feat, apply_func,
ntype=graph.canonical_etype[-1])
ir.WRITE_DICT_(var_out_nf, final_feat)
def schedule_apply_nodes(v,
apply_func,
node_frame,
inplace,
outframe=None):
outframe=None,
ntype=None):
"""Get apply nodes schedule
Parameters
......@@ -265,6 +271,9 @@ def schedule_apply_nodes(v,
If True, the update will be done in place
outframe : FrameRef, optional
The storage to write output data. If None, use the given node_frame.
ntype : str, optional
The node type, if running on a heterograph.
If None, assuming it's running on a homogeneous graph.
Returns
-------
......@@ -275,7 +284,7 @@ def schedule_apply_nodes(v,
var_out_nf = var_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
v_nf = ir.READ_ROW(var_nf, var_v)
def _afunc_wrapper(node_data):
nbatch = NodeBatch(v, node_data)
nbatch = NodeBatch(v, node_data, ntype=ntype)
return apply_func(nbatch)
afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf)
......@@ -472,7 +481,8 @@ def schedule_pull(graph,
if len(eid) == 0:
# All the nodes are 0deg; downgrades to apply.
if apply_func is not None:
schedule_apply_nodes(pull_nodes, apply_func, graph.dstframe, inplace, outframe)
schedule_apply_nodes(pull_nodes, apply_func, graph.dstframe, inplace,
outframe, ntype=graph.canonical_etype[-1])
else:
pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
pull_nodes = utils.toindex(pull_nodes)
......@@ -492,10 +502,12 @@ def schedule_pull(graph,
graph.dstframe, graph.edgeframe,
message_func, reduce_func, var_eid,
var_pull_nodes, uv_getter, adj_creator,
out_map_creator)
out_map_creator,
canonical_etype=graph.canonical_etype)
# generate optional apply
final_feat = _apply_with_accum(var_pull_nodes, var_dst_nf,
reduced_feat, apply_func)
reduced_feat, apply_func,
ntype=graph.canonical_etype[-1])
if inplace:
ir.WRITE_ROW_INPLACE_(var_out_nf, var_pull_nodes, final_feat)
else:
......@@ -535,7 +547,8 @@ def schedule_group_apply_edge(graph,
var_out_ef = var_ef if outframe is None else var.FEAT_DICT(outframe, name='out_ef')
var_out = var.FEAT_DICT(name='new_ef')
db.gen_group_apply_edge_schedule(apply_func, u, v, eid, group_by,
var_src_nf, var_dst_nf, var_ef, var_out)
var_src_nf, var_dst_nf, var_ef, var_out,
canonical_etype=graph.canonical_etype)
var_eid = var.IDX(eid)
if inplace:
ir.WRITE_ROW_INPLACE_(var_out_ef, var_eid, var_out)
......@@ -700,7 +713,7 @@ def _standardize_func_usage(func, func_name):
' Got: %s' % (func_name, str(func)))
return func
def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func):
def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func, ntype=None):
"""Apply with accumulated features.
Parameters
......@@ -713,6 +726,9 @@ def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func):
The accumulated features.
apply_func : callable, None
The apply function.
ntype : str, optional
The node type, if running on a heterograph.
If None, assuming it's running on a homogeneous graph.
"""
if apply_func:
# To avoid writing reduced features back to node frame and reading
......@@ -722,7 +738,7 @@ def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func):
v_nf = ir.UPDATE_DICT(v_nf, var_accum)
def _afunc_wrapper(node_data):
nbatch = NodeBatch(var_nodes.data, node_data)
nbatch = NodeBatch(var_nodes.data, node_data, ntype=ntype)
return apply_func(nbatch)
afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf)
......@@ -778,7 +794,8 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
else:
# gen degree bucketing schedule for UDF recv
db.gen_degree_bucketing_schedule(rfunc, eid, dst, recv_nodes,
var_dst_nf, var_msg, var_out)
var_dst_nf, var_msg, var_out,
ntype=graph.canonical_etype[-1])
return var_out
def _gen_send_reduce(
......@@ -791,7 +808,8 @@ def _gen_send_reduce(
var_reduce_nodes,
uv_getter,
adj_creator,
out_map_creator):
out_map_creator,
canonical_etype=(None, None, None)):
"""Generate send and reduce schedule.
The function generates symbolic program for computing
......@@ -832,9 +850,12 @@ def _gen_send_reduce(
adj_creator : callable
Function that returns the adjmat, edge order of csr matrix, and
bit-width.
out_map_creator: callable
out_map_creator : callable
A function that returns a mapping from reduce_nodes to relabeled
consecutive ids
canonical_etype : tuple[str, str, str], optional
Canonical edge type if running on a heterograph.
Default: (None, None, None), if running on a homogeneous graph.
Returns
-------
......@@ -917,7 +938,7 @@ def _gen_send_reduce(
else:
# generate UDF send schedule
var_mf = _gen_udf_send(var_src_nf, var_dst_nf, var_ef, var_u,
var_v, var_eid, mfunc)
var_v, var_eid, mfunc, canonical_etype=canonical_etype)
# 6. Generate reduce
if rfunc_is_list:
......@@ -935,17 +956,18 @@ def _gen_send_reduce(
mid = utils.toindex(slice(0, len(var_v.data)))
db.gen_degree_bucketing_schedule(rfunc, mid, var_v.data,
reduce_nodes, var_dst_nf, var_mf,
var_out)
var_out, ntype=canonical_etype[-1])
return var_out
def _gen_udf_send(var_src_nf, var_dst_nf, var_ef, u, v, eid, mfunc):
def _gen_udf_send(var_src_nf, var_dst_nf, var_ef, u, v, eid, mfunc,
canonical_etype=(None, None, None)):
"""Internal function to generate send schedule for UDF message function."""
fdsrc = ir.READ_ROW(var_src_nf, u)
fddst = ir.READ_ROW(var_dst_nf, v)
fdedge = ir.READ_ROW(var_ef, eid)
def _mfunc_wrapper(src_data, edge_data, dst_data):
ebatch = EdgeBatch((u.data, v.data, eid.data),
src_data, edge_data, dst_data)
src_data, edge_data, dst_data, canonical_etype=canonical_etype)
return mfunc(ebatch)
_mfunc_wrapper = var.FUNC(_mfunc_wrapper)
msg = ir.EDGE_UDF(_mfunc_wrapper, fdsrc, fdedge, fddst)
......@@ -982,7 +1004,8 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
else:
# UDF send
var_out = _gen_udf_send(var_src_nf, var_dst_nf, var_ef, var_u,
var_v, var_eid, mfunc)
var_v, var_eid, mfunc,
canonical_etype=graph.canonical_etype)
return var_out
def _build_idx_map(idx, nbits):
......
......@@ -17,12 +17,17 @@ class EdgeBatch(object):
dst_data : dict of tensors
The dst node features, in the form of ``dict``
with ``str`` keys and ``tensor`` values
canonical_etype : tuple of (str, str, str), optional
Canonical edge type of the edge batch, if UDF is
running on a heterograph.
"""
def __init__(self, edges, src_data, edge_data, dst_data):
def __init__(self, edges, src_data, edge_data, dst_data,
canonical_etype=(None, None, None)):
self._edges = edges
self._src_data = src_data
self._edge_data = edge_data
self._dst_data = dst_data
self._canonical_etype = canonical_etype
@property
def src(self):
......@@ -89,6 +94,12 @@ class EdgeBatch(object):
"""
return self.batch_size()
@property
def canonical_etype(self):
    """Canonical edge type of this edge batch, if available.

    This is the (source node type, edge type, destination node type)
    triplet stored on the batch.
    """
    return self._canonical_etype
class NodeBatch(object):
"""The class that can represent a batch of nodes.
......@@ -102,11 +113,15 @@ class NodeBatch(object):
msgs : dict, optional
The messages, in the form of ``dict``
with ``str`` keys and ``tensor`` values
ntype : str, optional
The node type of this node batch, if running
on a heterograph.
"""
def __init__(self, nodes, data, msgs=None):
def __init__(self, nodes, data, msgs=None, ntype=None):
self._nodes = nodes
self._data = data
self._msgs = msgs
self._ntype = ntype
@property
def data(self):
......@@ -160,3 +175,8 @@ class NodeBatch(object):
int
"""
return self.batch_size()
@property
def ntype(self):
    """Node type stored on this node batch, if available."""
    return self._ntype
......@@ -1412,6 +1412,63 @@ def test_compact():
_check(g2, new_g2, induced_nodes)
def test_types_in_function():
    """Check that UDFs can read the node type / canonical edge type.

    The same sequence of graph calls is run on a graph built with
    ``dgl.graph`` ('user' -follow-> 'user') and on one built with
    ``dgl.bipartite`` ('user' -plays-> 'game'); every UDF asserts on the
    type information carried by the batch it receives.
    """
    def make_udfs(cetype, ntype):
        # Build the message, reduce, node-filter and edge-filter UDFs for one
        # expected canonical edge type / node type pair.
        def mfunc(edges):
            assert edges.canonical_etype == cetype
            return {}

        def rfunc(nodes):
            assert nodes.ntype == ntype
            return {}

        def filter_nodes(nodes):
            assert nodes.ntype == ntype
            return F.zeros((3,))

        def filter_edges(edges):
            assert edges.canonical_etype == cetype
            return F.zeros((2,))

        return mfunc, rfunc, filter_nodes, filter_edges

    def exercise(g, udfs, **ntype_kwargs):
        # ``ntype_kwargs`` is forwarded only to the node-level calls that
        # accept an explicit node type.
        mfunc, rfunc, filter_nodes, filter_edges = udfs
        g.apply_nodes(rfunc, **ntype_kwargs)
        g.apply_edges(mfunc)
        g.update_all(mfunc, rfunc)
        g.send_and_recv([0, 1], mfunc, rfunc)
        g.send([0, 1], mfunc)
        g.recv([1, 2], rfunc)
        g.push([0], mfunc, rfunc)
        g.pull([1], mfunc, rfunc)
        g.filter_nodes(filter_nodes, **ntype_kwargs)
        g.filter_edges(filter_edges)

    exercise(dgl.graph([(0, 1), (1, 2)], 'user', 'follow'),
             make_udfs(('user', 'follow', 'user'), 'user'))
    exercise(dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game'),
             make_udfs(('user', 'plays', 'game'), 'game'),
             ntype='game')
if __name__ == '__main__':
test_create()
test_query()
......@@ -1433,3 +1490,4 @@ if __name__ == '__main__':
test_backward()
test_empty_heterograph()
test_compact()
test_types_in_function()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment