Unverified Commit f25bc176 authored by Minjie Wang, committed by GitHub

[Hetero] Improve speed of several Hetero APIs (#1486)

* add clone function to frame

* add utest

* replace all local_var with local_scope

* fix utest

* avoid creating canonical types in __getitem__

* lint

* try another utest approach for mx

* utest
parent 3c4506e9
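The main speedup is in ``local_var``/``local_scope``: both used to rebuild every node and edge frame and re-sync per-column initializers on each call, and after this change they shallow-clone the frames instead. A minimal sketch of the hot pattern (assuming the PyTorch backend; ``mean_aggregate`` is a hypothetical stand-in for an NN module's forward pass, not code from this PR):

import dgl
import dgl.function as fn
import torch as th

def mean_aggregate(graph, feat):
    # Entered on every forward call, so the cost of setting up the
    # temporary frames directly affects module throughput.
    with graph.local_scope():
        graph.ndata['h'] = feat
        graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
        return graph.ndata['neigh']

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
out = mean_aggregate(g, th.randn(3, 4))  # shape (3, 4)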
@@ -982,6 +982,21 @@ def logical_not(input):
     """
     pass

+def clone(input):
+    """Return a clone of the input tensor.
+
+    Parameters
+    ----------
+    input : Tensor
+        Input tensor.
+
+    Returns
+    -------
+    Tensor
+        A cloned tensor.
+    """
+    pass
+
 ###############################################################################
 # Tensor functions used *only* on index tensor
 # ----------------
...
@@ -313,6 +313,9 @@ def equal(x, y):
 def logical_not(input):
     return nd.logical_not(input)

+def clone(input):
+    return input.copy()
+
 def unique(input):
     # TODO: fallback to numpy is unfortunate
     tmp = input.asnumpy()
...
@@ -247,6 +247,9 @@ def equal(x, y):
 def logical_not(input):
     return ~input

+def clone(input):
+    return input.clone()
+
 def unique(input):
     return th.unique(input)
...
@@ -348,6 +348,9 @@ def equal(x, y):
 def logical_not(input):
     return ~input

+def clone(input):
+    # TF tensors are always immutable, so returning the input is safe.
+    return input
+
 def unique(input):
     return tf.unique(input).y
...
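The three backend implementations above rely on different ownership semantics. A quick sketch (plain PyTorch, not DGL test code) of the contract the new backend-level clone is expected to satisfy: it must hand back independent storage, so later in-place writes to the source do not alias the copy.

import torch as th

x = th.zeros(3)
y = x.clone()       # what the PyTorch backend's clone() returns
x[0] = 1.0          # mutate the original in place
assert y[0] == 0.0  # the clone is unaffected: independent storage

MXNet's ``copy()`` gives the same guarantee, while TensorFlow tensors are immutable, so returning the input unchanged is already safe.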
@@ -20,8 +20,6 @@ def is_all(arg):
     """Return true if the argument is a special symbol for all nodes or edges."""
     return isinstance(arg, str) and arg == ALL

-def dgl_warning(msg, warn_type=UserWarning):
-    """Print out warning messages."""
-    warnings.warn(msg, warn_type)
+dgl_warning = warnings.warn  # pylint: disable=invalid-name

 _init_internal_api()
@@ -123,7 +123,8 @@ def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True
         edata_schemes={})
     """
     if card is not None:
-        dgl_warning("card will be deprecated, please use num_nodes='{}' instead.")
+        dgl_warning("Argument 'card' will be deprecated. "
+                    "Please use num_nodes={} instead.".format(card))
         num_nodes = card
     if num_nodes is not None:
@@ -271,7 +272,8 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=Non
     if utype == vtype:
         raise DGLError('utype should not be equal to vtype. Use ``dgl.graph`` instead.')
     if card is not None:
-        dgl_warning("card will be deprecated, please use num_nodes='{}' instead.")
+        dgl_warning("Argument 'card' will be deprecated. "
+                    "Please use num_nodes={} instead.".format(card))
         num_nodes = card
     if num_nodes is not None:
         urange, vrange = num_nodes
...
@@ -62,6 +62,13 @@ class Column(object):
         The initial data of the column.
     scheme : Scheme, optional
         The scheme of the column. Will be inferred if not provided.
+
+    Attributes
+    ----------
+    data : Tensor
+        The data of the column.
+    scheme : Scheme
+        The scheme of the column.
     """
     def __init__(self, data, scheme=None):
         self.data = data
@@ -164,6 +171,10 @@ class Column(object):
             feats = F.copy_to(feats, F.context(self.data))
         self.data = F.cat([self.data, feats], dim=0)

+    def clone(self):
+        """Return a deep copy of this column."""
+        return Column(F.clone(self.data), self.scheme)
+
     @staticmethod
     def create(data):
         """Create a new column using the given data."""
@@ -255,6 +266,8 @@ class Frame(MutableMapping):
     def set_remote_init_builder(self, builder):
         """Set an initializer builder to create a remote initializer for a new column to a frame.

+        NOTE(minjie): This is a temporary solution. Will be replaced by KVStore in the future.
+
         The builder is a callable that returns an initializer. The returned initializer
         is also a callable that returns a tensor given a local tensor and tensor name.
@@ -268,6 +281,8 @@ class Frame(MutableMapping):
     def get_remote_initializer(self, name):
         """Get a remote initializer.

+        NOTE(minjie): This is a temporary solution. Will be replaced by KVStore in the future.
+
         Parameters
         ----------
         name : string
@@ -478,6 +493,46 @@ class Frame(MutableMapping):
         """Return the keys."""
         return self._columns.keys()

+    def clone(self):
+        """Return a clone of this frame.
+
+        The cloned frame does not share the underlying column storage with this
+        frame, i.e., adding or removing columns in one is not visible in the
+        other. However, they still share the tensor contents, so any mutable
+        operation on a column tensor is visible to both. Hence, the function
+        does not allocate extra tensor memory. Use :func:`~dgl.Frame.deepclone`
+        for cloning a frame that does not share any data.
+
+        Returns
+        -------
+        Frame
+            A cloned frame.
+        """
+        newframe = Frame(self._columns, self._num_rows)
+        newframe._initializers = self._initializers
+        newframe._remote_init_builder = self._remote_init_builder
+        newframe._default_initializer = self._default_initializer
+        return newframe
+
+    def deepclone(self):
+        """Return a deep clone of this frame.
+
+        The cloned frame holds a copy of this frame's data, so any modification
+        to the clone is not visible to this frame. The function allocates new
+        tensors and copies the contents from this frame. Use
+        :func:`~dgl.Frame.clone` for cloning a frame that does not allocate
+        extra tensor memory.
+
+        Returns
+        -------
+        Frame
+            A deep-cloned frame.
+        """
+        newframe = Frame({k : col.clone() for k, col in self._columns.items()}, self._num_rows)
+        newframe._initializers = self._initializers
+        newframe._remote_init_builder = self._remote_init_builder
+        newframe._default_initializer = self._default_initializer
+        return newframe
+
 class FrameRef(MutableMapping):
     """Reference object to a frame on a subset of rows.
@@ -538,6 +593,8 @@ class FrameRef(MutableMapping):
     def set_remote_init_builder(self, builder):
         """Set an initializer builder to create a remote initializer for a new column to a frame.

+        NOTE(minjie): This is a temporary solution. Will be replaced by KVStore in the future.
+
         The builder is a callable that returns an initializer. The returned initializer
         is also a callable that returns a tensor given a local tensor and tensor name.
@@ -865,6 +922,34 @@ class FrameRef(MutableMapping):
         """Return whether this refers to all the rows."""
         return self.is_contiguous() and self.num_rows == self._frame.num_rows

+    def clone(self):
+        """Return a new reference to a clone of the underlying frame.
+
+        Returns
+        -------
+        FrameRef
+            A cloned frame reference.
+
+        See Also
+        --------
+        dgl.Frame.clone
+        """
+        return FrameRef(self._frame.clone(), self._index)
+
+    def deepclone(self):
+        """Return a new reference to a deep clone of the underlying frame.
+
+        Returns
+        -------
+        FrameRef
+            A deep-cloned frame reference.
+
+        See Also
+        --------
+        dgl.Frame.deepclone
+        """
+        return FrameRef(self._frame.deepclone(), self._index)
+
     def _getrows(self, query):
         """Internal function to convert from the local row ids to the row ids of the frame.
...
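The clone/deepclone contract above can be pictured with plain dictionaries of tensors (a PyTorch sketch of the semantics, not the DGL internals): clone copies the column container but shares tensor storage, while deepclone copies both.

import torch as th

columns = {'h': th.zeros(3)}                        # the original frame's columns
shallow = dict(columns)                             # ~ Frame.clone()
deep = {k: v.clone() for k, v in columns.items()}   # ~ Frame.deepclone()

shallow['h2'] = th.ones(3)      # adding a column is not visible to the original
assert 'h2' not in columns

columns['h'][0] = 5.0           # an in-place write to shared storage...
assert shallow['h'][0] == 5.0   # ...shows through the shallow clone
assert deep['h'][0] == 0.0      # ...but not through the deep clone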
@@ -2,6 +2,7 @@
 #pylint: disable= too-many-lines
 from collections import defaultdict
 from contextlib import contextmanager
+import copy

 import networkx as nx
 import numpy as np
@@ -11,7 +12,7 @@ from . import utils
 from . import backend as F
 from . import init
 from .runtime import ir, scheduler, Runtime, GraphAdapter
-from .frame import Frame, FrameRef, frame_like, sync_frame_initializer
+from .frame import Frame, FrameRef, frame_like
 from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView
 from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError, dgl_warning
 from .udf import NodeBatch, EdgeBatch
@@ -200,6 +201,7 @@ class DGLHeteroGraph(object):
     def _init(self, gidx, ntypes, etypes, node_frames, edge_frames):
         """Init internal states."""
         self._graph = gidx
+        self._canonical_etypes = None

         # Handle node types
         if isinstance(ntypes, tuple):
@@ -214,6 +216,8 @@ class DGLHeteroGraph(object):
             self._srctypes_invmap = {t : i for i, t in enumerate(ntypes[0])}
             self._dsttypes_invmap = {t : i + len(ntypes[0]) for i, t in enumerate(ntypes[1])}
             self._is_unibipartite = True
+            if len(ntypes[0]) == 1 and len(ntypes[1]) == 1 and len(etypes) == 1:
+                self._canonical_etypes = [(ntypes[0][0], etypes[0], ntypes[1][0])]
         else:
             self._ntypes = ntypes
             src_dst_map = find_src_dst_ntypes(self._ntypes, self._graph.metagraph)
@@ -226,6 +230,7 @@ class DGLHeteroGraph(object):
         # Handle edge types
         self._etypes = etypes
-        self._canonical_etypes = make_canonical_etypes(
-            self._etypes, self._ntypes, self._graph.metagraph)
+        if self._canonical_etypes is None:
+            self._canonical_etypes = make_canonical_etypes(
+                self._etypes, self._ntypes, self._graph.metagraph)
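The ``_init`` change above caches the one possible canonical edge type for the common single-relation uni-bipartite case instead of deriving it from the metagraph. Illustratively (the type names here are hypothetical):

ntypes = (['user'], ['game'])   # (src types, dst types)
etypes = ['plays']
# Fast path: with one src type, one dst type and one edge type, the
# canonical triplet is fully determined without touching the metagraph.
if len(ntypes[0]) == 1 and len(ntypes[1]) == 1 and len(etypes) == 1:
    canonical_etypes = [(ntypes[0][0], etypes[0], ntypes[1][0])]
assert canonical_etypes == [('user', 'plays', 'game')]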
@@ -912,7 +917,7 @@ class DGLHeteroGraph(object):
             new_ntypes = [srctype]
             new_nframes = [self._node_frames[stid]]
         else:
-            new_ntypes = [srctype, dsttype]
+            new_ntypes = ([srctype], [dsttype])
             new_nframes = [self._node_frames[stid], self._node_frames[dtid]]
         new_etypes = [etype]
         new_eframes = [self._edge_frames[etid]]
@@ -4033,18 +4038,12 @@ class DGLHeteroGraph(object):
         --------
         local_var
         """
-        local_node_frames = [FrameRef(Frame(fr._frame)) for fr in self._node_frames]
-        local_edge_frames = [FrameRef(Frame(fr._frame)) for fr in self._edge_frames]
-        # Use same per-column initializers and default initializer.
-        # If registered, a column (based on key) initializer will be used first,
-        # otherwise the default initializer will be used.
-        for fr1, fr2 in zip(local_node_frames, self._node_frames):
-            sync_frame_initializer(fr1._frame, fr2._frame)
-        for fr1, fr2 in zip(local_edge_frames, self._edge_frames):
-            sync_frame_initializer(fr1._frame, fr2._frame)
-        return DGLHeteroGraph(self._graph, self.ntypes, self.etypes,
-                              local_node_frames,
-                              local_edge_frames)
+        local_node_frames = [fr.clone() for fr in self._node_frames]
+        local_edge_frames = [fr.clone() for fr in self._edge_frames]
+        ret = copy.copy(self)
+        ret._node_frames = local_node_frames
+        ret._edge_frames = local_edge_frames
+        return ret

     @contextmanager
     def local_scope(self):
@@ -4093,15 +4092,8 @@ class DGLHeteroGraph(object):
         """
         old_nframes = self._node_frames
         old_eframes = self._edge_frames
-        self._node_frames = [FrameRef(Frame(fr._frame)) for fr in self._node_frames]
-        self._edge_frames = [FrameRef(Frame(fr._frame)) for fr in self._edge_frames]
-        # Use same per-column initializers and default initializer.
-        # If registered, a column (based on key) initializer will be used first,
-        # otherwise the default initializer will be used.
-        for fr1, fr2 in zip(self._node_frames, old_nframes):
-            sync_frame_initializer(fr1._frame, fr2._frame)
-        for fr1, fr2 in zip(self._edge_frames, old_eframes):
-            sync_frame_initializer(fr1._frame, fr2._frame)
+        self._node_frames = [fr.clone() for fr in self._node_frames]
+        self._edge_frames = [fr.clone() for fr in self._edge_frames]
         yield
         self._node_frames = old_nframes
         self._edge_frames = old_eframes
...
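One consequence of implementing ``local_var``/``local_scope`` via ``clone()`` follows directly from the sharing described in ``Frame.clone``: features assigned inside the scope are discarded on exit, but in-place mutation of an existing feature tensor leaks out, because the tensor storage is shared. A sketch (PyTorch backend assumed):

import dgl
import torch as th

g = dgl.graph(([0, 1], [1, 2]))
g.ndata['h'] = th.zeros(3)

with g.local_scope():
    g.ndata['h2'] = th.ones(3)   # new column: dropped when the scope exits
    g.ndata['h'] += 1            # in-place write to the shared tensor

assert 'h2' not in g.ndata       # the temporary column did not leak
assert g.ndata['h'][0] == 1      # the in-place mutation did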
@@ -2,7 +2,7 @@
 from __future__ import absolute_import

 from ._ffi.function import _init_api
-from .ndarray import null
+from . import ndarray as nd

 # pylint: disable=invalid-name
 def infer_binary_feature_shape(op, lhs, rhs):
@@ -136,11 +136,11 @@ def binary_op_reduce(reducer, op, G, A_target, B_target, A, B, out,
         The rows to write to output tensor.
     """
     if A_rows is None:
-        A_rows = null()
+        A_rows = nd.NULL
     if B_rows is None:
-        B_rows = null()
+        B_rows = nd.NULL
     if out_rows is None:
-        out_rows = null()
+        out_rows = nd.NULL
     _CAPI_DGLKernelBinaryOpReduce(
         reducer, op, G,
         int(A_target), int(B_target),
@@ -200,11 +200,11 @@ def backward_lhs_binary_op_reduce(
         The rows written to output tensor.
     """
     if A_rows is None:
-        A_rows = null()
+        A_rows = nd.NULL
     if B_rows is None:
-        B_rows = null()
+        B_rows = nd.NULL
     if out_rows is None:
-        out_rows = null()
+        out_rows = nd.NULL
     _CAPI_DGLKernelBackwardLhsBinaryOpReduce(
         reducer, op, G,
         int(A_target), int(B_target),
@@ -265,11 +265,11 @@ def backward_rhs_binary_op_reduce(
         The rows written to output tensor.
     """
     if A_rows is None:
-        A_rows = null()
+        A_rows = nd.NULL
     if B_rows is None:
-        B_rows = null()
+        B_rows = nd.NULL
     if out_rows is None:
-        out_rows = null()
+        out_rows = nd.NULL
     _CAPI_DGLKernelBackwardRhsBinaryOpReduce(
         reducer, op, G,
         int(A_target), int(B_target),
@@ -364,9 +364,9 @@ def copy_reduce(reducer, G, target,
         The rows to write to output tensor.
     """
     if X_rows is None:
-        X_rows = null()
+        X_rows = nd.NULL
     if out_rows is None:
-        out_rows = null()
+        out_rows = nd.NULL
     _CAPI_DGLKernelCopyReduce(
         reducer, G, int(target),
         X, out, X_rows, out_rows)
@@ -406,9 +406,9 @@ def backward_copy_reduce(reducer, G, target,
         The rows written to output tensor.
     """
     if X_rows is None:
-        X_rows = null()
+        X_rows = nd.NULL
     if out_rows is None:
-        out_rows = null()
+        out_rows = nd.NULL
     _CAPI_DGLKernelBackwardCopyReduce(
         reducer, G, int(target),
         X, out, grad_out, grad_X,
...
@@ -90,17 +90,6 @@ def zerocopy_from_numpy(np_data):
     handle = ctypes.pointer(arr)
     return NDArray(handle, is_view=True)

-def null():
-    """Return a ndarray representing null value. It can be safely converted
-    to other backend tensors.
-
-    Returns
-    -------
-    NDArray
-        A null array
-    """
-    return array(_np.array([], dtype=_np.int64))
-
 class SparseFormat:
     """Format code"""
     ANY = 0
@@ -185,3 +174,7 @@ class SparseMatrix(ObjectBase):
 _set_class_ndarray(NDArray)
 _init_api("dgl.ndarray")
+
+# An array representing null (no value) that can be safely converted to
+# other backend tensors.
+NULL = array(_np.array([], dtype=_np.int64))
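The ndarray change trades a function for a constant: the old ``null()`` allocated a fresh empty array on every kernel invocation, while the new module-level ``NULL`` is built once at import time, which is safe because the sentinel is never mutated. A NumPy stand-in for the idea (the names here are illustrative, not DGL's):

import numpy as np

NULL = np.array([], dtype=np.int64)   # built once, reused everywhere

def copy_reduce_stub(x_rows=None, out_rows=None):
    # Substitute the shared sentinel for missing row selectors instead of
    # allocating a new empty array per call.
    x_rows = NULL if x_rows is None else x_rows
    out_rows = NULL if out_rows is None else out_rows
    return x_rows, out_rows

assert copy_reduce_stub()[0] is NULL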
@@ -59,8 +59,7 @@ class AGNNConv(nn.Block):
             The output feature of shape :math:`(N, *)` where :math:`*`
             should be the same as input shape.
         """
-        graph = graph.local_var()
-        feat_src, feat_dst = expand_as_pair(feat)
-        graph.srcdata['h'] = feat_src
-        graph.srcdata['norm_h'] = normalize(feat_src, p=2, axis=-1)
+        with graph.local_scope():
+            feat_src, feat_dst = expand_as_pair(feat)
+            graph.srcdata['h'] = feat_src
+            graph.srcdata['norm_h'] = normalize(feat_src, p=2, axis=-1)
...
@@ -54,7 +54,7 @@ class APPNPConv(nn.Block):
             The output feature of shape :math:`(N, *)` where :math:`*`
             should be the same as input shape.
         """
-        graph = graph.local_var()
-        norm = mx.nd.power(mx.nd.clip(
-            graph.in_degrees().astype(feat.dtype), a_min=1, a_max=float("inf")), -0.5)
-        shp = norm.shape + (1,) * (feat.ndim - 1)
+        with graph.local_scope():
+            norm = mx.nd.power(mx.nd.clip(
+                graph.in_degrees().astype(feat.dtype), a_min=1, a_max=float("inf")), -0.5)
+            shp = norm.shape + (1,) * (feat.ndim - 1)
...
@@ -117,7 +117,7 @@ class GATConv(nn.Block):
             The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
             is the number of heads, and :math:`D_{out}` is size of output feature.
         """
-        graph = graph.local_var()
-        if isinstance(feat, tuple):
-            h_src = self.feat_drop(feat[0])
-            h_dst = self.feat_drop(feat[1])
+        with graph.local_scope():
+            if isinstance(feat, tuple):
+                h_src = self.feat_drop(feat[0])
+                h_dst = self.feat_drop(feat[1])
...
@@ -75,10 +75,11 @@ class GatedGraphConv(nn.Block):
             The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
             is the output feature size.
         """
-        assert graph.is_homograph(), \
-            "not a homograph; convert it with to_homo and pass in the edge type as argument"
-        graph = graph.local_var()
-        zero_pad = nd.zeros((feat.shape[0], self._out_feats - feat.shape[1]), ctx=feat.context)
-        feat = nd.concat(feat, zero_pad, dim=-1)
-        for _ in range(self._n_steps):
+        with graph.local_scope():
+            assert graph.is_homograph(), \
+                "not a homograph; convert it with to_homo and pass in the edge type as argument"
+            zero_pad = nd.zeros((feat.shape[0], self._out_feats - feat.shape[1]),
+                                ctx=feat.context)
+            feat = nd.concat(feat, zero_pad, dim=-1)
+            for _ in range(self._n_steps):
...
@@ -74,7 +74,7 @@ class GINConv(nn.Block):
             If ``apply_func`` is None, :math:`D_{out}` should be the same
             as input dimensionality.
         """
-        graph = graph.local_var()
-        feat_src, feat_dst = expand_as_pair(feat)
-        graph.srcdata['h'] = feat_src
-        graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
+        with graph.local_scope():
+            feat_src, feat_dst = expand_as_pair(feat)
+            graph.srcdata['h'] = feat_src
+            graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
...
@@ -119,8 +119,7 @@ class GraphConv(gluon.Block):
         mxnet.NDArray
             The output feature
         """
-        graph = graph.local_var()
-        if self._norm == 'both':
-            degs = graph.out_degrees().as_in_context(feat.context).astype('float32')
-            degs = mx.nd.clip(degs, a_min=1, a_max=float("inf"))
+        with graph.local_scope():
+            if self._norm == 'both':
+                degs = graph.out_degrees().as_in_context(feat.context).astype('float32')
+                degs = mx.nd.clip(degs, a_min=1, a_max=float("inf"))
...
@@ -97,8 +97,7 @@ class SAGEConv(nn.Block):
             The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
             is size of output feature.
         """
-        graph = graph.local_var()
-        if isinstance(feat, tuple):
-            feat_src = self.feat_drop(feat[0])
-            feat_dst = self.feat_drop(feat[1])
+        with graph.local_scope():
+            if isinstance(feat, tuple):
+                feat_src = self.feat_drop(feat[0])
+                feat_dst = self.feat_drop(feat[1])
...
@@ -74,7 +74,7 @@ class SGConv(nn.Block):
         If ``cache`` is set to True, ``feat`` and ``graph`` should not change during
         training, or you will get wrong results.
         """
-        graph = graph.local_var()
-        if self._cached_h is not None:
-            feat = self._cached_h
-        else:
+        with graph.local_scope():
+            if self._cached_h is not None:
+                feat = self._cached_h
+            else:
...
@@ -76,8 +76,8 @@ class TAGConv(gluon.Block):
             The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
             is size of output feature.
         """
-        assert graph.is_homograph(), 'Graph is not homogeneous'
-        graph = graph.local_var()
-        degs = graph.in_degrees().astype('float32')
-        norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5)
+        with graph.local_scope():
+            assert graph.is_homograph(), 'Graph is not homogeneous'
+            degs = graph.in_degrees().astype('float32')
+            norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5)
...
"""Gluon layer for graph related softmax.""" """Gluon layer for graph related softmax."""
# pylint: disable= no-member, arguments-differ # pylint: disable= no-member, arguments-differ, access-member-before-definition, unpacking-non-sequence
import mxnet as mx import mxnet as mx
from ... import function as fn from ... import function as fn
...@@ -45,7 +45,8 @@ class EdgeSoftmax(mx.autograd.Function): ...@@ -45,7 +45,8 @@ class EdgeSoftmax(mx.autograd.Function):
out = score / score_sum # edge_div_dst, ret dgl.EData out = score / score_sum # edge_div_dst, ret dgl.EData
return out.data return out.data
""" """
g = self.g.local_var() g = self.g
with g.local_scope():
g.edata['s'] = score g.edata['s'] = score
g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax')) g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
g.apply_edges(fn.e_sub_v('s', 'smax', 'out')) g.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
...@@ -70,8 +71,9 @@ class EdgeSoftmax(mx.autograd.Function): ...@@ -70,8 +71,9 @@ class EdgeSoftmax(mx.autograd.Function):
sds_sum = sds.dst_sum() # type dgl.NData sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - sds * sds_sum # multiple expressions grad_score = sds - sds * sds_sum # multiple expressions
""" """
g = self.g.local_var() g = self.g
out, = self.saved_tensors # pylint: disable=access-member-before-definition, unpacking-non-sequence with g.local_scope():
out, = self.saved_tensors
# clear saved tensors explicitly # clear saved tensors explicitly
self.saved_tensors = None self.saved_tensors = None
g.edata['out'] = out g.edata['out'] = out
......