Unverified commit 303b150f, authored by rudongyu and committed by GitHub
Browse files

[Transform] Double Radius Node Labeling (#4513)

* add double radius node labeling

* add doc

* add test

* fix lint issue

* update

* update

* update

* update

* fix lint
parent 98325b10
...@@ -116,6 +116,7 @@ Operators for generating positional encodings of each node. ...@@ -116,6 +116,7 @@ Operators for generating positional encodings of each node.
random_walk_pe random_walk_pe
laplacian_pe laplacian_pe
double_radius_node_labeling
.. _api-partition: .. _api-partition:
......
...@@ -84,7 +84,8 @@ __all__ = [ ...@@ -84,7 +84,8 @@ __all__ = [
'laplacian_pe', 'laplacian_pe',
'to_half', 'to_half',
'to_float', 'to_float',
'to_double' 'to_double',
'double_radius_node_labeling',
] ]
...@@ -3790,4 +3791,69 @@ def to_double(g): ...@@ -3790,4 +3791,69 @@ def to_double(g):
ret._node_frames = [frame.double() for frame in ret._node_frames] ret._node_frames = [frame.double() for frame in ret._node_frames]
return ret return ret
def double_radius_node_labeling(g, src, dst):
    r"""Double Radius Node Labeling, as introduced in `Link Prediction
    Based on Graph Neural Networks <https://arxiv.org/abs/1802.09691>`__.

    This function computes the double radius node labeling for each node to mark
    nodes' different roles in an enclosing subgraph, given a target link.

    The node labels of source :math:`s` and destination :math:`t` are set to 1 and
    those of unreachable nodes from source or destination are set to 0. The labels
    of other nodes :math:`l` are defined according to the following hash function:

    :math:`l = 1 + \min(d_s, d_t) + \lfloor d/2 \rfloor
    \left[\lfloor d/2 \rfloor + (d \bmod 2) - 1\right]`

    where :math:`d_s` and :math:`d_t` denote the shortest distance to the source and
    the target, respectively. :math:`d = d_s + d_t`.

    Distances are measured on the graph viewed as undirected and unweighted.
    The distance to the source is computed with the destination node removed,
    and vice versa, so a path through the other endpoint of the target link
    never counts.

    Parameters
    ----------
    g : DGLGraph
        The input graph.
    src : int
        The source node ID of the target link.
    dst : int
        The destination node ID of the target link.

    Returns
    -------
    Tensor
        Labels of all nodes. The tensor is of shape :math:`(N,)`, where
        :math:`N` is the number of nodes in the input graph.

    Raises
    ------
    ValueError
        If ``src`` or ``dst`` is not a valid node ID of ``g``, or if
        ``src`` and ``dst`` are the same node.

    Example
    -------
    >>> import dgl

    >>> g = dgl.graph(([0,0,0,0,1,1,2,4], [1,2,3,6,3,4,4,5]))
    >>> dgl.double_radius_node_labeling(g, 0, 1)
    tensor([1, 1, 3, 2, 3, 7, 0])
    """
    adj = g.adj(scipy_fmt='csr')
    num_nodes = adj.shape[0]

    # Reject IDs that would make the row/column removal below drop the wrong
    # node (or no node at all) and silently produce meaningless labels.
    if not (0 <= src < num_nodes and 0 <= dst < num_nodes):
        raise ValueError(
            'src and dst must be node IDs in [0, {}), but got src={} and '
            'dst={}.'.format(num_nodes, src, dst))
    if src == dst:
        raise ValueError(
            'src and dst must be two distinct nodes, but both are {}.'.format(src))

    # The labeling is symmetric in the two endpoints, so order them such that
    # src < dst; this keeps the index bookkeeping after node removal simple.
    src, dst = (dst, src) if src > dst else (src, dst)

    node_ids = np.arange(num_nodes)
    idx = np.delete(node_ids, src)
    adj_wo_src = adj[idx, :][:, idx]

    idx = np.delete(node_ids, dst)
    adj_wo_dst = adj[idx, :][:, idx]

    # distance to the source node (dst removed; src keeps its index as src < dst)
    ds = sparse.csgraph.shortest_path(
        adj_wo_dst, directed=False, unweighted=True, indices=src)
    # re-insert a placeholder for the removed dst so that ds has N entries again
    ds = np.insert(ds, dst, 0, axis=0)
    # distance to the destination node (src removed; dst shifts down by one)
    dt = sparse.csgraph.shortest_path(
        adj_wo_src, directed=False, unweighted=True, indices=dst - 1)
    dt = np.insert(dt, src, 0, axis=0)

    d = ds + dt
    # Unreachable nodes have infinite distance; the arithmetic below yields
    # non-finite values for them, so suppress the resulting warnings.
    with np.errstate(invalid='ignore'):
        z = 1 + np.stack([ds, dt]).min(axis=0) + d // 2 * (d // 2 + d % 2 - 1)
    z[src] = 1
    z[dst] = 1
    z[~np.isfinite(z)] = 0  # unreachable nodes

    return F.tensor(z, F.int64)
_init_api("dgl.transform", __name__) _init_api("dgl.transform", __name__)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment