Unverified commit 02029dce, authored by Mufei Li, committed by GitHub
Browse files

Sync initializers (#772)

parent 165d4538
...@@ -684,7 +684,7 @@ class SharedMemoryDGLGraph(BaseGraphStore): ...@@ -684,7 +684,7 @@ class SharedMemoryDGLGraph(BaseGraphStore):
raise Exception("graph store only supports CPU context for node data") raise Exception("graph store only supports CPU context for node data")
init = self._node_frame.get_initializer(ndata_name) init = self._node_frame.get_initializer(ndata_name)
if init is None: if init is None:
self._node_frame._frame._warn_and_set_initializer() self._node_frame._frame._set_zero_default_initializer()
init = self._node_frame.get_initializer(ndata_name) init = self._node_frame.get_initializer(ndata_name)
init = self._init_manager.serialize(init) init = self._init_manager.serialize(init)
self.proxy.init_ndata(init, ndata_name, tuple(shape), dtype) self.proxy.init_ndata(init, ndata_name, tuple(shape), dtype)
...@@ -712,7 +712,7 @@ class SharedMemoryDGLGraph(BaseGraphStore): ...@@ -712,7 +712,7 @@ class SharedMemoryDGLGraph(BaseGraphStore):
raise Exception("graph store only supports CPU context for edge data") raise Exception("graph store only supports CPU context for edge data")
init = self._edge_frame.get_initializer(edata_name) init = self._edge_frame.get_initializer(edata_name)
if init is None: if init is None:
self._edge_frame._frame._warn_and_set_initializer() self._edge_frame._frame._set_zero_default_initializer()
init = self._edge_frame.get_initializer(edata_name) init = self._edge_frame.get_initializer(edata_name)
init = self._init_manager.serialize(init) init = self._init_manager.serialize(init)
self.proxy.init_edata(init, edata_name, tuple(shape), dtype) self.proxy.init_edata(init, edata_name, tuple(shape), dtype)
......
...@@ -215,10 +215,8 @@ class Frame(MutableMapping): ...@@ -215,10 +215,8 @@ class Frame(MutableMapping):
self._remote_init_builder = None self._remote_init_builder = None
self._default_initializer = None self._default_initializer = None
def _warn_and_set_initializer(self): def _set_zero_default_initializer(self):
dgl_warning('Initializer is not set. Use zero initializer instead.' """Set the default initializer to be zero initializer."""
' To suppress this warning, use `set_initializer` to'
' explicitly specify which initializer to use.')
self._default_initializer = zero_initializer self._default_initializer = zero_initializer
def get_initializer(self, column=None): def get_initializer(self, column=None):
...@@ -279,7 +277,7 @@ class Frame(MutableMapping): ...@@ -279,7 +277,7 @@ class Frame(MutableMapping):
return None return None
if self.get_initializer(name) is None: if self.get_initializer(name) is None:
self._warn_and_set_initializer() self._set_zero_default_initializer()
initializer = self.get_initializer(name) initializer = self.get_initializer(name)
return self._remote_init_builder(initializer, name) return self._remote_init_builder(initializer, name)
...@@ -364,7 +362,7 @@ class Frame(MutableMapping): ...@@ -364,7 +362,7 @@ class Frame(MutableMapping):
init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype, ctx) init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype, ctx)
else: else:
if self.get_initializer(name) is None: if self.get_initializer(name) is None:
self._warn_and_set_initializer() self._set_zero_default_initializer()
initializer = self.get_initializer(name) initializer = self.get_initializer(name)
init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype, init_data = initializer((self.num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(0, self.num_rows)) ctx, slice(0, self.num_rows))
...@@ -386,7 +384,7 @@ class Frame(MutableMapping): ...@@ -386,7 +384,7 @@ class Frame(MutableMapping):
scheme = col.scheme scheme = col.scheme
ctx = F.context(col.data) ctx = F.context(col.data)
if self.get_initializer(key) is None: if self.get_initializer(key) is None:
self._warn_and_set_initializer() self._set_zero_default_initializer()
initializer = self.get_initializer(key) initializer = self.get_initializer(key)
new_data = initializer((num_rows,) + scheme.shape, scheme.dtype, new_data = initializer((num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(self._num_rows, self._num_rows + num_rows)) ctx, slice(self._num_rows, self._num_rows + num_rows))
...@@ -433,7 +431,7 @@ class Frame(MutableMapping): ...@@ -433,7 +431,7 @@ class Frame(MutableMapping):
scheme = col.scheme scheme = col.scheme
ctx = F.context(col.data) ctx = F.context(col.data)
if self.get_initializer(key) is None: if self.get_initializer(key) is None:
self._warn_and_set_initializer() self._set_zero_default_initializer()
initializer = self.get_initializer(key) initializer = self.get_initializer(key)
new_data = initializer((other.num_rows,) + scheme.shape, new_data = initializer((other.num_rows,) + scheme.shape,
scheme.dtype, ctx, scheme.dtype, ctx,
...@@ -902,10 +900,23 @@ def frame_like(other, num_rows): ...@@ -902,10 +900,23 @@ def frame_like(other, num_rows):
newf = Frame(num_rows=num_rows) newf = Frame(num_rows=num_rows)
# set global initializer # set global initializer
if other.get_initializer() is None: if other.get_initializer() is None:
other._warn_and_set_initializer() other._set_zero_default_initializer()
newf._default_initializer = other._default_initializer sync_frame_initializer(newf, other)
return newf
def sync_frame_initializer(new_frame, reference_frame):
"""Set the initializers of the new_frame to be the same as the reference_frame,
for both the default initializer and per-column initializers.
Parameters
----------
new_frame : Frame
The frame to set initializers
reference_frame : Frame
The frame to copy initializers
"""
new_frame._default_initializer = reference_frame._default_initializer
# set per-col initializer # set per-col initializer
# TODO(minjie): hack; cannot rely on keys as the _initializers # TODO(minjie): hack; cannot rely on keys as the _initializers
# now supports non-exist columns. # now supports non-exist columns.
newf._initializers = other._initializers new_frame._initializers = reference_frame._initializers
return newf
...@@ -9,7 +9,7 @@ import dgl ...@@ -9,7 +9,7 @@ import dgl
from .base import ALL, is_all, DGLError from .base import ALL, is_all, DGLError
from . import backend as F from . import backend as F
from . import init from . import init
from .frame import FrameRef, Frame, Scheme from .frame import FrameRef, Frame, Scheme, sync_frame_initializer
from . import graph_index from . import graph_index
from .runtime import ir, scheduler, Runtime from .runtime import ir, scheduler, Runtime
from . import utils from . import utils
...@@ -3353,6 +3353,9 @@ class DGLGraph(DGLBaseGraph): ...@@ -3353,6 +3353,9 @@ class DGLGraph(DGLBaseGraph):
However, any out-place mutation to the feature data will not reflect to this graph, However, any out-place mutation to the feature data will not reflect to this graph,
thus making it easier to use in a function scope. thus making it easier to use in a function scope.
If set, the local graph object will use same initializers for node features and
edge features.
Examples Examples
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
...@@ -3401,9 +3404,16 @@ class DGLGraph(DGLBaseGraph): ...@@ -3401,9 +3404,16 @@ class DGLGraph(DGLBaseGraph):
DGLGraph DGLGraph
The graph object that can be used as a local variable. The graph object that can be used as a local variable.
""" """
local_node_frame = FrameRef(Frame(self._node_frame._frame))
local_edge_frame = FrameRef(Frame(self._edge_frame._frame))
# Use same per-column initializers and default initializer.
# If registered, a column (based on key) initializer will be used first,
# otherwise the default initializer will be used.
sync_frame_initializer(local_node_frame._frame, self._node_frame._frame)
sync_frame_initializer(local_edge_frame._frame, self._edge_frame._frame)
return DGLGraph(self._graph, return DGLGraph(self._graph,
FrameRef(Frame(self._node_frame._frame)), local_node_frame,
FrameRef(Frame(self._edge_frame._frame))) local_edge_frame)
@contextmanager @contextmanager
def local_scope(self): def local_scope(self):
...@@ -3412,6 +3422,9 @@ class DGLGraph(DGLBaseGraph): ...@@ -3412,6 +3422,9 @@ class DGLGraph(DGLBaseGraph):
By entering a local scope, any out-place mutation to the feature data will By entering a local scope, any out-place mutation to the feature data will
not reflect to the original graph, thus making it easier to use in a function scope. not reflect to the original graph, thus making it easier to use in a function scope.
If set, the local scope will use same initializers for node features and
edge features.
Examples Examples
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
...@@ -3451,6 +3464,11 @@ class DGLGraph(DGLBaseGraph): ...@@ -3451,6 +3464,11 @@ class DGLGraph(DGLBaseGraph):
old_eframe = self._edge_frame old_eframe = self._edge_frame
self._node_frame = FrameRef(Frame(self._node_frame._frame)) self._node_frame = FrameRef(Frame(self._node_frame._frame))
self._edge_frame = FrameRef(Frame(self._edge_frame._frame)) self._edge_frame = FrameRef(Frame(self._edge_frame._frame))
# Use same per-column initializers and default initializer.
# If registered, a column (based on key) initializer will be used first,
# otherwise the default initializer will be used.
sync_frame_initializer(self._node_frame._frame, old_nframe._frame)
sync_frame_initializer(self._edge_frame._frame, old_eframe._frame)
yield yield
self._node_frame = old_nframe self._node_frame = old_nframe
self._edge_frame = old_eframe self._edge_frame = old_eframe
...@@ -691,6 +691,28 @@ def test_local_var(): ...@@ -691,6 +691,28 @@ def test_local_var():
assert 'hh' not in g.ndata assert 'hh' not in g.ndata
assert 'ww' not in g.edata assert 'ww' not in g.edata
# test initializer1
g = DGLGraph()
g.add_nodes(2)
g.add_edges([0, 1], [1, 1])
g.set_n_initializer(dgl.init.zero_initializer)
def foo(g):
g = g.local_var()
g.nodes[0].data['h'] = F.ones((1, 1))
assert F.allclose(g.ndata['h'], F.tensor([[1.], [0.]]))
foo(g)
# test initializer2
def foo_e_initializer(shape, dtype, ctx, id_range):
return F.ones(shape)
g.set_e_initializer(foo_e_initializer, field='h')
def foo(g):
g = g.local_var()
g.edges[0, 1].data['h'] = F.ones((1, 1))
assert F.allclose(g.edata['h'], F.ones((2, 1)))
g.edges[0, 1].data['w'] = F.ones((1, 1))
assert F.allclose(g.edata['w'], F.tensor([[1.], [0.]]))
foo(g)
def test_local_scope(): def test_local_scope():
g = DGLGraph(nx.path_graph(5)) g = DGLGraph(nx.path_graph(5))
g.ndata['h'] = F.zeros((g.number_of_nodes(), 3)) g.ndata['h'] = F.zeros((g.number_of_nodes(), 3))
...@@ -742,6 +764,28 @@ def test_local_scope(): ...@@ -742,6 +764,28 @@ def test_local_scope():
assert 'hh' not in g.ndata assert 'hh' not in g.ndata
assert 'ww' not in g.edata assert 'ww' not in g.edata
# test initializer1
g = DGLGraph()
g.add_nodes(2)
g.add_edges([0, 1], [1, 1])
g.set_n_initializer(dgl.init.zero_initializer)
def foo(g):
with g.local_scope():
g.nodes[0].data['h'] = F.ones((1, 1))
assert F.allclose(g.ndata['h'], F.tensor([[1.], [0.]]))
foo(g)
# test initializer2
def foo_e_initializer(shape, dtype, ctx, id_range):
return F.ones(shape)
g.set_e_initializer(foo_e_initializer, field='h')
def foo(g):
with g.local_scope():
g.edges[0, 1].data['h'] = F.ones((1, 1))
assert F.allclose(g.edata['h'], F.ones((2, 1)))
g.edges[0, 1].data['w'] = F.ones((1, 1))
assert F.allclose(g.edata['w'], F.tensor([[1.], [0.]]))
foo(g)
if __name__ == '__main__': if __name__ == '__main__':
test_nx_conversion() test_nx_conversion()
test_batch_setter_getter() test_batch_setter_getter()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment