Unverified Commit 61139302 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[API Deprecation] Remove candidates in DGLGraph (#4946)

parent e088acac
...@@ -117,7 +117,7 @@ class UTransformer(nn.Module): ...@@ -117,7 +117,7 @@ class UTransformer(nn.Module):
g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids) g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
# Send weighted values to target nodes # Send weighted values to target nodes
g.send_and_recv(eids, g.send_and_recv(eids,
[fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')], [fn.u_mul_e('v', 'score', 'v'), fn.copy_e('score', 'score')],
[fn.sum('v', 'wv'), fn.sum('score', 'z')]) [fn.sum('v', 'wv'), fn.sum('score', 'z')])
def update_graph(self, g, eids, pre_pairs, post_pairs): def update_graph(self, g, eids, pre_pairs, post_pairs):
......
...@@ -79,8 +79,8 @@ class Transformer(nn.Module): ...@@ -79,8 +79,8 @@ class Transformer(nn.Module):
g.apply_edges(src_dot_dst('k', 'q', 'score'), eids) g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids) g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
# Send weighted values to target nodes # Send weighted values to target nodes
g.send_and_recv(eids, fn.src_mul_edge('v', 'score', 'v'), fn.sum('v', 'wv')) g.send_and_recv(eids, fn.u_mul_e('v', 'score', 'v'), fn.sum('v', 'wv'))
g.send_and_recv(eids, fn.copy_edge('score', 'score'), fn.sum('score', 'z')) g.send_and_recv(eids, fn.copy_e('score', 'score'), fn.sum('score', 'z'))
def update_graph(self, g, eids, pre_pairs, post_pairs): def update_graph(self, g, eids, pre_pairs, post_pairs):
"Update the node states and edge states of the graph." "Update the node states and edge states of the graph."
......
...@@ -17,7 +17,7 @@ def get_attention_map(g, src_nodes, dst_nodes, h): ...@@ -17,7 +17,7 @@ def get_attention_map(g, src_nodes, dst_nodes, h):
for j, dst in enumerate(dst_nodes.tolist()): for j, dst in enumerate(dst_nodes.tolist()):
if not g.has_edge_between(src, dst): if not g.has_edge_between(src, dst):
continue continue
eid = g.edge_id(src, dst) eid = g.edge_ids(src, dst)
weight[i][j] = g.edata['score'][eid].squeeze(-1).cpu().detach() weight[i][j] = g.edata['score'][eid].squeeze(-1).cpu().detach()
weight = weight.transpose(0, 2) weight = weight.transpose(0, 2)
......
...@@ -131,7 +131,7 @@ def main(args): ...@@ -131,7 +131,7 @@ def main(args):
root_ids = [ root_ids = [
i i
for i in range(g.number_of_nodes()) for i in range(g.number_of_nodes())
if g.out_degree(i) == 0 if g.out_degrees(i) == 0
] ]
root_acc = np.sum( root_acc = np.sum(
batch.label.cpu().data.numpy()[root_ids] batch.label.cpu().data.numpy()[root_ids]
...@@ -170,7 +170,7 @@ def main(args): ...@@ -170,7 +170,7 @@ def main(args):
acc = th.sum(th.eq(batch.label, pred)).item() acc = th.sum(th.eq(batch.label, pred)).item()
accs.append([acc, len(batch.label)]) accs.append([acc, len(batch.label)])
root_ids = [ root_ids = [
i for i in range(g.number_of_nodes()) if g.out_degree(i) == 0 i for i in range(g.number_of_nodes()) if g.out_degrees(i) == 0
] ]
root_acc = np.sum( root_acc = np.sum(
batch.label.cpu().data.numpy()[root_ids] batch.label.cpu().data.numpy()[root_ids]
...@@ -222,7 +222,7 @@ def main(args): ...@@ -222,7 +222,7 @@ def main(args):
acc = th.sum(th.eq(batch.label, pred)).item() acc = th.sum(th.eq(batch.label, pred)).item()
accs.append([acc, len(batch.label)]) accs.append([acc, len(batch.label)])
root_ids = [ root_ids = [
i for i in range(g.number_of_nodes()) if g.out_degree(i) == 0 i for i in range(g.number_of_nodes()) if g.out_degrees(i) == 0
] ]
root_acc = np.sum( root_acc = np.sum(
batch.label.cpu().data.numpy()[root_ids] batch.label.cpu().data.numpy()[root_ids]
......
...@@ -45,7 +45,7 @@ class GCNLayer(layers.Layer): ...@@ -45,7 +45,7 @@ class GCNLayer(layers.Layer):
h = self.dropout(h) h = self.dropout(h)
self.g.ndata['h'] = tf.matmul(h, self.weight) self.g.ndata['h'] = tf.matmul(h, self.weight)
self.g.ndata['norm_h'] = self.g.ndata['h'] * self.g.ndata['norm'] self.g.ndata['norm_h'] = self.g.ndata['h'] * self.g.ndata['norm']
self.g.update_all(fn.copy_src('norm_h', 'm'), self.g.update_all(fn.copy_u('norm_h', 'm'),
fn.sum('m', 'h')) fn.sum('m', 'h'))
h = self.g.ndata['h'] h = self.g.ndata['h']
if self.bias is not None: if self.bias is not None:
......
...@@ -3083,10 +3083,10 @@ class DGLGraph(DGLBaseGraph): ...@@ -3083,10 +3083,10 @@ class DGLGraph(DGLBaseGraph):
>>> g.add_nodes(3) >>> g.add_nodes(3)
>>> g.ndata['x'] = th.tensor([[0.], [1.], [2.]]) >>> g.ndata['x'] = th.tensor([[0.], [1.], [2.]])
Use the built-in message function :func:`~dgl.function.copy_src` for copying Use the built-in message function :func:`~dgl.function.copy_u` for copying
node features as the message. node features as the message.
>>> m_func = dgl.function.copy_src('x', 'm') >>> m_func = dgl.function.copy_u('x', 'm')
>>> g.register_message_func(m_func) >>> g.register_message_func(m_func)
Use the built-int message reducing function :func:`~dgl.function.sum`, which Use the built-int message reducing function :func:`~dgl.function.sum`, which
...@@ -3180,10 +3180,10 @@ class DGLGraph(DGLBaseGraph): ...@@ -3180,10 +3180,10 @@ class DGLGraph(DGLBaseGraph):
>>> g.add_nodes(3) >>> g.add_nodes(3)
>>> g.ndata['x'] = th.tensor([[1.], [2.], [3.]]) >>> g.ndata['x'] = th.tensor([[1.], [2.], [3.]])
Use the built-in message function :func:`~dgl.function.copy_src` for copying Use the built-in message function :func:`~dgl.function.copy_u` for copying
node features as the message. node features as the message.
>>> m_func = dgl.function.copy_src('x', 'm') >>> m_func = dgl.function.copy_u('x', 'm')
>>> g.register_message_func(m_func) >>> g.register_message_func(m_func)
Use the built-int message reducing function :func:`~dgl.function.sum`, which Use the built-int message reducing function :func:`~dgl.function.sum`, which
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from collections.abc import Mapping from collections.abc import Mapping
from . import backend as F from . import backend as F
from .base import ALL, is_all, DGLError, dgl_warning, NID, EID from .base import ALL, is_all, DGLError, NID, EID
from .heterograph_index import disjoint_union, slice_gidx from .heterograph_index import disjoint_union, slice_gidx
from .heterograph import DGLGraph from .heterograph import DGLGraph
from . import convert from . import convert
...@@ -11,8 +11,7 @@ from . import utils ...@@ -11,8 +11,7 @@ from . import utils
__all__ = ['batch', 'unbatch', 'slice_batch'] __all__ = ['batch', 'unbatch', 'slice_batch']
def batch(graphs, ndata=ALL, edata=ALL, *, def batch(graphs, ndata=ALL, edata=ALL):
node_attrs=None, edge_attrs=None):
r"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient r"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient
graph computation. graph computation.
...@@ -151,14 +150,6 @@ def batch(graphs, ndata=ALL, edata=ALL, *, ...@@ -151,14 +150,6 @@ def batch(graphs, ndata=ALL, edata=ALL, *,
""" """
if len(graphs) == 0: if len(graphs) == 0:
raise DGLError('The input list of graphs cannot be empty.') raise DGLError('The input list of graphs cannot be empty.')
if node_attrs is not None:
dgl_warning('Arguments node_attrs has been deprecated. Please use'
' ndata instead.')
ndata = node_attrs
if edge_attrs is not None:
dgl_warning('Arguments edge_attrs has been deprecated. Please use'
' edata instead.')
edata = edge_attrs
if not (is_all(ndata) or isinstance(ndata, list) or ndata is None): if not (is_all(ndata) or isinstance(ndata, list) or ndata is None):
raise DGLError('Invalid argument ndata: must be a string list but got {}.'.format( raise DGLError('Invalid argument ndata: must be a string list but got {}.'.format(
type(ndata))) type(ndata)))
......
...@@ -9,7 +9,7 @@ from .._deprecate.runtime import ir ...@@ -9,7 +9,7 @@ from .._deprecate.runtime import ir
from .._deprecate.runtime.ir import var from .._deprecate.runtime.ir import var
__all__ = ["src_mul_edge", "copy_src", "copy_edge", "copy_u", "copy_e", __all__ = ["copy_u", "copy_e",
"BinaryMessageFunction", "CopyMessageFunction"] "BinaryMessageFunction", "CopyMessageFunction"]
...@@ -34,7 +34,7 @@ class BinaryMessageFunction(MessageFunction): ...@@ -34,7 +34,7 @@ class BinaryMessageFunction(MessageFunction):
See Also See Also
-------- --------
src_mul_edge u_mul_e
""" """
def __init__(self, binary_op, lhs, rhs, lhs_field, rhs_field, out_field): def __init__(self, binary_op, lhs, rhs, lhs_field, rhs_field, out_field):
self.binary_op = binary_op self.binary_op = binary_op
...@@ -73,7 +73,7 @@ class CopyMessageFunction(MessageFunction): ...@@ -73,7 +73,7 @@ class CopyMessageFunction(MessageFunction):
See Also See Also
-------- --------
copy_src copy_u
""" """
def __init__(self, target, in_field, out_field): def __init__(self, target, in_field, out_field):
self.target = target self.target = target
...@@ -218,86 +218,3 @@ def _register_builtin_message_func(): ...@@ -218,86 +218,3 @@ def _register_builtin_message_func():
__all__.append(func.__name__) __all__.append(func.__name__)
_register_builtin_message_func() _register_builtin_message_func()
##############################################################################
# For backward compatibility
def src_mul_edge(src, edge, out):
"""Builtin message function that computes message by performing
binary operation mul between src feature and edge feature.
Notes
-----
This function is deprecated. Please use :func:`~dgl.function.u_mul_e` instead.
Parameters
----------
src : str
The source feature field.
edge : str
The edge feature field.
out : str
The output message field.
Examples
--------
>>> import dgl
>>> message_func = dgl.function.src_mul_edge('h', 'e', 'm')
"""
return getattr(sys.modules[__name__], "u_mul_e")(src, edge, out)
def copy_src(src, out):
"""Builtin message function that computes message using source node
feature.
Notes
-----
This function is deprecated. Please use :func:`~dgl.function.copy_u` instead.
Parameters
----------
src : str
The source feature field.
out : str
The output message field.
Examples
--------
>>> import dgl
>>> message_func = dgl.function.copy_src('h', 'm')
The above example is equivalent to the following user defined function:
>>> def message_func(edges):
>>> return {'m': edges.src['h']}
"""
return copy_u(src, out)
def copy_edge(edge, out):
"""Builtin message function that computes message using edge feature.
Notes
-----
This function is deprecated. Please use :func:`~dgl.function.copy_e` instead.
Parameters
----------
edge : str
The edge feature field.
out : str
The output message field.
Examples
--------
>>> import dgl
>>> message_func = dgl.function.copy_edge('h', 'm')
The above example is equivalent to the following user defined function:
>>> def message_func(edges):
>>> return {'m': edges.data['h']}
"""
return copy_e(edge, out)
...@@ -346,14 +346,6 @@ class DGLGraph(object): ...@@ -346,14 +346,6 @@ class DGLGraph(object):
self._node_frames[ntid].append(data) self._node_frames[ntid].append(data)
self._reset_cached_info() self._reset_cached_info()
def add_edge(self, u, v, data=None, etype=None):
"""Add one edge to the graph.
DEPRECATED: please use ``add_edges``.
"""
dgl_warning("DGLGraph.add_edge is deprecated. Please use DGLGraph.add_edges")
self.add_edges(u, v, data, etype)
def add_edges(self, u, v, data=None, etype=None): def add_edges(self, u, v, data=None, etype=None):
r"""Add multiple new edges for the specified edge type r"""Add multiple new edges for the specified edge type
...@@ -2623,20 +2615,6 @@ class DGLGraph(object): ...@@ -2623,20 +2615,6 @@ class DGLGraph(object):
""" """
return len(self.ntypes) == 1 and len(self.etypes) == 1 return len(self.ntypes) == 1 and len(self.etypes) == 1
@property
def is_readonly(self):
"""**DEPRECATED**: DGLGraph will always be mutable.
Returns
-------
bool
True if the graph is readonly, False otherwise.
"""
dgl_warning('DGLGraph.is_readonly is deprecated in v0.5.\n'
'DGLGraph now always supports mutable operations like add_nodes'
' and add_edges.')
return False
@property @property
def idtype(self): def idtype(self):
"""The data type for storing the structure-related graph information """The data type for storing the structure-related graph information
...@@ -2682,12 +2660,6 @@ class DGLGraph(object): ...@@ -2682,12 +2660,6 @@ class DGLGraph(object):
""" """
return self._graph.dtype return self._graph.dtype
def __contains__(self, vid):
"""**DEPRECATED**: please directly call :func:`has_nodes`."""
dgl_warning('DGLGraph.__contains__ is deprecated.'
' Please directly call has_nodes.')
return self.has_nodes(vid)
def has_nodes(self, vid, ntype=None): def has_nodes(self, vid, ntype=None):
"""Return whether the graph contains the given nodes. """Return whether the graph contains the given nodes.
...@@ -2745,14 +2717,6 @@ class DGLGraph(object): ...@@ -2745,14 +2717,6 @@ class DGLGraph(object):
else: else:
return F.astype(ret, F.bool) return F.astype(ret, F.bool)
def has_node(self, vid, ntype=None):
"""Whether the graph has a particular node of a given type.
**DEPRECATED**: see :func:`~DGLGraph.has_nodes`
"""
dgl_warning("DGLGraph.has_node is deprecated. Please use DGLGraph.has_nodes")
return self.has_nodes(vid, ntype)
def has_edges_between(self, u, v, etype=None): def has_edges_between(self, u, v, etype=None):
"""Return whether the graph contains the given edges. """Return whether the graph contains the given edges.
...@@ -2843,15 +2807,6 @@ class DGLGraph(object): ...@@ -2843,15 +2807,6 @@ class DGLGraph(object):
else: else:
return F.astype(ret, F.bool) return F.astype(ret, F.bool)
def has_edge_between(self, u, v, etype=None):
"""Whether the graph has edges of type ``etype``.
**DEPRECATED**: please use :func:`~DGLGraph.has_edge_between`.
"""
dgl_warning("DGLGraph.has_edge_between is deprecated. "
"Please use DGLGraph.has_edges_between")
return self.has_edges_between(u, v, etype)
def predecessors(self, v, etype=None): def predecessors(self, v, etype=None):
"""Return the predecessor(s) of a particular node with the specified edge type. """Return the predecessor(s) of a particular node with the specified edge type.
...@@ -2969,17 +2924,7 @@ class DGLGraph(object): ...@@ -2969,17 +2924,7 @@ class DGLGraph(object):
raise DGLError('Non-existing node ID {}'.format(v)) raise DGLError('Non-existing node ID {}'.format(v))
return self._graph.successors(self.get_etype_id(etype), v) return self._graph.successors(self.get_etype_id(etype), v)
def edge_id(self, u, v, force_multi=None, return_uv=False, etype=None): def edge_ids(self, u, v, return_uv=False, etype=None):
"""Return the edge ID, or an array of edge IDs, between source node
`u` and destination node `v`, with the specified edge type
**DEPRECATED**: See edge_ids
"""
dgl_warning("DGLGraph.edge_id is deprecated. Please use DGLGraph.edge_ids.")
return self.edge_ids(u, v, force_multi=force_multi,
return_uv=return_uv, etype=etype)
def edge_ids(self, u, v, force_multi=None, return_uv=False, etype=None):
"""Return the edge ID(s) given the two endpoints of the edge(s). """Return the edge ID(s) given the two endpoints of the edge(s).
Parameters Parameters
...@@ -2999,9 +2944,6 @@ class DGLGraph(object): ...@@ -2999,9 +2944,6 @@ class DGLGraph(object):
* Int Tensor: Each element is a node ID. The tensor must have the same device type * Int Tensor: Each element is a node ID. The tensor must have the same device type
and ID data type as the graph's. and ID data type as the graph's.
* iterable[int]: Each element is a node ID. * iterable[int]: Each element is a node ID.
force_multi : bool, optional
**DEPRECATED**, use :attr:`return_uv` instead. Whether to allow the graph to be a
multigraph, i.e. there can be multiple edges from one node to another.
return_uv : bool, optional return_uv : bool, optional
Whether to return the source and destination node IDs along with the edges. If Whether to return the source and destination node IDs along with the edges. If
False (default), it assumes that the graph is a simple graph and there is only False (default), it assumes that the graph is a simple graph and there is only
...@@ -3084,10 +3026,6 @@ class DGLGraph(object): ...@@ -3084,10 +3026,6 @@ class DGLGraph(object):
v = utils.prepare_tensor(self, v, 'v') v = utils.prepare_tensor(self, v, 'v')
if F.as_scalar(F.sum(self.has_nodes(v, ntype=dsttype), dim=0)) != len(v): if F.as_scalar(F.sum(self.has_nodes(v, ntype=dsttype), dim=0)) != len(v):
raise DGLError('v contains invalid node IDs') raise DGLError('v contains invalid node IDs')
if force_multi is not None:
dgl_warning("force_multi will be deprecated, " \
"Please use return_uv instead")
return_uv = force_multi
if return_uv: if return_uv:
return self._graph.edge_ids_all(self.get_etype_id(etype), u, v) return self._graph.edge_ids_all(self.get_etype_id(etype), u, v)
...@@ -3424,14 +3362,6 @@ class DGLGraph(object): ...@@ -3424,14 +3362,6 @@ class DGLGraph(object):
else: else:
raise DGLError('Invalid form: {}. Must be "all", "uv" or "eid".'.format(form)) raise DGLError('Invalid form: {}. Must be "all", "uv" or "eid".'.format(form))
def in_degree(self, v, etype=None):
"""Return the in-degree of node ``v`` with edges of type ``etype``.
**DEPRECATED**: Please use in_degrees
"""
dgl_warning("DGLGraph.in_degree is deprecated. Please use DGLGraph.in_degrees")
return self.in_degrees(v, etype)
def in_degrees(self, v=ALL, etype=None): def in_degrees(self, v=ALL, etype=None):
"""Return the in-degree(s) of the given nodes. """Return the in-degree(s) of the given nodes.
...@@ -3508,14 +3438,6 @@ class DGLGraph(object): ...@@ -3508,14 +3438,6 @@ class DGLGraph(object):
else: else:
return deg return deg
def out_degree(self, u, etype=None):
"""Return the out-degree of node `u` with edges of type ``etype``.
DEPRECATED: please use DGL.out_degrees
"""
dgl_warning("DGLGraph.out_degree is deprecated. Please use DGLGraph.out_degrees")
return self.out_degrees(u, etype)
def out_degrees(self, u=ALL, etype=None): def out_degrees(self, u=ALL, etype=None):
"""Return the out-degree(s) of the given nodes. """Return the out-degree(s) of the given nodes.
...@@ -3713,15 +3635,6 @@ class DGLGraph(object): ...@@ -3713,15 +3635,6 @@ class DGLGraph(object):
else: else:
return self._graph.adjacency_matrix_tensors(etid, False, fmt)[2:] return self._graph.adjacency_matrix_tensors(etid, False, fmt)[2:]
def adjacency_matrix_scipy(self, transpose=False, fmt='csr', return_edge_ids=None):
"""DEPRECATED: please use ``dgl.adjacency_matrix(transpose, scipy_fmt=fmt)``.
"""
dgl_warning('DGLGraph.adjacency_matrix_scipy is deprecated. '
'Please replace it with:\n\n\t'
'DGLGraph.adjacency_matrix(transpose, scipy_fmt="{}").\n'.format(fmt))
return self.adjacency_matrix(transpose=transpose, scipy_fmt=fmt)
def inc(self, typestr, ctx=F.cpu(), etype=None): def inc(self, typestr, ctx=F.cpu(), etype=None):
"""Return the incidence matrix representation of edges with the given """Return the incidence matrix representation of edges with the given
edge type. edge type.
...@@ -4283,7 +4196,7 @@ class DGLGraph(object): ...@@ -4283,7 +4196,7 @@ class DGLGraph(object):
# Message passing # Message passing
################################################################# #################################################################
def apply_nodes(self, func, v=ALL, ntype=None, inplace=False): def apply_nodes(self, func, v=ALL, ntype=None):
"""Update the features of the specified nodes by the provided function. """Update the features of the specified nodes by the provided function.
Parameters Parameters
...@@ -4303,8 +4216,6 @@ class DGLGraph(object): ...@@ -4303,8 +4216,6 @@ class DGLGraph(object):
ntype : str, optional ntype : str, optional
The node type name. Can be omitted if there is The node type name. Can be omitted if there is
only one type of nodes in the graph. only one type of nodes in the graph.
inplace : bool, optional
**DEPRECATED**.
Examples Examples
-------- --------
...@@ -4340,8 +4251,6 @@ class DGLGraph(object): ...@@ -4340,8 +4251,6 @@ class DGLGraph(object):
-------- --------
apply_edges apply_edges
""" """
if inplace:
raise DGLError('The `inplace` option is removed in v0.5.')
ntid = self.get_ntype_id(ntype) ntid = self.get_ntype_id(ntype)
ntype = self.ntypes[ntid] ntype = self.ntypes[ntid]
if is_all(v): if is_all(v):
...@@ -4351,7 +4260,7 @@ class DGLGraph(object): ...@@ -4351,7 +4260,7 @@ class DGLGraph(object):
ndata = core.invoke_node_udf(self, v_id, ntype, func, orig_nid=v_id) ndata = core.invoke_node_udf(self, v_id, ntype, func, orig_nid=v_id)
self._set_n_repr(ntid, v, ndata) self._set_n_repr(ntid, v, ndata)
def apply_edges(self, func, edges=ALL, etype=None, inplace=False): def apply_edges(self, func, edges=ALL, etype=None):
"""Update the features of the specified edges by the provided function. """Update the features of the specified edges by the provided function.
Parameters Parameters
...@@ -4382,9 +4291,6 @@ class DGLGraph(object): ...@@ -4382,9 +4291,6 @@ class DGLGraph(object):
Can be omitted if the graph has only one type of edges. Can be omitted if the graph has only one type of edges.
inplace: bool, optional
**DEPRECATED**.
Notes Notes
----- -----
DGL recommends using DGL's bulit-in function for the :attr:`func` argument, DGL recommends using DGL's bulit-in function for the :attr:`func` argument,
...@@ -4435,8 +4341,6 @@ class DGLGraph(object): ...@@ -4435,8 +4341,6 @@ class DGLGraph(object):
-------- --------
apply_nodes apply_nodes
""" """
if inplace:
raise DGLError('The `inplace` option is removed in v0.5.')
# Graph with one relation type # Graph with one relation type
if self._graph.number_of_etypes() == 1 or etype is not None: if self._graph.number_of_etypes() == 1 or etype is not None:
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
...@@ -4476,8 +4380,7 @@ class DGLGraph(object): ...@@ -4476,8 +4380,7 @@ class DGLGraph(object):
message_func, message_func,
reduce_func, reduce_func,
apply_node_func=None, apply_node_func=None,
etype=None, etype=None):
inplace=False):
"""Send messages along the specified edges and reduce them on """Send messages along the specified edges and reduce them on
the destination nodes to update their features. the destination nodes to update their features.
...@@ -4513,9 +4416,6 @@ class DGLGraph(object): ...@@ -4513,9 +4416,6 @@ class DGLGraph(object):
Can be omitted if the graph has only one type of edges. Can be omitted if the graph has only one type of edges.
inplace: bool, optional
**DEPRECATED**.
Notes Notes
----- -----
DGL recommends using DGL's bulit-in function for the :attr:`message_func` DGL recommends using DGL's bulit-in function for the :attr:`message_func`
...@@ -4558,7 +4458,7 @@ class DGLGraph(object): ...@@ -4558,7 +4458,7 @@ class DGLGraph(object):
... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]) ... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])
... }) ... })
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g.send_and_recv(g['follows'].edges(), fn.copy_src('h', 'm'), >>> g.send_and_recv(g['follows'].edges(), fn.copy_u('h', 'm'),
... fn.sum('m', 'h'), etype='follows') ... fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h'] >>> g.nodes['user'].data['h']
tensor([[0.], tensor([[0.],
...@@ -4588,8 +4488,6 @@ class DGLGraph(object): ...@@ -4588,8 +4488,6 @@ class DGLGraph(object):
Note that the feature of node 0 remains the same as it has no incoming edges. Note that the feature of node 0 remains the same as it has no incoming edges.
""" """
if inplace:
raise DGLError('The `inplace` option is removed in v0.5.')
# edge type # edge type
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
_, dtid = self._graph.metagraph.find_edge(etid) _, dtid = self._graph.metagraph.find_edge(etid)
...@@ -4612,8 +4510,7 @@ class DGLGraph(object): ...@@ -4612,8 +4510,7 @@ class DGLGraph(object):
message_func, message_func,
reduce_func, reduce_func,
apply_node_func=None, apply_node_func=None,
etype=None, etype=None):
inplace=False):
"""Pull messages from the specified node(s)' predecessors along the """Pull messages from the specified node(s)' predecessors along the
specified edge type, aggregate them to update the node features. specified edge type, aggregate them to update the node features.
...@@ -4645,9 +4542,6 @@ class DGLGraph(object): ...@@ -4645,9 +4542,6 @@ class DGLGraph(object):
Can be omitted if the graph has only one type of edges. Can be omitted if the graph has only one type of edges.
inplace: bool, optional
**DEPRECATED**.
Notes Notes
----- -----
* If some of the given nodes :attr:`v` has no in-edges, DGL does not invoke * If some of the given nodes :attr:`v` has no in-edges, DGL does not invoke
...@@ -4688,14 +4582,12 @@ class DGLGraph(object): ...@@ -4688,14 +4582,12 @@ class DGLGraph(object):
Pull. Pull.
>>> g['follows'].pull(2, fn.copy_src('h', 'm'), fn.sum('m', 'h'), etype='follows') >>> g['follows'].pull(2, fn.copy_u('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h'] >>> g.nodes['user'].data['h']
tensor([[0.], tensor([[0.],
[1.], [1.],
[1.]]) [1.]])
""" """
if inplace:
raise DGLError('The `inplace` option is removed in v0.5.')
v = utils.prepare_tensor(self, v, 'v') v = utils.prepare_tensor(self, v, 'v')
if len(v) == 0: if len(v) == 0:
# no computation # no computation
...@@ -4716,8 +4608,7 @@ class DGLGraph(object): ...@@ -4716,8 +4608,7 @@ class DGLGraph(object):
message_func, message_func,
reduce_func, reduce_func,
apply_node_func=None, apply_node_func=None,
etype=None, etype=None):
inplace=False):
"""Send message from the specified node(s) to their successors """Send message from the specified node(s) to their successors
along the specified edge type and update their node features. along the specified edge type and update their node features.
...@@ -4749,9 +4640,6 @@ class DGLGraph(object): ...@@ -4749,9 +4640,6 @@ class DGLGraph(object):
Can be omitted if the graph has only one type of edges. Can be omitted if the graph has only one type of edges.
inplace: bool, optional
**DEPRECATED**.
Notes Notes
----- -----
DGL recommends using DGL's bulit-in function for the :attr:`message_func` DGL recommends using DGL's bulit-in function for the :attr:`message_func`
...@@ -4785,14 +4673,12 @@ class DGLGraph(object): ...@@ -4785,14 +4673,12 @@ class DGLGraph(object):
Push. Push.
>>> g['follows'].push(0, fn.copy_src('h', 'm'), fn.sum('m', 'h'), etype='follows') >>> g['follows'].push(0, fn.copy_u('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h'] >>> g.nodes['user'].data['h']
tensor([[0.], tensor([[0.],
[0.], [0.],
[0.]]) [0.]])
""" """
if inplace:
raise DGLError('The `inplace` option is removed in v0.5.')
edges = self.out_edges(u, form='eid', etype=etype) edges = self.out_edges(u, form='eid', etype=etype)
self.send_and_recv(edges, message_func, reduce_func, apply_node_func, etype=etype) self.send_and_recv(edges, message_func, reduce_func, apply_node_func, etype=etype)
...@@ -4864,7 +4750,7 @@ class DGLGraph(object): ...@@ -4864,7 +4750,7 @@ class DGLGraph(object):
Update all. Update all.
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g['follows'].update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'), etype='follows') >>> g['follows'].update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h'] >>> g.nodes['user'].data['h']
tensor([[0.], tensor([[0.],
[0.], [0.],
...@@ -4881,7 +4767,7 @@ class DGLGraph(object): ...@@ -4881,7 +4767,7 @@ class DGLGraph(object):
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]]) >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
>>> g.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h')) >>> g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
>>> g.nodes['user'].data['h'] >>> g.nodes['user'].data['h']
tensor([[0.], tensor([[0.],
[4.]]) [4.]])
...@@ -4989,8 +4875,8 @@ class DGLGraph(object): ...@@ -4989,8 +4875,8 @@ class DGLGraph(object):
Update all. Update all.
>>> g.multi_update_all( >>> g.multi_update_all(
... {'follows': (fn.copy_src('h', 'm'), fn.sum('m', 'h')), ... {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')),
... 'attracts': (fn.copy_src('h', 'm'), fn.sum('m', 'h'))}, ... 'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))},
... "sum") ... "sum")
>>> g.nodes['user'].data['h'] >>> g.nodes['user'].data['h']
tensor([[0.], tensor([[0.],
...@@ -5004,8 +4890,8 @@ class DGLGraph(object): ...@@ -5004,8 +4890,8 @@ class DGLGraph(object):
Use the user-defined cross reducer. Use the user-defined cross reducer.
>>> g.multi_update_all( >>> g.multi_update_all(
... {'follows': (fn.copy_src('h', 'm'), fn.sum('m', 'h')), ... {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')),
... 'attracts': (fn.copy_src('h', 'm'), fn.sum('m', 'h'))}, ... 'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))},
... cross_sum) ... cross_sum)
""" """
all_out = defaultdict(list) all_out = defaultdict(list)
...@@ -5088,7 +4974,7 @@ class DGLGraph(object): ...@@ -5088,7 +4974,7 @@ class DGLGraph(object):
>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2, 3], [2, 3, 4, 4])}) >>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2, 3], [2, 3, 4, 4])})
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
>>> g['follows'].prop_nodes([[2, 3], [4]], fn.copy_src('h', 'm'), >>> g['follows'].prop_nodes([[2, 3], [4]], fn.copy_u('h', 'm'),
... fn.sum('m', 'h'), etype='follows') ... fn.sum('m', 'h'), etype='follows')
tensor([[1.], tensor([[1.],
[2.], [2.],
...@@ -5151,7 +5037,7 @@ class DGLGraph(object): ...@@ -5151,7 +5037,7 @@ class DGLGraph(object):
>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2, 3], [2, 3, 4, 4])}) >>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2, 3], [2, 3, 4, 4])})
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
>>> g['follows'].prop_edges([[0, 1], [2, 3]], fn.copy_src('h', 'm'), >>> g['follows'].prop_edges([[0, 1], [2, 3]], fn.copy_u('h', 'm'),
... fn.sum('m', 'h'), etype='follows') ... fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h'] >>> g.nodes['user'].data['h']
tensor([[1.], tensor([[1.],
...@@ -6062,113 +5948,6 @@ class DGLGraph(object): ...@@ -6062,113 +5948,6 @@ class DGLGraph(object):
""" """
return self.astype(F.int32) return self.astype(F.int32)
#################################################################
# DEPRECATED: from the old DGLGraph
#################################################################
def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
"""DEPRECATED: please use
``dgl.from_networkx(nx_graph, node_attrs, edge_attrs)``
which will return a new graph created from the networkx graph.
"""
raise DGLError('DGLGraph.from_networkx is deprecated. Please call the following\n\n'
'\t dgl.from_networkx(nx_graph, node_attrs, edge_attrs)\n\n'
', which creates a new DGLGraph from the networkx graph.')
def from_scipy_sparse_matrix(self, spmat, multigraph=None):
"""DEPRECATED: please use
``dgl.from_scipy(spmat)``
which will return a new graph created from the scipy matrix.
"""
raise DGLError('DGLGraph.from_scipy_sparse_matrix is deprecated. '
'Please call the following\n\n'
'\t dgl.from_scipy(spmat)\n\n'
', which creates a new DGLGraph from the scipy matrix.')
def register_apply_node_func(self, func):
"""Deprecated: please directly call :func:`apply_nodes` with ``func``
as argument.
"""
raise DGLError('DGLGraph.register_apply_node_func is deprecated.'
' Please directly call apply_nodes with func as the argument.')
def register_apply_edge_func(self, func):
"""Deprecated: please directly call :func:`apply_edges` with ``func``
as argument.
"""
raise DGLError('DGLGraph.register_apply_edge_func is deprecated.'
' Please directly call apply_edges with func as the argument.')
def register_message_func(self, func):
"""Deprecated: please directly call :func:`update_all` with ``func``
as argument.
"""
raise DGLError('DGLGraph.register_message_func is deprecated.'
' Please directly call update_all with func as the argument.')
def register_reduce_func(self, func):
"""Deprecated: please directly call :func:`update_all` with ``func``
as argument.
"""
raise DGLError('DGLGraph.register_reduce_func is deprecated.'
' Please directly call update_all with func as the argument.')
def group_apply_edges(self, group_by, func, edges=ALL, etype=None, inplace=False):
"""**DEPRECATED**: The API is removed in 0.5."""
raise DGLError('DGLGraph.group_apply_edges is removed in 0.5.')
def send(self, edges, message_func, etype=None):
"""Send messages along the given edges with the same edge type.
DEPRECATE: please use send_and_recv, update_all.
"""
raise DGLError('DGLGraph.send is deprecated. As a replacement, use DGLGraph.apply_edges\n'
' API to compute messages as edge data. Then use DGLGraph.send_and_recv\n'
' and set the message function as dgl.function.copy_e to conduct message\n'
' aggregation.')
def recv(self, v, reduce_func, apply_node_func=None, etype=None, inplace=False):
r"""Receive and reduce incoming messages and update the features of node(s) :math:`v`.
DEPRECATE: please use send_and_recv, update_all.
"""
raise DGLError('DGLGraph.recv is deprecated. As a replacement, use DGLGraph.apply_edges\n'
' API to compute messages as edge data. Then use DGLGraph.send_and_recv\n'
' and set the message function as dgl.function.copy_e to conduct message\n'
' aggregation.')
def multi_recv(self, v, reducer_dict, cross_reducer, apply_node_func=None, inplace=False):
r"""Receive messages from multiple edge types and perform aggregation.
DEPRECATE: please use multi_send_and_recv, multi_update_all.
"""
raise DGLError('DGLGraph.multi_recv is deprecated. As a replacement,\n'
' use DGLGraph.apply_edges API to compute messages as edge data.\n'
' Then use DGLGraph.multi_send_and_recv and set the message function\n'
' as dgl.function.copy_e to conduct message aggregation.')
def multi_send_and_recv(self, etype_dict, cross_reducer, apply_node_func=None, inplace=False):
r"""**DEPRECATED**: The API is removed in v0.5."""
raise DGLError('DGLGraph.multi_pull is removed in v0.5. As a replacement,\n'
' use DGLGraph.edge_subgraph to extract the subgraph first \n'
' and then call DGLGraph.multi_update_all.')
def multi_pull(self, v, etype_dict, cross_reducer, apply_node_func=None, inplace=False):
r"""**DEPRECATED**: The API is removed in v0.5."""
raise DGLError('DGLGraph.multi_pull is removed in v0.5. As a replacement,\n'
' use DGLGraph.edge_subgraph to extract the subgraph first \n'
' and then call DGLGraph.multi_update_all.')
def readonly(self, readonly_state=True):
"""Deprecated: DGLGraph will always be mutable."""
dgl_warning('DGLGraph.readonly is deprecated in v0.5.\n'
'DGLGraph now always supports mutable operations like add_nodes'
' and add_edges.')
############################################################ ############################################################
# Internal APIs # Internal APIs
############################################################ ############################################################
......
...@@ -261,13 +261,13 @@ class GraphConv(gluon.Block): ...@@ -261,13 +261,13 @@ class GraphConv(gluon.Block):
if weight is not None: if weight is not None:
feat_src = mx.nd.dot(feat_src, weight) feat_src = mx.nd.dot(feat_src, weight)
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src(src='h', out='m'), graph.update_all(fn.copy_u(u='h', out='m'),
fn.sum(msg='m', out='h')) fn.sum(msg='m', out='h'))
rst = graph.dstdata.pop('h') rst = graph.dstdata.pop('h')
else: else:
# aggregate first then mult W # aggregate first then mult W
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src(src='h', out='m'), graph.update_all(fn.copy_u(u='h', out='m'),
fn.sum(msg='m', out='h')) fn.sum(msg='m', out='h'))
rst = graph.dstdata.pop('h') rst = graph.dstdata.pop('h')
if weight is not None: if weight is not None:
......
...@@ -114,7 +114,7 @@ class TAGConv(gluon.Block): ...@@ -114,7 +114,7 @@ class TAGConv(gluon.Block):
rst = rst * norm rst = rst * norm
graph.ndata['h'] = rst graph.ndata['h'] = rst
graph.update_all(fn.copy_src(src='h', out='m'), graph.update_all(fn.copy_u(u='h', out='m'),
fn.sum(msg='m', out='h')) fn.sum(msg='m', out='h'))
rst = graph.ndata['h'] rst = graph.ndata['h']
rst = rst * norm rst = rst * norm
......
...@@ -136,7 +136,7 @@ class GINConv(nn.Module): ...@@ -136,7 +136,7 @@ class GINConv(nn.Module):
""" """
_reducer = getattr(fn, self._aggregator_type) _reducer = getattr(fn, self._aggregator_type)
with graph.local_scope(): with graph.local_scope():
aggregate_fn = fn.copy_src('h', 'm') aggregate_fn = fn.copy_u('h', 'm')
if edge_weight is not None: if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges() assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight graph.edata['_edge_weight'] = edge_weight
......
...@@ -114,13 +114,13 @@ class EdgeWeightNorm(nn.Module): ...@@ -114,13 +114,13 @@ class EdgeWeightNorm(nn.Module):
if self._norm == 'both': if self._norm == 'both':
reversed_g = reverse(graph) reversed_g = reverse(graph)
reversed_g.edata['_edge_w'] = edge_weight reversed_g.edata['_edge_w'] = edge_weight
reversed_g.update_all(fn.copy_edge('_edge_w', 'm'), fn.sum('m', 'out_weight')) reversed_g.update_all(fn.copy_e('_edge_w', 'm'), fn.sum('m', 'out_weight'))
degs = reversed_g.dstdata['out_weight'] + self._eps degs = reversed_g.dstdata['out_weight'] + self._eps
norm = th.pow(degs, -0.5) norm = th.pow(degs, -0.5)
graph.srcdata['_src_out_w'] = norm graph.srcdata['_src_out_w'] = norm
if self._norm != 'none': if self._norm != 'none':
graph.update_all(fn.copy_edge('_edge_w', 'm'), fn.sum('m', 'in_weight')) graph.update_all(fn.copy_e('_edge_w', 'm'), fn.sum('m', 'in_weight'))
degs = graph.dstdata['in_weight'] + self._eps degs = graph.dstdata['in_weight'] + self._eps
if self._norm == 'both': if self._norm == 'both':
norm = th.pow(degs, -0.5) norm = th.pow(degs, -0.5)
...@@ -389,7 +389,7 @@ class GraphConv(nn.Module): ...@@ -389,7 +389,7 @@ class GraphConv(nn.Module):
'the issue. Setting ``allow_zero_in_degree`` ' 'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will ' 'to be `True` when constructing this module will '
'suppress the check and let the code run.') 'suppress the check and let the code run.')
aggregate_fn = fn.copy_src('h', 'm') aggregate_fn = fn.copy_u('h', 'm')
if edge_weight is not None: if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges() assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight graph.edata['_edge_weight'] = edge_weight
......
...@@ -213,7 +213,7 @@ class SAGEConv(nn.Module): ...@@ -213,7 +213,7 @@ class SAGEConv(nn.Module):
feat_src = feat_dst = self.feat_drop(feat) feat_src = feat_dst = self.feat_drop(feat)
if graph.is_block: if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()] feat_dst = feat_src[:graph.number_of_dst_nodes()]
msg_fn = fn.copy_src('h', 'm') msg_fn = fn.copy_u('h', 'm')
if edge_weight is not None: if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges() assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight graph.edata['_edge_weight'] = edge_weight
......
...@@ -334,19 +334,6 @@ class NodeEmbedding: # NodeEmbedding ...@@ -334,19 +334,6 @@ class NodeEmbedding: # NodeEmbedding
""" """
self._trace = [] self._trace = []
@property
def emb_tensor(self):
"""Return the tensor storing the node embeddings
DEPRECATED: renamed weight
Returns
-------
torch.Tensor
The tensor storing the node embeddings
"""
return self._tensor
@property @property
def weight(self): def weight(self):
"""Return the tensor storing the node embeddings """Return the tensor storing the node embeddings
......
...@@ -253,13 +253,13 @@ class GraphConv(layers.Layer): ...@@ -253,13 +253,13 @@ class GraphConv(layers.Layer):
if weight is not None: if weight is not None:
feat_src = tf.matmul(feat_src, weight) feat_src = tf.matmul(feat_src, weight)
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src(src='h', out='m'), graph.update_all(fn.copy_u(u='h', out='m'),
fn.sum(msg='m', out='h')) fn.sum(msg='m', out='h'))
rst = graph.dstdata['h'] rst = graph.dstdata['h']
else: else:
# aggregate first then mult W # aggregate first then mult W
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src(src='h', out='m'), graph.update_all(fn.copy_u(u='h', out='m'),
fn.sum(msg='m', out='h')) fn.sum(msg='m', out='h'))
rst = graph.dstdata['h'] rst = graph.dstdata['h']
if weight is not None: if weight is not None:
......
...@@ -166,24 +166,24 @@ class SAGEConv(layers.Layer): ...@@ -166,24 +166,24 @@ class SAGEConv(layers.Layer):
if self._aggre_type == 'mean': if self._aggre_type == 'mean':
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh')) graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh'] h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn': elif self._aggre_type == 'gcn':
check_eq_shape(feat) check_eq_shape(feat)
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh')) graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
# divide in_degrees # divide in_degrees
degs = tf.cast(graph.in_degrees(), tf.float32) degs = tf.cast(graph.in_degrees(), tf.float32)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h'] h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']
) / (tf.expand_dims(degs, -1) + 1) ) / (tf.expand_dims(degs, -1) + 1)
elif self._aggre_type == 'pool': elif self._aggre_type == 'pool':
graph.srcdata['h'] = tf.nn.relu(self.fc_pool(feat_src)) graph.srcdata['h'] = tf.nn.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh')) graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.dstdata['neigh'] h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'lstm': elif self._aggre_type == 'lstm':
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer) graph.update_all(fn.copy_u('h', 'm'), self._lstm_reducer)
h_neigh = graph.dstdata['neigh'] h_neigh = graph.dstdata['neigh']
else: else:
raise KeyError( raise KeyError(
......
...@@ -526,7 +526,7 @@ class SparseAdagrad(SparseGradOptimizer): ...@@ -526,7 +526,7 @@ class SparseAdagrad(SparseGradOptimizer):
), "SparseAdagrad only supports dgl.nn.NodeEmbedding" ), "SparseAdagrad only supports dgl.nn.NodeEmbedding"
emb_name = emb.name emb_name = emb.name
if th.device(emb.emb_tensor.device) == th.device("cpu"): if th.device(emb.weight.device) == th.device("cpu"):
# if our embedding is on the CPU, our state also has to be # if our embedding is on the CPU, our state also has to be
if self._rank < 0: if self._rank < 0:
state = th.empty( state = th.empty(
...@@ -550,9 +550,9 @@ class SparseAdagrad(SparseGradOptimizer): ...@@ -550,9 +550,9 @@ class SparseAdagrad(SparseGradOptimizer):
else: else:
# distributed state on on gpu # distributed state on on gpu
state = th.empty( state = th.empty(
emb.emb_tensor.shape, emb.weight.shape,
dtype=th.float32, dtype=th.float32,
device=emb.emb_tensor.device, device=emb.weight.device,
).zero_() ).zero_()
emb.set_optm_state(state) emb.set_optm_state(state)
...@@ -689,7 +689,7 @@ class SparseAdam(SparseGradOptimizer): ...@@ -689,7 +689,7 @@ class SparseAdam(SparseGradOptimizer):
), "SparseAdam only supports dgl.nn.NodeEmbedding" ), "SparseAdam only supports dgl.nn.NodeEmbedding"
emb_name = emb.name emb_name = emb.name
self._is_using_uva[emb_name] = self._use_uva self._is_using_uva[emb_name] = self._use_uva
if th.device(emb.emb_tensor.device) == th.device("cpu"): if th.device(emb.weight.device) == th.device("cpu"):
# if our embedding is on the CPU, our state also has to be # if our embedding is on the CPU, our state also has to be
if self._rank < 0: if self._rank < 0:
state_step = th.empty( state_step = th.empty(
...@@ -743,19 +743,19 @@ class SparseAdam(SparseGradOptimizer): ...@@ -743,19 +743,19 @@ class SparseAdam(SparseGradOptimizer):
# distributed state on on gpu # distributed state on on gpu
state_step = th.empty( state_step = th.empty(
[emb.emb_tensor.shape[0]], [emb.weight.shape[0]],
dtype=th.int32, dtype=th.int32,
device=emb.emb_tensor.device, device=emb.weight.device,
).zero_() ).zero_()
state_mem = th.empty( state_mem = th.empty(
emb.emb_tensor.shape, emb.weight.shape,
dtype=self._dtype, dtype=self._dtype,
device=emb.emb_tensor.device, device=emb.weight.device,
).zero_() ).zero_()
state_power = th.empty( state_power = th.empty(
emb.emb_tensor.shape, emb.weight.shape,
dtype=self._dtype, dtype=self._dtype,
device=emb.emb_tensor.device, device=emb.weight.device,
).zero_() ).zero_()
state = (state_step, state_mem, state_power) state = (state_step, state_mem, state_power)
emb.set_optm_state(state) emb.set_optm_state(state)
......
...@@ -32,10 +32,10 @@ def generate_graph_old(grad=False): ...@@ -32,10 +32,10 @@ def generate_graph_old(grad=False):
# create a graph where 0 is the source and 9 is the sink # create a graph where 0 is the source and 9 is the sink
# 17 edges # 17 edges
for i in range(1, 9): for i in range(1, 9):
g.add_edge(0, i) g.add_edges(0, i)
g.add_edge(i, 9) g.add_edges(i, 9)
# add a back flow from 9 to 0 # add a back flow from 9 to 0
g.add_edge(9, 0) g.add_edges(9, 0)
g = g.to(F.ctx()) g = g.to(F.ctx())
ncol = F.randn((10, D)) ncol = F.randn((10, D))
ecol = F.randn((17, D)) ecol = F.randn((17, D))
...@@ -431,8 +431,8 @@ def test_dynamic_addition(): ...@@ -431,8 +431,8 @@ def test_dynamic_addition():
assert g.ndata['h1'].shape[0] == g.ndata['h2'].shape[0] == N + 3 assert g.ndata['h1'].shape[0] == g.ndata['h2'].shape[0] == N + 3
# Test edge addition # Test edge addition
g.add_edge(0, 1) g.add_edges(0, 1)
g.add_edge(1, 0) g.add_edges(1, 0)
g.edata.update({'h1': F.randn((2, D)), g.edata.update({'h1': F.randn((2, D)),
'h2': F.randn((2, D))}) 'h2': F.randn((2, D))})
assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 2 assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 2
...@@ -441,12 +441,12 @@ def test_dynamic_addition(): ...@@ -441,12 +441,12 @@ def test_dynamic_addition():
g.edata['h1'] = F.randn((4, D)) g.edata['h1'] = F.randn((4, D))
assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 4 assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 4
g.add_edge(1, 2) g.add_edges(1, 2)
g.edges[4].data['h1'] = F.randn((1, D)) g.edges[4].data['h1'] = F.randn((1, D))
assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 5 assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 5
# test add edge with part of the features # test add edge with part of the features
g.add_edge(2, 1, {'h1': F.randn((1, D))}) g.add_edges(2, 1, {'h1': F.randn((1, D))})
assert len(g.edata['h1']) == len(g.edata['h2']) assert len(g.edata['h1']) == len(g.edata['h2'])
......
...@@ -15,10 +15,10 @@ def tree1(idtype): ...@@ -15,10 +15,10 @@ def tree1(idtype):
""" """
g = dgl.graph(([], [])).astype(idtype).to(F.ctx()) g = dgl.graph(([], [])).astype(idtype).to(F.ctx())
g.add_nodes(5) g.add_nodes(5)
g.add_edge(3, 1) g.add_edges(3, 1)
g.add_edge(4, 1) g.add_edges(4, 1)
g.add_edge(1, 0) g.add_edges(1, 0)
g.add_edge(2, 0) g.add_edges(2, 0)
g.ndata['h'] = F.tensor([0, 1, 2, 3, 4]) g.ndata['h'] = F.tensor([0, 1, 2, 3, 4])
g.edata['h'] = F.randn((4, 10)) g.edata['h'] = F.randn((4, 10))
return g return g
...@@ -34,10 +34,10 @@ def tree2(idtype): ...@@ -34,10 +34,10 @@ def tree2(idtype):
""" """
g = dgl.graph(([], [])).astype(idtype).to(F.ctx()) g = dgl.graph(([], [])).astype(idtype).to(F.ctx())
g.add_nodes(5) g.add_nodes(5)
g.add_edge(2, 4) g.add_edges(2, 4)
g.add_edge(0, 4) g.add_edges(0, 4)
g.add_edge(4, 1) g.add_edges(4, 1)
g.add_edge(3, 1) g.add_edges(3, 1)
g.ndata['h'] = F.tensor([0, 1, 2, 3, 4]) g.ndata['h'] = F.tensor([0, 1, 2, 3, 4])
g.edata['h'] = F.randn((4, 10)) g.edata['h'] = F.randn((4, 10))
return g return g
...@@ -191,8 +191,8 @@ def test_batched_edge_ordering(idtype): ...@@ -191,8 +191,8 @@ def test_batched_edge_ordering(idtype):
e2 = F.randn((6, 10)) e2 = F.randn((6, 10))
g2.edata['h'] = e2 g2.edata['h'] = e2
g = dgl.batch([g1, g2]) g = dgl.batch([g1, g2])
r1 = g.edata['h'][g.edge_id(4, 5)] r1 = g.edata['h'][g.edge_ids(4, 5)]
r2 = g1.edata['h'][g1.edge_id(4, 5)] r2 = g1.edata['h'][g1.edge_ids(4, 5)]
assert F.array_equal(r1, r2) assert F.array_equal(r1, r2)
@parametrize_idtype @parametrize_idtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment