Unverified Commit c235d17d authored by Zhiteng Li's avatar Zhiteng Li Committed by GitHub
Browse files

[Transform] Add ToLevi transform (#4884)



* add ToLevi transform

* fix parameter typo

* fix lint issues

* fix device issue

* refine according to dongyu's comments
Co-authored-by: default avatarrudongyu <ru_dongyu@outlook.com>
parent ca144886
......@@ -33,3 +33,4 @@ dgl.transforms
FeatMask
RowFeatNormalizer
SIGNDiffusion
ToLevi
......@@ -23,6 +23,7 @@ from .. import backend as F
from .. import function as fn
from ..base import DGLError
from . import functional
from .. import utils
try:
import torch
......@@ -52,7 +53,8 @@ __all__ = [
'DropNode',
'DropEdge',
'AddEdge',
'SIGNDiffusion'
'SIGNDiffusion',
'ToLevi'
]
def update_graph_structure(g, data_dict, copy_edata=True):
......@@ -1718,3 +1720,71 @@ class SIGNDiffusion(BaseTransform):
self.alpha * in_feat
feat_list.append(g.ndata[self.in_feat_name])
return feat_list
class ToLevi(BaseTransform):
r"""This function transforms the original graph to its heterogeneous Levi graph,
by converting edges to intermediate nodes, only support homogeneous directed graph.
Example
-------
>>> import dgl
>>> import torch as th
>>> from dgl import ToLevi
>>> transform = ToLevi()
>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
>>> g.ndata['h'] = th.randn((g.num_nodes(), 2))
>>> g.edata['w'] = th.randn((g.num_edges(), 2))
>>> lg = transform(g)
>>> lg
Grpah(num_nodes={'edge': 4, 'node': 4},
num_edges={('edge', 'e2n', 'node'): 4,
('node', 'n2e', 'edge'): 4},
metagraph=[('edge', 'node', 'e2n'),
('node', 'edge', 'n2e')])
>>> lg.nodes('node')
tensor([0, 1, 2, 3])
>>> lg.nodes('edge')
tensor([0, 1, 2, 3])
>>> lg.nodes['node'].data['h'].shape
torch.Size([4, 2])
>>> lg.nodes['edge'].data['w'].shape
torch.Size([4, 2])
"""
def __init__(self):
pass
def __call__(self, g):
r"""
Parameters
----------
g : DGLGraph
The input graph, should be a homogeneous directed graph.
Returns
-------
DGLGraph
The Levi graph of input, will be a heterogeneous graph, where nodes of
ntypes ``'node'`` and ``'edge'`` have corresponding IDs of nodes and edges
in the original graph. Edge features of the input graph are copied to
corresponding new nodes of ntype ``'edge'``.
"""
device = g.device
idtype = g.idtype
edge_list = g.edges()
n2e = edge_list[0], F.arange(0, g.num_edges(), idtype, device)
e2n = F.arange(0, g.num_edges(), idtype, device), edge_list[1]
graph_data = {('node', 'n2e', 'edge'): n2e,
('edge', 'e2n', 'node'): e2n}
levi_g = convert.heterograph(graph_data, idtype=idtype, device=device)
# Copy ndata and edata
# Since the node types in dgl.heterograph are in alphabetical order
# ('edge' < 'node'), edge_frames should be in front of node_frames.
node_frames = utils.extract_node_subframes(g, nodes_or_device=device)
edge_frames = utils.extract_edge_subframes(g, edges_or_device=device)
utils.set_new_frames(levi_g, node_frames=edge_frames+node_frames)
return levi_g
......@@ -955,7 +955,7 @@ def extract_edge_subframes(graph, edges_or_device, store_ids=True):
subf[EID] = ind_edges
edge_frames.append(subf)
else: # device object
edge_frames = [nf.to(device) for nf in graph._edge_frames]
edge_frames = [nf.to(edges_or_device) for nf in graph._edge_frames]
return edge_frames
......
......@@ -2672,6 +2672,34 @@ def test_shortest_dist(idtype):
assert F.array_equal(dist, tgt_dist)
assert F.array_equal(paths, tgt_paths)
@parametrize_idtype
def test_module_to_levi(idtype):
transform = dgl.ToLevi()
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]), idtype=idtype, device=F.ctx())
g.ndata['h'] = F.randn((g.num_nodes(), 2))
g.edata['w'] = F.randn((g.num_edges(), 2))
lg = transform(g)
assert lg.device == g.device
assert lg.idtype == g.idtype
assert lg.ntypes == ['edge', 'node']
assert lg.canonical_etypes == [('edge', 'e2n', 'node'),
('node', 'n2e', 'edge')]
assert lg.num_nodes('node') == g.num_nodes()
assert lg.num_nodes('edge') == g.num_edges()
assert lg.num_edges('n2e') == g.num_edges()
assert lg.num_edges('e2n') == g.num_edges()
src, dst = lg.edges(etype='n2e')
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 0), (1, 1), (2, 2), (3, 3)}
src, dst = lg.edges(etype='e2n')
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 1), (1, 2), (2, 3), (3, 0)}
assert F.allclose(lg.nodes['node'].data['h'], g.ndata['h'])
assert F.allclose(lg.nodes['edge'].data['w'], g.edata['w'])
if __name__ == '__main__':
test_partition_with_halo()
test_module_heat_kernel(F.int32)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment