"references/vscode:/vscode.git/clone" did not exist on "095437aad6cfd25b01fa01961e5774904d738347"
Unverified commit 61139302, authored by peizhou001, committed by GitHub

[API Deprecation] Remove candidates in DGLGraph (#4946)

parent e088acac
@@ -117,7 +117,7 @@ class UTransformer(nn.Module):
         g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
         # Send weighted values to target nodes
         g.send_and_recv(eids,
-                        [fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')],
+                        [fn.u_mul_e('v', 'score', 'v'), fn.copy_e('score', 'score')],
                         [fn.sum('v', 'wv'), fn.sum('score', 'z')])
     def update_graph(self, g, eids, pre_pairs, post_pairs):
...
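For readers migrating similar call sites, here is a minimal runnable sketch of the renamed built-ins; the toy graph, feature names, and shapes are illustrative, not taken from this repository:

```python
import dgl
import dgl.function as fn
import torch as th

g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))  # toy 3-node cycle
g.ndata['v'] = th.randn(3, 4)                 # per-node value vectors
g.edata['score'] = th.rand(g.num_edges(), 1)  # per-edge attention scores

# fn.u_mul_e replaces fn.src_mul_edge: message = source feature * edge feature
g.update_all(fn.u_mul_e('v', 'score', 'm'), fn.sum('m', 'wv'))
# fn.copy_e replaces fn.copy_edge: message = the edge feature itself
g.update_all(fn.copy_e('score', 's'), fn.sum('s', 'z'))
```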
@@ -79,8 +79,8 @@ class Transformer(nn.Module):
         g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
         g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
         # Send weighted values to target nodes
-        g.send_and_recv(eids, fn.src_mul_edge('v', 'score', 'v'), fn.sum('v', 'wv'))
-        g.send_and_recv(eids, fn.copy_edge('score', 'score'), fn.sum('score', 'z'))
+        g.send_and_recv(eids, fn.u_mul_e('v', 'score', 'v'), fn.sum('v', 'wv'))
+        g.send_and_recv(eids, fn.copy_e('score', 'score'), fn.sum('score', 'z'))
     def update_graph(self, g, eids, pre_pairs, post_pairs):
         "Update the node states and edge states of the graph."
...
@@ -17,7 +17,7 @@ def get_attention_map(g, src_nodes, dst_nodes, h):
         for j, dst in enumerate(dst_nodes.tolist()):
             if not g.has_edge_between(src, dst):
                 continue
-            eid = g.edge_id(src, dst)
+            eid = g.edge_ids(src, dst)
             weight[i][j] = g.edata['score'][eid].squeeze(-1).cpu().detach()
     weight = weight.transpose(0, 2)
...
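Beyond the rename, edge_ids also accepts vectors of endpoints, so per-pair loops like the one above can be batched; a sketch with illustrative node IDs:

```python
import dgl
import torch as th

g = dgl.graph((th.tensor([0, 1]), th.tensor([1, 2])))
eid = g.edge_ids(0, 1)                                   # single pair -> tensor(0)
eids = g.edge_ids(th.tensor([0, 1]), th.tensor([1, 2]))  # batched -> tensor([0, 1])
```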
@@ -131,7 +131,7 @@ def main(args):
         root_ids = [
             i
             for i in range(g.number_of_nodes())
-            if g.out_degree(i) == 0
+            if g.out_degrees(i) == 0
         ]
         root_acc = np.sum(
             batch.label.cpu().data.numpy()[root_ids]
@@ -170,7 +170,7 @@ def main(args):
         acc = th.sum(th.eq(batch.label, pred)).item()
         accs.append([acc, len(batch.label)])
         root_ids = [
-            i for i in range(g.number_of_nodes()) if g.out_degree(i) == 0
+            i for i in range(g.number_of_nodes()) if g.out_degrees(i) == 0
         ]
         root_acc = np.sum(
             batch.label.cpu().data.numpy()[root_ids]
@@ -222,7 +222,7 @@ def main(args):
         acc = th.sum(th.eq(batch.label, pred)).item()
         accs.append([acc, len(batch.label)])
         root_ids = [
-            i for i in range(g.number_of_nodes()) if g.out_degree(i) == 0
+            i for i in range(g.number_of_nodes()) if g.out_degrees(i) == 0
         ]
         root_acc = np.sum(
             batch.label.cpu().data.numpy()[root_ids]
...
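Note that out_degrees also has a zero-argument form returning the out-degree of every node, so the three root-finding loops above could be collapsed into one vectorized call; a sketch, assuming g and th from the surrounding code:

```python
# all sink nodes (out-degree 0) in one call, no Python loop needed
root_ids = th.nonzero(g.out_degrees() == 0, as_tuple=True)[0].tolist()
```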
@@ -45,7 +45,7 @@ class GCNLayer(layers.Layer):
             h = self.dropout(h)
         self.g.ndata['h'] = tf.matmul(h, self.weight)
         self.g.ndata['norm_h'] = self.g.ndata['h'] * self.g.ndata['norm']
-        self.g.update_all(fn.copy_src('norm_h', 'm'),
+        self.g.update_all(fn.copy_u('norm_h', 'm'),
                           fn.sum('m', 'h'))
         h = self.g.ndata['h']
         if self.bias is not None:
...
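For reference, fn.copy_u('norm_h', 'm') is the built-in equivalent of the following user-defined message function (mirroring the docstring examples touched later in this commit):

```python
def message_func(edges):
    # copy each edge's source-node feature 'norm_h' into message field 'm',
    # which fn.sum('m', 'h') then aggregates at the destination nodes
    return {'m': edges.src['norm_h']}
```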
@@ -3083,10 +3083,10 @@ class DGLGraph(DGLBaseGraph):
         >>> g.add_nodes(3)
         >>> g.ndata['x'] = th.tensor([[0.], [1.], [2.]])
-        Use the built-in message function :func:`~dgl.function.copy_src` for copying
+        Use the built-in message function :func:`~dgl.function.copy_u` for copying
         node features as the message.
-        >>> m_func = dgl.function.copy_src('x', 'm')
+        >>> m_func = dgl.function.copy_u('x', 'm')
         >>> g.register_message_func(m_func)
         Use the built-in message reducing function :func:`~dgl.function.sum`, which
@@ -3180,10 +3180,10 @@ class DGLGraph(DGLBaseGraph):
         >>> g.add_nodes(3)
         >>> g.ndata['x'] = th.tensor([[1.], [2.], [3.]])
-        Use the built-in message function :func:`~dgl.function.copy_src` for copying
+        Use the built-in message function :func:`~dgl.function.copy_u` for copying
         node features as the message.
-        >>> m_func = dgl.function.copy_src('x', 'm')
+        >>> m_func = dgl.function.copy_u('x', 'm')
         >>> g.register_message_func(m_func)
         Use the built-in message reducing function :func:`~dgl.function.sum`, which
...
@@ -2,7 +2,7 @@
 from collections.abc import Mapping
 from . import backend as F
-from .base import ALL, is_all, DGLError, dgl_warning, NID, EID
+from .base import ALL, is_all, DGLError, NID, EID
 from .heterograph_index import disjoint_union, slice_gidx
 from .heterograph import DGLGraph
 from . import convert
@@ -11,8 +11,7 @@ from . import utils
 __all__ = ['batch', 'unbatch', 'slice_batch']
-def batch(graphs, ndata=ALL, edata=ALL, *,
-          node_attrs=None, edge_attrs=None):
+def batch(graphs, ndata=ALL, edata=ALL):
     r"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient
     graph computation.
@@ -151,14 +150,6 @@ def batch(graphs, ndata=ALL, edata=ALL, *,
     """
     if len(graphs) == 0:
         raise DGLError('The input list of graphs cannot be empty.')
-    if node_attrs is not None:
-        dgl_warning('Arguments node_attrs has been deprecated. Please use'
-                    ' ndata instead.')
-        ndata = node_attrs
-    if edge_attrs is not None:
-        dgl_warning('Arguments edge_attrs has been deprecated. Please use'
-                    ' edata instead.')
-        edata = edge_attrs
     if not (is_all(ndata) or isinstance(ndata, list) or ndata is None):
         raise DGLError('Invalid argument ndata: must be a string list but got {}.'.format(
             type(ndata)))
...
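With the deprecated keywords gone, feature selection goes through ndata/edata directly; a minimal sketch with illustrative toy graphs:

```python
import dgl
import torch as th

g1 = dgl.graph((th.tensor([0]), th.tensor([1])))
g1.ndata['x'] = th.zeros(2, 3)
g2 = dgl.graph((th.tensor([0, 1]), th.tensor([1, 2])))
g2.ndata['x'] = th.ones(3, 3)

# was: dgl.batch([g1, g2], node_attrs=['x'])  -- keyword removed by this commit
bg = dgl.batch([g1, g2], ndata=['x'])  # batch, carrying over only feature 'x'
bg_all = dgl.batch([g1, g2])           # default ALL: carry over every feature
```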
@@ -9,7 +9,7 @@ from .._deprecate.runtime import ir
 from .._deprecate.runtime.ir import var
-__all__ = ["src_mul_edge", "copy_src", "copy_edge", "copy_u", "copy_e",
+__all__ = ["copy_u", "copy_e",
            "BinaryMessageFunction", "CopyMessageFunction"]
@@ -34,7 +34,7 @@ class BinaryMessageFunction(MessageFunction):
     See Also
     --------
-    src_mul_edge
+    u_mul_e
     """
     def __init__(self, binary_op, lhs, rhs, lhs_field, rhs_field, out_field):
         self.binary_op = binary_op
@@ -73,7 +73,7 @@ class CopyMessageFunction(MessageFunction):
     See Also
     --------
-    copy_src
+    copy_u
     """
     def __init__(self, target, in_field, out_field):
         self.target = target
@@ -218,86 +218,3 @@ def _register_builtin_message_func():
         __all__.append(func.__name__)
 _register_builtin_message_func()
-##############################################################################
-# For backward compatibility
-def src_mul_edge(src, edge, out):
-    """Builtin message function that computes message by performing
-    binary operation mul between src feature and edge feature.
-
-    Notes
-    -----
-    This function is deprecated. Please use :func:`~dgl.function.u_mul_e` instead.
-
-    Parameters
-    ----------
-    src : str
-        The source feature field.
-    edge : str
-        The edge feature field.
-    out : str
-        The output message field.
-
-    Examples
-    --------
-    >>> import dgl
-    >>> message_func = dgl.function.src_mul_edge('h', 'e', 'm')
-    """
-    return getattr(sys.modules[__name__], "u_mul_e")(src, edge, out)
-
-def copy_src(src, out):
-    """Builtin message function that computes message using source node
-    feature.
-
-    Notes
-    -----
-    This function is deprecated. Please use :func:`~dgl.function.copy_u` instead.
-
-    Parameters
-    ----------
-    src : str
-        The source feature field.
-    out : str
-        The output message field.
-
-    Examples
-    --------
-    >>> import dgl
-    >>> message_func = dgl.function.copy_src('h', 'm')
-
-    The above example is equivalent to the following user defined function:
-
-    >>> def message_func(edges):
-    >>>     return {'m': edges.src['h']}
-    """
-    return copy_u(src, out)
-
-def copy_edge(edge, out):
-    """Builtin message function that computes message using edge feature.
-
-    Notes
-    -----
-    This function is deprecated. Please use :func:`~dgl.function.copy_e` instead.
-
-    Parameters
-    ----------
-    edge : str
-        The edge feature field.
-    out : str
-        The output message field.
-
-    Examples
-    --------
-    >>> import dgl
-    >>> message_func = dgl.function.copy_edge('h', 'm')
-
-    The above example is equivalent to the following user defined function:
-
-    >>> def message_func(edges):
-    >>>     return {'m': edges.data['h']}
-    """
-    return copy_e(edge, out)
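The shims removed above map one-for-one onto the surviving built-ins, so migration is a mechanical rename:

```python
import dgl.function as fn

# each deprecated call (deleted above) and its drop-in replacement
msg = fn.u_mul_e('h', 'e', 'm')  # was: fn.src_mul_edge('h', 'e', 'm')
msg = fn.copy_u('h', 'm')        # was: fn.copy_src('h', 'm')
msg = fn.copy_e('h', 'm')        # was: fn.copy_edge('h', 'm')
```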
This diff is collapsed.
@@ -261,13 +261,13 @@ class GraphConv(gluon.Block):
             if weight is not None:
                 feat_src = mx.nd.dot(feat_src, weight)
             graph.srcdata['h'] = feat_src
-            graph.update_all(fn.copy_src(src='h', out='m'),
+            graph.update_all(fn.copy_u(u='h', out='m'),
                              fn.sum(msg='m', out='h'))
             rst = graph.dstdata.pop('h')
         else:
             # aggregate first then mult W
             graph.srcdata['h'] = feat_src
-            graph.update_all(fn.copy_src(src='h', out='m'),
+            graph.update_all(fn.copy_u(u='h', out='m'),
                              fn.sum(msg='m', out='h'))
             rst = graph.dstdata.pop('h')
         if weight is not None:
...
@@ -114,7 +114,7 @@ class TAGConv(gluon.Block):
             rst = rst * norm
             graph.ndata['h'] = rst
-            graph.update_all(fn.copy_src(src='h', out='m'),
+            graph.update_all(fn.copy_u(u='h', out='m'),
                              fn.sum(msg='m', out='h'))
             rst = graph.ndata['h']
             rst = rst * norm
...
@@ -136,7 +136,7 @@ class GINConv(nn.Module):
         """
         _reducer = getattr(fn, self._aggregator_type)
         with graph.local_scope():
-            aggregate_fn = fn.copy_src('h', 'm')
+            aggregate_fn = fn.copy_u('h', 'm')
             if edge_weight is not None:
                 assert edge_weight.shape[0] == graph.number_of_edges()
                 graph.edata['_edge_weight'] = edge_weight
...
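For context on the branch above: when edge_weight is given, DGL's conv modules swap the copy-style message for a weighted one. The continuation is not shown in this diff; the sketch below assumes it matches DGL's released GINConv:

```python
# assumed continuation of the edge_weight branch (not part of this diff):
graph.edata['_edge_weight'] = edge_weight
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')  # weight each message
```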
@@ -114,13 +114,13 @@ class EdgeWeightNorm(nn.Module):
         if self._norm == 'both':
             reversed_g = reverse(graph)
             reversed_g.edata['_edge_w'] = edge_weight
-            reversed_g.update_all(fn.copy_edge('_edge_w', 'm'), fn.sum('m', 'out_weight'))
+            reversed_g.update_all(fn.copy_e('_edge_w', 'm'), fn.sum('m', 'out_weight'))
             degs = reversed_g.dstdata['out_weight'] + self._eps
             norm = th.pow(degs, -0.5)
             graph.srcdata['_src_out_w'] = norm
         if self._norm != 'none':
-            graph.update_all(fn.copy_edge('_edge_w', 'm'), fn.sum('m', 'in_weight'))
+            graph.update_all(fn.copy_e('_edge_w', 'm'), fn.sum('m', 'in_weight'))
             degs = graph.dstdata['in_weight'] + self._eps
             if self._norm == 'both':
                 norm = th.pow(degs, -0.5)
@@ -389,7 +389,7 @@ class GraphConv(nn.Module):
                 'the issue. Setting ``allow_zero_in_degree`` '
                 'to be `True` when constructing this module will '
                 'suppress the check and let the code run.')
-            aggregate_fn = fn.copy_src('h', 'm')
+            aggregate_fn = fn.copy_u('h', 'm')
             if edge_weight is not None:
                 assert edge_weight.shape[0] == graph.number_of_edges()
                 graph.edata['_edge_weight'] = edge_weight
...
@@ -213,7 +213,7 @@ class SAGEConv(nn.Module):
             feat_src = feat_dst = self.feat_drop(feat)
             if graph.is_block:
                 feat_dst = feat_src[:graph.number_of_dst_nodes()]
-            msg_fn = fn.copy_src('h', 'm')
+            msg_fn = fn.copy_u('h', 'm')
             if edge_weight is not None:
                 assert edge_weight.shape[0] == graph.number_of_edges()
                 graph.edata['_edge_weight'] = edge_weight
...
@@ -334,19 +334,6 @@ class NodeEmbedding: # NodeEmbedding
         """
         self._trace = []
-    @property
-    def emb_tensor(self):
-        """Return the tensor storing the node embeddings
-
-        DEPRECATED: renamed weight
-
-        Returns
-        -------
-        torch.Tensor
-            The tensor storing the node embeddings
-        """
-        return self._tensor
     @property
     def weight(self):
         """Return the tensor storing the node embeddings
...
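Downstream code now reads the embedding tensor through the weight property; a usage sketch with illustrative constructor arguments:

```python
import torch as th
from dgl.nn import NodeEmbedding

emb = NodeEmbedding(num_embeddings=100, embedding_dim=16, name='node_emb')
print(emb.weight.shape)  # torch.Size([100, 16]); formerly also emb.emb_tensor
```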
@@ -253,13 +253,13 @@ class GraphConv(layers.Layer):
             if weight is not None:
                 feat_src = tf.matmul(feat_src, weight)
             graph.srcdata['h'] = feat_src
-            graph.update_all(fn.copy_src(src='h', out='m'),
+            graph.update_all(fn.copy_u(u='h', out='m'),
                              fn.sum(msg='m', out='h'))
             rst = graph.dstdata['h']
         else:
             # aggregate first then mult W
             graph.srcdata['h'] = feat_src
-            graph.update_all(fn.copy_src(src='h', out='m'),
+            graph.update_all(fn.copy_u(u='h', out='m'),
                              fn.sum(msg='m', out='h'))
             rst = graph.dstdata['h']
         if weight is not None:
...
@@ -166,24 +166,24 @@ class SAGEConv(layers.Layer):
             if self._aggre_type == 'mean':
                 graph.srcdata['h'] = feat_src
-                graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
+                graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
                 h_neigh = graph.dstdata['neigh']
             elif self._aggre_type == 'gcn':
                 check_eq_shape(feat)
                 graph.srcdata['h'] = feat_src
                 graph.dstdata['h'] = feat_dst  # same as above if homogeneous
-                graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
+                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
                 # divide in_degrees
                 degs = tf.cast(graph.in_degrees(), tf.float32)
                 h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']
                            ) / (tf.expand_dims(degs, -1) + 1)
             elif self._aggre_type == 'pool':
                 graph.srcdata['h'] = tf.nn.relu(self.fc_pool(feat_src))
-                graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
+                graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
                 h_neigh = graph.dstdata['neigh']
             elif self._aggre_type == 'lstm':
                 graph.srcdata['h'] = feat_src
-                graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
+                graph.update_all(fn.copy_u('h', 'm'), self._lstm_reducer)
                 h_neigh = graph.dstdata['neigh']
             else:
                 raise KeyError(
...
@@ -526,7 +526,7 @@ class SparseAdagrad(SparseGradOptimizer):
             ), "SparseAdagrad only supports dgl.nn.NodeEmbedding"
             emb_name = emb.name
-            if th.device(emb.emb_tensor.device) == th.device("cpu"):
+            if th.device(emb.weight.device) == th.device("cpu"):
                 # if our embedding is on the CPU, our state also has to be
                 if self._rank < 0:
                     state = th.empty(
@@ -550,9 +550,9 @@ class SparseAdagrad(SparseGradOptimizer):
             else:
                 # distributed state on gpu
                 state = th.empty(
-                    emb.emb_tensor.shape,
+                    emb.weight.shape,
                     dtype=th.float32,
-                    device=emb.emb_tensor.device,
+                    device=emb.weight.device,
                 ).zero_()
             emb.set_optm_state(state)
@@ -689,7 +689,7 @@ class SparseAdam(SparseGradOptimizer):
             ), "SparseAdam only supports dgl.nn.NodeEmbedding"
             emb_name = emb.name
             self._is_using_uva[emb_name] = self._use_uva
-            if th.device(emb.emb_tensor.device) == th.device("cpu"):
+            if th.device(emb.weight.device) == th.device("cpu"):
                 # if our embedding is on the CPU, our state also has to be
                 if self._rank < 0:
                     state_step = th.empty(
@@ -743,19 +743,19 @@ class SparseAdam(SparseGradOptimizer):
                 # distributed state on gpu
                 state_step = th.empty(
-                    [emb.emb_tensor.shape[0]],
+                    [emb.weight.shape[0]],
                     dtype=th.int32,
-                    device=emb.emb_tensor.device,
+                    device=emb.weight.device,
                 ).zero_()
                 state_mem = th.empty(
-                    emb.emb_tensor.shape,
+                    emb.weight.shape,
                     dtype=self._dtype,
-                    device=emb.emb_tensor.device,
+                    device=emb.weight.device,
                 ).zero_()
                 state_power = th.empty(
-                    emb.emb_tensor.shape,
+                    emb.weight.shape,
                     dtype=self._dtype,
-                    device=emb.emb_tensor.device,
+                    device=emb.weight.device,
                 ).zero_()
                 state = (state_step, state_mem, state_power)
                 emb.set_optm_state(state)
...
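From the user's side nothing changes but the property name: the optimizer still sizes its state from the embedding; a usage sketch with an illustrative learning rate:

```python
from dgl.optim import SparseAdam

# the state tensors above (step/mem/power) are allocated from the shape and
# device of emb.weight, where they previously read emb.emb_tensor
optimizer = SparseAdam(params=[emb], lr=0.01)
```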
@@ -32,10 +32,10 @@ def generate_graph_old(grad=False):
     # create a graph where 0 is the source and 9 is the sink
     # 17 edges
     for i in range(1, 9):
-        g.add_edge(0, i)
-        g.add_edge(i, 9)
+        g.add_edges(0, i)
+        g.add_edges(i, 9)
     # add a back flow from 9 to 0
-    g.add_edge(9, 0)
+    g.add_edges(9, 0)
     g = g.to(F.ctx())
     ncol = F.randn((10, D))
     ecol = F.randn((17, D))
@@ -431,8 +431,8 @@ def test_dynamic_addition():
     assert g.ndata['h1'].shape[0] == g.ndata['h2'].shape[0] == N + 3
     # Test edge addition
-    g.add_edge(0, 1)
-    g.add_edge(1, 0)
+    g.add_edges(0, 1)
+    g.add_edges(1, 0)
     g.edata.update({'h1': F.randn((2, D)),
                     'h2': F.randn((2, D))})
     assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 2
@@ -441,12 +441,12 @@ def test_dynamic_addition():
     g.edata['h1'] = F.randn((4, D))
     assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 4
-    g.add_edge(1, 2)
+    g.add_edges(1, 2)
     g.edges[4].data['h1'] = F.randn((1, D))
     assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 5
     # test add edge with part of the features
-    g.add_edge(2, 1, {'h1': F.randn((1, D))})
+    g.add_edges(2, 1, {'h1': F.randn((1, D))})
     assert len(g.edata['h1']) == len(g.edata['h2'])
...
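add_edges subsumes every old add_edge call pattern in these tests; a sketch of the accepted forms on an illustrative toy graph:

```python
import dgl
import torch as th

g = dgl.graph(([], []))
g.add_nodes(3)
g.add_edges(0, 1)                                  # single edge, scalar endpoints
g.add_edges(th.tensor([1, 2]), th.tensor([2, 0]))  # several edges in one call
g.add_edges(2, 1, {'h': th.randn(1, 4)})           # single edge with features
```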
@@ -15,10 +15,10 @@ def tree1(idtype):
     """
     g = dgl.graph(([], [])).astype(idtype).to(F.ctx())
     g.add_nodes(5)
-    g.add_edge(3, 1)
-    g.add_edge(4, 1)
-    g.add_edge(1, 0)
-    g.add_edge(2, 0)
+    g.add_edges(3, 1)
+    g.add_edges(4, 1)
+    g.add_edges(1, 0)
+    g.add_edges(2, 0)
     g.ndata['h'] = F.tensor([0, 1, 2, 3, 4])
     g.edata['h'] = F.randn((4, 10))
     return g
@@ -34,10 +34,10 @@ def tree2(idtype):
     """
     g = dgl.graph(([], [])).astype(idtype).to(F.ctx())
     g.add_nodes(5)
-    g.add_edge(2, 4)
-    g.add_edge(0, 4)
-    g.add_edge(4, 1)
-    g.add_edge(3, 1)
+    g.add_edges(2, 4)
+    g.add_edges(0, 4)
+    g.add_edges(4, 1)
+    g.add_edges(3, 1)
     g.ndata['h'] = F.tensor([0, 1, 2, 3, 4])
     g.edata['h'] = F.randn((4, 10))
     return g
@@ -191,8 +191,8 @@ def test_batched_edge_ordering(idtype):
     e2 = F.randn((6, 10))
     g2.edata['h'] = e2
     g = dgl.batch([g1, g2])
-    r1 = g.edata['h'][g.edge_id(4, 5)]
-    r2 = g1.edata['h'][g1.edge_id(4, 5)]
+    r1 = g.edata['h'][g.edge_ids(4, 5)]
+    r2 = g1.edata['h'][g1.edge_ids(4, 5)]
     assert F.array_equal(r1, r2)
 @parametrize_idtype
...