Unverified commit f25bc176, authored by Minjie Wang, committed by GitHub

[Hetero] Improve speed of several Hetero APIs (#1486)

* add clone function to frame

* add utest

* replace all local_var with local_scope

* fix utest

* avoid creating canonical types in __getitem__

* lint

* try another utest approach for mx

* utest
parent 3c4506e9
...@@ -982,6 +982,21 @@ def logical_not(input): ...@@ -982,6 +982,21 @@ def logical_not(input):
""" """
pass pass
def clone(input):
"""Return a clone of the input tensor.
Parameters
----------
input : Tensor
Input tensor.
Returns
-------
Tensor
A cloned tensor.
"""
pass
############################################################################### ###############################################################################
# Tensor functions used *only* on index tensor # Tensor functions used *only* on index tensor
# ---------------- # ----------------
......
...@@ -313,6 +313,9 @@ def equal(x, y): ...@@ -313,6 +313,9 @@ def equal(x, y):
def logical_not(input): def logical_not(input):
return nd.logical_not(input) return nd.logical_not(input)
def clone(input):
return input.copy()
def unique(input): def unique(input):
# TODO: fallback to numpy is unfortunate # TODO: fallback to numpy is unfortunate
tmp = input.asnumpy() tmp = input.asnumpy()
......
...@@ -247,6 +247,9 @@ def equal(x, y): ...@@ -247,6 +247,9 @@ def equal(x, y):
def logical_not(input): def logical_not(input):
return ~input return ~input
def clone(input):
return input.clone()
def unique(input): def unique(input):
return th.unique(input) return th.unique(input)
......
...@@ -348,6 +348,9 @@ def equal(x, y): ...@@ -348,6 +348,9 @@ def equal(x, y):
def logical_not(input): def logical_not(input):
return ~input return ~input
def clone(input):
# TF tensors are immutable, so returning the input is safe.
return input
def unique(input): def unique(input):
return tf.unique(input).y return tf.unique(input).y
......
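Taken together, the three backend implementations give `F.clone` copy semantics wherever tensors are mutable. A minimal usage sketch from the caller's side (helper name illustrative; assumes `x` is a tensor of the active backend):

from dgl import backend as F

def detach_copy(x):
    # Dispatches to input.copy() on MXNet, input.clone() on PyTorch, and the
    # identity on TensorFlow (immutable tensors), per the diffs above.
    y = F.clone(x)
    # For the mutable backends, in-place writes to y do not affect x.
    return y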
...@@ -20,8 +20,6 @@ def is_all(arg): ...@@ -20,8 +20,6 @@ def is_all(arg):
"""Return true if the argument is a special symbol for all nodes or edges.""" """Return true if the argument is a special symbol for all nodes or edges."""
return isinstance(arg, str) and arg == ALL return isinstance(arg, str) and arg == ALL
def dgl_warning(msg, warn_type=UserWarning): dgl_warning = warnings.warn # pylint: disable=invalid-name
"""Print out warning messages."""
warnings.warn(msg, warn_type)
_init_internal_api() _init_internal_api()
...@@ -123,7 +123,8 @@ def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True ...@@ -123,7 +123,8 @@ def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True
edata_schemes={}) edata_schemes={})
""" """
if card is not None: if card is not None:
dgl_warning("card will be deprecated, please use num_nodes='{}' instead.") dgl_warning("Argument 'card' will be deprecated. "
"Please use num_nodes={} instead.".format(card))
num_nodes = card num_nodes = card
if num_nodes is not None: if num_nodes is not None:
...@@ -271,7 +272,8 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=Non ...@@ -271,7 +272,8 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=Non
if utype == vtype: if utype == vtype:
raise DGLError('utype should not be equal to vtype. Use ``dgl.graph`` instead.') raise DGLError('utype should not be equal to vtype. Use ``dgl.graph`` instead.')
if card is not None: if card is not None:
dgl_warning("card will be deprecated, please use num_nodes='{}' instead.") dgl_warning("Argument 'card' will be deprecated. "
"Please use num_nodes={} instead.".format(card))
num_nodes = card num_nodes = card
if num_nodes is not None: if num_nodes is not None:
urange, vrange = num_nodes urange, vrange = num_nodes
......
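A hedged illustration of what the corrected deprecation message looks like when triggered (edge lists and node count illustrative; the warning category is whatever `warnings.warn` uses by default):

import dgl

# Passing the deprecated 'card' argument now emits a message such as:
#   "Argument 'card' will be deprecated. Please use num_nodes=4 instead."
g = dgl.graph(([0, 1, 2], [1, 2, 3]), card=4)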
...@@ -62,6 +62,13 @@ class Column(object): ...@@ -62,6 +62,13 @@ class Column(object):
The initial data of the column. The initial data of the column.
scheme : Scheme, optional scheme : Scheme, optional
The scheme of the column. Will be inferred if not provided. The scheme of the column. Will be inferred if not provided.
Attributes
----------
data : Tensor
The data of the column.
scheme : Scheme
The scheme of the column.
""" """
def __init__(self, data, scheme=None): def __init__(self, data, scheme=None):
self.data = data self.data = data
...@@ -164,6 +171,10 @@ class Column(object): ...@@ -164,6 +171,10 @@ class Column(object):
feats = F.copy_to(feats, F.context(self.data)) feats = F.copy_to(feats, F.context(self.data))
self.data = F.cat([self.data, feats], dim=0) self.data = F.cat([self.data, feats], dim=0)
def clone(self):
"""Return a deepcopy of this column."""
return Column(F.clone(self.data), self.scheme)
@staticmethod @staticmethod
def create(data): def create(data):
"""Create a new column using the given data.""" """Create a new column using the given data."""
...@@ -255,6 +266,8 @@ class Frame(MutableMapping): ...@@ -255,6 +266,8 @@ class Frame(MutableMapping):
def set_remote_init_builder(self, builder): def set_remote_init_builder(self, builder):
"""Set an initializer builder to create a remote initializer for a new column to a frame. """Set an initializer builder to create a remote initializer for a new column to a frame.
NOTE(minjie): This is a temporary solution. Will be replaced by KVStore in the future.
The builder is a callable that returns an initializer. The returned initializer The builder is a callable that returns an initializer. The returned initializer
is also a callable that returns a tensor given a local tensor and tensor name. is also a callable that returns a tensor given a local tensor and tensor name.
...@@ -268,6 +281,8 @@ class Frame(MutableMapping): ...@@ -268,6 +281,8 @@ class Frame(MutableMapping):
def get_remote_initializer(self, name): def get_remote_initializer(self, name):
"""Get a remote initializer. """Get a remote initializer.
NOTE(minjie): This is a temporary solution. Will be replaced by KVStore in the future.
Parameters Parameters
---------- ----------
name : string name : string
...@@ -478,6 +493,46 @@ class Frame(MutableMapping): ...@@ -478,6 +493,46 @@ class Frame(MutableMapping):
"""Return the keys.""" """Return the keys."""
return self._columns.keys() return self._columns.keys()
def clone(self):
"""Return a clone of this frame.
The cloned frame does not share the underlying column storage with this frame,
i.e., adding or removing columns in one frame is not visible to the other. However,
the two frames still share the tensor contents, so any in-place operation on a
column tensor is visible to both. Hence, this function does not allocate extra
tensor memory. Use :func:`~dgl.Frame.deepclone` for cloning
a frame that does not share any data.
Returns
-------
Frame
A cloned frame.
"""
newframe = Frame(self._columns, self._num_rows)
newframe._initializers = self._initializers
newframe._remote_init_builder = self._remote_init_builder
newframe._default_initializer = self._default_initializer
return newframe
def deepclone(self):
"""Return a deep clone of this frame.
The cloned frame holds a full copy of this frame's data, so any modification to the
clone is not visible to this frame. The function allocates new tensors and copies the
contents from this frame. Use :func:`~dgl.Frame.clone` for cloning a frame that does
not allocate extra tensor memory.
Returns
-------
Frame
A deep-cloned frame.
"""
newframe = Frame({k : col.clone() for k, col in self._columns.items()}, self._num_rows)
newframe._initializers = self._initializers
newframe._remote_init_builder = self._remote_init_builder
newframe._default_initializer = self._default_initializer
return newframe
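A minimal sketch of the intended difference between the two methods, assuming the PyTorch backend is active (frame contents illustrative):

import torch
from dgl.frame import Frame

f = Frame({'h': torch.zeros(4, 2)})
shallow = f.clone()          # new column dict, same underlying tensors
deep = f.deepclone()         # new column dict and freshly copied tensors

shallow['x'] = torch.ones(4, 2)     # adding a column to the clone is not visible in f
shallow['h'].data[0, 0] = 1.0       # ...but this in-place write is shared with f,
deep['h'].data[0, 0] = 2.0          # whereas the deep clone is fully independent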
class FrameRef(MutableMapping): class FrameRef(MutableMapping):
"""Reference object to a frame on a subset of rows. """Reference object to a frame on a subset of rows.
...@@ -538,6 +593,8 @@ class FrameRef(MutableMapping): ...@@ -538,6 +593,8 @@ class FrameRef(MutableMapping):
def set_remote_init_builder(self, builder): def set_remote_init_builder(self, builder):
"""Set an initializer builder to create a remote initializer for a new column to a frame. """Set an initializer builder to create a remote initializer for a new column to a frame.
NOTE(minjie): This is a temporary solution. Will be replaced by KVStore in the future.
The builder is a callable that returns an initializer. The returned initializer The builder is a callable that returns an initializer. The returned initializer
is also a callable that returns a tensor given a local tensor and tensor name. is also a callable that returns a tensor given a local tensor and tensor name.
...@@ -865,6 +922,34 @@ class FrameRef(MutableMapping): ...@@ -865,6 +922,34 @@ class FrameRef(MutableMapping):
"""Return whether this refers to all the rows.""" """Return whether this refers to all the rows."""
return self.is_contiguous() and self.num_rows == self._frame.num_rows return self.is_contiguous() and self.num_rows == self._frame.num_rows
def clone(self):
"""Return a new reference to a clone of the underlying frame.
Returns
-------
FrameRef
A cloned frame reference.
See Also
--------
dgl.Frame.clone
"""
return FrameRef(self._frame.clone(), self._index)
def deepclone(self):
"""Return a new reference to a deep clone of the underlying frame.
Returns
-------
FrameRef
A deep-cloned frame reference.
See Also
--------
dgl.Frame.deepclone
"""
return FrameRef(self._frame.deepclone(), self._index)
def _getrows(self, query): def _getrows(self, query):
"""Internal function to convert from the local row ids to the row ids of the frame. """Internal function to convert from the local row ids to the row ids of the frame.
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#pylint: disable= too-many-lines #pylint: disable= too-many-lines
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
import copy
import networkx as nx import networkx as nx
import numpy as np import numpy as np
...@@ -11,7 +12,7 @@ from . import utils ...@@ -11,7 +12,7 @@ from . import utils
from . import backend as F from . import backend as F
from . import init from . import init
from .runtime import ir, scheduler, Runtime, GraphAdapter from .runtime import ir, scheduler, Runtime, GraphAdapter
from .frame import Frame, FrameRef, frame_like, sync_frame_initializer from .frame import Frame, FrameRef, frame_like
from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView
from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError, dgl_warning from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError, dgl_warning
from .udf import NodeBatch, EdgeBatch from .udf import NodeBatch, EdgeBatch
...@@ -200,6 +201,7 @@ class DGLHeteroGraph(object): ...@@ -200,6 +201,7 @@ class DGLHeteroGraph(object):
def _init(self, gidx, ntypes, etypes, node_frames, edge_frames): def _init(self, gidx, ntypes, etypes, node_frames, edge_frames):
"""Init internal states.""" """Init internal states."""
self._graph = gidx self._graph = gidx
self._canonical_etypes = None
# Handle node types # Handle node types
if isinstance(ntypes, tuple): if isinstance(ntypes, tuple):
...@@ -214,6 +216,8 @@ class DGLHeteroGraph(object): ...@@ -214,6 +216,8 @@ class DGLHeteroGraph(object):
self._srctypes_invmap = {t : i for i, t in enumerate(ntypes[0])} self._srctypes_invmap = {t : i for i, t in enumerate(ntypes[0])}
self._dsttypes_invmap = {t : i + len(ntypes[0]) for i, t in enumerate(ntypes[1])} self._dsttypes_invmap = {t : i + len(ntypes[0]) for i, t in enumerate(ntypes[1])}
self._is_unibipartite = True self._is_unibipartite = True
if len(ntypes[0]) == 1 and len(ntypes[1]) == 1 and len(etypes) == 1:
self._canonical_etypes = [(ntypes[0][0], etypes[0], ntypes[1][0])]
else: else:
self._ntypes = ntypes self._ntypes = ntypes
src_dst_map = find_src_dst_ntypes(self._ntypes, self._graph.metagraph) src_dst_map = find_src_dst_ntypes(self._ntypes, self._graph.metagraph)
...@@ -226,8 +230,9 @@ class DGLHeteroGraph(object): ...@@ -226,8 +230,9 @@ class DGLHeteroGraph(object):
# Handle edge types # Handle edge types
self._etypes = etypes self._etypes = etypes
self._canonical_etypes = make_canonical_etypes( if self._canonical_etypes is None:
self._etypes, self._ntypes, self._graph.metagraph) self._canonical_etypes = make_canonical_etypes(
self._etypes, self._ntypes, self._graph.metagraph)
# An internal map from etype to canonical etype tuple. # An internal map from etype to canonical etype tuple.
# If two etypes have the same name, an empty tuple is stored instead to indicate # If two etypes have the same name, an empty tuple is stored instead to indicate
...@@ -912,7 +917,7 @@ class DGLHeteroGraph(object): ...@@ -912,7 +917,7 @@ class DGLHeteroGraph(object):
new_ntypes = [srctype] new_ntypes = [srctype]
new_nframes = [self._node_frames[stid]] new_nframes = [self._node_frames[stid]]
else: else:
new_ntypes = [srctype, dsttype] new_ntypes = ([srctype], [dsttype])
new_nframes = [self._node_frames[stid], self._node_frames[dtid]] new_nframes = [self._node_frames[stid], self._node_frames[dtid]]
new_etypes = [etype] new_etypes = [etype]
new_eframes = [self._edge_frames[etid]] new_eframes = [self._edge_frames[etid]]
...@@ -4033,18 +4038,12 @@ class DGLHeteroGraph(object): ...@@ -4033,18 +4038,12 @@ class DGLHeteroGraph(object):
-------- --------
local_var local_var
""" """
local_node_frames = [FrameRef(Frame(fr._frame)) for fr in self._node_frames] local_node_frames = [fr.clone() for fr in self._node_frames]
local_edge_frames = [FrameRef(Frame(fr._frame)) for fr in self._edge_frames] local_edge_frames = [fr.clone() for fr in self._edge_frames]
# Use same per-column initializers and default initializer. ret = copy.copy(self)
# If registered, a column (based on key) initializer will be used first, ret._node_frames = local_node_frames
# otherwise the default initializer will be used. ret._edge_frames = local_edge_frames
for fr1, fr2 in zip(local_node_frames, self._node_frames): return ret
sync_frame_initializer(fr1._frame, fr2._frame)
for fr1, fr2 in zip(local_edge_frames, self._edge_frames):
sync_frame_initializer(fr1._frame, fr2._frame)
return DGLHeteroGraph(self._graph, self.ntypes, self.etypes,
local_node_frames,
local_edge_frames)
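With the clone-based rewrite, `local_var()` keeps its documented behavior: it returns a graph whose feature writes do not leak back to the caller while reusing the existing feature tensors. A minimal usage sketch (helper name illustrative):

import dgl.function as fn

def sum_neighbor_features(g, feat):
    g = g.local_var()      # shallow clone of the node/edge frames; no tensor copy
    g.ndata['h'] = feat    # new column, invisible to the caller's graph
    g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
    return g.ndata.pop('h_sum')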
@contextmanager @contextmanager
def local_scope(self): def local_scope(self):
...@@ -4093,15 +4092,8 @@ class DGLHeteroGraph(object): ...@@ -4093,15 +4092,8 @@ class DGLHeteroGraph(object):
""" """
old_nframes = self._node_frames old_nframes = self._node_frames
old_eframes = self._edge_frames old_eframes = self._edge_frames
self._node_frames = [FrameRef(Frame(fr._frame)) for fr in self._node_frames] self._node_frames = [fr.clone() for fr in self._node_frames]
self._edge_frames = [FrameRef(Frame(fr._frame)) for fr in self._edge_frames] self._edge_frames = [fr.clone() for fr in self._edge_frames]
# Use same per-column initializers and default initializer.
# If registered, a column (based on key) initializer will be used first,
# otherwise the default initializer will be used.
for fr1, fr2 in zip(self._node_frames, old_nframes):
sync_frame_initializer(fr1._frame, fr2._frame)
for fr1, fr2 in zip(self._edge_frames, old_eframes):
sync_frame_initializer(fr1._frame, fr2._frame)
yield yield
self._node_frames = old_nframes self._node_frames = old_nframes
self._edge_frames = old_eframes self._edge_frames = old_eframes
......
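The context-manager form gets the same treatment; a minimal sketch of how `local_scope()` is meant to be used after this change (helper name illustrative):

import dgl.function as fn

def mean_neighbor_features(g, feat):
    with g.local_scope():  # frames are cloned on entry and restored on exit
        g.ndata['h'] = feat
        g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_mean'))
        return g.ndata['h_mean']   # 'h' and 'h_mean' are gone once the block exits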
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .ndarray import null from . import ndarray as nd
# pylint: disable=invalid-name # pylint: disable=invalid-name
def infer_binary_feature_shape(op, lhs, rhs): def infer_binary_feature_shape(op, lhs, rhs):
...@@ -136,11 +136,11 @@ def binary_op_reduce(reducer, op, G, A_target, B_target, A, B, out, ...@@ -136,11 +136,11 @@ def binary_op_reduce(reducer, op, G, A_target, B_target, A, B, out,
The rows to write to output tensor. The rows to write to output tensor.
""" """
if A_rows is None: if A_rows is None:
A_rows = null() A_rows = nd.NULL
if B_rows is None: if B_rows is None:
B_rows = null() B_rows = nd.NULL
if out_rows is None: if out_rows is None:
out_rows = null() out_rows = nd.NULL
_CAPI_DGLKernelBinaryOpReduce( _CAPI_DGLKernelBinaryOpReduce(
reducer, op, G, reducer, op, G,
int(A_target), int(B_target), int(A_target), int(B_target),
...@@ -200,11 +200,11 @@ def backward_lhs_binary_op_reduce( ...@@ -200,11 +200,11 @@ def backward_lhs_binary_op_reduce(
The rows written to output tensor. The rows written to output tensor.
""" """
if A_rows is None: if A_rows is None:
A_rows = null() A_rows = nd.NULL
if B_rows is None: if B_rows is None:
B_rows = null() B_rows = nd.NULL
if out_rows is None: if out_rows is None:
out_rows = null() out_rows = nd.NULL
_CAPI_DGLKernelBackwardLhsBinaryOpReduce( _CAPI_DGLKernelBackwardLhsBinaryOpReduce(
reducer, op, G, reducer, op, G,
int(A_target), int(B_target), int(A_target), int(B_target),
...@@ -265,11 +265,11 @@ def backward_rhs_binary_op_reduce( ...@@ -265,11 +265,11 @@ def backward_rhs_binary_op_reduce(
The rows written to output tensor. The rows written to output tensor.
""" """
if A_rows is None: if A_rows is None:
A_rows = null() A_rows = nd.NULL
if B_rows is None: if B_rows is None:
B_rows = null() B_rows = nd.NULL
if out_rows is None: if out_rows is None:
out_rows = null() out_rows = nd.NULL
_CAPI_DGLKernelBackwardRhsBinaryOpReduce( _CAPI_DGLKernelBackwardRhsBinaryOpReduce(
reducer, op, G, reducer, op, G,
int(A_target), int(B_target), int(A_target), int(B_target),
...@@ -364,9 +364,9 @@ def copy_reduce(reducer, G, target, ...@@ -364,9 +364,9 @@ def copy_reduce(reducer, G, target,
The rows to write to output tensor. The rows to write to output tensor.
""" """
if X_rows is None: if X_rows is None:
X_rows = null() X_rows = nd.NULL
if out_rows is None: if out_rows is None:
out_rows = null() out_rows = nd.NULL
_CAPI_DGLKernelCopyReduce( _CAPI_DGLKernelCopyReduce(
reducer, G, int(target), reducer, G, int(target),
X, out, X_rows, out_rows) X, out, X_rows, out_rows)
...@@ -406,9 +406,9 @@ def backward_copy_reduce(reducer, G, target, ...@@ -406,9 +406,9 @@ def backward_copy_reduce(reducer, G, target,
The rows written to output tensor. The rows written to output tensor.
""" """
if X_rows is None: if X_rows is None:
X_rows = null() X_rows = nd.NULL
if out_rows is None: if out_rows is None:
out_rows = null() out_rows = nd.NULL
_CAPI_DGLKernelBackwardCopyReduce( _CAPI_DGLKernelBackwardCopyReduce(
reducer, G, int(target), reducer, G, int(target),
X, out, grad_out, grad_X, X, out, grad_out, grad_X,
......
...@@ -90,17 +90,6 @@ def zerocopy_from_numpy(np_data): ...@@ -90,17 +90,6 @@ def zerocopy_from_numpy(np_data):
handle = ctypes.pointer(arr) handle = ctypes.pointer(arr)
return NDArray(handle, is_view=True) return NDArray(handle, is_view=True)
def null():
"""Return a ndarray representing null value. It can be safely converted
to other backend tensors.
Returns
-------
NDArray
A null array
"""
return array(_np.array([], dtype=_np.int64))
class SparseFormat: class SparseFormat:
"""Format code""" """Format code"""
ANY = 0 ANY = 0
...@@ -185,3 +174,7 @@ class SparseMatrix(ObjectBase): ...@@ -185,3 +174,7 @@ class SparseMatrix(ObjectBase):
_set_class_ndarray(NDArray) _set_class_ndarray(NDArray)
_init_api("dgl.ndarray") _init_api("dgl.ndarray")
# An array representing null (no value) that can be safely converted to
# other backend tensors.
NULL = array(_np.array([], dtype=_np.int64))
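Since the sentinel is now built once at import time, the kernel wrappers above no longer allocate a fresh empty NDArray on every call. A hedged sketch of the pattern they follow (helper name illustrative):

from dgl import ndarray as nd

def _rows_or_null(rows):
    # Kernel entry points substitute the shared nd.NULL sentinel when no
    # row mapping is supplied, instead of constructing null() per call.
    return nd.NULL if rows is None else rows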
...@@ -59,17 +59,16 @@ class AGNNConv(nn.Block): ...@@ -59,17 +59,16 @@ class AGNNConv(nn.Block):
The output feature of shape :math:`(N, *)` where :math:`*` The output feature of shape :math:`(N, *)` where :math:`*`
should be the same as input shape. should be the same as input shape.
""" """
graph = graph.local_var() with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat)
feat_src, feat_dst = expand_as_pair(feat) graph.srcdata['h'] = feat_src
graph.srcdata['h'] = feat_src graph.srcdata['norm_h'] = normalize(feat_src, p=2, axis=-1)
graph.srcdata['norm_h'] = normalize(feat_src, p=2, axis=-1) if isinstance(feat, tuple):
if isinstance(feat, tuple): graph.dstdata['norm_h'] = normalize(feat_dst, p=2, axis=-1)
graph.dstdata['norm_h'] = normalize(feat_dst, p=2, axis=-1) # compute cosine distance
# compute cosine distance graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos'))
graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos')) cos = graph.edata.pop('cos')
cos = graph.edata.pop('cos') e = self.beta.data(feat_src.context) * cos
e = self.beta.data(feat_src.context) * cos graph.edata['p'] = edge_softmax(graph, e)
graph.edata['p'] = edge_softmax(graph, e) graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h'))
graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h')) return graph.dstdata.pop('h')
return graph.dstdata.pop('h')
...@@ -54,22 +54,22 @@ class APPNPConv(nn.Block): ...@@ -54,22 +54,22 @@ class APPNPConv(nn.Block):
The output feature of shape :math:`(N, *)` where :math:`*` The output feature of shape :math:`(N, *)` where :math:`*`
should be the same as input shape. should be the same as input shape.
""" """
graph = graph.local_var() with graph.local_scope():
norm = mx.nd.power(mx.nd.clip( norm = mx.nd.power(mx.nd.clip(
graph.in_degrees().astype(feat.dtype), a_min=1, a_max=float("inf")), -0.5) graph.in_degrees().astype(feat.dtype), a_min=1, a_max=float("inf")), -0.5)
shp = norm.shape + (1,) * (feat.ndim - 1) shp = norm.shape + (1,) * (feat.ndim - 1)
norm = norm.reshape(shp).as_in_context(feat.context) norm = norm.reshape(shp).as_in_context(feat.context)
feat_0 = feat feat_0 = feat
for _ in range(self._k): for _ in range(self._k):
# normalization by src node # normalization by src node
feat = feat * norm feat = feat * norm
graph.ndata['h'] = feat graph.ndata['h'] = feat
graph.edata['w'] = self.edge_drop( graph.edata['w'] = self.edge_drop(
nd.ones((graph.number_of_edges(), 1), ctx=feat.context)) nd.ones((graph.number_of_edges(), 1), ctx=feat.context))
graph.update_all(fn.u_mul_e('h', 'w', 'm'), graph.update_all(fn.u_mul_e('h', 'w', 'm'),
fn.sum('m', 'h')) fn.sum('m', 'h'))
feat = graph.ndata.pop('h') feat = graph.ndata.pop('h')
# normalization by dst node # normalization by dst node
feat = feat * norm feat = feat * norm
feat = (1 - self._alpha) * feat + self._alpha * feat_0 feat = (1 - self._alpha) * feat + self._alpha * feat_0
return feat return feat
...@@ -117,45 +117,45 @@ class GATConv(nn.Block): ...@@ -117,45 +117,45 @@ class GATConv(nn.Block):
The output feature of shape :math:`(N, H, D_{out})` where :math:`H` The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
is the number of heads, and :math:`D_{out}` is size of output feature. is the number of heads, and :math:`D_{out}` is size of output feature.
""" """
graph = graph.local_var() with graph.local_scope():
if isinstance(feat, tuple): if isinstance(feat, tuple):
h_src = self.feat_drop(feat[0]) h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1]) h_dst = self.feat_drop(feat[1])
feat_src = self.fc_src(h_src).reshape( feat_src = self.fc_src(h_src).reshape(
-1, self._num_heads, self._out_feats) -1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(h_dst).reshape( feat_dst = self.fc_dst(h_dst).reshape(
-1, self._num_heads, self._out_feats) -1, self._num_heads, self._out_feats)
else: else:
h_src = h_dst = self.feat_drop(feat) h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = self.fc(h_src).reshape( feat_src = feat_dst = self.fc(h_src).reshape(
-1, self._num_heads, self._out_feats) -1, self._num_heads, self._out_feats)
# NOTE: GAT paper uses "first concatenation then linear projection" # NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then # to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent: # addition", the two approaches are mathematically equivalent:
# We decompose the weight vector a mentioned in the paper into # We decompose the weight vector a mentioned in the paper into
# [a_l || a_r], then # [a_l || a_r], then
# a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
# Our implementation is much more efficient because we do not need to # Our implementation is much more efficient because we do not need to
# save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus, # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
# addition could be optimized with DGL's built-in function u_add_v, # addition could be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and saves memory footprint. # which further speeds up computation and saves memory footprint.
el = (feat_src * self.attn_l.data(feat_src.context)).sum(axis=-1).expand_dims(-1) el = (feat_src * self.attn_l.data(feat_src.context)).sum(axis=-1).expand_dims(-1)
er = (feat_dst * self.attn_r.data(feat_src.context)).sum(axis=-1).expand_dims(-1) er = (feat_dst * self.attn_r.data(feat_src.context)).sum(axis=-1).expand_dims(-1)
graph.srcdata.update({'ft': feat_src, 'el': el}) graph.srcdata.update({'ft': feat_src, 'el': el})
graph.dstdata.update({'er': er}) graph.dstdata.update({'er': er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively. # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v('el', 'er', 'e')) graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e')) e = self.leaky_relu(graph.edata.pop('e'))
# compute softmax # compute softmax
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
graph.update_all(fn.u_mul_e('ft', 'a', 'm'), graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft')) fn.sum('m', 'ft'))
rst = graph.dstdata['ft'] rst = graph.dstdata['ft']
# residual # residual
if self.res_fc is not None: if self.res_fc is not None:
resval = self.res_fc(h_dst).reshape(h_dst.shape[0], -1, self._out_feats) resval = self.res_fc(h_dst).reshape(h_dst.shape[0], -1, self._out_feats)
rst = rst + resval rst = rst + resval
# activation # activation
if self.activation: if self.activation:
rst = self.activation(rst) rst = self.activation(rst)
return rst return rst
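The equivalence cited in the comment, a^T [Wh_i || Wh_j] = a_l^T Wh_i + a_r^T Wh_j, is easy to check numerically; a small NumPy sketch (dimensions illustrative):

import numpy as np

d = 8
rng = np.random.default_rng(0)
a = rng.normal(size=2 * d)                 # attention vector a = [a_l || a_r]
a_l, a_r = a[:d], a[d:]
wh_i = rng.normal(size=d)                  # projected source feature W h_i
wh_j = rng.normal(size=d)                  # projected destination feature W h_j

concat_score = a @ np.concatenate([wh_i, wh_j])
split_score = a_l @ wh_i + a_r @ wh_j
assert np.isclose(concat_score, split_score)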
...@@ -75,23 +75,24 @@ class GatedGraphConv(nn.Block): ...@@ -75,23 +75,24 @@ class GatedGraphConv(nn.Block):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size. is the output feature size.
""" """
assert graph.is_homograph(), \ with graph.local_scope():
"not a homograph; convert it with to_homo and pass in the edge type as argument" assert graph.is_homograph(), \
graph = graph.local_var() "not a homograph; convert it with to_homo and pass in the edge type as argument"
zero_pad = nd.zeros((feat.shape[0], self._out_feats - feat.shape[1]), ctx=feat.context) zero_pad = nd.zeros((feat.shape[0], self._out_feats - feat.shape[1]),
feat = nd.concat(feat, zero_pad, dim=-1) ctx=feat.context)
feat = nd.concat(feat, zero_pad, dim=-1)
for _ in range(self._n_steps): for _ in range(self._n_steps):
graph.ndata['h'] = feat graph.ndata['h'] = feat
for i in range(self._n_etypes): for i in range(self._n_etypes):
eids = (etypes.asnumpy() == i).nonzero()[0] eids = (etypes.asnumpy() == i).nonzero()[0]
eids = nd.from_numpy(eids, zero_copy=True) eids = nd.from_numpy(eids, zero_copy=True)
if len(eids) > 0: if len(eids) > 0:
graph.apply_edges( graph.apply_edges(
lambda edges: {'W_e*h': self.linears[i](edges.src['h'])}, lambda edges: {'W_e*h': self.linears[i](edges.src['h'])},
eids eids
) )
graph.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a')) graph.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a'))
a = graph.ndata.pop('a') a = graph.ndata.pop('a')
feat = self.gru(a, [feat])[0] feat = self.gru(a, [feat])[0]
return feat return feat
...@@ -74,11 +74,11 @@ class GINConv(nn.Block): ...@@ -74,11 +74,11 @@ class GINConv(nn.Block):
If ``apply_func`` is None, :math:`D_{out}` should be the same If ``apply_func`` is None, :math:`D_{out}` should be the same
as input dimensionality. as input dimensionality.
""" """
graph = graph.local_var() with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat) feat_src, feat_dst = expand_as_pair(feat)
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh')) graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
rst = (1 + self.eps.data(feat_dst.context)) * feat_dst + graph.dstdata['neigh'] rst = (1 + self.eps.data(feat_dst.context)) * feat_dst + graph.dstdata['neigh']
if self.apply_func is not None: if self.apply_func is not None:
rst = self.apply_func(rst) rst = self.apply_func(rst)
return rst return rst
...@@ -119,59 +119,58 @@ class GraphConv(gluon.Block): ...@@ -119,59 +119,58 @@ class GraphConv(gluon.Block):
mxnet.NDArray mxnet.NDArray
The output feature The output feature
""" """
graph = graph.local_var() with graph.local_scope():
if self._norm == 'both':
degs = graph.out_degrees().as_in_context(feat.context).astype('float32')
degs = mx.nd.clip(degs, a_min=1, a_max=float("inf"))
norm = mx.nd.power(degs, -0.5)
shp = norm.shape + (1,) * (feat.ndim - 1)
norm = norm.reshape(shp)
feat = feat * norm
if weight is not None:
if self.weight is not None:
raise DGLError('External weight is provided while at the same time the'
' module has defined its own weight parameter. Please'
' create the module with flag weight=False.')
else:
weight = self.weight.data(feat.context)
if self._in_feats > self._out_feats:
# mult W first to reduce the feature size for aggregation.
if weight is not None:
feat = mx.nd.dot(feat, weight)
graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata.pop('h')
else:
# aggregate first then mult W
graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata.pop('h')
if weight is not None:
rst = mx.nd.dot(rst, weight)
if self._norm != 'none':
degs = graph.in_degrees().as_in_context(feat.context).astype('float32')
degs = mx.nd.clip(degs, a_min=1, a_max=float("inf"))
if self._norm == 'both': if self._norm == 'both':
degs = graph.out_degrees().as_in_context(feat.context).astype('float32')
degs = mx.nd.clip(degs, a_min=1, a_max=float("inf"))
norm = mx.nd.power(degs, -0.5) norm = mx.nd.power(degs, -0.5)
else: shp = norm.shape + (1,) * (feat.ndim - 1)
norm = 1.0 / degs norm = norm.reshape(shp)
shp = norm.shape + (1,) * (feat.ndim - 1) feat = feat * norm
norm = norm.reshape(shp)
rst = rst * norm
if self.bias is not None:
rst = rst + self.bias.data(rst.context)
if self._activation is not None: if weight is not None:
rst = self._activation(rst) if self.weight is not None:
raise DGLError('External weight is provided while at the same time the'
return rst ' module has defined its own weight parameter. Please'
' create the module with flag weight=False.')
else:
weight = self.weight.data(feat.context)
if self._in_feats > self._out_feats:
# mult W first to reduce the feature size for aggregation.
if weight is not None:
feat = mx.nd.dot(feat, weight)
graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata.pop('h')
else:
# aggregate first then mult W
graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata.pop('h')
if weight is not None:
rst = mx.nd.dot(rst, weight)
if self._norm != 'none':
degs = graph.in_degrees().as_in_context(feat.context).astype('float32')
degs = mx.nd.clip(degs, a_min=1, a_max=float("inf"))
if self._norm == 'both':
norm = mx.nd.power(degs, -0.5)
else:
norm = 1.0 / degs
shp = norm.shape + (1,) * (feat.ndim - 1)
norm = norm.reshape(shp)
rst = rst * norm
if self.bias is not None:
rst = rst + self.bias.data(rst.context)
if self._activation is not None:
rst = self._activation(rst)
return rst
def __repr__(self): def __repr__(self):
summary = 'GraphConv(' summary = 'GraphConv('
......
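For reference, the `norm='both'` branch above amounts to symmetric normalization, roughly D^{-1/2} A D^{-1/2} X W before bias and activation. A dense NumPy sketch for a small symmetric adjacency (values illustrative):

import numpy as np

A = np.array([[0., 1., 1.],
              [1., 0., 0.],
              [1., 0., 0.]])               # symmetric adjacency of a 3-node graph
X = np.eye(3)                              # node features
W = np.ones((3, 2))                        # projection weight (in_feats, out_feats)

deg = np.clip(A.sum(axis=1), 1, None)      # degree clipping as in the layer
D_inv_sqrt = np.diag(deg ** -0.5)
out = D_inv_sqrt @ A @ D_inv_sqrt @ X @ W  # GraphConv(norm='both') output, up to bias/activation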
...@@ -97,46 +97,45 @@ class SAGEConv(nn.Block): ...@@ -97,46 +97,45 @@ class SAGEConv(nn.Block):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature. is size of output feature.
""" """
graph = graph.local_var() with graph.local_scope():
if isinstance(feat, tuple):
if isinstance(feat, tuple): feat_src = self.feat_drop(feat[0])
feat_src = self.feat_drop(feat[0]) feat_dst = self.feat_drop(feat[1])
feat_dst = self.feat_drop(feat[1]) else:
else: feat_src = feat_dst = self.feat_drop(feat)
feat_src = feat_dst = self.feat_drop(feat)
h_self = feat_dst
h_self = feat_dst
if self._aggre_type == 'mean':
if self._aggre_type == 'mean': graph.srcdata['h'] = feat_src
graph.srcdata['h'] = feat_src graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh')) h_neigh = graph.dstdata['neigh']
h_neigh = graph.dstdata['neigh'] elif self._aggre_type == 'gcn':
elif self._aggre_type == 'gcn': check_eq_shape(feat)
check_eq_shape(feat) graph.srcdata['h'] = feat_src
graph.srcdata['h'] = feat_src graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.dstdata['h'] = feat_dst # same as above if homogeneous graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh')) # divide in degrees
# divide in degrees degs = graph.in_degrees().astype(feat_dst.dtype)
degs = graph.in_degrees().astype(feat_dst.dtype) degs = degs.as_in_context(feat_dst.context)
degs = degs.as_in_context(feat_dst.context) h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.expand_dims(-1) + 1)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.expand_dims(-1) + 1) elif self._aggre_type == 'pool':
elif self._aggre_type == 'pool': graph.srcdata['h'] = nd.relu(self.fc_pool(feat_src))
graph.srcdata['h'] = nd.relu(self.fc_pool(feat_src)) graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh')) h_neigh = graph.dstdata['neigh']
h_neigh = graph.dstdata['neigh'] elif self._aggre_type == 'lstm':
elif self._aggre_type == 'lstm': raise NotImplementedError
raise NotImplementedError else:
else: raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
if self._aggre_type == 'gcn':
if self._aggre_type == 'gcn': rst = self.fc_neigh(h_neigh)
rst = self.fc_neigh(h_neigh) else:
else: rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
rst = self.fc_self(h_self) + self.fc_neigh(h_neigh) # activation
# activation if self.activation is not None:
if self.activation is not None: rst = self.activation(rst)
rst = self.activation(rst) # normalization
# normalization if self.norm is not None:
if self.norm is not None: rst = self.norm(rst)
rst = self.norm(rst) return rst
return rst
...@@ -74,27 +74,27 @@ class SGConv(nn.Block): ...@@ -74,27 +74,27 @@ class SGConv(nn.Block):
If ``cache`` is set to True, ``feat`` and ``graph`` should not change during If ``cache`` is set to True, ``feat`` and ``graph`` should not change during
training, or you will get wrong results. training, or you will get wrong results.
""" """
graph = graph.local_var() with graph.local_scope():
if self._cached_h is not None: if self._cached_h is not None:
feat = self._cached_h feat = self._cached_h
else: else:
# compute normalization # compute normalization
degs = nd.clip(graph.in_degrees().astype(feat.dtype), 1, float('inf')) degs = nd.clip(graph.in_degrees().astype(feat.dtype), 1, float('inf'))
norm = nd.power(degs, -0.5).expand_dims(1) norm = nd.power(degs, -0.5).expand_dims(1)
norm = norm.as_in_context(feat.context) norm = norm.as_in_context(feat.context)
# compute (D^-1 A D)^k X # compute (D^-1 A D)^k X
for _ in range(self._k): for _ in range(self._k):
feat = feat * norm feat = feat * norm
graph.ndata['h'] = feat graph.ndata['h'] = feat
graph.update_all(fn.copy_u('h', 'm'), graph.update_all(fn.copy_u('h', 'm'),
fn.sum('m', 'h')) fn.sum('m', 'h'))
feat = graph.ndata.pop('h') feat = graph.ndata.pop('h')
feat = feat * norm feat = feat * norm
if self.norm is not None: if self.norm is not None:
feat = self.norm(feat) feat = self.norm(feat)
# cache feature # cache feature
if self._cached: if self._cached:
self._cached_h = feat self._cached_h = feat
return self.fc(feat) return self.fc(feat)
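The loop above simply applies the normalized propagation k times before the single linear layer, roughly fc((D^{-1/2} A D^{-1/2})^k X). A dense sketch of the propagation step for a symmetric adjacency (caching omitted):

import numpy as np

def sgc_propagate(A, X, k):
    deg = np.clip(A.sum(axis=1), 1, None)
    D_inv_sqrt = np.diag(deg ** -0.5)
    S = D_inv_sqrt @ A @ D_inv_sqrt
    for _ in range(k):
        X = S @ X
    return X    # the layer then applies self.fc (and the optional norm) to this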
...@@ -76,30 +76,30 @@ class TAGConv(gluon.Block): ...@@ -76,30 +76,30 @@ class TAGConv(gluon.Block):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature. is size of output feature.
""" """
assert graph.is_homograph(), 'Graph is not homogeneous' with graph.local_scope():
graph = graph.local_var() assert graph.is_homograph(), 'Graph is not homogeneous'
degs = graph.in_degrees().astype('float32') degs = graph.in_degrees().astype('float32')
norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5) norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5)
shp = norm.shape + (1,) * (feat.ndim - 1) shp = norm.shape + (1,) * (feat.ndim - 1)
norm = norm.reshape(shp).as_in_context(feat.context) norm = norm.reshape(shp).as_in_context(feat.context)
rst = feat rst = feat
for _ in range(self.k): for _ in range(self.k):
rst = rst * norm rst = rst * norm
graph.ndata['h'] = rst graph.ndata['h'] = rst
graph.update_all(fn.copy_src(src='h', out='m'), graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h')) fn.sum(msg='m', out='h'))
rst = graph.ndata['h'] rst = graph.ndata['h']
rst = rst * norm rst = rst * norm
feat = mx.nd.concat(feat, rst, dim=-1) feat = mx.nd.concat(feat, rst, dim=-1)
rst = mx.nd.dot(feat, self.lin.data(feat.context)) rst = mx.nd.dot(feat, self.lin.data(feat.context))
if self.bias is not None: if self.bias is not None:
rst = rst + self.h_bias.data(rst.context) rst = rst + self.h_bias.data(rst.context)
if self.activation is not None: if self.activation is not None:
rst = self.activation(rst) rst = self.activation(rst)
return rst return rst
"""Gluon layer for graph related softmax.""" """Gluon layer for graph related softmax."""
# pylint: disable= no-member, arguments-differ # pylint: disable= no-member, arguments-differ, access-member-before-definition, unpacking-non-sequence
import mxnet as mx import mxnet as mx
from ... import function as fn from ... import function as fn
...@@ -45,16 +45,17 @@ class EdgeSoftmax(mx.autograd.Function): ...@@ -45,16 +45,17 @@ class EdgeSoftmax(mx.autograd.Function):
out = score / score_sum # edge_div_dst, ret dgl.EData out = score / score_sum # edge_div_dst, ret dgl.EData
return out.data return out.data
""" """
g = self.g.local_var() g = self.g
g.edata['s'] = score with g.local_scope():
g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax')) g.edata['s'] = score
g.apply_edges(fn.e_sub_v('s', 'smax', 'out')) g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
g.edata['out'] = g.edata['out'].exp() g.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum')) g.edata['out'] = g.edata['out'].exp()
g.apply_edges(fn.e_div_v('out', 'out_sum', 'out')) g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
out = g.edata['out'] g.apply_edges(fn.e_div_v('out', 'out_sum', 'out'))
self.save_for_backward(out) out = g.edata['out']
return out self.save_for_backward(out)
return out
def backward(self, grad_out): def backward(self, grad_out):
"""Backward function. """Backward function.
...@@ -70,16 +71,17 @@ class EdgeSoftmax(mx.autograd.Function): ...@@ -70,16 +71,17 @@ class EdgeSoftmax(mx.autograd.Function):
sds_sum = sds.dst_sum() # type dgl.NData sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - sds * sds_sum # multiple expressions grad_score = sds - sds * sds_sum # multiple expressions
""" """
g = self.g.local_var() g = self.g
out, = self.saved_tensors # pylint: disable=access-member-before-definition, unpacking-non-sequence with g.local_scope():
# clear saved tensors explicitly out, = self.saved_tensors
self.saved_tensors = None # clear saved tensors explicitly
g.edata['out'] = out self.saved_tensors = None
g.edata['grad_score'] = out * grad_out g.edata['out'] = out
g.update_all(fn.copy_e('grad_score', 'm'), fn.sum('m', 'accum')) g.edata['grad_score'] = out * grad_out
g.apply_edges(fn.e_mul_v('out', 'accum', 'out')) g.update_all(fn.copy_e('grad_score', 'm'), fn.sum('m', 'accum'))
grad_score = g.edata['grad_score'] - g.edata['out'] g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
return grad_score grad_score = g.edata['grad_score'] - g.edata['out']
return grad_score
def edge_softmax(graph, logits, eids=ALL): def edge_softmax(graph, logits, eids=ALL):
r"""Compute edge softmax. r"""Compute edge softmax.
......
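The forward pass above is the standard max-shifted softmax, computed per destination node with message passing; a NumPy sketch of the same arithmetic for the edges incident to a single destination (grouping by destination omitted for brevity):

import numpy as np

def edge_softmax_one_dst(scores):
    s = np.asarray(scores, dtype=np.float64)
    s = s - s.max()        # e_sub_v with the per-destination max (fn.max reduction)
    e = np.exp(s)          # edata['out'].exp()
    return e / e.sum()     # e_div_v with the per-destination sum (fn.sum reduction)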