Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
02029dce
Unverified
Commit
02029dce
authored
Aug 17, 2019
by
Mufei Li
Committed by
GitHub
Aug 17, 2019
Browse files
Sync initializers (#772)
parent
165d4538
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
90 additions
and
17 deletions
+90
-17
python/dgl/contrib/graph_store.py
python/dgl/contrib/graph_store.py
+2
-2
python/dgl/frame.py
python/dgl/frame.py
+23
-12
python/dgl/graph.py
python/dgl/graph.py
+21
-3
tests/compute/test_basics.py
tests/compute/test_basics.py
+44
-0
No files found.
python/dgl/contrib/graph_store.py
View file @
02029dce
...
@@ -684,7 +684,7 @@ class SharedMemoryDGLGraph(BaseGraphStore):
...
@@ -684,7 +684,7 @@ class SharedMemoryDGLGraph(BaseGraphStore):
raise
Exception
(
"graph store only supports CPU context for node data"
)
raise
Exception
(
"graph store only supports CPU context for node data"
)
init
=
self
.
_node_frame
.
get_initializer
(
ndata_name
)
init
=
self
.
_node_frame
.
get_initializer
(
ndata_name
)
if
init
is
None
:
if
init
is
None
:
self
.
_node_frame
.
_frame
.
_
warn_and_se
t_initializer
()
self
.
_node_frame
.
_frame
.
_
set_zero_defaul
t_initializer
()
init
=
self
.
_node_frame
.
get_initializer
(
ndata_name
)
init
=
self
.
_node_frame
.
get_initializer
(
ndata_name
)
init
=
self
.
_init_manager
.
serialize
(
init
)
init
=
self
.
_init_manager
.
serialize
(
init
)
self
.
proxy
.
init_ndata
(
init
,
ndata_name
,
tuple
(
shape
),
dtype
)
self
.
proxy
.
init_ndata
(
init
,
ndata_name
,
tuple
(
shape
),
dtype
)
...
@@ -712,7 +712,7 @@ class SharedMemoryDGLGraph(BaseGraphStore):
...
@@ -712,7 +712,7 @@ class SharedMemoryDGLGraph(BaseGraphStore):
raise
Exception
(
"graph store only supports CPU context for edge data"
)
raise
Exception
(
"graph store only supports CPU context for edge data"
)
init
=
self
.
_edge_frame
.
get_initializer
(
edata_name
)
init
=
self
.
_edge_frame
.
get_initializer
(
edata_name
)
if
init
is
None
:
if
init
is
None
:
self
.
_edge_frame
.
_frame
.
_
warn_and_se
t_initializer
()
self
.
_edge_frame
.
_frame
.
_
set_zero_defaul
t_initializer
()
init
=
self
.
_edge_frame
.
get_initializer
(
edata_name
)
init
=
self
.
_edge_frame
.
get_initializer
(
edata_name
)
init
=
self
.
_init_manager
.
serialize
(
init
)
init
=
self
.
_init_manager
.
serialize
(
init
)
self
.
proxy
.
init_edata
(
init
,
edata_name
,
tuple
(
shape
),
dtype
)
self
.
proxy
.
init_edata
(
init
,
edata_name
,
tuple
(
shape
),
dtype
)
...
...
python/dgl/frame.py
View file @
02029dce
...
@@ -215,10 +215,8 @@ class Frame(MutableMapping):
...
@@ -215,10 +215,8 @@ class Frame(MutableMapping):
self
.
_remote_init_builder
=
None
self
.
_remote_init_builder
=
None
self
.
_default_initializer
=
None
self
.
_default_initializer
=
None
def
_warn_and_set_initializer
(
self
):
def
_set_zero_default_initializer
(
self
):
dgl_warning
(
'Initializer is not set. Use zero initializer instead.'
"""Set the default initializer to be zero initializer."""
' To suppress this warning, use `set_initializer` to'
' explicitly specify which initializer to use.'
)
self
.
_default_initializer
=
zero_initializer
self
.
_default_initializer
=
zero_initializer
def
get_initializer
(
self
,
column
=
None
):
def
get_initializer
(
self
,
column
=
None
):
...
@@ -279,7 +277,7 @@ class Frame(MutableMapping):
...
@@ -279,7 +277,7 @@ class Frame(MutableMapping):
return
None
return
None
if
self
.
get_initializer
(
name
)
is
None
:
if
self
.
get_initializer
(
name
)
is
None
:
self
.
_
warn_and_se
t_initializer
()
self
.
_
set_zero_defaul
t_initializer
()
initializer
=
self
.
get_initializer
(
name
)
initializer
=
self
.
get_initializer
(
name
)
return
self
.
_remote_init_builder
(
initializer
,
name
)
return
self
.
_remote_init_builder
(
initializer
,
name
)
...
@@ -364,7 +362,7 @@ class Frame(MutableMapping):
...
@@ -364,7 +362,7 @@ class Frame(MutableMapping):
init_data
=
initializer
((
self
.
num_rows
,)
+
scheme
.
shape
,
scheme
.
dtype
,
ctx
)
init_data
=
initializer
((
self
.
num_rows
,)
+
scheme
.
shape
,
scheme
.
dtype
,
ctx
)
else
:
else
:
if
self
.
get_initializer
(
name
)
is
None
:
if
self
.
get_initializer
(
name
)
is
None
:
self
.
_
warn_and_se
t_initializer
()
self
.
_
set_zero_defaul
t_initializer
()
initializer
=
self
.
get_initializer
(
name
)
initializer
=
self
.
get_initializer
(
name
)
init_data
=
initializer
((
self
.
num_rows
,)
+
scheme
.
shape
,
scheme
.
dtype
,
init_data
=
initializer
((
self
.
num_rows
,)
+
scheme
.
shape
,
scheme
.
dtype
,
ctx
,
slice
(
0
,
self
.
num_rows
))
ctx
,
slice
(
0
,
self
.
num_rows
))
...
@@ -386,7 +384,7 @@ class Frame(MutableMapping):
...
@@ -386,7 +384,7 @@ class Frame(MutableMapping):
scheme
=
col
.
scheme
scheme
=
col
.
scheme
ctx
=
F
.
context
(
col
.
data
)
ctx
=
F
.
context
(
col
.
data
)
if
self
.
get_initializer
(
key
)
is
None
:
if
self
.
get_initializer
(
key
)
is
None
:
self
.
_
warn_and_se
t_initializer
()
self
.
_
set_zero_defaul
t_initializer
()
initializer
=
self
.
get_initializer
(
key
)
initializer
=
self
.
get_initializer
(
key
)
new_data
=
initializer
((
num_rows
,)
+
scheme
.
shape
,
scheme
.
dtype
,
new_data
=
initializer
((
num_rows
,)
+
scheme
.
shape
,
scheme
.
dtype
,
ctx
,
slice
(
self
.
_num_rows
,
self
.
_num_rows
+
num_rows
))
ctx
,
slice
(
self
.
_num_rows
,
self
.
_num_rows
+
num_rows
))
...
@@ -433,7 +431,7 @@ class Frame(MutableMapping):
...
@@ -433,7 +431,7 @@ class Frame(MutableMapping):
scheme
=
col
.
scheme
scheme
=
col
.
scheme
ctx
=
F
.
context
(
col
.
data
)
ctx
=
F
.
context
(
col
.
data
)
if
self
.
get_initializer
(
key
)
is
None
:
if
self
.
get_initializer
(
key
)
is
None
:
self
.
_
warn_and_se
t_initializer
()
self
.
_
set_zero_defaul
t_initializer
()
initializer
=
self
.
get_initializer
(
key
)
initializer
=
self
.
get_initializer
(
key
)
new_data
=
initializer
((
other
.
num_rows
,)
+
scheme
.
shape
,
new_data
=
initializer
((
other
.
num_rows
,)
+
scheme
.
shape
,
scheme
.
dtype
,
ctx
,
scheme
.
dtype
,
ctx
,
...
@@ -902,10 +900,23 @@ def frame_like(other, num_rows):
...
@@ -902,10 +900,23 @@ def frame_like(other, num_rows):
newf
=
Frame
(
num_rows
=
num_rows
)
newf
=
Frame
(
num_rows
=
num_rows
)
# set global initializr
# set global initializr
if
other
.
get_initializer
()
is
None
:
if
other
.
get_initializer
()
is
None
:
other
.
_warn_and_set_initializer
()
other
.
_set_zero_default_initializer
()
newf
.
_default_initializer
=
other
.
_default_initializer
sync_frame_initializer
(
newf
,
other
)
return
newf
def
sync_frame_initializer
(
new_frame
,
reference_frame
):
"""Set the initializers of the new_frame to be the same as the reference_frame,
for both the default initializer and per-column initializers.
Parameters
----------
new_frame : Frame
The frame to set initializers
reference_frame : Frame
The frame to copy initializers
"""
new_frame
.
_default_initializer
=
reference_frame
.
_default_initializer
# set per-col initializer
# set per-col initializer
# TODO(minjie): hack; cannot rely on keys as the _initializers
# TODO(minjie): hack; cannot rely on keys as the _initializers
# now supports non-exist columns.
# now supports non-exist columns.
newf
.
_initializers
=
other
.
_initializers
new_frame
.
_initializers
=
reference_frame
.
_initializers
return
newf
python/dgl/graph.py
View file @
02029dce
...
@@ -9,7 +9,7 @@ import dgl
...
@@ -9,7 +9,7 @@ import dgl
from
.base
import
ALL
,
is_all
,
DGLError
from
.base
import
ALL
,
is_all
,
DGLError
from
.
import
backend
as
F
from
.
import
backend
as
F
from
.
import
init
from
.
import
init
from
.frame
import
FrameRef
,
Frame
,
Scheme
from
.frame
import
FrameRef
,
Frame
,
Scheme
,
sync_frame_initializer
from
.
import
graph_index
from
.
import
graph_index
from
.runtime
import
ir
,
scheduler
,
Runtime
from
.runtime
import
ir
,
scheduler
,
Runtime
from
.
import
utils
from
.
import
utils
...
@@ -3353,6 +3353,9 @@ class DGLGraph(DGLBaseGraph):
...
@@ -3353,6 +3353,9 @@ class DGLGraph(DGLBaseGraph):
However, any out-place mutation to the feature data will not reflect to this graph,
However, any out-place mutation to the feature data will not reflect to this graph,
thus making it easier to use in a function scope.
thus making it easier to use in a function scope.
If set, the local graph object will use same initializers for node features and
edge features.
Examples
Examples
--------
--------
The following example uses PyTorch backend.
The following example uses PyTorch backend.
...
@@ -3401,9 +3404,16 @@ class DGLGraph(DGLBaseGraph):
...
@@ -3401,9 +3404,16 @@ class DGLGraph(DGLBaseGraph):
DGLGraph
DGLGraph
The graph object that can be used as a local variable.
The graph object that can be used as a local variable.
"""
"""
local_node_frame
=
FrameRef
(
Frame
(
self
.
_node_frame
.
_frame
))
local_edge_frame
=
FrameRef
(
Frame
(
self
.
_edge_frame
.
_frame
))
# Use same per-column initializers and default initializer.
# If registered, a column (based on key) initializer will be used first,
# otherwise the default initializer will be used.
sync_frame_initializer
(
local_node_frame
.
_frame
,
self
.
_node_frame
.
_frame
)
sync_frame_initializer
(
local_edge_frame
.
_frame
,
self
.
_edge_frame
.
_frame
)
return
DGLGraph
(
self
.
_graph
,
return
DGLGraph
(
self
.
_graph
,
FrameRef
(
Frame
(
self
.
_node_frame
.
_frame
))
,
local
_node_frame
,
FrameRef
(
Frame
(
self
.
_edge_frame
.
_frame
))
)
local
_edge_frame
)
@
contextmanager
@
contextmanager
def
local_scope
(
self
):
def
local_scope
(
self
):
...
@@ -3412,6 +3422,9 @@ class DGLGraph(DGLBaseGraph):
...
@@ -3412,6 +3422,9 @@ class DGLGraph(DGLBaseGraph):
By entering a local scope, any out-place mutation to the feature data will
By entering a local scope, any out-place mutation to the feature data will
not reflect to the original graph, thus making it easier to use in a function scope.
not reflect to the original graph, thus making it easier to use in a function scope.
If set, the local scope will use same initializers for node features and
edge features.
Examples
Examples
--------
--------
The following example uses PyTorch backend.
The following example uses PyTorch backend.
...
@@ -3451,6 +3464,11 @@ class DGLGraph(DGLBaseGraph):
...
@@ -3451,6 +3464,11 @@ class DGLGraph(DGLBaseGraph):
old_eframe
=
self
.
_edge_frame
old_eframe
=
self
.
_edge_frame
self
.
_node_frame
=
FrameRef
(
Frame
(
self
.
_node_frame
.
_frame
))
self
.
_node_frame
=
FrameRef
(
Frame
(
self
.
_node_frame
.
_frame
))
self
.
_edge_frame
=
FrameRef
(
Frame
(
self
.
_edge_frame
.
_frame
))
self
.
_edge_frame
=
FrameRef
(
Frame
(
self
.
_edge_frame
.
_frame
))
# Use same per-column initializers and default initializer.
# If registered, a column (based on key) initializer will be used first,
# otherwise the default initializer will be used.
sync_frame_initializer
(
self
.
_node_frame
.
_frame
,
old_nframe
.
_frame
)
sync_frame_initializer
(
self
.
_edge_frame
.
_frame
,
old_eframe
.
_frame
)
yield
yield
self
.
_node_frame
=
old_nframe
self
.
_node_frame
=
old_nframe
self
.
_edge_frame
=
old_eframe
self
.
_edge_frame
=
old_eframe
tests/compute/test_basics.py
View file @
02029dce
...
@@ -691,6 +691,28 @@ def test_local_var():
...
@@ -691,6 +691,28 @@ def test_local_var():
assert
'hh'
not
in
g
.
ndata
assert
'hh'
not
in
g
.
ndata
assert
'ww'
not
in
g
.
edata
assert
'ww'
not
in
g
.
edata
# test initializer1
g
=
DGLGraph
()
g
.
add_nodes
(
2
)
g
.
add_edges
([
0
,
1
],
[
1
,
1
])
g
.
set_n_initializer
(
dgl
.
init
.
zero_initializer
)
def
foo
(
g
):
g
=
g
.
local_var
()
g
.
nodes
[
0
].
data
[
'h'
]
=
F
.
ones
((
1
,
1
))
assert
F
.
allclose
(
g
.
ndata
[
'h'
],
F
.
tensor
([[
1.
],
[
0.
]]))
foo
(
g
)
# test initializer2
def
foo_e_initializer
(
shape
,
dtype
,
ctx
,
id_range
):
return
F
.
ones
(
shape
)
g
.
set_e_initializer
(
foo_e_initializer
,
field
=
'h'
)
def
foo
(
g
):
g
=
g
.
local_var
()
g
.
edges
[
0
,
1
].
data
[
'h'
]
=
F
.
ones
((
1
,
1
))
assert
F
.
allclose
(
g
.
edata
[
'h'
],
F
.
ones
((
2
,
1
)))
g
.
edges
[
0
,
1
].
data
[
'w'
]
=
F
.
ones
((
1
,
1
))
assert
F
.
allclose
(
g
.
edata
[
'w'
],
F
.
tensor
([[
1.
],
[
0.
]]))
foo
(
g
)
def
test_local_scope
():
def
test_local_scope
():
g
=
DGLGraph
(
nx
.
path_graph
(
5
))
g
=
DGLGraph
(
nx
.
path_graph
(
5
))
g
.
ndata
[
'h'
]
=
F
.
zeros
((
g
.
number_of_nodes
(),
3
))
g
.
ndata
[
'h'
]
=
F
.
zeros
((
g
.
number_of_nodes
(),
3
))
...
@@ -742,6 +764,28 @@ def test_local_scope():
...
@@ -742,6 +764,28 @@ def test_local_scope():
assert
'hh'
not
in
g
.
ndata
assert
'hh'
not
in
g
.
ndata
assert
'ww'
not
in
g
.
edata
assert
'ww'
not
in
g
.
edata
# test initializer1
g
=
DGLGraph
()
g
.
add_nodes
(
2
)
g
.
add_edges
([
0
,
1
],
[
1
,
1
])
g
.
set_n_initializer
(
dgl
.
init
.
zero_initializer
)
def
foo
(
g
):
with
g
.
local_scope
():
g
.
nodes
[
0
].
data
[
'h'
]
=
F
.
ones
((
1
,
1
))
assert
F
.
allclose
(
g
.
ndata
[
'h'
],
F
.
tensor
([[
1.
],
[
0.
]]))
foo
(
g
)
# test initializer2
def
foo_e_initializer
(
shape
,
dtype
,
ctx
,
id_range
):
return
F
.
ones
(
shape
)
g
.
set_e_initializer
(
foo_e_initializer
,
field
=
'h'
)
def
foo
(
g
):
with
g
.
local_scope
():
g
.
edges
[
0
,
1
].
data
[
'h'
]
=
F
.
ones
((
1
,
1
))
assert
F
.
allclose
(
g
.
edata
[
'h'
],
F
.
ones
((
2
,
1
)))
g
.
edges
[
0
,
1
].
data
[
'w'
]
=
F
.
ones
((
1
,
1
))
assert
F
.
allclose
(
g
.
edata
[
'w'
],
F
.
tensor
([[
1.
],
[
0.
]]))
foo
(
g
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_nx_conversion
()
test_nx_conversion
()
test_batch_setter_getter
()
test_batch_setter_getter
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment