Unverified Commit eb40ed55 authored by Nick Baker's avatar Nick Baker Committed by GitHub
Browse files

[Model] Add Node explanation for Heterogenous PGExplainer Impl. (#6050)


Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent a848aa3e
...@@ -1898,8 +1898,11 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes): ...@@ -1898,8 +1898,11 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes):
g = transform2(g) g = transform2(g)
class Model(th.nn.Module): class Model(th.nn.Module):
def __init__(self, in_feats, embed_dim, out_feats, canonical_etypes): def __init__(
self, in_feats, embed_dim, out_feats, canonical_etypes, graph=True
):
super(Model, self).__init__() super(Model, self).__init__()
self.graph = graph
self.conv = nn.HeteroGraphConv( self.conv = nn.HeteroGraphConv(
{ {
c_etype: nn.GraphConv(in_feats, embed_dim) c_etype: nn.GraphConv(in_feats, embed_dim)
...@@ -1918,7 +1921,7 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes): ...@@ -1918,7 +1921,7 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes):
else: else:
h = self.conv(g, h) h = self.conv(g, h)
if embed: if not self.graph or embed:
return h return h
with g.local_scope(): with g.local_scope():
...@@ -1931,13 +1934,33 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes): ...@@ -1931,13 +1934,33 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes):
embed_dim = input_dim embed_dim = input_dim
# graph explainer # graph explainer
model = Model(input_dim, embed_dim, n_classes, g.canonical_etypes) model = Model(
input_dim, embed_dim, n_classes, g.canonical_etypes, graph=True
)
model = model.to(ctx) model = model.to(ctx)
explainer = nn.HeteroPGExplainer(model, embed_dim) explainer = nn.HeteroPGExplainer(model, embed_dim)
explainer.train_step(g, feat, 5.0) explainer.train_step(g, feat, 5.0)
probs, edge_weight = explainer.explain_graph(g, feat) probs, edge_weight = explainer.explain_graph(g, feat)
# node explainer
model = Model(
input_dim, embed_dim, n_classes, g.canonical_etypes, graph=False
)
model = model.to(ctx)
explainer = nn.HeteroPGExplainer(
model, embed_dim, num_hops=1, explain_graph=False
)
explainer.train_step_node({g.ntypes[0]: [0]}, g, feat, 5.0)
explainer.train_step_node({g.ntypes[0]: th.tensor([0, 1])}, g, feat, 5.0)
probs, edge_weight, bg, inverse_indices = explainer.explain_node(
{g.ntypes[0]: [0]}, g, feat
)
probs, edge_weight, bg, inverse_indices = explainer.explain_node(
{g.ntypes[0]: th.tensor([0, 1])}, g, feat
)
def test_jumping_knowledge(): def test_jumping_knowledge():
ctx = F.ctx() ctx = F.ctx()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment