Unverified Commit 02029dce authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

Sync initializers (#772)

parent 165d4538
......@@ -684,7 +684,7 @@ class SharedMemoryDGLGraph(BaseGraphStore):
raise Exception("graph store only supports CPU context for node data")
init = self._node_frame.get_initializer(ndata_name)
if init is None:
self._node_frame._frame._warn_and_set_initializer()
self._node_frame._frame._set_zero_default_initializer()
init = self._node_frame.get_initializer(ndata_name)
init = self._init_manager.serialize(init)
self.proxy.init_ndata(init, ndata_name, tuple(shape), dtype)
......@@ -712,7 +712,7 @@ class SharedMemoryDGLGraph(BaseGraphStore):
raise Exception("graph store only supports CPU context for edge data")
init = self._edge_frame.get_initializer(edata_name)
if init is None:
self._edge_frame._frame._warn_and_set_initializer()
self._edge_frame._frame._set_zero_default_initializer()
init = self._edge_frame.get_initializer(edata_name)
init = self._init_manager.serialize(init)
self.proxy.init_edata(init, edata_name, tuple(shape), dtype)
......
......@@ -215,10 +215,8 @@ class Frame(MutableMapping):
self._remote_init_builder = None
self._default_initializer = None
def _warn_and_set_initializer(self):
dgl_warning('Initializer is not set. Use zero initializer instead.'
' To suppress this warning, use `set_initializer` to'
' explicitly specify which initializer to use.')
def _set_zero_default_initializer(self):
"""Set the default initializer to be zero initializer."""
self._default_initializer = zero_initializer
def get_initializer(self, column=None):
......@@ -279,7 +277,7 @@ class Frame(MutableMapping):
return None
if self.get_initializer(name) is None:
self._warn_and_set_initializer()
self._set_zero_default_initializer()
initializer = self.get_initializer(name)
return self._remote_init_builder(initializer, name)
......@@ -364,7 +362,7 @@ class Frame(MutableMapping):
init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype, ctx)
else:
if self.get_initializer(name) is None:
self._warn_and_set_initializer()
self._set_zero_default_initializer()
initializer = self.get_initializer(name)
init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(0, self.num_rows))
......@@ -386,7 +384,7 @@ class Frame(MutableMapping):
scheme = col.scheme
ctx = F.context(col.data)
if self.get_initializer(key) is None:
self._warn_and_set_initializer()
self._set_zero_default_initializer()
initializer = self.get_initializer(key)
new_data = initializer((num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(self._num_rows, self._num_rows + num_rows))
......@@ -433,7 +431,7 @@ class Frame(MutableMapping):
scheme = col.scheme
ctx = F.context(col.data)
if self.get_initializer(key) is None:
self._warn_and_set_initializer()
self._set_zero_default_initializer()
initializer = self.get_initializer(key)
new_data = initializer((other.num_rows,) + scheme.shape,
scheme.dtype, ctx,
......@@ -902,10 +900,23 @@ def frame_like(other, num_rows):
newf = Frame(num_rows=num_rows)
# set global initializer
if other.get_initializer() is None:
other._warn_and_set_initializer()
newf._default_initializer = other._default_initializer
other._set_zero_default_initializer()
sync_frame_initializer(newf, other)
return newf
def sync_frame_initializer(new_frame, reference_frame):
"""Set the initializers of the new_frame to be the same as the reference_frame,
for both the default initializer and per-column initializers.
Parameters
----------
new_frame : Frame
The frame whose initializers are set.
reference_frame : Frame
The frame whose initializers are copied.
"""
new_frame._default_initializer = reference_frame._default_initializer
# set per-col initializer
# TODO(minjie): hack; cannot rely on keys as the _initializers
# now supports non-existent columns.
newf._initializers = other._initializers
return newf
new_frame._initializers = reference_frame._initializers
......@@ -9,7 +9,7 @@ import dgl
from .base import ALL, is_all, DGLError
from . import backend as F
from . import init
from .frame import FrameRef, Frame, Scheme
from .frame import FrameRef, Frame, Scheme, sync_frame_initializer
from . import graph_index
from .runtime import ir, scheduler, Runtime
from . import utils
......@@ -3353,6 +3353,9 @@ class DGLGraph(DGLBaseGraph):
However, any out-place mutation to the feature data will not be reflected in this graph,
thus making it easier to use in a function scope.
If initializers have been set on this graph, the local graph object will use the same
initializers for node features and edge features.
Examples
--------
The following example uses PyTorch backend.
......@@ -3401,9 +3404,16 @@ class DGLGraph(DGLBaseGraph):
DGLGraph
The graph object that can be used as a local variable.
"""
local_node_frame = FrameRef(Frame(self._node_frame._frame))
local_edge_frame = FrameRef(Frame(self._edge_frame._frame))
# Use the same per-column initializers and default initializer as the original frame.
# If a per-column initializer is registered for a key, it is used first;
# otherwise the default initializer is used.
sync_frame_initializer(local_node_frame._frame, self._node_frame._frame)
sync_frame_initializer(local_edge_frame._frame, self._edge_frame._frame)
return DGLGraph(self._graph,
FrameRef(Frame(self._node_frame._frame)),
FrameRef(Frame(self._edge_frame._frame)))
local_node_frame,
local_edge_frame)
@contextmanager
def local_scope(self):
......@@ -3412,6 +3422,9 @@ class DGLGraph(DGLBaseGraph):
By entering a local scope, any out-place mutation to the feature data will
not be reflected in the original graph, thus making it easier to use in a function scope.
If initializers have been set on this graph, the local scope will use the same
initializers for node features and edge features.
Examples
--------
The following example uses PyTorch backend.
......@@ -3451,6 +3464,11 @@ class DGLGraph(DGLBaseGraph):
old_eframe = self._edge_frame
self._node_frame = FrameRef(Frame(self._node_frame._frame))
self._edge_frame = FrameRef(Frame(self._edge_frame._frame))
# Use the same per-column initializers and default initializer as the original frame.
# If a per-column initializer is registered for a key, it is used first;
# otherwise the default initializer is used.
sync_frame_initializer(self._node_frame._frame, old_nframe._frame)
sync_frame_initializer(self._edge_frame._frame, old_eframe._frame)
yield
self._node_frame = old_nframe
self._edge_frame = old_eframe
......@@ -691,6 +691,28 @@ def test_local_var():
assert 'hh' not in g.ndata
assert 'ww' not in g.edata
# test initializer1
g = DGLGraph()
g.add_nodes(2)
g.add_edges([0, 1], [1, 1])
g.set_n_initializer(dgl.init.zero_initializer)
def foo(g):
g = g.local_var()
g.nodes[0].data['h'] = F.ones((1, 1))
assert F.allclose(g.ndata['h'], F.tensor([[1.], [0.]]))
foo(g)
# test initializer2
# Test-only edge initializer registered below for field 'h': it returns
# F.ones(shape); dtype, ctx and id_range are accepted but not used.
def foo_e_initializer(shape, dtype, ctx, id_range):
return F.ones(shape)
g.set_e_initializer(foo_e_initializer, field='h')
def foo(g):
g = g.local_var()
g.edges[0, 1].data['h'] = F.ones((1, 1))
assert F.allclose(g.edata['h'], F.ones((2, 1)))
g.edges[0, 1].data['w'] = F.ones((1, 1))
assert F.allclose(g.edata['w'], F.tensor([[1.], [0.]]))
foo(g)
def test_local_scope():
g = DGLGraph(nx.path_graph(5))
g.ndata['h'] = F.zeros((g.number_of_nodes(), 3))
......@@ -742,6 +764,28 @@ def test_local_scope():
assert 'hh' not in g.ndata
assert 'ww' not in g.edata
# test initializer1
g = DGLGraph()
g.add_nodes(2)
g.add_edges([0, 1], [1, 1])
g.set_n_initializer(dgl.init.zero_initializer)
def foo(g):
with g.local_scope():
g.nodes[0].data['h'] = F.ones((1, 1))
assert F.allclose(g.ndata['h'], F.tensor([[1.], [0.]]))
foo(g)
# test initializer2
# Test-only edge initializer registered below for field 'h': it returns
# F.ones(shape); dtype, ctx and id_range are accepted but not used.
def foo_e_initializer(shape, dtype, ctx, id_range):
return F.ones(shape)
g.set_e_initializer(foo_e_initializer, field='h')
def foo(g):
with g.local_scope():
g.edges[0, 1].data['h'] = F.ones((1, 1))
assert F.allclose(g.edata['h'], F.ones((2, 1)))
g.edges[0, 1].data['w'] = F.ones((1, 1))
assert F.allclose(g.edata['w'], F.tensor([[1.], [0.]]))
foo(g)
if __name__ == '__main__':
test_nx_conversion()
test_batch_setter_getter()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment