Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b0e02e5b
Commit
b0e02e5b
authored
Oct 05, 2018
by
Da Zheng
Committed by
Minjie Wang
Oct 05, 2018
Browse files
Update the subgraph (#73)
* update subgraph. * update subgraph API. * keep node embedding.
parent
2883eda6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
29 additions
and
29 deletions
+29
-29
python/dgl/frame.py
python/dgl/frame.py
+13
-6
python/dgl/graph.py
python/dgl/graph.py
+3
-3
python/dgl/subgraph.py
python/dgl/subgraph.py
+12
-19
tests/pytorch/test_subgraph.py
tests/pytorch/test_subgraph.py
+1
-1
No files found.
python/dgl/frame.py
View file @
b0e02e5b
...
@@ -141,7 +141,7 @@ class FrameRef(MutableMapping):
...
@@ -141,7 +141,7 @@ class FrameRef(MutableMapping):
else
:
else
:
self
.
update_rows
(
key
,
val
)
self
.
update_rows
(
key
,
val
)
def
add_column
(
self
,
name
,
col
):
def
add_column
(
self
,
name
,
col
,
inplace
=
False
):
shp
=
F
.
shape
(
col
)
shp
=
F
.
shape
(
col
)
if
self
.
is_span_whole_column
():
if
self
.
is_span_whole_column
():
if
self
.
num_columns
==
0
:
if
self
.
num_columns
==
0
:
...
@@ -157,18 +157,25 @@ class FrameRef(MutableMapping):
...
@@ -157,18 +157,25 @@ class FrameRef(MutableMapping):
fcol
=
F
.
zeros
((
self
.
_frame
.
num_rows
,)
+
shp
[
1
:])
fcol
=
F
.
zeros
((
self
.
_frame
.
num_rows
,)
+
shp
[
1
:])
fcol
=
F
.
to_context
(
fcol
,
colctx
)
fcol
=
F
.
to_context
(
fcol
,
colctx
)
idx
=
self
.
index
().
tousertensor
(
colctx
)
idx
=
self
.
index
().
tousertensor
(
colctx
)
newfcol
=
F
.
scatter_row
(
fcol
,
idx
,
col
)
if
inplace
:
self
.
_frame
[
name
]
=
newfcol
self
.
_frame
[
name
]
=
fcol
self
.
_frame
[
name
][
idx
]
=
col
else
:
newfcol
=
F
.
scatter_row
(
fcol
,
idx
,
col
)
self
.
_frame
[
name
]
=
newfcol
def
update_rows
(
self
,
query
,
other
):
def
update_rows
(
self
,
query
,
other
,
inplace
=
False
):
rowids
=
self
.
_getrowid
(
query
)
rowids
=
self
.
_getrowid
(
query
)
for
key
,
col
in
other
.
items
():
for
key
,
col
in
other
.
items
():
if
key
not
in
self
:
if
key
not
in
self
:
# add new column
# add new column
tmpref
=
FrameRef
(
self
.
_frame
,
rowids
)
tmpref
=
FrameRef
(
self
.
_frame
,
rowids
)
tmpref
.
add_column
(
key
,
col
)
tmpref
.
add_column
(
key
,
col
,
inplace
)
idx
=
rowids
.
tousertensor
(
F
.
get_context
(
self
.
_frame
[
key
]))
idx
=
rowids
.
tousertensor
(
F
.
get_context
(
self
.
_frame
[
key
]))
self
.
_frame
[
key
]
=
F
.
scatter_row
(
self
.
_frame
[
key
],
idx
,
col
)
if
inplace
:
self
.
_frame
[
key
][
idx
]
=
col
else
:
self
.
_frame
[
key
]
=
F
.
scatter_row
(
self
.
_frame
[
key
],
idx
,
col
)
def
__delitem__
(
self
,
key
):
def
__delitem__
(
self
,
key
):
if
isinstance
(
key
,
str
):
if
isinstance
(
key
,
str
):
...
...
python/dgl/graph.py
View file @
b0e02e5b
...
@@ -486,7 +486,7 @@ class DGLGraph(object):
...
@@ -486,7 +486,7 @@ class DGLGraph(object):
"""
"""
return
self
.
_edge_frame
.
schemes
return
self
.
_edge_frame
.
schemes
def
set_n_repr
(
self
,
hu
,
u
=
ALL
):
def
set_n_repr
(
self
,
hu
,
u
=
ALL
,
inplace
=
False
):
"""Set node(s) representation.
"""Set node(s) representation.
To set multiple node representations at once, pass `u` with a tensor or
To set multiple node representations at once, pass `u` with a tensor or
...
@@ -524,9 +524,9 @@ class DGLGraph(object):
...
@@ -524,9 +524,9 @@ class DGLGraph(object):
self
.
_node_frame
[
__REPR__
]
=
hu
self
.
_node_frame
[
__REPR__
]
=
hu
else
:
else
:
if
utils
.
is_dict_like
(
hu
):
if
utils
.
is_dict_like
(
hu
):
self
.
_node_frame
[
u
]
=
hu
self
.
_node_frame
.
update_rows
(
u
,
hu
,
inplace
=
inplace
)
else
:
else
:
self
.
_node_frame
[
u
]
=
{
__REPR__
:
hu
}
self
.
_node_frame
.
update_rows
(
u
,
{
__REPR__
:
hu
}
,
inplace
=
inplace
)
def
get_n_repr
(
self
,
u
=
ALL
):
def
get_n_repr
(
self
,
u
=
ALL
):
"""Get node(s) representation.
"""Get node(s) representation.
...
...
python/dgl/subgraph.py
View file @
b0e02e5b
...
@@ -15,28 +15,21 @@ class DGLSubGraph(DGLGraph):
...
@@ -15,28 +15,21 @@ class DGLSubGraph(DGLGraph):
nodes
):
nodes
):
super
(
DGLSubGraph
,
self
).
__init__
()
super
(
DGLSubGraph
,
self
).
__init__
()
# relabel nodes
# relabel nodes
self
.
_
node_mapping
=
utils
.
build_relabel_dict
(
nodes
)
self
.
_
parent
=
parent
self
.
_parent_nid
=
utils
.
toindex
(
nodes
)
self
.
_parent_nid
=
utils
.
toindex
(
nodes
)
eids
=
[]
self
.
_graph
,
self
.
_parent_eid
=
parent
.
_graph
.
node_subgraph
(
self
.
_parent_nid
)
# create subgraph
self
.
reset_messages
()
for
eid
,
(
u
,
v
)
in
enumerate
(
parent
.
edge_list
):
if
u
in
self
.
_node_mapping
and
v
in
self
.
_node_mapping
:
self
.
add_edge
(
self
.
_node_mapping
[
u
],
self
.
_node_mapping
[
v
])
eids
.
append
(
eid
)
self
.
_parent_eid
=
utils
.
toindex
(
eids
)
def
copy_from
(
self
,
parent
):
def
copy_to_parent
(
self
,
inplace
=
False
):
self
.
_parent
.
_node_frame
.
update_rows
(
self
.
_parent_nid
,
self
.
_node_frame
,
inplace
=
inplace
)
self
.
_parent
.
_edge_frame
.
update_rows
(
self
.
_parent_eid
,
self
.
_edge_frame
,
inplace
=
inplace
)
def
copy_from_parent
(
self
):
"""Copy node/edge features from the parent graph.
"""Copy node/edge features from the parent graph.
All old features will be removed.
All old features will be removed.
Parameters
----------
parent : DGLGraph
The parent graph to copy from.
"""
"""
if
parent
.
_node_frame
.
num_rows
!=
0
:
if
self
.
_
parent
.
_node_frame
.
num_rows
!=
0
:
self
.
_node_frame
=
FrameRef
(
Frame
(
parent
.
_node_frame
[
self
.
_parent_nid
]))
self
.
_node_frame
=
FrameRef
(
Frame
(
self
.
_
parent
.
_node_frame
[
self
.
_parent_nid
]))
if
parent
.
_edge_frame
.
num_rows
!=
0
:
if
self
.
_
parent
.
_edge_frame
.
num_rows
!=
0
:
self
.
_edge_frame
=
FrameRef
(
Frame
(
parent
.
_edge_frame
[
self
.
_parent_eid
]))
self
.
_edge_frame
=
FrameRef
(
Frame
(
self
.
_
parent
.
_edge_frame
[
self
.
_parent_eid
]))
tests/pytorch/test_subgraph.py
View file @
b0e02e5b
...
@@ -35,7 +35,7 @@ def test_basics():
...
@@ -35,7 +35,7 @@ def test_basics():
assert
len
(
sg
.
get_n_repr
())
==
0
assert
len
(
sg
.
get_n_repr
())
==
0
assert
len
(
sg
.
get_e_repr
())
==
0
assert
len
(
sg
.
get_e_repr
())
==
0
# the data is copied after explict copy from
# the data is copied after explict copy from
sg
.
copy_from
(
g
)
sg
.
copy_from
_parent
()
assert
len
(
sg
.
get_n_repr
())
==
1
assert
len
(
sg
.
get_n_repr
())
==
1
assert
len
(
sg
.
get_e_repr
())
==
1
assert
len
(
sg
.
get_e_repr
())
==
1
sh
=
sg
.
get_n_repr
()[
'h'
]
sh
=
sg
.
get_n_repr
()[
'h'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment