"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "c4b86b0707b804154fb47dd2425b62b4092284a1"
Unverified commit fc9d30fa, authored by Minjie Wang, committed by GitHub

[Graph] add local scope function (#735)

* add local scope function

* fix lint

* fix docstring

* change local_scope to local_var; add context manager

* address comments
Parent: e9e587b6
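For orientation, here is a minimal sketch of the calling pattern this change enables, assuming the PyTorch backend. `TinyConv` is an illustrative layer, not the library's `GraphConv`, but it uses only graph calls that appear in the diffs below (`local_var`, `update_all`, `fn.copy_src`, `fn.sum`, `ndata`):

import torch
import torch.nn as nn
import dgl.function as fn

class TinyConv(nn.Module):
    """Illustrative graph layer: scratch features stay local to the call."""
    def __init__(self, in_feats, out_feats):
        super(TinyConv, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, graph, feat):
        graph = graph.local_var()          # mutations below do not touch the caller's graph
        graph.ndata['h'] = feat
        graph.update_all(fn.copy_src(src='h', out='m'),
                         fn.sum(msg='m', out='h'))
        return self.linear(graph.ndata['h'])

After the call, the scratch keys 'h' and 'm' are collected together with the local frame instead of lingering in (or clobbering) the caller's `graph.ndata`, which is exactly what the updated tests below assert with `len(g.ndata) == 0`.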
@@ -95,6 +95,8 @@ Using Node/edge features
     DGLGraph.edge_attr_schemes
     DGLGraph.set_n_initializer
     DGLGraph.set_e_initializer
+    DGLGraph.local_var
+    DGLGraph.local_scope

 Computing with DGLGraph
 -----------------------
...
@@ -197,7 +197,9 @@ class Frame(MutableMapping):
             # Note that we always create a new column for the given data.
             # This avoids two frames accidentally sharing the same column.
             self._columns = {k : Column.create(v) for k, v in data.items()}
-            if len(self._columns) != 0:
+            if isinstance(data, (Frame, FrameRef)):
+                self._num_rows = data.num_rows
+            elif len(self._columns) != 0:
                 self._num_rows = len(next(iter(self._columns.values())))
             else:
                 self._num_rows = 0
...
@@ -2,6 +2,7 @@
 from __future__ import absolute_import

 from collections import defaultdict
+from contextlib import contextmanager
 import networkx as nx

 import dgl
@@ -3344,3 +3345,112 @@ class DGLGraph(DGLBaseGraph):
         for k in self.edata.keys():
             self.edata[k] = F.copy_to(self.edata[k], ctx)
     # pylint: enable=invalid-name
+
+    def local_var(self):
+        """Return a graph object that can be used in a local function scope.
+
+        The returned graph object shares the feature data and graph structure of this
+        graph. However, any out-of-place mutation to the feature data is not reflected
+        in this graph, which makes the returned object easier to use inside a function
+        scope.
+
+        Examples
+        --------
+        The following examples use the PyTorch backend.
+
+        Avoid accidentally overriding existing feature data. This is quite common when
+        implementing an NN module:
+
+        >>> def foo(g):
+        >>>     g = g.local_var()
+        >>>     g.ndata['h'] = torch.ones((g.number_of_nodes(), 3))
+        >>>     return g.ndata['h']
+        >>>
+        >>> g = ... # some graph
+        >>> g.ndata['h'] = torch.zeros((g.number_of_nodes(), 3))
+        >>> newh = foo(g)        # a tensor of all ones
+        >>> print(g.ndata['h'])  # still a tensor of all zeros
+
+        Automatically garbage-collect locally defined tensors without the need to
+        manually ``pop`` them:
+
+        >>> def foo(g):
+        >>>     g = g.local_var()
+        >>>     # This 'xxx' feature stays local and is GCed when the function exits
+        >>>     g.ndata['xxx'] = torch.ones((g.number_of_nodes(), 3))
+        >>>     return g.ndata['xxx']
+        >>>
+        >>> g = ... # some graph
+        >>> xxx = foo(g)
+        >>> print('xxx' in g.ndata)
+        False
+
+        Notes
+        -----
+        Internally, the returned graph shares the same feature tensors but constructs a
+        new dictionary structure (a.k.a. a Frame), so adding or removing feature tensors
+        on the returned graph is not reflected in the original graph. In-place operations,
+        however, do change the shared tensor values and are therefore visible to the
+        original graph. This function also has little overhead when the number of feature
+        tensors in the graph is small.
+
+        See Also
+        --------
+        local_scope
+
+        Returns
+        -------
+        DGLGraph
+            The graph object that can be used as a local variable.
+        """
+        return DGLGraph(self._graph,
+                        FrameRef(Frame(self._node_frame._frame)),
+                        FrameRef(Frame(self._edge_frame._frame)))
+
+    @contextmanager
+    def local_scope(self):
+        """Enter a local scope context for this graph.
+
+        Within a local scope, any out-of-place mutation to the feature data is not
+        reflected in the original graph, which makes it easier to use inside a
+        function scope.
+
+        Examples
+        --------
+        The following examples use the PyTorch backend.
+
+        Avoid accidentally overriding existing feature data. This is quite common when
+        implementing an NN module:
+
+        >>> def foo(g):
+        >>>     with g.local_scope():
+        >>>         g.ndata['h'] = torch.ones((g.number_of_nodes(), 3))
+        >>>         return g.ndata['h']
+        >>>
+        >>> g = ... # some graph
+        >>> g.ndata['h'] = torch.zeros((g.number_of_nodes(), 3))
+        >>> newh = foo(g)        # a tensor of all ones
+        >>> print(g.ndata['h'])  # still a tensor of all zeros
+
+        Automatically garbage-collect locally defined tensors without the need to
+        manually ``pop`` them:
+
+        >>> def foo(g):
+        >>>     with g.local_scope():
+        >>>         # This 'xxx' feature stays local and is GCed when the function exits
+        >>>         g.ndata['xxx'] = torch.ones((g.number_of_nodes(), 3))
+        >>>         return g.ndata['xxx']
+        >>>
+        >>> g = ... # some graph
+        >>> xxx = foo(g)
+        >>> print('xxx' in g.ndata)
+        False
+
+        See Also
+        --------
+        local_var
+        """
+        old_nframe = self._node_frame
+        old_eframe = self._edge_frame
+        self._node_frame = FrameRef(Frame(self._node_frame._frame))
+        self._edge_frame = FrameRef(Frame(self._edge_frame._frame))
+        yield
+        self._node_frame = old_nframe
+        self._edge_frame = old_eframe
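To make the sharing semantics of the Notes section concrete, here is a minimal sketch (PyTorch backend assumed, graph construction borrowed from the tests further down):

import networkx as nx
import torch
import dgl

g = dgl.DGLGraph(nx.path_graph(3))
g.ndata['h'] = torch.zeros(3, 2)

lg = g.local_var()
lg.ndata['h'][0, 0] = 9.0                # in-place write on the shared tensor ...
print(g.ndata['h'][0, 0])                # ... is visible on the original graph: tensor(9.)

lg.ndata['scratch'] = torch.ones(3, 2)   # a new key lives only in the local frame
print('scratch' in g.ndata)              # False

local_scope gives the same behavior through a with block: because it swaps the original frames back in after yield, scratch keys disappear when the block exits, as the nested-scope test below demonstrates.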
@@ -4,7 +4,6 @@ import mxnet as mx
 from mxnet import gluon

 from ... import function as fn
-from ...utils import get_ndata_name

 __all__ = ['GraphConv']
@@ -68,8 +67,6 @@ class GraphConv(gluon.Block):
         self._in_feats = in_feats
         self._out_feats = out_feats
         self._norm = norm
-        self._feat_name = "_gconv_feat"
-        self._msg_name = "_gconv_msg"

         with self.name_scope():
             self.weight = self.params.get('weight', shape=(in_feats, out_feats),
@@ -104,8 +101,7 @@
         mxnet.NDArray
             The output feature
         """
-        self._feat_name = get_ndata_name(graph, self._feat_name)
+        graph = graph.local_var()
         if self._norm:
             degs = graph.in_degrees().astype('float32')
             norm = mx.nd.power(degs, -0.5)
@@ -116,16 +112,16 @@
         if self._in_feats > self._out_feats:
             # mult W first to reduce the feature size for aggregation.
             feat = mx.nd.dot(feat, self.weight.data(feat.context))
-            graph.ndata[self._feat_name] = feat
-            graph.update_all(fn.copy_src(src=self._feat_name, out=self._msg_name),
-                             fn.sum(msg=self._msg_name, out=self._feat_name))
-            rst = graph.ndata.pop(self._feat_name)
+            graph.ndata['h'] = feat
+            graph.update_all(fn.copy_src(src='h', out='m'),
+                             fn.sum(msg='m', out='h'))
+            rst = graph.ndata.pop('h')
         else:
             # aggregate first then mult W
-            graph.ndata[self._feat_name] = feat
-            graph.update_all(fn.copy_src(src=self._feat_name, out=self._msg_name),
-                             fn.sum(msg=self._msg_name, out=self._feat_name))
-            rst = graph.ndata.pop(self._feat_name)
+            graph.ndata['h'] = feat
+            graph.update_all(fn.copy_src(src='h', out='m'),
+                             fn.sum(msg='m', out='h'))
+            rst = graph.ndata.pop('h')
             rst = mx.nd.dot(rst, self.weight.data(feat.context))
         if self._norm:
...
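For reference, with norm enabled the forward pass above applies the in-degree based scaling both to the input features and to the aggregated output, so for an undirected graph with adjacency matrix $A$ and diagonal degree matrix $D$ the layer computes, roughly (bias and any activation aside):

$$ \mathrm{rst} = D^{-1/2} A D^{-1/2}\, \mathrm{feat}\; W $$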
@@ -2,7 +2,6 @@
 # pylint: disable= no-member, arguments-differ
 import mxnet as mx

-from ... import utils
 from ... import function as fn

 __all__ = ['edge_softmax']
@@ -30,7 +29,9 @@ class EdgeSoftmax(mx.autograd.Function):
         self.g = g

     def forward(self, score):
-        """
+        """Forward function.
+
+        Pseudo-code:
         score = dgl.EData(g, score)
         score_max = score.dst_max()  # of type dgl.NData
         score = score - score_max  # edge_sub_dst, ret dgl.EData
@@ -38,48 +39,39 @@
         out = score / score_sum  # edge_div_dst, ret dgl.EData
         return out.data
         """
-        g = self.g
-        score_name = utils.get_edata_name(g, 'score')
-        tmp_name = utils.get_ndata_name(g, 'tmp')
-        out_name = utils.get_edata_name(g, 'out')
-        g.edata[score_name] = score
-        g.update_all(fn.copy_e(score_name, 'm'), fn.max('m', tmp_name))
-        g.apply_edges(fn.e_sub_v(score_name, tmp_name, out_name))
-        g.edata[out_name] = g.edata[out_name].exp()
-        g.update_all(fn.copy_e(out_name, 'm'), fn.sum('m', tmp_name))
-        g.apply_edges(fn.e_div_v(out_name, tmp_name, out_name))
-        g.edata.pop(score_name)
-        g.ndata.pop(tmp_name)
-        out = g.edata.pop(out_name)
+        g = self.g.local_var()
+        g.edata['s'] = score
+        g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
+        g.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
+        g.edata['out'] = g.edata['out'].exp()
+        g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
+        g.apply_edges(fn.e_div_v('out', 'out_sum', 'out'))
+        out = g.edata['out']
         self.save_for_backward(out)
         return out

     def backward(self, grad_out):
-        """
+        """Backward function.
+
+        Pseudo-code:
         g, out = ctx.backward_cache
         grad_out = dgl.EData(g, grad_out)
         out = dgl.EData(g, out)
         sds = out * grad_out  # type dgl.EData
         sds_sum = sds.dst_sum()  # type dgl.NData
         grad_score = sds - sds * sds_sum  # multiple expressions
-        return grad_score.data
         """
-        g = self.g
+        g = self.g.local_var()
         out, = self.saved_tensors  # pylint: disable=access-member-before-definition, unpacking-non-sequence
         # clear saved tensors explicitly
         self.saved_tensors = None
-        out_name = utils.get_edata_name(g, 'out')
-        accum_name = utils.get_ndata_name(g, 'accum')
-        grad_score_name = utils.get_edata_name(g, 'grad_score')
-        g.edata[out_name] = out
-        g.edata[grad_score_name] = out * grad_out
-        g.update_all(fn.copy_e(grad_score_name, 'm'), fn.sum('m', accum_name))
-        g.apply_edges(fn.e_mul_v(out_name, accum_name, out_name))
-        g.ndata.pop(accum_name)
-        grad_score = g.edata.pop(grad_score_name) - g.edata.pop(out_name)
+        g.edata['out'] = out
+        g.edata['grad_score'] = out * grad_out
+        g.update_all(fn.copy_e('grad_score', 'm'), fn.sum('m', 'accum'))
+        g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
+        grad_score = g.edata['grad_score'] - g.edata['out']
         return grad_score

 def edge_softmax(graph, logits):
     r"""Compute edge softmax.
...
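As a cross-check of the message-passing version above, here is a minimal framework-free sketch of what edge_softmax computes. `dst[e]` is the destination node of edge `e` and `score[e]` its unnormalized score; the function name and list-based representation are illustrative, not DGL API:

import math
from collections import defaultdict

def edge_softmax_reference(dst, score):
    """Softmax each edge score over the group of edges entering the same node."""
    smax = defaultdict(lambda: float('-inf'))
    for e, v in enumerate(dst):
        smax[v] = max(smax[v], score[e])        # per-destination max (numerical stability)
    expval = [math.exp(score[e] - smax[v]) for e, v in enumerate(dst)]
    ssum = defaultdict(float)
    for e, v in enumerate(dst):
        ssum[v] += expval[e]                    # per-destination sum of exponentials
    return [expval[e] / ssum[v] for e, v in enumerate(dst)]

# A 3-node path stored as four directed edges with destinations [1, 0, 2, 1]
# and all-ones scores: node 1 receives two edges, nodes 0 and 2 one each.
print(edge_softmax_reference([1, 0, 2, 1], [1.0, 1.0, 1.0, 1.0]))
# [0.5, 1.0, 1.0, 0.5]

This is the same uniform-attention result that the tests below check against for a path graph with constant scores.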
@@ -5,7 +5,6 @@ from torch import nn
 from torch.nn import init

 from ... import function as fn
-from ...utils import get_ndata_name

 __all__ = ['GraphConv']
@@ -69,8 +68,6 @@ class GraphConv(nn.Module):
         self._in_feats = in_feats
         self._out_feats = out_feats
         self._norm = norm
-        self._feat_name = "_gconv_feat"
-        self._msg_name = "_gconv_msg"

         self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))
         if bias:
@@ -109,8 +106,7 @@
         torch.Tensor
             The output feature
         """
-        self._feat_name = get_ndata_name(graph, self._feat_name)
+        graph = graph.local_var()
         if self._norm:
             norm = th.pow(graph.in_degrees().float(), -0.5)
             shp = norm.shape + (1,) * (feat.dim() - 1)
@@ -120,16 +116,16 @@
         if self._in_feats > self._out_feats:
             # mult W first to reduce the feature size for aggregation.
             feat = th.matmul(feat, self.weight)
-            graph.ndata[self._feat_name] = feat
-            graph.update_all(fn.copy_src(src=self._feat_name, out=self._msg_name),
-                             fn.sum(msg=self._msg_name, out=self._feat_name))
-            rst = graph.ndata.pop(self._feat_name)
+            graph.ndata['h'] = feat
+            graph.update_all(fn.copy_src(src='h', out='m'),
+                             fn.sum(msg='m', out='h'))
+            rst = graph.ndata['h']
         else:
             # aggregate first then mult W
-            graph.ndata[self._feat_name] = feat
-            graph.update_all(fn.copy_src(src=self._feat_name, out=self._msg_name),
-                             fn.sum(msg=self._msg_name, out=self._feat_name))
-            rst = graph.ndata.pop(self._feat_name)
+            graph.ndata['h'] = feat
+            graph.update_all(fn.copy_src(src='h', out='m'),
+                             fn.sum(msg='m', out='h'))
+            rst = graph.ndata['h']
             rst = th.matmul(rst, self.weight)
         if self._norm:
...
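A note on the branch kept intact above (in both backends): with N nodes, E edges, and feature sizes d_in and d_out, the dense projection by W costs about N * d_in * d_out in either order, but summing messages costs about E * d_out when the projection is applied first versus E * d_in when it is applied after aggregation. Multiplying by W first is therefore the cheaper choice exactly when out_feats < in_feats, which is what the `self._in_feats > self._out_feats` test selects.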
@@ -2,7 +2,6 @@
 # pylint: disable= no-member, arguments-differ
 import torch as th

-from ... import utils
 from ... import function as fn

 __all__ = ['edge_softmax']
@@ -27,7 +26,9 @@ class EdgeSoftmax(th.autograd.Function):
     @staticmethod
     def forward(ctx, g, score):
-        """
+        """Forward function.
+
+        Pseudo-code:
         score = dgl.EData(g, score)
         score_max = score.dst_max()  # of type dgl.NData
         score = score - score_max  # edge_sub_dst, ret dgl.EData
@@ -35,46 +36,43 @@
         score_sum = score.dst_sum()  # of type dgl.NData
         out = score / score_sum  # edge_div_dst, ret dgl.EData
         return out.data
         """
-        score_name = utils.get_edata_name(g, 'score')
-        tmp_name = utils.get_ndata_name(g, 'tmp')
-        out_name = utils.get_edata_name(g, 'out')
-        g.edata[score_name] = score
-        g.update_all(fn.copy_e(score_name, 'm'), fn.max('m', tmp_name))
-        g.apply_edges(fn.e_sub_v(score_name, tmp_name, out_name))
-        g.edata[out_name] = th.exp(g.edata[out_name])
-        g.update_all(fn.copy_e(out_name, 'm'), fn.sum('m', tmp_name))
-        g.apply_edges(fn.e_div_v(out_name, tmp_name, out_name))
-        g.edata.pop(score_name)
-        g.ndata.pop(tmp_name)
-        out = g.edata.pop(out_name)
-        ctx.save_for_backward(out)
+        # remember to save the graph to backward cache before making it
+        # a local variable
         ctx.backward_cache = g
+        g = g.local_var()
+        g.edata['s'] = score
+        g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
+        g.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
+        g.edata['out'] = th.exp(g.edata['out'])
+        g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
+        g.apply_edges(fn.e_div_v('out', 'out_sum', 'out'))
+        out = g.edata['out']
+        ctx.save_for_backward(out)
         return out

     @staticmethod
     def backward(ctx, grad_out):
-        """
+        """Backward function.
+
+        Pseudo-code:
         g, out = ctx.backward_cache
         grad_out = dgl.EData(g, grad_out)
         out = dgl.EData(g, out)
         sds = out * grad_out  # type dgl.EData
         sds_sum = sds.dst_sum()  # type dgl.NData
-        grad_score = sds - out * sds_sum  # multiple expressions
+        grad_score = sds - sds * sds_sum  # multiple expressions
         return grad_score.data
         """
         g = ctx.backward_cache
+        g = g.local_var()
         out, = ctx.saved_tensors
         # clear backward cache explicitly
         ctx.backward_cache = None
-        out_name = utils.get_edata_name(g, 'out')
-        accum_name = utils.get_ndata_name(g, 'accum')
-        grad_score_name = utils.get_edata_name(g, 'grad_score')
-        g.edata[out_name] = out
-        g.edata[grad_score_name] = out * grad_out
-        g.update_all(fn.copy_e(grad_score_name, 'm'), fn.sum('m', accum_name))
-        g.apply_edges(fn.e_mul_v(out_name, accum_name, out_name))
-        g.ndata.pop(accum_name)
-        grad_score = g.edata.pop(grad_score_name) - g.edata.pop(out_name)
+        g.edata['out'] = out
+        g.edata['grad_s'] = out * grad_out
+        g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum'))
+        g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
+        grad_score = g.edata['grad_s'] - g.edata['out']
         return None, grad_score
...
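For reference, the backward pass above is the standard softmax Jacobian applied independently over the edges entering each destination node $v$. Writing $\mathrm{out}_e$ for the forward output and $g_e$ for the incoming gradient on an edge $e \to v$:

$$ \frac{\partial L}{\partial s_e} = \mathrm{out}_e \Bigl( g_e - \sum_{e' \to v} \mathrm{out}_{e'}\, g_{e'} \Bigr) $$

which is what the code computes: `'grad_s'` holds $\mathrm{out}_e\, g_e$ (the `sds` of the pseudo-code), `'accum'` is its per-destination sum, `e_mul_v` multiplies that sum back onto $\mathrm{out}_e$, and the final subtraction returns $\mathrm{sds}_e - \mathrm{out}_e \cdot \mathrm{sds\_sum}_v$.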
@@ -490,48 +490,6 @@ def is_iterable(obj):
     """Return true if the object is an iterable."""
     return isinstance(obj, Iterable)

-def get_ndata_name(g, name):
-    """Return a node data name that does not exist in the given graph.
-
-    The given name is directly returned if it does not exist in the given graph.
-
-    Parameters
-    ----------
-    g : DGLGraph
-        The graph.
-    name : str
-        The proposed name.
-
-    Returns
-    -------
-    str
-        The node data name that does not exist.
-    """
-    while name in g.ndata:
-        name += '_'
-    return name
-
-def get_edata_name(g, name):
-    """Return an edge data name that does not exist in the given graph.
-
-    The given name is directly returned if it does not exist in the given graph.
-
-    Parameters
-    ----------
-    g : DGLGraph
-        The graph.
-    name : str
-        The proposed name.
-
-    Returns
-    -------
-    str
-        The node data name that does not exist.
-    """
-    while name in g.edata:
-        name += '_'
-    return name
-
 def to_dgl_context(ctx):
     """Convert a backend context to DGLContext"""
     device_type = nd.DGLContext.STR2MASK[F.device_type(ctx)]
...
@@ -654,6 +654,93 @@ def test_group_apply_edges():
     # test group by destination nodes
     _test('dst')

+def test_local_var():
+    g = DGLGraph(nx.path_graph(5))
+    g.ndata['h'] = F.zeros((g.number_of_nodes(), 3))
+    g.edata['w'] = F.zeros((g.number_of_edges(), 4))
+    # test override
+    def foo(g):
+        g = g.local_var()
+        g.ndata['h'] = F.ones((g.number_of_nodes(), 3))
+        g.edata['w'] = F.ones((g.number_of_edges(), 4))
+    foo(g)
+    assert F.allclose(g.ndata['h'], F.zeros((g.number_of_nodes(), 3)))
+    assert F.allclose(g.edata['w'], F.zeros((g.number_of_edges(), 4)))
+    # test out-place update
+    def foo(g):
+        g = g.local_var()
+        g.nodes[[2, 3]].data['h'] = F.ones((2, 3))
+        g.edges[[2, 3]].data['w'] = F.ones((2, 4))
+    foo(g)
+    assert F.allclose(g.ndata['h'], F.zeros((g.number_of_nodes(), 3)))
+    assert F.allclose(g.edata['w'], F.zeros((g.number_of_edges(), 4)))
+    # test out-place update 2
+    def foo(g):
+        g = g.local_var()
+        g.apply_nodes(lambda nodes: {'h' : nodes.data['h'] + 10}, [2, 3])
+        g.apply_edges(lambda edges: {'w' : edges.data['w'] + 10}, [2, 3])
+    foo(g)
+    assert F.allclose(g.ndata['h'], F.zeros((g.number_of_nodes(), 3)))
+    assert F.allclose(g.edata['w'], F.zeros((g.number_of_edges(), 4)))
+    # test auto-pop
+    def foo(g):
+        g = g.local_var()
+        g.ndata['hh'] = F.ones((g.number_of_nodes(), 3))
+        g.edata['ww'] = F.ones((g.number_of_edges(), 4))
+    foo(g)
+    assert 'hh' not in g.ndata
+    assert 'ww' not in g.edata
+
+def test_local_scope():
+    g = DGLGraph(nx.path_graph(5))
+    g.ndata['h'] = F.zeros((g.number_of_nodes(), 3))
+    g.edata['w'] = F.zeros((g.number_of_edges(), 4))
+    # test override
+    def foo(g):
+        with g.local_scope():
+            g.ndata['h'] = F.ones((g.number_of_nodes(), 3))
+            g.edata['w'] = F.ones((g.number_of_edges(), 4))
+    foo(g)
+    assert F.allclose(g.ndata['h'], F.zeros((g.number_of_nodes(), 3)))
+    assert F.allclose(g.edata['w'], F.zeros((g.number_of_edges(), 4)))
+    # test out-place update
+    def foo(g):
+        with g.local_scope():
+            g.nodes[[2, 3]].data['h'] = F.ones((2, 3))
+            g.edges[[2, 3]].data['w'] = F.ones((2, 4))
+    foo(g)
+    assert F.allclose(g.ndata['h'], F.zeros((g.number_of_nodes(), 3)))
+    assert F.allclose(g.edata['w'], F.zeros((g.number_of_edges(), 4)))
+    # test out-place update 2
+    def foo(g):
+        with g.local_scope():
+            g.apply_nodes(lambda nodes: {'h' : nodes.data['h'] + 10}, [2, 3])
+            g.apply_edges(lambda edges: {'w' : edges.data['w'] + 10}, [2, 3])
+    foo(g)
+    assert F.allclose(g.ndata['h'], F.zeros((g.number_of_nodes(), 3)))
+    assert F.allclose(g.edata['w'], F.zeros((g.number_of_edges(), 4)))
+    # test auto-pop
+    def foo(g):
+        with g.local_scope():
+            g.ndata['hh'] = F.ones((g.number_of_nodes(), 3))
+            g.edata['ww'] = F.ones((g.number_of_edges(), 4))
+    foo(g)
+    assert 'hh' not in g.ndata
+    assert 'ww' not in g.edata
+    # test nested scope
+    def foo(g):
+        with g.local_scope():
+            g.ndata['hh'] = F.ones((g.number_of_nodes(), 3))
+            g.edata['ww'] = F.ones((g.number_of_edges(), 4))
+            with g.local_scope():
+                g.ndata['hhh'] = F.ones((g.number_of_nodes(), 3))
+                g.edata['www'] = F.ones((g.number_of_edges(), 4))
+            assert 'hhh' not in g.ndata
+            assert 'www' not in g.edata
+    foo(g)
+    assert 'hh' not in g.ndata
+    assert 'ww' not in g.edata
+
 if __name__ == '__main__':
     test_nx_conversion()
@@ -672,3 +759,5 @@ if __name__ == '__main__':
     test_dynamic_addition()
     test_repr()
     test_group_apply_edges()
+    test_local_var()
+    test_local_scope()
@@ -24,10 +24,14 @@ def test_graph_conv():
     # test#1: basic
     h0 = mx.nd.ones((3, 5))
     h1 = conv(h0, g)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0
     check_eq(h1, _AXWb(adj, h0, conv.weight, conv.bias))
     # test#2: more-dim
     h0 = mx.nd.ones((3, 5, 5))
     h1 = conv(h0, g)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0
     check_eq(h1, _AXWb(adj, h0, conv.weight, conv.bias))

     conv = nn.GraphConv(5, 2)
@@ -36,9 +40,13 @@ def test_graph_conv():
     # test#3: basic
     h0 = mx.nd.ones((3, 5))
     h1 = conv(h0, g)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0
     # test#4: basic
     h0 = mx.nd.ones((3, 5, 5))
     h1 = conv(h0, g)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0

     conv = nn.GraphConv(5, 2)
     conv.initialize(ctx=ctx)
@@ -47,14 +55,21 @@ def test_graph_conv():
     # test#3: basic
     h0 = mx.nd.ones((3, 5))
     h1 = conv(h0, g)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0
     # test#4: basic
     h0 = mx.nd.ones((3, 5, 5))
     h1 = conv(h0, g)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0

-    # test repeated features
-    g.ndata["_gconv_feat"] = 2 * mx.nd.ones((3, 1))
+    # test not override features
+    g.ndata["h"] = 2 * mx.nd.ones((3, 1))
     h1 = conv(h0, g)
-    assert "_gconv_feat" in g.ndata
+    assert len(g.ndata) == 1
+    assert len(g.edata) == 0
+    assert "h" in g.ndata
+    check_eq(g.ndata['h'], 2 * mx.nd.ones((3, 1)))

 def uniform_attention(g, shape):
     a = mx.nd.ones(shape)
@@ -66,12 +81,16 @@ def test_edge_softmax():
     g = dgl.DGLGraph(nx.path_graph(3))
     edata = mx.nd.ones((g.number_of_edges(), 1))
     a = nn.edge_softmax(g, edata)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0
     assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
                        1e-4, 1e-4)

     # Test higher dimension case
     edata = mx.nd.ones((g.number_of_edges(), 3, 1))
     a = nn.edge_softmax(g, edata)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0
     assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
                        1e-4, 1e-4)
...
@@ -21,27 +21,39 @@ def test_graph_conv():
     # test#1: basic
     h0 = th.ones((3, 5))
     h1 = conv(h0, g)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0
     assert th.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
     # test#2: more-dim
     h0 = th.ones((3, 5, 5))
     h1 = conv(h0, g)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0
     assert th.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

     conv = nn.GraphConv(5, 2)
     # test#3: basic
     h0 = th.ones((3, 5))
     h1 = conv(h0, g)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0
     # test#4: basic
     h0 = th.ones((3, 5, 5))
     h1 = conv(h0, g)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0

     conv = nn.GraphConv(5, 2)
     # test#3: basic
     h0 = th.ones((3, 5))
     h1 = conv(h0, g)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0
     # test#4: basic
     h0 = th.ones((3, 5, 5))
     h1 = conv(h0, g)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0

     # test rest_parameters
     old_weight = deepcopy(conv.weight.data)
@@ -59,11 +71,15 @@ def test_edge_softmax():
     g = dgl.DGLGraph(nx.path_graph(3))
     edata = th.ones(g.number_of_edges(), 1)
     a = nn.edge_softmax(g, edata)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0
     assert th.allclose(a, uniform_attention(g, a.shape))

     # Test higher dimension case
     edata = th.ones(g.number_of_edges(), 3, 1)
     a = nn.edge_softmax(g, edata)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0
     assert th.allclose(a, uniform_attention(g, a.shape))

     # Test both forward and backward with PyTorch built-in softmax.
@@ -82,6 +98,8 @@ def test_edge_softmax():
     grad_score = score.grad
     score.grad.zero_()
     y_dgl = nn.edge_softmax(g, score)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0
     # check forward
     assert th.allclose(y_dgl, y)
     y_dgl.backward(grad)
@@ -104,6 +122,8 @@ def test_edge_softmax():
     builtin_sm = nn.edge_softmax(g, a2)
     builtin_sm.sum().backward()
     print(a1.grad - a2.grad)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 2
     assert th.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4)  # Follow tolerance in unittest backend
...