Unverified Commit 61139302 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[API Deprecation] Remove candidates in DGLGraph (#4946)

parent e088acac
...@@ -31,7 +31,7 @@ def track_time(feat_size, num_relations, multi_reduce_type): ...@@ -31,7 +31,7 @@ def track_time(feat_size, num_relations, multi_reduce_type):
update_dict = {} update_dict = {}
for i in range(num_relations): for i in range(num_relations):
update_dict['e_{}'.format(i)] = ( update_dict['e_{}'.format(i)] = (
fn.copy_src('h', 'm'), fn.sum('m', 'h')) fn.copy_u('h', 'm'), fn.sum('m', 'h'))
graph.multi_update_all( graph.multi_update_all(
update_dict, update_dict,
multi_reduce_type) multi_reduce_type)
......
...@@ -117,10 +117,6 @@ Here is a cheatsheet of all the DGL built-in functions. ...@@ -117,10 +117,6 @@ Here is a cheatsheet of all the DGL built-in functions.
| Unary message function | ``copy_u`` | | | Unary message function | ``copy_u`` | |
| +-----------------------------------------------------------------+-----------------------+ | +-----------------------------------------------------------------+-----------------------+
| | ``copy_e`` | | | | ``copy_e`` | |
| +-----------------------------------------------------------------+-----------------------+
| | ``copy_src`` | alias of ``copy_u`` |
| +-----------------------------------------------------------------+-----------------------+
| | ``copy_edge`` | alias of ``copy_e`` |
+-------------------------+-----------------------------------------------------------------+-----------------------+ +-------------------------+-----------------------------------------------------------------+-----------------------+
| Binary message function | ``u_add_v``, ``u_sub_v``, ``u_mul_v``, ``u_div_v``, ``u_dot_v`` | | | Binary message function | ``u_add_v``, ``u_sub_v``, ``u_mul_v``, ``u_div_v``, ``u_dot_v`` | |
| +-----------------------------------------------------------------+-----------------------+ | +-----------------------------------------------------------------+-----------------------+
...@@ -133,8 +129,6 @@ Here is a cheatsheet of all the DGL built-in functions. ...@@ -133,8 +129,6 @@ Here is a cheatsheet of all the DGL built-in functions.
| | ``e_add_u``, ``e_sub_u``, ``e_mul_u``, ``e_div_u``, ``e_dot_u`` | | | | ``e_add_u``, ``e_sub_u``, ``e_mul_u``, ``e_div_u``, ``e_dot_u`` | |
| +-----------------------------------------------------------------+-----------------------+ | +-----------------------------------------------------------------+-----------------------+
| | ``e_add_v``, ``e_sub_v``, ``e_mul_v``, ``e_div_v``, ``e_dot_v`` | | | | ``e_add_v``, ``e_sub_v``, ``e_mul_v``, ``e_div_v``, ``e_dot_v`` | |
| +-----------------------------------------------------------------+-----------------------+
| | ``src_mul_edge`` | alias of ``u_mul_e`` |
+-------------------------+-----------------------------------------------------------------+-----------------------+ +-------------------------+-----------------------------------------------------------------+-----------------------+
| Reduce function | ``max`` | | | Reduce function | ``max`` | |
| +-----------------------------------------------------------------+-----------------------+ | +-----------------------------------------------------------------+-----------------------+
...@@ -151,9 +145,6 @@ Message functions ...@@ -151,9 +145,6 @@ Message functions
.. autosummary:: .. autosummary::
:toctree: ../../generated/ :toctree: ../../generated/
copy_src
copy_edge
src_mul_edge
copy_u copy_u
copy_e copy_e
u_add_v u_add_v
......
...@@ -28,7 +28,7 @@ class GCNLayer(gluon.Block): ...@@ -28,7 +28,7 @@ class GCNLayer(gluon.Block):
def forward(self, h): def forward(self, h):
self.g.ndata['h'] = h * self.g.ndata['out_norm'] self.g.ndata['h'] = h * self.g.ndata['out_norm']
self.g.update_all(fn.copy_src(src='h', out='m'), self.g.update_all(fn.copy_u(u='h', out='m'),
fn.sum(msg='m', out='accum')) fn.sum(msg='m', out='accum'))
accum = self.g.ndata.pop('accum') accum = self.g.ndata.pop('accum')
accum = self.dense(accum * self.g.ndata['in_norm']) accum = self.dense(accum * self.g.ndata['in_norm'])
......
...@@ -118,7 +118,7 @@ def main(args): ...@@ -118,7 +118,7 @@ def main(args):
pseudo = [] pseudo = []
for i in range(g.number_of_edges()): for i in range(g.number_of_edges()):
pseudo.append( pseudo.append(
[1 / np.sqrt(g.in_degree(us[i])), 1 / np.sqrt(g.in_degree(vs[i]))] [1 / np.sqrt(g.in_degrees(us[i])), 1 / np.sqrt(g.in_degrees(vs[i]))]
) )
pseudo = nd.array(pseudo, ctx=ctx) pseudo = nd.array(pseudo, ctx=ctx)
......
...@@ -171,7 +171,7 @@ def main(args): ...@@ -171,7 +171,7 @@ def main(args):
root_ids = [ root_ids = [
i i
for i in range(batch.graph.number_of_nodes()) for i in range(batch.graph.number_of_nodes())
if batch.graph.out_degree(i) == 0 if batch.graph.out_degrees(i) == 0
] ]
root_acc = np.sum( root_acc = np.sum(
batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids] batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]
...@@ -208,7 +208,7 @@ def main(args): ...@@ -208,7 +208,7 @@ def main(args):
root_ids = [ root_ids = [
i i
for i in range(batch.graph.number_of_nodes()) for i in range(batch.graph.number_of_nodes())
if batch.graph.out_degree(i) == 0 if batch.graph.out_degrees(i) == 0
] ]
root_acc = np.sum( root_acc = np.sum(
batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids] batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]
...@@ -261,7 +261,7 @@ def main(args): ...@@ -261,7 +261,7 @@ def main(args):
root_ids = [ root_ids = [
i i
for i in range(batch.graph.number_of_nodes()) for i in range(batch.graph.number_of_nodes())
if batch.graph.out_degree(i) == 0 if batch.graph.out_degrees(i) == 0
] ]
root_acc = np.sum( root_acc = np.sum(
batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids] batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]
......
...@@ -204,7 +204,7 @@ class ChooseDestAndUpdate(nn.Module): ...@@ -204,7 +204,7 @@ class ChooseDestAndUpdate(nn.Module):
if not self.training: if not self.training:
dest = Categorical(dests_probs).sample().item() dest = Categorical(dests_probs).sample().item()
if not g.has_edge_between(src, dest): if not g.has_edges_between(src, dest):
# For undirected graphs, we add edges for both directions # For undirected graphs, we add edges for both directions
# so that we can perform graph propagation. # so that we can perform graph propagation.
src_list = [src, dest] src_list = [src, dest]
......
...@@ -40,7 +40,7 @@ class GraphSageLayer(nn.Module): ...@@ -40,7 +40,7 @@ class GraphSageLayer(nn.Module):
if self.use_bn and not hasattr(self, 'bn'): if self.use_bn and not hasattr(self, 'bn'):
device = h.device device = h.device
self.bn = nn.BatchNorm1d(h.size()[1]).to(device) self.bn = nn.BatchNorm1d(h.size()[1]).to(device)
g.update_all(fn.copy_src(src='h', out='m'), self.aggregator, g.update_all(fn.copy_u(u='h', out='m'), self.aggregator,
self.bundler) self.bundler)
if self.use_bn: if self.use_bn:
h = self.bn(h) h = self.bn(h)
......
...@@ -89,7 +89,7 @@ class GCMCGraphConv(nn.Module): ...@@ -89,7 +89,7 @@ class GCMCGraphConv(nn.Module):
feat = feat * self.dropout(cj) feat = feat * self.dropout(cj)
graph.srcdata['h'] = feat graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'), graph.update_all(fn.copy_u(u='h', out='m'),
fn.sum(msg='m', out='h')) fn.sum(msg='m', out='h'))
rst = graph.dstdata['h'] rst = graph.dstdata['h']
rst = rst * ci rst = rst * ci
......
...@@ -117,10 +117,10 @@ def _ns_dataloader( ...@@ -117,10 +117,10 @@ def _ns_dataloader(
edge_types = [] edge_types = []
for s, e, t in edges: for s, e, t in edges:
g.add_edge(nid2idx[s], nid2idx[t]) g.add_edges(nid2idx[s], nid2idx[t])
edge_types.append(e) edge_types.append(e)
if e in reverse_edge: if e in reverse_edge:
g.add_edge(nid2idx[t], nid2idx[s]) g.add_edges(nid2idx[t], nid2idx[s])
edge_types.append(reverse_edge[e]) edge_types.append(reverse_edge[e])
g.edata["type"] = torch.tensor(edge_types, dtype=torch.long) g.edata["type"] = torch.tensor(edge_types, dtype=torch.long)
annotation = torch.zeros(len(node_ids), dtype=torch.long) annotation = torch.zeros(len(node_ids), dtype=torch.long)
...@@ -234,10 +234,10 @@ def _gc_dataloader( ...@@ -234,10 +234,10 @@ def _gc_dataloader(
edge_types = [] edge_types = []
for s, e, t in edges: for s, e, t in edges:
g.add_edge(nid2idx[s], nid2idx[t]) g.add_edges(nid2idx[s], nid2idx[t])
edge_types.append(e) edge_types.append(e)
if e in reverse_edge: if e in reverse_edge:
g.add_edge(nid2idx[t], nid2idx[s]) g.add_edges(nid2idx[t], nid2idx[s])
edge_types.append(reverse_edge[e]) edge_types.append(reverse_edge[e])
g.edata["type"] = torch.tensor(edge_types, dtype=torch.long) g.edata["type"] = torch.tensor(edge_types, dtype=torch.long)
annotation = torch.zeros([len(node_ids), 2], dtype=torch.long) annotation = torch.zeros([len(node_ids), 2], dtype=torch.long)
...@@ -361,10 +361,10 @@ def _path_finding_dataloader( ...@@ -361,10 +361,10 @@ def _path_finding_dataloader(
edge_types = [] edge_types = []
for s, e, t in edges: for s, e, t in edges:
g.add_edge(nid2idx[s], nid2idx[t]) g.add_edges(nid2idx[s], nid2idx[t])
edge_types.append(e) edge_types.append(e)
if e in reverse_edge: if e in reverse_edge:
g.add_edge(nid2idx[t], nid2idx[s]) g.add_edges(nid2idx[t], nid2idx[s])
edge_types.append(reverse_edge[e]) edge_types.append(reverse_edge[e])
g.edata["type"] = torch.tensor(edge_types, dtype=torch.long) g.edata["type"] = torch.tensor(edge_types, dtype=torch.long)
annotation = torch.zeros([len(node_ids), 2], dtype=torch.long) annotation = torch.zeros([len(node_ids), 2], dtype=torch.long)
......
...@@ -179,7 +179,7 @@ def get_edges_to_match(G, node_id, matched_nodes): ...@@ -179,7 +179,7 @@ def get_edges_to_match(G, node_id, matched_nodes):
index = np.array([], dtype=int) index = np.array([], dtype=int)
direction = np.array([], dtype=int) direction = np.array([], dtype=int)
if G.has_edge_between(node_id, node_id): if G.has_edge_between(node_id, node_id):
self_edge_ids = G.edge_id(node_id, node_id, return_array=True).numpy() self_edge_ids = G.edge_ids(node_id, node_id, return_array=True).numpy()
incident_edges = np.concatenate((incident_edges, self_edge_ids)) incident_edges = np.concatenate((incident_edges, self_edge_ids))
index = np.concatenate((index, [-1] * len(self_edge_ids))) index = np.concatenate((index, [-1] * len(self_edge_ids)))
direction = np.concatenate((direction, [0] * len(self_edge_ids))) direction = np.concatenate((direction, [0] * len(self_edge_ids)))
...@@ -647,7 +647,7 @@ def contextual_cost_matrix_construction( ...@@ -647,7 +647,7 @@ def contextual_cost_matrix_construction(
for i in range(num_G1_nodes): for i in range(num_G1_nodes):
if G1.has_edge_between(i, i): if G1.has_edge_between(i, i):
self_edge_list_G1[i] = sorted( self_edge_list_G1[i] = sorted(
G1.edge_id(i, i, return_array=True).numpy() G1.edge_ids(i, i, return_array=True).numpy()
) )
incoming_edges_G1[i] = G1.in_edges([i], "eid").numpy() incoming_edges_G1[i] = G1.in_edges([i], "eid").numpy()
incoming_edges_G1[i] = np.setdiff1d( incoming_edges_G1[i] = np.setdiff1d(
...@@ -660,7 +660,7 @@ def contextual_cost_matrix_construction( ...@@ -660,7 +660,7 @@ def contextual_cost_matrix_construction(
for i in range(num_G2_nodes): for i in range(num_G2_nodes):
if G2.has_edge_between(i, i): if G2.has_edge_between(i, i):
self_edge_list_G2[i] = sorted( self_edge_list_G2[i] = sorted(
G2.edge_id(i, i, return_array=True).numpy() G2.edge_ids(i, i, return_array=True).numpy()
) )
incoming_edges_G2[i] = G2.in_edges([i], "eid").numpy() incoming_edges_G2[i] = G2.in_edges([i], "eid").numpy()
incoming_edges_G2[i] = np.setdiff1d( incoming_edges_G2[i] = np.setdiff1d(
...@@ -790,7 +790,7 @@ def hausdorff_matching( ...@@ -790,7 +790,7 @@ def hausdorff_matching(
for i in range(num_G1_nodes): for i in range(num_G1_nodes):
if G1.has_edge_between(i, i): if G1.has_edge_between(i, i):
self_edge_list_G1[i] = sorted( self_edge_list_G1[i] = sorted(
G1.edge_id(i, i, return_array=True).numpy() G1.edge_ids(i, i, return_array=True).numpy()
) )
incoming_edges_G1[i] = G1.in_edges([i], "eid").numpy() incoming_edges_G1[i] = G1.in_edges([i], "eid").numpy()
incoming_edges_G1[i] = np.setdiff1d( incoming_edges_G1[i] = np.setdiff1d(
...@@ -803,7 +803,7 @@ def hausdorff_matching( ...@@ -803,7 +803,7 @@ def hausdorff_matching(
for i in range(num_G2_nodes): for i in range(num_G2_nodes):
if G2.has_edge_between(i, i): if G2.has_edge_between(i, i):
self_edge_list_G2[i] = sorted( self_edge_list_G2[i] = sorted(
G2.edge_id(i, i, return_array=True).numpy() G2.edge_ids(i, i, return_array=True).numpy()
) )
incoming_edges_G2[i] = G2.in_edges([i], "eid").numpy() incoming_edges_G2[i] = G2.in_edges([i], "eid").numpy()
incoming_edges_G2[i] = np.setdiff1d( incoming_edges_G2[i] = np.setdiff1d(
......
...@@ -251,7 +251,7 @@ class GraphCrossNet(torch.nn.Module): ...@@ -251,7 +251,7 @@ class GraphCrossNet(torch.nn.Module):
edge_feat = self.e2l_lin(edge_feat) edge_feat = self.e2l_lin(edge_feat)
with graph.local_scope(): with graph.local_scope():
graph.edata["he"] = edge_feat graph.edata["he"] = edge_feat
graph.update_all(fn.copy_edge("he", "m"), fn.sum("m", "hn")) graph.update_all(fn.copy_e("he", "m"), fn.sum("m", "hn"))
edge2node_feat = graph.ndata.pop("hn") edge2node_feat = graph.ndata.pop("hn")
node_feat = torch.cat((node_feat, edge2node_feat), dim=1) node_feat = torch.cat((node_feat, edge2node_feat), dim=1)
......
...@@ -44,7 +44,7 @@ class WeightedGraphConv(GraphConv): ...@@ -44,7 +44,7 @@ class WeightedGraphConv(GraphConv):
n_feat = n_feat * src_norm n_feat = n_feat * src_norm
graph.ndata["h"] = n_feat graph.ndata["h"] = n_feat
graph.edata["e"] = e_feat graph.edata["e"] = e_feat
graph.update_all(fn.src_mul_edge("h", "e", "m"), graph.update_all(fn.u_mul_e("h", "e", "m"),
fn.sum("m", "h")) fn.sum("m", "h"))
n_feat = graph.ndata.pop("h") n_feat = graph.ndata.pop("h")
n_feat = n_feat * dst_norm n_feat = n_feat * dst_norm
...@@ -100,7 +100,7 @@ class NodeInfoScoreLayer(nn.Module): ...@@ -100,7 +100,7 @@ class NodeInfoScoreLayer(nn.Module):
graph.ndata["h"] = src_feat graph.ndata["h"] = src_feat
graph.edata["e"] = e_feat graph.edata["e"] = e_feat
graph = dgl.remove_self_loop(graph) graph = dgl.remove_self_loop(graph)
graph.update_all(fn.src_mul_edge("h", "e", "m"), fn.sum("m", "h")) graph.update_all(fn.u_mul_e("h", "e", "m"), fn.sum("m", "h"))
dst_feat = graph.ndata.pop("h") * dst_norm dst_feat = graph.ndata.pop("h") * dst_norm
feat = feat - dst_feat feat = feat - dst_feat
...@@ -111,7 +111,7 @@ class NodeInfoScoreLayer(nn.Module): ...@@ -111,7 +111,7 @@ class NodeInfoScoreLayer(nn.Module):
graph.ndata["h"] = feat graph.ndata["h"] = feat
graph.edata["e"] = e_feat graph.edata["e"] = e_feat
graph = dgl.remove_self_loop(graph) graph = dgl.remove_self_loop(graph)
graph.update_all(fn.src_mul_edge("h", "e", "m"), fn.sum("m", "h")) graph.update_all(fn.u_mul_e("h", "e", "m"), fn.sum("m", "h"))
feat = feat - dst_norm * graph.ndata.pop("h") feat = feat - dst_norm * graph.ndata.pop("h")
......
...@@ -114,11 +114,11 @@ class LoopyBPUpdate(nn.Module): ...@@ -114,11 +114,11 @@ class LoopyBPUpdate(nn.Module):
if PAPER: if PAPER:
mpn_gather_msg = [ mpn_gather_msg = [
DGLF.copy_edge(edge='msg', out='msg'), DGLF.copy_e(edge='msg', out='msg'),
DGLF.copy_edge(edge='alpha', out='alpha') DGLF.copy_e(edge='alpha', out='alpha')
] ]
else: else:
mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg') mpn_gather_msg = DGLF.copy_e(edge='msg', out='msg')
if PAPER: if PAPER:
......
...@@ -21,7 +21,7 @@ def dfs_order(forest, roots): ...@@ -21,7 +21,7 @@ def dfs_order(forest, roots):
# using find_edges(). # using find_edges().
yield e ^ l, l yield e ^ l, l
dec_tree_node_msg = DGLF.copy_edge(edge='m', out='m') dec_tree_node_msg = DGLF.copy_e(edge='m', out='m')
dec_tree_node_reduce = DGLF.sum(msg='m', out='h') dec_tree_node_reduce = DGLF.sum(msg='m', out='h')
...@@ -353,7 +353,7 @@ class DGLJTNNDecoder(nn.Module): ...@@ -353,7 +353,7 @@ class DGLJTNNDecoder(nn.Module):
break # At root, terminate break # At root, terminate
pu, _ = stack[-2] pu, _ = stack[-2]
u_pu = mol_tree_graph.edge_id(u, pu) u_pu = mol_tree_graph.edge_ids(u, pu)
mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's')) mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm')) mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
......
...@@ -31,11 +31,11 @@ class GNNModule(nn.Module): ...@@ -31,11 +31,11 @@ class GNNModule(nn.Module):
def aggregate(self, g, z): def aggregate(self, g, z):
z_list = [] z_list = []
g.ndata['z'] = z g.ndata['z'] = z
g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z')) g.update_all(fn.copy_u(u='z', out='m'), fn.sum(msg='m', out='z'))
z_list.append(g.ndata['z']) z_list.append(g.ndata['z'])
for i in range(self.radius - 1): for i in range(self.radius - 1):
for j in range(2 ** i): for j in range(2 ** i):
g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z')) g.update_all(fn.copy_u(u='z', out='m'), fn.sum(msg='m', out='z'))
z_list.append(g.ndata['z']) z_list.append(g.ndata['z'])
return z_list return z_list
...@@ -45,7 +45,7 @@ class GNNModule(nn.Module): ...@@ -45,7 +45,7 @@ class GNNModule(nn.Module):
sum_x = sum(theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))) sum_x = sum(theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x)))
g.edata['y'] = y g.edata['y'] = y
g.update_all(fn.copy_edge(edge='y', out='m'), fn.sum('m', 'pmpd_y')) g.update_all(fn.copy_e(e='y', out='m'), fn.sum('m', 'pmpd_y'))
pmpd_y = g.ndata.pop('pmpd_y') pmpd_y = g.ndata.pop('pmpd_y')
x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum_x + self.theta_y(pmpd_y) x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum_x + self.theta_y(pmpd_y)
......
...@@ -270,8 +270,8 @@ if __name__ == "__main__": ...@@ -270,8 +270,8 @@ if __name__ == "__main__":
mask[test_idx] = True mask[test_idx] = True
graph.ndata["test_mask"] = mask graph.ndata["test_mask"] = mask
graph.in_degree(0) graph.in_degrees(0)
graph.out_degree(0) graph.out_degrees(0)
graph.find_edges(0) graph.find_edges(0)
cluster_iter_data = ClusterIter( cluster_iter_data = ClusterIter(
......
...@@ -125,7 +125,6 @@ def net2graph(net_sm): ...@@ -125,7 +125,6 @@ def net2graph(net_sm):
def make_undirected(G): def make_undirected(G):
# G.readonly(False)
G.add_edges(G.edges()[1], G.edges()[0]) G.add_edges(G.edges()[1], G.edges()[0])
return G return G
......
...@@ -124,7 +124,6 @@ def net2graph(net_sm): ...@@ -124,7 +124,6 @@ def net2graph(net_sm):
def make_undirected(G): def make_undirected(G):
# G.readonly(False)
G.add_edges(G.edges()[1], G.edges()[0]) G.add_edges(G.edges()[1], G.edges()[0])
return G return G
......
...@@ -63,7 +63,7 @@ class MWEConv(nn.Module): ...@@ -63,7 +63,7 @@ class MWEConv(nn.Module):
else: else:
g.ndata["feat_" + str(c)] = node_state_c g.ndata["feat_" + str(c)] = node_state_c
g.update_all( g.update_all(
fn.src_mul_edge("feat_" + str(c), "feat_" + str(c), "m"), fn.sum("m", "feat_" + str(c) + "_new") fn.u_mul_e("feat_" + str(c), "feat_" + str(c), "m"), fn.sum("m", "feat_" + str(c) + "_new")
) )
node_state_c = g.ndata.pop("feat_" + str(c) + "_new") node_state_c = g.ndata.pop("feat_" + str(c) + "_new")
if self._out_feats >= self._in_feats: if self._out_feats >= self._in_feats:
......
...@@ -15,7 +15,7 @@ def compute_pagerank(g): ...@@ -15,7 +15,7 @@ def compute_pagerank(g):
degrees = g.out_degrees(g.nodes()).type(torch.float32) degrees = g.out_degrees(g.nodes()).type(torch.float32)
for k in range(K): for k in range(K):
g.ndata['pv'] = g.ndata['pv'] / degrees g.ndata['pv'] = g.ndata['pv'] / degrees
g.update_all(message_func=fn.copy_src(src='pv', out='m'), g.update_all(message_func=fn.copy_u(u='pv', out='m'),
reduce_func=fn.sum(msg='m', out='pv')) reduce_func=fn.sum(msg='m', out='pv'))
g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['pv'] g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['pv']
return g.ndata['pv'] return g.ndata['pv']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment