Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c7c0fd0e
Unverified
Commit
c7c0fd0e
authored
Feb 17, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Feb 17, 2020
Browse files
[Feature] Accessing node/edge types in UDFs. (#1243)
* typed udf * docstrings
parent
b133abb8
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
164 additions
and
36 deletions
+164
-36
python/dgl/graph.py
python/dgl/graph.py
+5
-0
python/dgl/heterograph.py
python/dgl/heterograph.py
+9
-3
python/dgl/nodeflow.py
python/dgl/nodeflow.py
+6
-0
python/dgl/runtime/degree_bucketing.py
python/dgl/runtime/degree_bucketing.py
+18
-8
python/dgl/runtime/scheduler.py
python/dgl/runtime/scheduler.py
+46
-23
python/dgl/udf.py
python/dgl/udf.py
+22
-2
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+58
-0
No files found.
python/dgl/graph.py
View file @
c7c0fd0e
...
@@ -3564,3 +3564,8 @@ class AdaptedDGLGraph(GraphAdapter):
...
@@ -3564,3 +3564,8 @@ class AdaptedDGLGraph(GraphAdapter):
def
bits_needed
(
self
):
def
bits_needed
(
self
):
return
self
.
graph
.
_graph
.
bits_needed
()
return
self
.
graph
.
_graph
.
bits_needed
()
@
property
def
canonical_etype
(
self
):
"""Canonical edge type (None for homogeneous graph)"""
return
(
None
,
None
,
None
)
python/dgl/heterograph.py
View file @
c7c0fd0e
...
@@ -2276,7 +2276,7 @@ class DGLHeteroGraph(object):
...
@@ -2276,7 +2276,7 @@ class DGLHeteroGraph(object):
v_ntype
=
utils
.
toindex
(
v
)
v_ntype
=
utils
.
toindex
(
v
)
with
ir
.
prog
()
as
prog
:
with
ir
.
prog
()
as
prog
:
scheduler
.
schedule_apply_nodes
(
v_ntype
,
func
,
self
.
_node_frames
[
ntid
],
scheduler
.
schedule_apply_nodes
(
v_ntype
,
func
,
self
.
_node_frames
[
ntid
],
inplace
=
inplace
)
inplace
=
inplace
,
ntype
=
self
.
_ntypes
[
ntid
]
)
Runtime
.
run
(
prog
)
Runtime
.
run
(
prog
)
def
apply_edges
(
self
,
func
,
edges
=
ALL
,
etype
=
None
,
inplace
=
False
):
def
apply_edges
(
self
,
func
,
edges
=
ALL
,
etype
=
None
,
inplace
=
False
):
...
@@ -3496,7 +3496,7 @@ class DGLHeteroGraph(object):
...
@@ -3496,7 +3496,7 @@ class DGLHeteroGraph(object):
v
=
utils
.
toindex
(
nodes
)
v
=
utils
.
toindex
(
nodes
)
n_repr
=
self
.
_get_n_repr
(
ntid
,
v
)
n_repr
=
self
.
_get_n_repr
(
ntid
,
v
)
nbatch
=
NodeBatch
(
v
,
n_repr
)
nbatch
=
NodeBatch
(
v
,
n_repr
,
ntype
=
self
.
ntypes
[
ntid
]
)
n_mask
=
F
.
copy_to
(
predicate
(
nbatch
),
F
.
cpu
())
n_mask
=
F
.
copy_to
(
predicate
(
nbatch
),
F
.
cpu
())
if
is_all
(
nodes
):
if
is_all
(
nodes
):
...
@@ -3557,7 +3557,8 @@ class DGLHeteroGraph(object):
...
@@ -3557,7 +3557,8 @@ class DGLHeteroGraph(object):
src_data
=
self
.
_get_n_repr
(
stid
,
u
)
src_data
=
self
.
_get_n_repr
(
stid
,
u
)
edge_data
=
self
.
_get_e_repr
(
etid
,
eid
)
edge_data
=
self
.
_get_e_repr
(
etid
,
eid
)
dst_data
=
self
.
_get_n_repr
(
dtid
,
v
)
dst_data
=
self
.
_get_n_repr
(
dtid
,
v
)
ebatch
=
EdgeBatch
((
u
,
v
,
eid
),
src_data
,
edge_data
,
dst_data
)
ebatch
=
EdgeBatch
((
u
,
v
,
eid
),
src_data
,
edge_data
,
dst_data
,
canonical_etype
=
self
.
canonical_etypes
[
etid
])
e_mask
=
F
.
copy_to
(
predicate
(
ebatch
),
F
.
cpu
())
e_mask
=
F
.
copy_to
(
predicate
(
ebatch
),
F
.
cpu
())
if
is_all
(
edges
):
if
is_all
(
edges
):
...
@@ -3988,3 +3989,8 @@ class AdaptedHeteroGraph(GraphAdapter):
...
@@ -3988,3 +3989,8 @@ class AdaptedHeteroGraph(GraphAdapter):
def
bits_needed
(
self
):
def
bits_needed
(
self
):
return
self
.
graph
.
_graph
.
bits_needed
(
self
.
etid
)
return
self
.
graph
.
_graph
.
bits_needed
(
self
.
etid
)
@
property
def
canonical_etype
(
self
):
"""Canonical edge type."""
return
self
.
graph
.
canonical_etypes
[
self
.
etid
]
python/dgl/nodeflow.py
View file @
c7c0fd0e
...
@@ -1017,6 +1017,12 @@ class NodeFlow(DGLBaseGraph):
...
@@ -1017,6 +1017,12 @@ class NodeFlow(DGLBaseGraph):
self
.
block_compute
(
i
,
message_func
,
reduce_func
,
apply_node_func
,
self
.
block_compute
(
i
,
message_func
,
reduce_func
,
apply_node_func
,
inplace
=
inplace
)
inplace
=
inplace
)
@
property
def
canonical_etype
(
self
):
"""Return canonical edge type to be compatible with GraphAdapter
"""
return
(
None
,
None
,
None
)
def
_copy_to_like
(
arr1
,
arr2
):
def
_copy_to_like
(
arr1
,
arr2
):
return
F
.
copy_to
(
arr1
,
F
.
context
(
arr2
))
return
F
.
copy_to
(
arr1
,
F
.
context
(
arr2
))
...
...
python/dgl/runtime/degree_bucketing.py
View file @
c7c0fd0e
...
@@ -16,7 +16,8 @@ def gen_degree_bucketing_schedule(
...
@@ -16,7 +16,8 @@ def gen_degree_bucketing_schedule(
recv_nodes
,
recv_nodes
,
var_nf
,
var_nf
,
var_mf
,
var_mf
,
var_out
):
var_out
,
ntype
=
None
):
"""Create degree bucketing schedule.
"""Create degree bucketing schedule.
The messages will be divided by their receivers into buckets. Each bucket
The messages will be divided by their receivers into buckets. Each bucket
...
@@ -44,6 +45,9 @@ def gen_degree_bucketing_schedule(
...
@@ -44,6 +45,9 @@ def gen_degree_bucketing_schedule(
The variable for message frame.
The variable for message frame.
var_out : var.FEAT_DICT
var_out : var.FEAT_DICT
The variable for output feature dicts.
The variable for output feature dicts.
ntype : str, optional
The node type, if running on a heterograph.
If None, the graph is assumed to be homogeneous.
"""
"""
buckets
=
_degree_bucketing_schedule
(
message_ids
,
dst_nodes
,
recv_nodes
)
buckets
=
_degree_bucketing_schedule
(
message_ids
,
dst_nodes
,
recv_nodes
)
# generate schedule
# generate schedule
...
@@ -53,7 +57,7 @@ def gen_degree_bucketing_schedule(
...
@@ -53,7 +57,7 @@ def gen_degree_bucketing_schedule(
fd_list
=
[]
fd_list
=
[]
for
deg
,
vbkt
,
mid
in
zip
(
degs
,
buckets
,
msg_ids
):
for
deg
,
vbkt
,
mid
in
zip
(
degs
,
buckets
,
msg_ids
):
# create per-bkt rfunc
# create per-bkt rfunc
rfunc
=
_create_per_bkt_rfunc
(
reduce_udf
,
deg
,
vbkt
)
rfunc
=
_create_per_bkt_rfunc
(
reduce_udf
,
deg
,
vbkt
,
ntype
=
ntype
)
# vars
# vars
vbkt
=
var
.
IDX
(
vbkt
)
vbkt
=
var
.
IDX
(
vbkt
)
mid
=
var
.
IDX
(
mid
)
mid
=
var
.
IDX
(
mid
)
...
@@ -141,7 +145,7 @@ def _process_node_buckets(buckets):
...
@@ -141,7 +145,7 @@ def _process_node_buckets(buckets):
return
v
,
degs
,
dsts
,
msg_ids
,
zero_deg_nodes
return
v
,
degs
,
dsts
,
msg_ids
,
zero_deg_nodes
def
_create_per_bkt_rfunc
(
reduce_udf
,
deg
,
vbkt
):
def
_create_per_bkt_rfunc
(
reduce_udf
,
deg
,
vbkt
,
ntype
=
None
):
"""Internal function to generate the per degree bucket node UDF."""
"""Internal function to generate the per degree bucket node UDF."""
def
_rfunc_wrapper
(
node_data
,
mail_data
):
def
_rfunc_wrapper
(
node_data
,
mail_data
):
def
_reshaped_getter
(
key
):
def
_reshaped_getter
(
key
):
...
@@ -149,7 +153,7 @@ def _create_per_bkt_rfunc(reduce_udf, deg, vbkt):
...
@@ -149,7 +153,7 @@ def _create_per_bkt_rfunc(reduce_udf, deg, vbkt):
new_shape
=
(
len
(
vbkt
),
deg
)
+
F
.
shape
(
msg
)[
1
:]
new_shape
=
(
len
(
vbkt
),
deg
)
+
F
.
shape
(
msg
)[
1
:]
return
F
.
reshape
(
msg
,
new_shape
)
return
F
.
reshape
(
msg
,
new_shape
)
reshaped_mail_data
=
utils
.
LazyDict
(
_reshaped_getter
,
mail_data
.
keys
())
reshaped_mail_data
=
utils
.
LazyDict
(
_reshaped_getter
,
mail_data
.
keys
())
nbatch
=
NodeBatch
(
vbkt
,
node_data
,
reshaped_mail_data
)
nbatch
=
NodeBatch
(
vbkt
,
node_data
,
reshaped_mail_data
,
ntype
=
ntype
)
return
reduce_udf
(
nbatch
)
return
reduce_udf
(
nbatch
)
return
_rfunc_wrapper
return
_rfunc_wrapper
...
@@ -160,7 +164,8 @@ def gen_group_apply_edge_schedule(
...
@@ -160,7 +164,8 @@ def gen_group_apply_edge_schedule(
var_src_nf
,
var_src_nf
,
var_dst_nf
,
var_dst_nf
,
var_ef
,
var_ef
,
var_out
):
var_out
,
canonical_etype
=
(
None
,
None
,
None
)):
"""Create degree bucketing schedule for group_apply_edge
"""Create degree bucketing schedule for group_apply_edge
Edges will be grouped by either its source node or destination node
Edges will be grouped by either its source node or destination node
...
@@ -189,6 +194,9 @@ def gen_group_apply_edge_schedule(
...
@@ -189,6 +194,9 @@ def gen_group_apply_edge_schedule(
The variable for edge frame.
The variable for edge frame.
var_out : var.FEAT_DICT
var_out : var.FEAT_DICT
The variable for output feature dicts.
The variable for output feature dicts.
canonical_etype : tuple[str, str, str], optional
Canonical edge type if running on a heterograph.
Default: (None, None, None), if running on a homogeneous graph.
"""
"""
if
group_by
==
"src"
:
if
group_by
==
"src"
:
buckets
=
_degree_bucketing_for_edge_grouping
(
u
,
v
,
eid
)
buckets
=
_degree_bucketing_for_edge_grouping
(
u
,
v
,
eid
)
...
@@ -204,7 +212,8 @@ def gen_group_apply_edge_schedule(
...
@@ -204,7 +212,8 @@ def gen_group_apply_edge_schedule(
for
deg
,
u_bkt
,
v_bkt
,
eid_bkt
in
zip
(
degs
,
uids
,
vids
,
eids
):
for
deg
,
u_bkt
,
v_bkt
,
eid_bkt
in
zip
(
degs
,
uids
,
vids
,
eids
):
# create per-bkt efunc
# create per-bkt efunc
_efunc
=
var
.
FUNC
(
_create_per_bkt_efunc
(
apply_func
,
deg
,
_efunc
=
var
.
FUNC
(
_create_per_bkt_efunc
(
apply_func
,
deg
,
u_bkt
,
v_bkt
,
eid_bkt
))
u_bkt
,
v_bkt
,
eid_bkt
,
canonical_etype
=
canonical_etype
))
# vars
# vars
var_u
=
var
.
IDX
(
u_bkt
)
var_u
=
var
.
IDX
(
u_bkt
)
var_v
=
var
.
IDX
(
v_bkt
)
var_v
=
var
.
IDX
(
v_bkt
)
...
@@ -274,7 +283,7 @@ def _process_edge_buckets(buckets):
...
@@ -274,7 +283,7 @@ def _process_edge_buckets(buckets):
eids
=
split
(
eids
)
eids
=
split
(
eids
)
return
degs
,
uids
,
vids
,
eids
return
degs
,
uids
,
vids
,
eids
def
_create_per_bkt_efunc
(
apply_func
,
deg
,
u
,
v
,
eid
):
def
_create_per_bkt_efunc
(
apply_func
,
deg
,
u
,
v
,
eid
,
canonical_etype
=
(
None
,
None
,
None
)
):
"""Internal function to generate the per degree bucket edge UDF."""
"""Internal function to generate the per degree bucket edge UDF."""
batch_size
=
len
(
u
)
//
deg
batch_size
=
len
(
u
)
//
deg
def
_efunc_wrapper
(
src_data
,
edge_data
,
dst_data
):
def
_efunc_wrapper
(
src_data
,
edge_data
,
dst_data
):
...
@@ -297,7 +306,8 @@ def _create_per_bkt_efunc(apply_func, deg, u, v, eid):
...
@@ -297,7 +306,8 @@ def _create_per_bkt_efunc(apply_func, deg, u, v, eid):
reshaped_dst_data
=
utils
.
LazyDict
(
_reshape_func
(
dst_data
),
reshaped_dst_data
=
utils
.
LazyDict
(
_reshape_func
(
dst_data
),
dst_data
.
keys
())
dst_data
.
keys
())
ebatch
=
EdgeBatch
((
u
,
v
,
eid
),
reshaped_src_data
,
ebatch
=
EdgeBatch
((
u
,
v
,
eid
),
reshaped_src_data
,
reshaped_edge_data
,
reshaped_dst_data
)
reshaped_edge_data
,
reshaped_dst_data
,
canonical_etype
=
canonical_etype
)
return
{
k
:
_reshape_back
(
v
)
for
k
,
v
in
apply_func
(
ebatch
).
items
()}
return
{
k
:
_reshape_back
(
v
)
for
k
,
v
in
apply_func
(
ebatch
).
items
()}
return
_efunc_wrapper
return
_efunc_wrapper
...
...
python/dgl/runtime/scheduler.py
View file @
c7c0fd0e
...
@@ -104,7 +104,7 @@ def schedule_recv(graph,
...
@@ -104,7 +104,7 @@ def schedule_recv(graph,
# 2) no send has been called
# 2) no send has been called
if
apply_func
is
not
None
:
if
apply_func
is
not
None
:
schedule_apply_nodes
(
recv_nodes
,
apply_func
,
graph
.
dstframe
,
schedule_apply_nodes
(
recv_nodes
,
apply_func
,
graph
.
dstframe
,
inplace
,
outframe
)
inplace
,
outframe
,
ntype
=
graph
.
canonical_etype
[
-
1
]
)
else
:
else
:
var_dst_nf
=
var
.
FEAT_DICT
(
graph
.
dstframe
,
'dst_nf'
)
var_dst_nf
=
var
.
FEAT_DICT
(
graph
.
dstframe
,
'dst_nf'
)
var_out_nf
=
var_dst_nf
if
outframe
is
None
else
var
.
FEAT_DICT
(
outframe
,
name
=
'out_nf'
)
var_out_nf
=
var_dst_nf
if
outframe
is
None
else
var
.
FEAT_DICT
(
outframe
,
name
=
'out_nf'
)
...
@@ -117,7 +117,8 @@ def schedule_recv(graph,
...
@@ -117,7 +117,8 @@ def schedule_recv(graph,
recv_nodes
)
recv_nodes
)
# apply
# apply
final_feat
=
_apply_with_accum
(
var_recv_nodes
,
var_dst_nf
,
final_feat
=
_apply_with_accum
(
var_recv_nodes
,
var_dst_nf
,
reduced_feat
,
apply_func
)
reduced_feat
,
apply_func
,
ntype
=
graph
.
canonical_etype
[
-
1
])
if
inplace
:
if
inplace
:
ir
.
WRITE_ROW_INPLACE_
(
var_out_nf
,
var_recv_nodes
,
final_feat
)
ir
.
WRITE_ROW_INPLACE_
(
var_out_nf
,
var_recv_nodes
,
final_feat
)
else
:
else
:
...
@@ -182,10 +183,11 @@ def schedule_snr(graph,
...
@@ -182,10 +183,11 @@ def schedule_snr(graph,
var_reduce_nodes
=
var_recv_nodes
,
var_reduce_nodes
=
var_recv_nodes
,
uv_getter
=
uv_getter
,
uv_getter
=
uv_getter
,
adj_creator
=
adj_creator
,
adj_creator
=
adj_creator
,
out_map_creator
=
out_map_creator
)
out_map_creator
=
out_map_creator
,
canonical_etype
=
graph
.
canonical_etype
)
# generate apply schedule
# generate apply schedule
final_feat
=
_apply_with_accum
(
var_recv_nodes
,
var_dst_nf
,
reduced_feat
,
final_feat
=
_apply_with_accum
(
var_recv_nodes
,
var_dst_nf
,
reduced_feat
,
apply_func
)
apply_func
,
ntype
=
graph
.
canonical_etype
[
-
1
]
)
if
inplace
:
if
inplace
:
ir
.
WRITE_ROW_INPLACE_
(
var_out_nf
,
var_recv_nodes
,
final_feat
)
ir
.
WRITE_ROW_INPLACE_
(
var_out_nf
,
var_recv_nodes
,
final_feat
)
else
:
else
:
...
@@ -216,7 +218,8 @@ def schedule_update_all(graph,
...
@@ -216,7 +218,8 @@ def schedule_update_all(graph,
if
apply_func
is
not
None
:
if
apply_func
is
not
None
:
nodes
=
utils
.
toindex
(
slice
(
0
,
graph
.
num_dst
()))
nodes
=
utils
.
toindex
(
slice
(
0
,
graph
.
num_dst
()))
schedule_apply_nodes
(
nodes
,
apply_func
,
graph
.
dstframe
,
schedule_apply_nodes
(
nodes
,
apply_func
,
graph
.
dstframe
,
inplace
=
False
,
outframe
=
outframe
)
inplace
=
False
,
outframe
=
outframe
,
ntype
=
graph
.
canonical_etype
[
-
1
])
else
:
else
:
eid
=
utils
.
toindex
(
slice
(
0
,
graph
.
num_edges
()))
# ALL
eid
=
utils
.
toindex
(
slice
(
0
,
graph
.
num_edges
()))
# ALL
recv_nodes
=
utils
.
toindex
(
slice
(
0
,
graph
.
num_dst
()))
# ALL
recv_nodes
=
utils
.
toindex
(
slice
(
0
,
graph
.
num_dst
()))
# ALL
...
@@ -240,17 +243,20 @@ def schedule_update_all(graph,
...
@@ -240,17 +243,20 @@ def schedule_update_all(graph,
var_reduce_nodes
=
var_recv_nodes
,
var_reduce_nodes
=
var_recv_nodes
,
uv_getter
=
uv_getter
,
uv_getter
=
uv_getter
,
adj_creator
=
adj_creator
,
adj_creator
=
adj_creator
,
out_map_creator
=
out_map_creator
)
out_map_creator
=
out_map_creator
,
canonical_etype
=
graph
.
canonical_etype
)
# generate optional apply
# generate optional apply
final_feat
=
_apply_with_accum
(
var_recv_nodes
,
var_dst_nf
,
final_feat
=
_apply_with_accum
(
var_recv_nodes
,
var_dst_nf
,
reduced_feat
,
apply_func
)
reduced_feat
,
apply_func
,
ntype
=
graph
.
canonical_etype
[
-
1
])
ir
.
WRITE_DICT_
(
var_out_nf
,
final_feat
)
ir
.
WRITE_DICT_
(
var_out_nf
,
final_feat
)
def
schedule_apply_nodes
(
v
,
def
schedule_apply_nodes
(
v
,
apply_func
,
apply_func
,
node_frame
,
node_frame
,
inplace
,
inplace
,
outframe
=
None
):
outframe
=
None
,
ntype
=
None
):
"""Get apply nodes schedule
"""Get apply nodes schedule
Parameters
Parameters
...
@@ -265,6 +271,9 @@ def schedule_apply_nodes(v,
...
@@ -265,6 +271,9 @@ def schedule_apply_nodes(v,
If True, the update will be done in place
If True, the update will be done in place
outframe : FrameRef, optional
outframe : FrameRef, optional
The storage to write output data. If None, use the given node_frame.
The storage to write output data. If None, use the given node_frame.
ntype : str, optional
The node type, if running on a heterograph.
If None, the graph is assumed to be homogeneous.
Returns
Returns
-------
-------
...
@@ -275,7 +284,7 @@ def schedule_apply_nodes(v,
...
@@ -275,7 +284,7 @@ def schedule_apply_nodes(v,
var_out_nf
=
var_nf
if
outframe
is
None
else
var
.
FEAT_DICT
(
outframe
,
name
=
'out_nf'
)
var_out_nf
=
var_nf
if
outframe
is
None
else
var
.
FEAT_DICT
(
outframe
,
name
=
'out_nf'
)
v_nf
=
ir
.
READ_ROW
(
var_nf
,
var_v
)
v_nf
=
ir
.
READ_ROW
(
var_nf
,
var_v
)
def
_afunc_wrapper
(
node_data
):
def
_afunc_wrapper
(
node_data
):
nbatch
=
NodeBatch
(
v
,
node_data
)
nbatch
=
NodeBatch
(
v
,
node_data
,
ntype
=
ntype
)
return
apply_func
(
nbatch
)
return
apply_func
(
nbatch
)
afunc
=
var
.
FUNC
(
_afunc_wrapper
)
afunc
=
var
.
FUNC
(
_afunc_wrapper
)
applied_feat
=
ir
.
NODE_UDF
(
afunc
,
v_nf
)
applied_feat
=
ir
.
NODE_UDF
(
afunc
,
v_nf
)
...
@@ -472,7 +481,8 @@ def schedule_pull(graph,
...
@@ -472,7 +481,8 @@ def schedule_pull(graph,
if
len
(
eid
)
==
0
:
if
len
(
eid
)
==
0
:
# All the nodes are 0deg; downgrades to apply.
# All the nodes are 0deg; downgrades to apply.
if
apply_func
is
not
None
:
if
apply_func
is
not
None
:
schedule_apply_nodes
(
pull_nodes
,
apply_func
,
graph
.
dstframe
,
inplace
,
outframe
)
schedule_apply_nodes
(
pull_nodes
,
apply_func
,
graph
.
dstframe
,
inplace
,
outframe
,
ntype
=
graph
.
canonical_etype
[
-
1
])
else
:
else
:
pull_nodes
,
_
=
F
.
sort_1d
(
F
.
unique
(
pull_nodes
.
tousertensor
()))
pull_nodes
,
_
=
F
.
sort_1d
(
F
.
unique
(
pull_nodes
.
tousertensor
()))
pull_nodes
=
utils
.
toindex
(
pull_nodes
)
pull_nodes
=
utils
.
toindex
(
pull_nodes
)
...
@@ -492,10 +502,12 @@ def schedule_pull(graph,
...
@@ -492,10 +502,12 @@ def schedule_pull(graph,
graph
.
dstframe
,
graph
.
edgeframe
,
graph
.
dstframe
,
graph
.
edgeframe
,
message_func
,
reduce_func
,
var_eid
,
message_func
,
reduce_func
,
var_eid
,
var_pull_nodes
,
uv_getter
,
adj_creator
,
var_pull_nodes
,
uv_getter
,
adj_creator
,
out_map_creator
)
out_map_creator
,
canonical_etype
=
graph
.
canonical_etype
)
# generate optional apply
# generate optional apply
final_feat
=
_apply_with_accum
(
var_pull_nodes
,
var_dst_nf
,
final_feat
=
_apply_with_accum
(
var_pull_nodes
,
var_dst_nf
,
reduced_feat
,
apply_func
)
reduced_feat
,
apply_func
,
ntype
=
graph
.
canonical_etype
[
-
1
])
if
inplace
:
if
inplace
:
ir
.
WRITE_ROW_INPLACE_
(
var_out_nf
,
var_pull_nodes
,
final_feat
)
ir
.
WRITE_ROW_INPLACE_
(
var_out_nf
,
var_pull_nodes
,
final_feat
)
else
:
else
:
...
@@ -535,7 +547,8 @@ def schedule_group_apply_edge(graph,
...
@@ -535,7 +547,8 @@ def schedule_group_apply_edge(graph,
var_out_ef
=
var_ef
if
outframe
is
None
else
var
.
FEAT_DICT
(
outframe
,
name
=
'out_ef'
)
var_out_ef
=
var_ef
if
outframe
is
None
else
var
.
FEAT_DICT
(
outframe
,
name
=
'out_ef'
)
var_out
=
var
.
FEAT_DICT
(
name
=
'new_ef'
)
var_out
=
var
.
FEAT_DICT
(
name
=
'new_ef'
)
db
.
gen_group_apply_edge_schedule
(
apply_func
,
u
,
v
,
eid
,
group_by
,
db
.
gen_group_apply_edge_schedule
(
apply_func
,
u
,
v
,
eid
,
group_by
,
var_src_nf
,
var_dst_nf
,
var_ef
,
var_out
)
var_src_nf
,
var_dst_nf
,
var_ef
,
var_out
,
canonical_etype
=
graph
.
canonical_etype
)
var_eid
=
var
.
IDX
(
eid
)
var_eid
=
var
.
IDX
(
eid
)
if
inplace
:
if
inplace
:
ir
.
WRITE_ROW_INPLACE_
(
var_out_ef
,
var_eid
,
var_out
)
ir
.
WRITE_ROW_INPLACE_
(
var_out_ef
,
var_eid
,
var_out
)
...
@@ -700,7 +713,7 @@ def _standardize_func_usage(func, func_name):
...
@@ -700,7 +713,7 @@ def _standardize_func_usage(func, func_name):
' Got: %s'
%
(
func_name
,
str
(
func
)))
' Got: %s'
%
(
func_name
,
str
(
func
)))
return
func
return
func
def
_apply_with_accum
(
var_nodes
,
var_nf
,
var_accum
,
apply_func
):
def
_apply_with_accum
(
var_nodes
,
var_nf
,
var_accum
,
apply_func
,
ntype
=
None
):
"""Apply with accumulated features.
"""Apply with accumulated features.
Parameters
Parameters
...
@@ -713,6 +726,9 @@ def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func):
...
@@ -713,6 +726,9 @@ def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func):
The accumulated features.
The accumulated features.
apply_func : callable, None
apply_func : callable, None
The apply function.
The apply function.
ntype : str, optional
The node type, if running on a heterograph.
If None, the graph is assumed to be homogeneous.
"""
"""
if
apply_func
:
if
apply_func
:
# To avoid writing reduced features back to node frame and reading
# To avoid writing reduced features back to node frame and reading
...
@@ -722,7 +738,7 @@ def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func):
...
@@ -722,7 +738,7 @@ def _apply_with_accum(var_nodes, var_nf, var_accum, apply_func):
v_nf
=
ir
.
UPDATE_DICT
(
v_nf
,
var_accum
)
v_nf
=
ir
.
UPDATE_DICT
(
v_nf
,
var_accum
)
def
_afunc_wrapper
(
node_data
):
def
_afunc_wrapper
(
node_data
):
nbatch
=
NodeBatch
(
var_nodes
.
data
,
node_data
)
nbatch
=
NodeBatch
(
var_nodes
.
data
,
node_data
,
ntype
=
ntype
)
return
apply_func
(
nbatch
)
return
apply_func
(
nbatch
)
afunc
=
var
.
FUNC
(
_afunc_wrapper
)
afunc
=
var
.
FUNC
(
_afunc_wrapper
)
applied_feat
=
ir
.
NODE_UDF
(
afunc
,
v_nf
)
applied_feat
=
ir
.
NODE_UDF
(
afunc
,
v_nf
)
...
@@ -778,7 +794,8 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
...
@@ -778,7 +794,8 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
else
:
else
:
# gen degree bucketing schedule for UDF recv
# gen degree bucketing schedule for UDF recv
db
.
gen_degree_bucketing_schedule
(
rfunc
,
eid
,
dst
,
recv_nodes
,
db
.
gen_degree_bucketing_schedule
(
rfunc
,
eid
,
dst
,
recv_nodes
,
var_dst_nf
,
var_msg
,
var_out
)
var_dst_nf
,
var_msg
,
var_out
,
ntype
=
graph
.
canonical_etype
[
-
1
])
return
var_out
return
var_out
def
_gen_send_reduce
(
def
_gen_send_reduce
(
...
@@ -791,7 +808,8 @@ def _gen_send_reduce(
...
@@ -791,7 +808,8 @@ def _gen_send_reduce(
var_reduce_nodes
,
var_reduce_nodes
,
uv_getter
,
uv_getter
,
adj_creator
,
adj_creator
,
out_map_creator
):
out_map_creator
,
canonical_etype
=
(
None
,
None
,
None
)):
"""Generate send and reduce schedule.
"""Generate send and reduce schedule.
The function generates symbolic program for computing
The function generates symbolic program for computing
...
@@ -832,9 +850,12 @@ def _gen_send_reduce(
...
@@ -832,9 +850,12 @@ def _gen_send_reduce(
adj_creator : callable
adj_creator : callable
Function that returns the adjmat, edge order of csr matrix, and
Function that returns the adjmat, edge order of csr matrix, and
bit-width.
bit-width.
out_map_creator: callable
out_map_creator
: callable
A function that returns a mapping from reduce_nodes to relabeled
A function that returns a mapping from reduce_nodes to relabeled
consecutive ids
consecutive ids
canonical_etype : tuple[str, str, str], optional
Canonical edge type if running on a heterograph.
Default: (None, None, None), if running on a homogeneous graph.
Returns
Returns
-------
-------
...
@@ -917,7 +938,7 @@ def _gen_send_reduce(
...
@@ -917,7 +938,7 @@ def _gen_send_reduce(
else
:
else
:
# generate UDF send schedule
# generate UDF send schedule
var_mf
=
_gen_udf_send
(
var_src_nf
,
var_dst_nf
,
var_ef
,
var_u
,
var_mf
=
_gen_udf_send
(
var_src_nf
,
var_dst_nf
,
var_ef
,
var_u
,
var_v
,
var_eid
,
mfunc
)
var_v
,
var_eid
,
mfunc
,
canonical_etype
=
canonical_etype
)
# 6. Generate reduce
# 6. Generate reduce
if
rfunc_is_list
:
if
rfunc_is_list
:
...
@@ -935,17 +956,18 @@ def _gen_send_reduce(
...
@@ -935,17 +956,18 @@ def _gen_send_reduce(
mid
=
utils
.
toindex
(
slice
(
0
,
len
(
var_v
.
data
)))
mid
=
utils
.
toindex
(
slice
(
0
,
len
(
var_v
.
data
)))
db
.
gen_degree_bucketing_schedule
(
rfunc
,
mid
,
var_v
.
data
,
db
.
gen_degree_bucketing_schedule
(
rfunc
,
mid
,
var_v
.
data
,
reduce_nodes
,
var_dst_nf
,
var_mf
,
reduce_nodes
,
var_dst_nf
,
var_mf
,
var_out
)
var_out
,
ntype
=
canonical_etype
[
-
1
]
)
return
var_out
return
var_out
def
_gen_udf_send
(
var_src_nf
,
var_dst_nf
,
var_ef
,
u
,
v
,
eid
,
mfunc
):
def
_gen_udf_send
(
var_src_nf
,
var_dst_nf
,
var_ef
,
u
,
v
,
eid
,
mfunc
,
canonical_etype
=
(
None
,
None
,
None
)):
"""Internal function to generate send schedule for UDF message function."""
"""Internal function to generate send schedule for UDF message function."""
fdsrc
=
ir
.
READ_ROW
(
var_src_nf
,
u
)
fdsrc
=
ir
.
READ_ROW
(
var_src_nf
,
u
)
fddst
=
ir
.
READ_ROW
(
var_dst_nf
,
v
)
fddst
=
ir
.
READ_ROW
(
var_dst_nf
,
v
)
fdedge
=
ir
.
READ_ROW
(
var_ef
,
eid
)
fdedge
=
ir
.
READ_ROW
(
var_ef
,
eid
)
def
_mfunc_wrapper
(
src_data
,
edge_data
,
dst_data
):
def
_mfunc_wrapper
(
src_data
,
edge_data
,
dst_data
):
ebatch
=
EdgeBatch
((
u
.
data
,
v
.
data
,
eid
.
data
),
ebatch
=
EdgeBatch
((
u
.
data
,
v
.
data
,
eid
.
data
),
src_data
,
edge_data
,
dst_data
)
src_data
,
edge_data
,
dst_data
,
canonical_etype
=
canonical_etype
)
return
mfunc
(
ebatch
)
return
mfunc
(
ebatch
)
_mfunc_wrapper
=
var
.
FUNC
(
_mfunc_wrapper
)
_mfunc_wrapper
=
var
.
FUNC
(
_mfunc_wrapper
)
msg
=
ir
.
EDGE_UDF
(
_mfunc_wrapper
,
fdsrc
,
fdedge
,
fddst
)
msg
=
ir
.
EDGE_UDF
(
_mfunc_wrapper
,
fdsrc
,
fdedge
,
fddst
)
...
@@ -982,7 +1004,8 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
...
@@ -982,7 +1004,8 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
else
:
else
:
# UDF send
# UDF send
var_out
=
_gen_udf_send
(
var_src_nf
,
var_dst_nf
,
var_ef
,
var_u
,
var_out
=
_gen_udf_send
(
var_src_nf
,
var_dst_nf
,
var_ef
,
var_u
,
var_v
,
var_eid
,
mfunc
)
var_v
,
var_eid
,
mfunc
,
canonical_etype
=
graph
.
canonical_etype
)
return
var_out
return
var_out
def
_build_idx_map
(
idx
,
nbits
):
def
_build_idx_map
(
idx
,
nbits
):
...
...
python/dgl/udf.py
View file @
c7c0fd0e
...
@@ -17,12 +17,17 @@ class EdgeBatch(object):
...
@@ -17,12 +17,17 @@ class EdgeBatch(object):
dst_data : dict of tensors
dst_data : dict of tensors
The dst node features, in the form of ``dict``
The dst node features, in the form of ``dict``
with ``str`` keys and ``tensor`` values
with ``str`` keys and ``tensor`` values
canonical_etype : tuple of (str, str, str), optional
Canonical edge type of the edge batch, if UDF is
running on a heterograph.
"""
"""
def
__init__
(
self
,
edges
,
src_data
,
edge_data
,
dst_data
):
def
__init__
(
self
,
edges
,
src_data
,
edge_data
,
dst_data
,
canonical_etype
=
(
None
,
None
,
None
)):
self
.
_edges
=
edges
self
.
_edges
=
edges
self
.
_src_data
=
src_data
self
.
_src_data
=
src_data
self
.
_edge_data
=
edge_data
self
.
_edge_data
=
edge_data
self
.
_dst_data
=
dst_data
self
.
_dst_data
=
dst_data
self
.
_canonical_etype
=
canonical_etype
@
property
@
property
def
src
(
self
):
def
src
(
self
):
...
@@ -89,6 +94,12 @@ class EdgeBatch(object):
...
@@ -89,6 +94,12 @@ class EdgeBatch(object):
"""
"""
return
self
.
batch_size
()
return
self
.
batch_size
()
@
property
def
canonical_etype
(
self
):
"""Return the canonical edge type (i.e. triplet of source, edge, and
destination node type) for this edge batch, if available."""
return
self
.
_canonical_etype
class
NodeBatch
(
object
):
class
NodeBatch
(
object
):
"""The class that can represent a batch of nodes.
"""The class that can represent a batch of nodes.
...
@@ -102,11 +113,15 @@ class NodeBatch(object):
...
@@ -102,11 +113,15 @@ class NodeBatch(object):
msgs : dict, optional
msgs : dict, optional
The messages, in the form of ``dict``
The messages, in the form of ``dict``
with ``str`` keys and ``tensor`` values
with ``str`` keys and ``tensor`` values
ntype : str, optional
The node type of this node batch, if running
on a heterograph.
"""
"""
def
__init__
(
self
,
nodes
,
data
,
msgs
=
None
):
def
__init__
(
self
,
nodes
,
data
,
msgs
=
None
,
ntype
=
None
):
self
.
_nodes
=
nodes
self
.
_nodes
=
nodes
self
.
_data
=
data
self
.
_data
=
data
self
.
_msgs
=
msgs
self
.
_msgs
=
msgs
self
.
_ntype
=
ntype
@
property
@
property
def
data
(
self
):
def
data
(
self
):
...
@@ -160,3 +175,8 @@ class NodeBatch(object):
...
@@ -160,3 +175,8 @@ class NodeBatch(object):
int
int
"""
"""
return
self
.
batch_size
()
return
self
.
batch_size
()
@
property
def
ntype
(
self
):
"""Return the node type of this node batch, if available."""
return
self
.
_ntype
tests/compute/test_heterograph.py
View file @
c7c0fd0e
...
@@ -1412,6 +1412,63 @@ def test_compact():
...
@@ -1412,6 +1412,63 @@ def test_compact():
_check
(
g2
,
new_g2
,
induced_nodes
)
_check
(
g2
,
new_g2
,
induced_nodes
)
def
test_types_in_function
():
def
mfunc1
(
edges
):
assert
edges
.
canonical_etype
==
(
'user'
,
'follow'
,
'user'
)
return
{}
def
rfunc1
(
nodes
):
assert
nodes
.
ntype
==
'user'
return
{}
def
filter_nodes1
(
nodes
):
assert
nodes
.
ntype
==
'user'
return
F
.
zeros
((
3
,))
def
filter_edges1
(
edges
):
assert
edges
.
canonical_etype
==
(
'user'
,
'follow'
,
'user'
)
return
F
.
zeros
((
2
,))
def
mfunc2
(
edges
):
assert
edges
.
canonical_etype
==
(
'user'
,
'plays'
,
'game'
)
return
{}
def
rfunc2
(
nodes
):
assert
nodes
.
ntype
==
'game'
return
{}
def
filter_nodes2
(
nodes
):
assert
nodes
.
ntype
==
'game'
return
F
.
zeros
((
3
,))
def
filter_edges2
(
edges
):
assert
edges
.
canonical_etype
==
(
'user'
,
'plays'
,
'game'
)
return
F
.
zeros
((
2
,))
g
=
dgl
.
graph
([(
0
,
1
),
(
1
,
2
)],
'user'
,
'follow'
)
g
.
apply_nodes
(
rfunc1
)
g
.
apply_edges
(
mfunc1
)
g
.
update_all
(
mfunc1
,
rfunc1
)
g
.
send_and_recv
([
0
,
1
],
mfunc1
,
rfunc1
)
g
.
send
([
0
,
1
],
mfunc1
)
g
.
recv
([
1
,
2
],
rfunc1
)
g
.
push
([
0
],
mfunc1
,
rfunc1
)
g
.
pull
([
1
],
mfunc1
,
rfunc1
)
g
.
filter_nodes
(
filter_nodes1
)
g
.
filter_edges
(
filter_edges1
)
g
=
dgl
.
bipartite
([(
0
,
1
),
(
1
,
2
)],
'user'
,
'plays'
,
'game'
)
g
.
apply_nodes
(
rfunc2
,
ntype
=
'game'
)
g
.
apply_edges
(
mfunc2
)
g
.
update_all
(
mfunc2
,
rfunc2
)
g
.
send_and_recv
([
0
,
1
],
mfunc2
,
rfunc2
)
g
.
send
([
0
,
1
],
mfunc2
)
g
.
recv
([
1
,
2
],
rfunc2
)
g
.
push
([
0
],
mfunc2
,
rfunc2
)
g
.
pull
([
1
],
mfunc2
,
rfunc2
)
g
.
filter_nodes
(
filter_nodes2
,
ntype
=
'game'
)
g
.
filter_edges
(
filter_edges2
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_create
()
test_create
()
test_query
()
test_query
()
...
@@ -1433,3 +1490,4 @@ if __name__ == '__main__':
...
@@ -1433,3 +1490,4 @@ if __name__ == '__main__':
test_backward
()
test_backward
()
test_empty_heterograph
()
test_empty_heterograph
()
test_compact
()
test_compact
()
test_types_in_function
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment