Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c7c0fd0e
Unverified
Commit
c7c0fd0e
authored
Feb 17, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Feb 17, 2020
Browse files
[Feature] Accessing node/edge types in UDFs. (#1243)
* typed udf * docstrings
parent
b133abb8
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
164 additions
and
36 deletions
+164
-36
python/dgl/graph.py
python/dgl/graph.py
+5
-0
python/dgl/heterograph.py
python/dgl/heterograph.py
+9
-3
python/dgl/nodeflow.py
python/dgl/nodeflow.py
+6
-0
python/dgl/runtime/degree_bucketing.py
python/dgl/runtime/degree_bucketing.py
+18
-8
python/dgl/runtime/scheduler.py
python/dgl/runtime/scheduler.py
+46
-23
python/dgl/udf.py
python/dgl/udf.py
+22
-2
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+58
-0
No files found.
python/dgl/graph.py
View file @
c7c0fd0e
...
...
@@ -3564,3 +3564,8 @@ class AdaptedDGLGraph(GraphAdapter):
def
bits_needed
(
self
):
return
self
.
graph
.
_graph
.
bits_needed
()
@
property
def
canonical_etype
(
self
):
"""Canonical edge type (None for homogeneous graph)"""
return
(
None
,
None
,
None
)
python/dgl/heterograph.py
View file @
c7c0fd0e
...
...
@@ -2276,7 +2276,7 @@ class DGLHeteroGraph(object):
v_ntype
=
utils
.
toindex
(
v
)
with
ir
.
prog
()
as
prog
:
scheduler
.
schedule_apply_nodes
(
v_ntype
,
func
,
self
.
_node_frames
[
ntid
],
inplace
=
inplace
)
inplace
=
inplace
,
ntype
=
self
.
_ntypes
[
ntid
]
)
Runtime
.
run
(
prog
)
def
apply_edges
(
self
,
func
,
edges
=
ALL
,
etype
=
None
,
inplace
=
False
):
...
...
@@ -3496,7 +3496,7 @@ class DGLHeteroGraph(object):
v
=
utils
.
toindex
(
nodes
)
n_repr
=
self
.
_get_n_repr
(
ntid
,
v
)
nbatch
=
NodeBatch
(
v
,
n_repr
)
nbatch
=
NodeBatch
(
v
,
n_repr
,
ntype
=
self
.
ntypes
[
ntid
]
)
n_mask
=
F
.
copy_to
(
predicate
(
nbatch
),
F
.
cpu
())
if
is_all
(
nodes
):
...
...
@@ -3557,7 +3557,8 @@ class DGLHeteroGraph(object):
src_data
=
self
.
_get_n_repr
(
stid
,
u
)
edge_data
=
self
.
_get_e_repr
(
etid
,
eid
)
dst_data
=
self
.
_get_n_repr
(
dtid
,
v
)
ebatch
=
EdgeBatch
((
u
,
v
,
eid
),
src_data
,
edge_data
,
dst_data
)
ebatch
=
EdgeBatch
((
u
,
v
,
eid
),
src_data
,
edge_data
,
dst_data
,
canonical_etype
=
self
.
canonical_etypes
[
etid
])
e_mask
=
F
.
copy_to
(
predicate
(
ebatch
),
F
.
cpu
())
if
is_all
(
edges
):
...
...
@@ -3988,3 +3989,8 @@ class AdaptedHeteroGraph(GraphAdapter):
def
bits_needed
(
self
):
return
self
.
graph
.
_graph
.
bits_needed
(
self
.
etid
)
@
property
def
canonical_etype
(
self
):
"""Canonical edge type."""
return
self
.
graph
.
canonical_etypes
[
self
.
etid
]
python/dgl/nodeflow.py
View file @
c7c0fd0e
...
...
@@ -1017,6 +1017,12 @@ class NodeFlow(DGLBaseGraph):
self
.
block_compute
(
i
,
message_func
,
reduce_func
,
apply_node_func
,
inplace
=
inplace
)
@
property
def
canonical_etype
(
self
):
"""Return canonical edge type to be compatible with GraphAdapter
"""
return
(
None
,
None
,
None
)
def
_copy_to_like
(
arr1
,
arr2
):
return
F
.
copy_to
(
arr1
,
F
.
context
(
arr2
))
...
...
python/dgl/runtime/degree_bucketing.py
View file @
c7c0fd0e
...
...
@@ -16,7 +16,8 @@ def gen_degree_bucketing_schedule(
recv_nodes
,
var_nf
,
var_mf
,
var_out
):
var_out
,
ntype
=
None
):
"""Create degree bucketing schedule.
The messages will be divided by their receivers into buckets. Each bucket
...
...
@@ -44,6 +45,9 @@ def gen_degree_bucketing_schedule(
The variable for message frame.
var_out : var.FEAT_DICT
The variable for output feature dicts.
ntype : str, optional
The node type, if running on a heterograph.
If None, assuming it's running on a homogeneous graph.
"""
buckets
=
_degree_bucketing_schedule
(
message_ids
,
dst_nodes
,
recv_nodes
)
# generate schedule
...
...
@@ -53,7 +57,7 @@ def gen_degree_bucketing_schedule(
fd_list
=
[]
for
deg
,
vbkt
,
mid
in
zip
(
degs
,
buckets
,
msg_ids
):
# create per-bkt rfunc
rfunc
=
_create_per_bkt_rfunc
(
reduce_udf
,
deg
,
vbkt
)
rfunc
=
_create_per_bkt_rfunc
(
reduce_udf
,
deg
,
vbkt
,
ntype
=
ntype
)
# vars
vbkt
=
var
.
IDX
(
vbkt
)
mid
=
var
.
IDX
(
mid
)
...
...
@@ -141,7 +145,7 @@ def _process_node_buckets(buckets):
return
v
,
degs
,
dsts
,
msg_ids
,
zero_deg_nodes
def
_create_per_bkt_rfunc
(
reduce_udf
,
deg
,
vbkt
):
def
_create_per_bkt_rfunc
(
reduce_udf
,
deg
,
vbkt
,
ntype
=
None
):
"""Internal function to generate the per degree bucket node UDF."""
def
_rfunc_wrapper
(
node_data
,
mail_data
):
def
_reshaped_getter
(
key
):
...
...
@@ -149,7 +153,7 @@ def _create_per_bkt_rfunc(reduce_udf, deg, vbkt):
new_shape
=
(
len
(
vbkt
),
deg
)
+
F
.
shape
(
msg
)[
1
:]
return
F
.
reshape
(
msg
,
new_shape
)
reshaped_mail_data
=
utils
.
LazyDict
(
_reshaped_getter
,
mail_data
.
keys
())
nbatch
=
NodeBatch
(
vbkt
,
node_data
,
reshaped_mail_data
)
nbatch
=
NodeBatch
(
vbkt
,
node_data
,
reshaped_mail_data
,
ntype
=
ntype
)
return
reduce_udf
(
nbatch
)
return
_rfunc_wrapper
...
...
@@ -160,7 +164,8 @@ def gen_group_apply_edge_schedule(
var_src_nf
,
var_dst_nf
,
var_ef
,
var_out
):
var_out
,
canonical_etype
=
(
None
,
None
,
None
)):
"""Create degree bucketing schedule for group_apply_edge
Edges will be grouped by either its source node or destination node
...
...
@@ -189,6 +194,9 @@ def gen_group_apply_edge_schedule(
The variable for edge frame.
var_out : var.FEAT_DICT
The variable for output feature dicts.
canonical_etype : tuple[str, str, str], optional
Canonical edge type if running on a heterograph.
Default: (None, None, None), if running on a homogeneous graph.
"""
if
group_by
==
"src"
:
buckets
=
_degree_bucketing_for_edge_grouping
(
u
,
v
,
eid
)
...
...
@@ -204,7 +212,8 @@ def gen_group_apply_edge_schedule(
for
deg
,
u_bkt
,
v_bkt
,
eid_bkt
in
zip
(
degs
,
uids
,
vids
,
eids
):
# create per-bkt efunc
_efunc
=
var
.
FUNC
(
_create_per_bkt_efunc
(
apply_func
,
deg
,
u_bkt
,
v_bkt
,
eid_bkt
))
u_bkt
,
v_bkt
,
eid_bkt
,
canonical_etype
=
canonical_etype
))
# vars
var_u
=
var
.
IDX
(
u_bkt
)
var_v
=
var
.
IDX
(
v_bkt
)
...
...
@@ -274,7 +283,7 @@ def _process_edge_buckets(buckets):
eids
=
split
(
eids
)
return
degs
,
uids
,
vids
,
eids
def
_create_per_bkt_efunc
(
apply_func
,
deg
,
u
,
v
,
eid
):
def
_create_per_bkt_efunc
(
apply_func
,
deg
,
u
,
v
,
eid
,
canonical_etype
=
(
None
,
None
,
None
)
):
"""Internal function to generate the per degree bucket edge UDF."""
batch_size
=
len
(
u
)
//
deg
def
_efunc_wrapper
(
src_data
,
edge_data
,
dst_data
):
...
...
@@ -297,7 +306,8 @@ def _create_per_bkt_efunc(apply_func, deg, u, v, eid):
reshaped_dst_data
=
utils
.
LazyDict
(
_reshape_func
(
dst_data
),
dst_data
.
keys
())
ebatch
=
EdgeBatch
((
u
,
v
,
eid
),
reshaped_src_data
,
reshaped_edge_data
,
reshaped_dst_data
)
reshaped_edge_data
,
reshaped_dst_data
,
canonical_etype
=
canonical_etype
)
return
{
k
:
_reshape_back
(
v
)
for
k
,
v
in
apply_func
(
ebatch
).
items
()}
return
_efunc_wrapper
...
...
python/dgl/runtime/scheduler.py
View file @
c7c0fd0e
...
...
@@ -104,7 +104,7 @@ def schedule_recv(graph,
# 2) no send has been called
if
apply_func
is
not
None
:
schedule_apply_nodes
(
recv_nodes
,
apply_func
,
graph
.
dstframe
,
inplace
,
outframe
)
inplace
,
outframe
,
ntype
=
graph
.
canonical_etype
[
-
1
]
)
else
:
var_dst_nf
=
var
.
FEAT_DICT
(
graph
.
dstframe
,
'dst_nf'
)
var_out_nf
=
var_dst_nf
if
outframe
is
None
else
var
.
FEAT_DICT
(
outframe
,
name
=
'out_nf'
)
...
...
@@ -117,7 +117,8 @@ def schedule_recv(graph,
recv_nodes
)
# apply
final_feat
=
_apply_with_accum
(
var_recv_nodes
,
var_dst_nf
,
reduced_feat
,
apply_func
)
reduced_feat
,
apply_func
,
ntype
=
graph
.
canonical_etype
[
-
1
])
if
inplace
:
ir
.
WRITE_ROW_INPLACE_
(
var_out_nf
,
var_recv_nodes
,
final_feat
)
else
:
...
...
@@ -182,10 +183,11 @@ def schedule_snr(graph,
var_reduce_nodes
=
var_recv_nodes
,
uv_getter
=
uv_getter
,
adj_creator
=
adj_creator
,
out_map_creator
=
out_map_creator
)
out_map_creator
=
out_map_creator
,
canonical_etype
=
graph
.
canonical_etype
)
# generate apply schedule
final_feat
=
_apply_with_accum
(
var_recv_nodes
,
var_dst_nf
,
reduced_feat
,
apply_func
)
apply_func
,
ntype
=
graph
.
canonical_etype
[
-
1
]
)
if
inplace
:
ir
.
WRITE_ROW_INPLACE_
(
var_out_nf
,
var_recv_nodes
,
final_feat
)
else
:
...
...
@@ -216,7 +218,8 @@ def schedule_update_all(graph,
if
apply_func
is
not
None
:
nodes
=
utils
.
toindex
(
slice
(
0
,
graph
.
num_dst
()))
schedule_apply_nodes
(
nodes
,
apply_func
,
graph
.
dstframe
,
inplace
=
False
,
outframe
=
outframe
)
inplace
=
False
,
outframe
=
outframe
,
ntype
=
graph
.
canonical_etype
[
-
1
])
else
:
eid
=
utils
.
toindex
(
slice
(
0
,
graph
.
num_edges
()))
# ALL
recv_nodes
=
utils
.
toindex
(
slice
(
0
,
graph
.
num_dst
()))
# ALL
...
...
@@ -240,17 +243,20 @@ def schedule_update_all(graph,
var_reduce_nodes
=
var_recv_nodes
,
uv_getter
=
uv_getter
,
adj_creator
=
adj_creator
,
out_map_creator
=
out_map_creator
)
out_map_creator
=
out_map_creator
,
canonical_etype
=
graph
.
canonical_etype
)
# generate optional apply
final_feat
=
_apply_with_accum
(
var_recv_nodes
,
var_dst_nf
,
reduced_feat
,
apply_func
)
reduced_feat
,
apply_func
,
ntype
=
graph
.
canonical_etype
[
-
1
])
ir
.
WRITE_DICT_
(
var_out_nf
,
final_feat
)
def
schedule_apply_nodes
(
v
,
apply_func
,
node_frame
,
inplace
,
outframe
=
None
):
outframe
=
None
,
ntype
=
None
):
"""Get apply nodes schedule
Parameters
...
...
@@ -265,6 +271,9 @@ def schedule_apply_nodes(v,
If True, the update will be done in place
outframe : FrameRef, optional
The storage to write output data. If None, use the given node_frame.
ntype : str, optional
The node type, if running on a heterograph.
If None, assuming it's running on a homogeneous graph.
Returns
-------
...
...
@@ -275,7 +284,7 @@ def schedule_apply_nodes(v,
var_out_nf
=
var_nf
if
outframe
is
None
else
var
.
FEAT_DICT
(
outframe
,
name
=
'out_nf'
)
v_nf
=
ir
.
READ_ROW
(
var_nf
,
var_v
)
def
_afunc_wrapper
(
node_data
):
nbatch
=
NodeBatch
(
v
,
node_data
)
nbatch
=
NodeBatch
(
v
,
node_data
,
ntype
=
ntype
)
return
apply_func
(
nbatch
)
afunc
=
var
.
FUNC
(
_afunc_wrapper
)
applied_feat
=
ir
.
NODE_UDF
(
afunc
,
v_nf
)
...
...
@@ -472,7 +481,8 @@ def schedule_pull(graph,
if
len
(
eid
)
==
0
:
# All the nodes are 0deg; downgrades to apply.
if
apply_func
is
not
None
:
schedule_apply_nodes
(
pull_nodes
,
apply_func
,
graph
.
dstframe
,
inplace
,
outframe
)
schedule_apply_nodes
(
pull_nodes
,
apply_func
,
graph
.
dstframe
,
inplace
,
outframe
,
ntype
=
graph
.
canonical_etype
[
-
1
])
else
:
pull_nodes
,
_
=
F
.
sort_1d
(
F
.
unique
(
pull_nodes
.
tousertensor
()))
pull_nodes
=
utils
.
toindex
(
pull_nodes
)
...
...
@@ -492,10 +502,12 @@ def schedule_pull(graph,
graph
.
dstframe
,
graph
.
edgeframe
,
message_func
,
reduce_func
,
var_eid
,
var_pull_nodes
,
uv_getter
,
adj_creator
,
out_map_creator
)
out_map_creator
,
canonical_etype
=
graph
.
canonical_etype
)
# generate optional apply
final_feat
=
_apply_with_accum
(
var_pull_nodes
,
var_dst_nf
,
reduced_feat
,
apply_func
)
reduced_feat
,
apply_func
,
ntype
=
graph
.
canonical_etype
[
-
1
])
if
inplace
:
ir
.
WRITE_ROW_INPLACE_
(
var_out_nf
,
var_pull_nodes
,
final_feat
)
else
:
...
...
@@ -535,7 +547,8 @@ def schedule_group_apply_edge(graph,
var_out_ef
=
var_ef
if
outframe
is
None
else
var
.
FEAT_DICT
(
outframe
,
name
=
'out_ef'
)
var_out
=
var
.
FEAT_DICT
(
name
=
'new_ef'
)
db
.
gen_group_apply_edge_schedule
(
apply_func
,
u
,
v
,
eid
,
group_by
,
var_src_nf
,
var_dst_nf
,
var_ef
,
var_out
)
var_src_nf
,
var_dst_nf
,
var_ef
,
var_out
,
canonical_etype
=
graph
.
canonical_etype
)
var_eid
=
var
.
IDX
(
eid
)
if
inplace
:
ir
.
WRITE_ROW_INPLACE_
(
var_out_ef
,
var_eid
,
var_out
)
...
...
@@ -700,7 +713,7 @@ def _standardize_func_usage(func, func_name):
' Got: %s'
%
(
func_name
,
str
(
func
)))
return
func
def
_apply_with_accum
(
var_nodes
,
var_nf
,
var_accum
,
apply_func
):
def
_apply_with_accum
(
var_nodes
,
var_nf
,
var_accum
,
apply_func
,
ntype
=
None
):
"""Apply with accumulated features.
Paramters
...
...
@@ -713,6 +726,9 @@ def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func):
The accumulated features.
apply_func : callable, None
The apply function.
ntype : str, optional
The node type, if running on a heterograph.
If None, assuming it's running on a homogeneous graph.
"""
if
apply_func
:
# To avoid writing reduced features back to node frame and reading
...
...
@@ -722,7 +738,7 @@ def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func):
v_nf
=
ir
.
UPDATE_DICT
(
v_nf
,
var_accum
)
def
_afunc_wrapper
(
node_data
):
nbatch
=
NodeBatch
(
var_nodes
.
data
,
node_data
)
nbatch
=
NodeBatch
(
var_nodes
.
data
,
node_data
,
ntype
=
ntype
)
return
apply_func
(
nbatch
)
afunc
=
var
.
FUNC
(
_afunc_wrapper
)
applied_feat
=
ir
.
NODE_UDF
(
afunc
,
v_nf
)
...
...
@@ -778,7 +794,8 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
else
:
# gen degree bucketing schedule for UDF recv
db
.
gen_degree_bucketing_schedule
(
rfunc
,
eid
,
dst
,
recv_nodes
,
var_dst_nf
,
var_msg
,
var_out
)
var_dst_nf
,
var_msg
,
var_out
,
ntype
=
graph
.
canonical_etype
[
-
1
])
return
var_out
def
_gen_send_reduce
(
...
...
@@ -791,7 +808,8 @@ def _gen_send_reduce(
var_reduce_nodes
,
uv_getter
,
adj_creator
,
out_map_creator
):
out_map_creator
,
canonical_etype
=
(
None
,
None
,
None
)):
"""Generate send and reduce schedule.
The function generates symbolic program for computing
...
...
@@ -832,9 +850,12 @@ def _gen_send_reduce(
adj_creator : callable
Function that returns the adjmat, edge order of csr matrix, and
bit-width.
out_map_creator: callable
out_map_creator
: callable
A function that returns a mapping from reduce_nodes to relabeled
consecutive ids
canonical_etype : tuple[str, str, str], optional
Canonical edge type if running on a heterograph.
Default: (None, None, None), if running on a homogeneous graph.
Returns
-------
...
...
@@ -917,7 +938,7 @@ def _gen_send_reduce(
else
:
# generate UDF send schedule
var_mf
=
_gen_udf_send
(
var_src_nf
,
var_dst_nf
,
var_ef
,
var_u
,
var_v
,
var_eid
,
mfunc
)
var_v
,
var_eid
,
mfunc
,
canonical_etype
=
canonical_etype
)
# 6. Generate reduce
if
rfunc_is_list
:
...
...
@@ -935,17 +956,18 @@ def _gen_send_reduce(
mid
=
utils
.
toindex
(
slice
(
0
,
len
(
var_v
.
data
)))
db
.
gen_degree_bucketing_schedule
(
rfunc
,
mid
,
var_v
.
data
,
reduce_nodes
,
var_dst_nf
,
var_mf
,
var_out
)
var_out
,
ntype
=
canonical_etype
[
-
1
]
)
return
var_out
def
_gen_udf_send
(
var_src_nf
,
var_dst_nf
,
var_ef
,
u
,
v
,
eid
,
mfunc
):
def
_gen_udf_send
(
var_src_nf
,
var_dst_nf
,
var_ef
,
u
,
v
,
eid
,
mfunc
,
canonical_etype
=
(
None
,
None
,
None
)):
"""Internal function to generate send schedule for UDF message function."""
fdsrc
=
ir
.
READ_ROW
(
var_src_nf
,
u
)
fddst
=
ir
.
READ_ROW
(
var_dst_nf
,
v
)
fdedge
=
ir
.
READ_ROW
(
var_ef
,
eid
)
def
_mfunc_wrapper
(
src_data
,
edge_data
,
dst_data
):
ebatch
=
EdgeBatch
((
u
.
data
,
v
.
data
,
eid
.
data
),
src_data
,
edge_data
,
dst_data
)
src_data
,
edge_data
,
dst_data
,
canonical_etype
=
canonical_etype
)
return
mfunc
(
ebatch
)
_mfunc_wrapper
=
var
.
FUNC
(
_mfunc_wrapper
)
msg
=
ir
.
EDGE_UDF
(
_mfunc_wrapper
,
fdsrc
,
fdedge
,
fddst
)
...
...
@@ -982,7 +1004,8 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
else
:
# UDF send
var_out
=
_gen_udf_send
(
var_src_nf
,
var_dst_nf
,
var_ef
,
var_u
,
var_v
,
var_eid
,
mfunc
)
var_v
,
var_eid
,
mfunc
,
canonical_etype
=
graph
.
canonical_etype
)
return
var_out
def
_build_idx_map
(
idx
,
nbits
):
...
...
python/dgl/udf.py
View file @
c7c0fd0e
...
...
@@ -17,12 +17,17 @@ class EdgeBatch(object):
dst_data : dict of tensors
The dst node features, in the form of ``dict``
with ``str`` keys and ``tensor`` values
canonical_etype : tuple of (str, str, str), optional
Canonical edge type of the edge batch, if UDF is
running on a heterograph.
"""
def
__init__
(
self
,
edges
,
src_data
,
edge_data
,
dst_data
):
def
__init__
(
self
,
edges
,
src_data
,
edge_data
,
dst_data
,
canonical_etype
=
(
None
,
None
,
None
)):
self
.
_edges
=
edges
self
.
_src_data
=
src_data
self
.
_edge_data
=
edge_data
self
.
_dst_data
=
dst_data
self
.
_canonical_etype
=
canonical_etype
@
property
def
src
(
self
):
...
...
@@ -89,6 +94,12 @@ class EdgeBatch(object):
"""
return
self
.
batch_size
()
@
property
def
canonical_etype
(
self
):
"""Return the canonical edge type (i.e. triplet of source, edge, and
destination node type) for this edge batch, if available."""
return
self
.
_canonical_etype
class
NodeBatch
(
object
):
"""The class that can represent a batch of nodes.
...
...
@@ -102,11 +113,15 @@ class NodeBatch(object):
msgs : dict, optional
The messages, , in the form of ``dict``
with ``str`` keys and ``tensor`` values
ntype : str, optional
The node type of this node batch, if running
on a heterograph.
"""
def
__init__
(
self
,
nodes
,
data
,
msgs
=
None
):
def
__init__
(
self
,
nodes
,
data
,
msgs
=
None
,
ntype
=
None
):
self
.
_nodes
=
nodes
self
.
_data
=
data
self
.
_msgs
=
msgs
self
.
_ntype
=
ntype
@
property
def
data
(
self
):
...
...
@@ -160,3 +175,8 @@ class NodeBatch(object):
int
"""
return
self
.
batch_size
()
@
property
def
ntype
(
self
):
"""Return the node type of this node batch, if available."""
return
self
.
_ntype
tests/compute/test_heterograph.py
View file @
c7c0fd0e
...
...
@@ -1412,6 +1412,63 @@ def test_compact():
_check
(
g2
,
new_g2
,
induced_nodes
)
def
test_types_in_function
():
def
mfunc1
(
edges
):
assert
edges
.
canonical_etype
==
(
'user'
,
'follow'
,
'user'
)
return
{}
def
rfunc1
(
nodes
):
assert
nodes
.
ntype
==
'user'
return
{}
def
filter_nodes1
(
nodes
):
assert
nodes
.
ntype
==
'user'
return
F
.
zeros
((
3
,))
def
filter_edges1
(
edges
):
assert
edges
.
canonical_etype
==
(
'user'
,
'follow'
,
'user'
)
return
F
.
zeros
((
2
,))
def
mfunc2
(
edges
):
assert
edges
.
canonical_etype
==
(
'user'
,
'plays'
,
'game'
)
return
{}
def
rfunc2
(
nodes
):
assert
nodes
.
ntype
==
'game'
return
{}
def
filter_nodes2
(
nodes
):
assert
nodes
.
ntype
==
'game'
return
F
.
zeros
((
3
,))
def
filter_edges2
(
edges
):
assert
edges
.
canonical_etype
==
(
'user'
,
'plays'
,
'game'
)
return
F
.
zeros
((
2
,))
g
=
dgl
.
graph
([(
0
,
1
),
(
1
,
2
)],
'user'
,
'follow'
)
g
.
apply_nodes
(
rfunc1
)
g
.
apply_edges
(
mfunc1
)
g
.
update_all
(
mfunc1
,
rfunc1
)
g
.
send_and_recv
([
0
,
1
],
mfunc1
,
rfunc1
)
g
.
send
([
0
,
1
],
mfunc1
)
g
.
recv
([
1
,
2
],
rfunc1
)
g
.
push
([
0
],
mfunc1
,
rfunc1
)
g
.
pull
([
1
],
mfunc1
,
rfunc1
)
g
.
filter_nodes
(
filter_nodes1
)
g
.
filter_edges
(
filter_edges1
)
g
=
dgl
.
bipartite
([(
0
,
1
),
(
1
,
2
)],
'user'
,
'plays'
,
'game'
)
g
.
apply_nodes
(
rfunc2
,
ntype
=
'game'
)
g
.
apply_edges
(
mfunc2
)
g
.
update_all
(
mfunc2
,
rfunc2
)
g
.
send_and_recv
([
0
,
1
],
mfunc2
,
rfunc2
)
g
.
send
([
0
,
1
],
mfunc2
)
g
.
recv
([
1
,
2
],
rfunc2
)
g
.
push
([
0
],
mfunc2
,
rfunc2
)
g
.
pull
([
1
],
mfunc2
,
rfunc2
)
g
.
filter_nodes
(
filter_nodes2
,
ntype
=
'game'
)
g
.
filter_edges
(
filter_edges2
)
if
__name__
==
'__main__'
:
test_create
()
test_query
()
...
...
@@ -1433,3 +1490,4 @@ if __name__ == '__main__':
test_backward
()
test_empty_heterograph
()
test_compact
()
test_types_in_function
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment