"docs/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "acefb9a9f0af08a72d0e25d632e3a0be493e9178"
Unverified Commit 4b1fb681 authored by Infinity_X's avatar Infinity_X Committed by GitHub
Browse files

[Model] Heterogeneous graph support for GNNExplainer (#4401)



* [Model] Heterogeneous graph support for GNNExplainer (#1)

* add HeteroGNNExplainer

* GNNExplainer for heterogeneous graph

* fix typo

* variable name cleanup

* added HeteroGNNExplainer test

* added doc indexing for HeteroGNNExplainer

* Update python/dgl/nn/pytorch/explain/gnnexplainer.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

* Update python/dgl/nn/pytorch/explain/gnnexplainer.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

* Update python/dgl/nn/pytorch/explain/gnnexplainer.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

* Update python/dgl/nn/pytorch/explain/gnnexplainer.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

* Update python/dgl/nn/pytorch/explain/gnnexplainer.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

* Update python/dgl/nn/pytorch/explain/gnnexplainer.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

* Update python/dgl/nn/pytorch/explain/gnnexplainer.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

* Update python/dgl/nn/pytorch/explain/gnnexplainer.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

* Update python/dgl/nn/pytorch/explain/gnnexplainer.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

* Update python/dgl/nn/pytorch/explain/gnnexplainer.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

* Update python/dgl/nn/pytorch/explain/gnnexplainer.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

* Update python/dgl/nn/pytorch/explain/gnnexplainer.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

* Update python/dgl/nn/pytorch/explain/gnnexplainer.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

* Update gnnexplainer.py

Change DGLHeteroGraph to DGLGraph, and specified parameter inputs

* Added ntype parameter to the explainer_node call

* responding to @mufeili's comment regarding restoring empty lines at appropriate places to be consistent with existing practices

* responding to @mufeili's comment regarding restoring empty lines at appropriate places that were missed in the last commit

* docstring comments added based on @mufeili suggestions

* incorporated @mufeili requested changes related to docstring model declaration.

* example model and test_nn.py added for explain_graphs

* explain_nodes fixed and fixed the way hetero num nodes and edges are handled

* white spaces removed

* lint issues fixed

* explain_graph model updated

* explain nodes model updated

* minor fixes related to gpu compatibility

* cuda support added

* simplify WIP

* _init_masks for gnnexplainer updated to match heterographs

* Update

* model simplified and docstring comments updated

* nits: docstring updated

* lint check issues updated

* lint check updated

* some formatting updated

* disabling int32 testing for GNNExplainer

* Update
Co-authored-by: default avatarKangkook Jee <kangkook.jee@gmail.com>
Co-authored-by: default avatarahadjawaid <94938815+ahadjawaid@users.noreply.github.com>
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
Co-authored-by: default avatarkxm180046 <kxm180046@utdallas.edu>
Co-authored-by: default avatarKunal Mukherjee <kunmukh@gmail.com>
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-9-26.ap-northeast-1.compute.internal>
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
parent 98adca68
...@@ -112,6 +112,7 @@ Utility Modules
    ~dgl.nn.pytorch.utils.JumpingKnowledge
    ~dgl.nn.pytorch.sparse_emb.NodeEmbedding
    ~dgl.nn.pytorch.explain.GNNExplainer
    ~dgl.nn.pytorch.explain.HeteroGNNExplainer
    ~dgl.nn.pytorch.utils.LabelPropagation

Network Embedding Modules
......
"""Torch modules for explanation models.""" """Torch modules for explanation models."""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
from .gnnexplainer import GNNExplainer from .gnnexplainer import GNNExplainer, HeteroGNNExplainer
__all__ = ["GNNExplainer"]
...@@ -1329,6 +1329,71 @@ def test_gnnexplainer(g, idtype, out_dim):
    explainer = nn.GNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)
@pytest.mark.parametrize('g', get_cases(['hetero'], exclude=['zero-degree']))
@pytest.mark.parametrize('idtype', [F.int64])
@pytest.mark.parametrize('input_dim', [5])
@pytest.mark.parametrize('output_dim', [1, 2])
def test_heterognnexplainer(g, idtype, input_dim, output_dim):
    """Smoke-test HeteroGNNExplainer for node- and graph-level explanations."""
    g = g.astype(idtype).to(F.ctx())
    device = g.device

    # Augment the graph: self-loops (as new edge types) plus reverse edges.
    g = dgl.transforms.AddSelfLoop(new_etypes=True)(g)
    g = dgl.transforms.AddReverse(copy_edata=True)(g)

    # All-zero input features, one tensor per node type.
    feat = {}
    for ntype in g.ntypes:
        feat[ntype] = th.zeros((g.num_nodes(ntype), input_dim), device=device)

    class Model(th.nn.Module):
        """Minimal relation-wise linear model with optional edge weights."""

        def __init__(self, in_dim, num_classes, canonical_etypes, graph=False):
            super(Model, self).__init__()
            self.graph = graph
            # One linear projection per canonical edge type.
            weights = {}
            for c_etype in canonical_etypes:
                weights['_'.join(c_etype)] = th.nn.Linear(in_dim, num_classes)
            self.etype_weights = th.nn.ModuleDict(weights)

        def forward(self, graph, feat, eweight=None):
            with graph.local_scope():
                funcs = {}
                for c_etype in graph.canonical_etypes:
                    src_type, etype, dst_type = c_etype
                    projected = self.etype_weights['_'.join(c_etype)](feat[src_type])
                    graph.nodes[src_type].data[f'h_{c_etype}'] = projected
                    if eweight is None:
                        funcs[c_etype] = (fn.copy_u(f'h_{c_etype}', 'm'),
                                          fn.mean('m', 'h'))
                    else:
                        # Scale messages by the per-edge weights from the explainer.
                        graph.edges[c_etype].data['w'] = eweight[c_etype]
                        funcs[c_etype] = (fn.u_mul_e(f'h_{c_etype}', 'w', 'm'),
                                          fn.mean('m', 'h'))
                graph.multi_update_all(funcs, 'sum')
                if not self.graph:
                    return graph.ndata['h']
                # Graph-level readout: mean over each node type, summed.
                readout = 0
                for ntype in graph.ntypes:
                    if graph.num_nodes(ntype):
                        readout = readout + dgl.mean_nodes(graph, 'h', ntype=ntype)
                return readout

    # Explain a node prediction on the first node type.
    model = Model(input_dim, output_dim, g.canonical_etypes).to(F.ctx())
    ntype = g.ntypes[0]
    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(ntype, 0, g, feat)

    # Explain a graph prediction.
    model = Model(input_dim, output_dim, g.canonical_etypes, graph=True).to(F.ctx())
    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)
def test_jumping_knowledge(): def test_jumping_knowledge():
ctx = F.ctx() ctx = F.ctx()
num_layers = 2 num_layers = 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment