Unverified Commit eb40ed55 authored by Nick Baker, committed by GitHub

[Model] Add Node explanation for Heterogeneous PGExplainer Impl. (#6050)


Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent a848aa3e
@@ -14,10 +14,11 @@ class PGExplainer(nn.Module):
r"""PGExplainer from `Parameterized Explainer for Graph Neural Network
<https://arxiv.org/pdf/2011.04573>`__
PGExplainer adopts a deep neural network (explanation network) to parameterize the generation
process of explanations, which enables it to explain multiple instances
collectively. PGExplainer models the underlying structure as edge
distributions, from which the explanatory graph is sampled.
PGExplainer adopts a deep neural network (explanation network) to
parameterize the generation process of explanations, which enables it to
explain multiple instances collectively. PGExplainer models the underlying
structure as edge distributions, from which the explanatory graph is
sampled.
Parameters
----------
@@ -58,6 +59,7 @@ class PGExplainer(nn.Module):
self.model = model
self.graph_explanation = explain_graph
# Node explanation requires additional self-embedding data.
self.num_features = num_features * (2 if self.graph_explanation else 3)
self.num_hops = num_hops
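# The multiplier above reflects how inputs to the explanation network are
# assembled: graph explanation concatenates the embeddings of an edge's two
# endpoints (2 * D), while node explanation also appends the embedding of
# the center node being explained (3 * D). A hedged toy illustration:
import torch
E, D = 4, 16  # toy edge count and embedding size
col_emb = torch.randn(E, D)  # source endpoint embeddings
row_emb = torch.randn(E, D)  # destination endpoint embeddings
self_emb = torch.randn(1, D).repeat(E, 1)  # center node, repeated per edge
graph_inp = torch.cat([col_emb, row_emb], dim=-1)  # (E, 2 * D)
node_inp = torch.cat([col_emb, row_emb, self_emb], dim=-1)  # (E, 3 * D)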
@@ -206,8 +208,8 @@ class PGExplainer(nn.Module):
return gate_inputs
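# For orientation, a minimal sketch of the binary Concrete (Gumbel-Softmax
# style) relaxation that concrete_sample implements; the standalone helper
# name and the epsilon clamp are assumptions, not the module's verbatim code.
import torch

def concrete_sample_sketch(log_alpha, beta=1.0, training=True):
    # During training, perturb the edge logits with Logistic(0, 1) noise and
    # sharpen by the temperature beta; at inference, just squash the logits.
    if training:
        u = torch.rand_like(log_alpha).clamp(1e-6, 1 - 1e-6)
        noise = torch.log(u) - torch.log(1 - u)
        return torch.sigmoid((log_alpha + noise) / beta)
    return torch.sigmoid(log_alpha)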
def train_step(self, graph, feat, tmp, **kwargs):
r"""Compute the loss of the explanation network
def train_step(self, graph, feat, temperature, **kwargs):
r"""Compute the loss of the explanation network for graph classification
Parameters
----------
@@ -216,7 +218,7 @@ class PGExplainer(nn.Module):
feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size.
tmp : float
temperature : float
The temperature parameter fed to the sampling procedure.
kwargs : dict
Additional arguments passed to the GNN model.
@@ -228,7 +230,7 @@ class PGExplainer(nn.Module):
"""
assert (
self.graph_explanation
), '"explain_graph" must be True in initializing the module.'
), '"explain_graph" must be True when initializing the module.'
self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
@@ -237,26 +239,26 @@ class PGExplainer(nn.Module):
pred = pred.argmax(-1).data
prob, _ = self.explain_graph(
graph, feat, tmp=tmp, training=True, **kwargs
graph, feat, temperature, training=True, **kwargs
)
loss = self.loss(prob, pred)
return loss
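# A hedged training-loop sketch for the graph setting, following the
# temperature-annealing schedule used in the docstring examples below;
# explainer, g, and feat are assumed to already exist.
import numpy as np
import torch as th
init_tmp, final_tmp, epochs = 5.0, 1.0, 20
optimizer = th.optim.Adam(explainer.parameters(), lr=0.01)
for epoch in range(epochs):
    # Anneal the temperature from init_tmp down to final_tmp.
    temperature = float(init_tmp * np.power(final_tmp / init_tmp, epoch / epochs))
    loss = explainer.train_step(g, feat, temperature)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()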
def train_step_node(self, nodes, graph, feat, tmp, **kwargs):
r"""Compute the loss of the explanation network
def train_step_node(self, nodes, graph, feat, temperature, **kwargs):
r"""Compute the loss of the explanation network for node classification
Parameters
----------
nodes : int, iterable[int], tensor
The nodes from the graph used to train the explanation network, which cannot
have any duplicate value.
The nodes from the graph used to train the explanation network,
which cannot have any duplicate value.
graph : DGLGraph
Input homogeneous graph.
feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size.
tmp : float
temperature : float
The temperature parameter fed to the sampling procedure.
kwargs : dict
Additional arguments passed to the GNN model.
@@ -268,7 +270,7 @@ class PGExplainer(nn.Module):
"""
assert (
not self.graph_explanation
), '"explain_graph" must be False in initializing the module.'
), '"explain_graph" must be False when initializing the module.'
self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
@@ -279,7 +281,7 @@ class PGExplainer(nn.Module):
nodes = [nodes]
prob, _, batched_graph, inverse_indices = self.explain_node(
nodes, graph, feat, tmp=tmp, training=True, **kwargs
nodes, graph, feat, temperature, training=True, **kwargs
)
pred = self.model(
@@ -290,7 +292,9 @@ class PGExplainer(nn.Module):
loss = self.loss(prob[inverse_indices], pred[inverse_indices])
return loss
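# The node setting trains the same way; a hedged variant of the loop above
# that trains on every node (g, feat, and temperature assumed as before).
loss = explainer.train_step_node(g.nodes(), g, feat, temperature)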
def explain_graph(self, graph, feat, tmp=1.0, training=False, **kwargs):
def explain_graph(
self, graph, feat, temperature=1.0, training=False, **kwargs
):
r"""Learn and return an edge mask that plays a crucial role to
explain the prediction made by the GNN for a graph. Also, return
the prediction made with the edges chosen based on the edge mask.
@@ -302,7 +306,7 @@ class PGExplainer(nn.Module):
feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size.
tmp : float
temperature : float
The temperature parameter fed to the sampling procedure.
training : bool
Training the explanation network.
@@ -312,13 +316,13 @@ class PGExplainer(nn.Module):
Returns
-------
Tensor
Classification probabilities given the masked graph. It is a tensor of
shape :math:`(B, L)`, where :math:`L` is the different types of label
in the dataset, and :math:`B` is the batch size.
Classification probabilities given the masked graph. It is a tensor
of shape :math:`(B, L)`, where :math:`L` is the number of label
types in the dataset, and :math:`B` is the batch size.
Tensor
Edge weights which is a tensor of shape :math:`(E)`, where :math:`E`
is the number of edges in the graph. A higher weight suggests a larger
contribution of the edge.
is the number of edges in the graph. A higher weight suggests a
larger contribution of the edge.
Examples
--------
@@ -388,7 +392,7 @@ class PGExplainer(nn.Module):
"""
assert (
self.graph_explanation
), '"explain_graph" must be True in initializing the module.'
), '"explain_graph" must be True when initializing the module.'
self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
@@ -403,7 +407,9 @@ class PGExplainer(nn.Module):
emb = self.elayers(emb)
values = emb.reshape(-1)
values = self.concrete_sample(values, beta=tmp, training=training)
values = self.concrete_sample(
values, beta=temperature, training=training
)
self.sparse_mask_values = values
reverse_eids = graph.edge_ids(row, col).long()
@@ -423,12 +429,11 @@ class PGExplainer(nn.Module):
return (probs, edge_mask)
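# A hedged post-training usage sketch (trained explainer, g, and feat
# assumed): the returned mask holds one weight per edge.
probs, edge_mask = explainer.explain_graph(g, feat)
k = min(5, edge_mask.numel())
top_edges = edge_mask.topk(k).indices  # indices of the most influential edges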
def explain_node(
self, nodes, graph, feat, tmp=1.0, training=False, **kwargs
self, nodes, graph, feat, temperature=1.0, training=False, **kwargs
):
r"""Learn and return an edge mask that plays a crucial role to
explain the prediction made by the GNN for node :attr:`node_id`.
Also, return the prediction made with the edges chosen based on
the edge mask.
explain the prediction made by the GNN for provided set of node IDs.
Also, return the prediction made with the graph and edge mask.
Parameters
----------
@@ -439,7 +444,7 @@ class PGExplainer(nn.Module):
feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size.
tmp : float
temperature : float
The temperature parameter fed to the sampling procedure.
training : bool
Training the explanation network.
@@ -449,13 +454,14 @@ class PGExplainer(nn.Module):
Returns
-------
Tensor
Classification probabilities given the masked graph. It is a tensor of
shape :math:`(B, L)`, where :math:`L` is the different types of label
in the dataset, and :math:`B` is the batch size.
Classification probabilities given the masked graph. It is a tensor
of shape :math:`(N, L)`, where :math:`L` is the number of node label
types in the dataset, and :math:`N` is the number of nodes
in the graph.
Tensor
Edge weights which is a tensor of shape :math:`(E)`, where :math:`E`
is the number of edges in the graph. A higher weight suggests a larger
contribution of the edge.
is the number of edges in the graph. A higher weight suggests a
larger contribution of the edge.
DGLGraph
The batched set of subgraphs induced on the k-hop in-neighborhood
of the input center nodes.
@@ -523,10 +529,10 @@ class PGExplainer(nn.Module):
"""
assert (
not self.graph_explanation
), '"explain_graph" must be False in initializing the module.'
), '"explain_graph" must be False when initializing the module.'
assert (
self.num_hops is not None
), '"num_hops" must be provided in initializing the module.'
), '"num_hops" must be provided when initializing the module.'
if isinstance(nodes, torch.Tensor):
nodes = nodes.tolist()
@@ -537,17 +543,17 @@ class PGExplainer(nn.Module):
self.elayers = self.elayers.to(graph.device)
batched_graph = []
batched_feats = []
batched_embed = []
batched_inverse_indices = []
node_idx = 0
for node_id in nodes:
sg, inverse_indices = khop_in_subgraph(
graph, node_id, self.num_hops
)
sg_feat = feat[sg.ndata[NID].long()]
sg.ndata["feat"] = feat[sg.ndata[NID].long()]
sg.ndata["train"] = torch.tensor(
[nid in inverse_indices for nid in sg.nodes()], device=sg.device
)
embed = self.model(sg, sg_feat, embed=True, **kwargs)
embed = self.model(sg, sg.ndata["feat"], embed=True, **kwargs)
embed = embed.data
col, row = sg.edges()
@@ -557,20 +563,16 @@ class PGExplainer(nn.Module):
emb = torch.cat([col_emb, row_emb, self_emb], dim=-1)
batched_embed.append(emb)
batched_graph.append(sg)
batched_feats.append(sg_feat)
# node IDs of the subgraph mapped to the batch:
# https://docs.dgl.ai/en/latest/generated/dgl.batch.html#dgl.batch
batched_inverse_indices.append(inverse_indices[0].item() + node_idx)
node_idx += sg.num_nodes()
batched_graph = batch(batched_graph)
batched_feats = torch.cat(batched_feats)
batched_embed = torch.cat(batched_embed)
batched_embed = self.elayers(batched_embed)
values = batched_embed.reshape(-1)
values = self.concrete_sample(values, beta=tmp, training=training)
values = self.concrete_sample(
values, beta=temperature, training=training
)
self.sparse_mask_values = values
col, row = batched_graph.edges()
@@ -579,12 +581,17 @@ class PGExplainer(nn.Module):
self.set_masks(batched_graph, edge_mask)
batched_feats = batched_graph.ndata["feat"]
# the model prediction with the updated edge mask
logits = self.model(
batched_graph, batched_feats, edge_weight=self.edge_mask, **kwargs
)
probs = F.softmax(logits, dim=-1)
batched_inverse_indices = (
batched_graph.ndata["train"].nonzero().squeeze(1)
)
if training:
self.batched_feats = batched_feats
probs = probs.data
@@ -592,7 +599,7 @@ class PGExplainer(nn.Module):
self.clear_masks()
return (
probs.data,
probs,
edge_mask,
batched_graph,
batched_inverse_indices,
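# A hedged sketch of consuming the four return values (trained explainer,
# g, and feat assumed); inverse_indices selects the center-node rows.
probs, edge_mask, bg, inverse_indices = explainer.explain_node([0], g, feat)
center_probs = probs[inverse_indices]  # predictions for the center nodes only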
@@ -603,10 +610,11 @@ class HeteroPGExplainer(PGExplainer):
r"""PGExplainer from `Parameterized Explainer for Graph Neural Network
<https://arxiv.org/pdf/2011.04573>`__, adapted for heterogeneous graphs
PGExplainer adopts a deep neural network (explanation network) to parameterize the generation
process of explanations, which enables it to explain multiple instances
collectively. PGExplainer models the underlying structure as edge
distributions, from which the explanatory graph is sampled.
PGExplainer adopts a deep neural network (explanation network) to
parameterize the generation process of explanations, which enables it to
explain multiple instances collectively. PGExplainer models the underlying
structure as edge distributions, from which the explanatory graph is
sampled.
Parameters
----------
@@ -628,8 +636,9 @@ class HeteroPGExplainer(PGExplainer):
in a sample than others. Default: 0.0.
"""
def train_step(self, graph, feat, tmp, **kwargs):
r"""Compute the loss of the explanation network
def train_step(self, graph, feat, temperature, **kwargs):
# pylint: disable=useless-super-delegation
r"""Compute the loss of the explanation network for graph classification
Parameters
----------
@@ -637,10 +646,36 @@ class HeteroPGExplainer(PGExplainer):
Input batched heterogeneous graph.
feat : dict[str, Tensor]
A dict mapping node types (keys) to feature tensors (values).
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is the
number of nodes for node type :math:`t`, and :math:`D_t` is the feature
size for node type :math:`t`
tmp : float
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
the number of nodes for node type :math:`t`, and :math:`D_t` is the
feature size for node type :math:`t`
temperature : float
The temperature parameter fed to the sampling procedure.
kwargs : dict
Additional arguments passed to the GNN model.
Returns
-------
Tensor
A scalar tensor representing the loss.
"""
return super().train_step(graph, feat, temperature, **kwargs)
def train_step_node(self, nodes, graph, feat, temperature, **kwargs):
r"""Compute the loss of the explanation network for node classification
Parameters
----------
nodes : dict[str, Iterable[int]]
A dict mapping node types (keys) to an iterable set of node ids (values).
graph : DGLGraph
Input heterogeneous graph.
feat : dict[str, Tensor]
A dict mapping node types (keys) to feature tensors (values).
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
the number of nodes for node type :math:`t`, and :math:`D_t` is the
feature size for node type :math:`t`
temperature : float
The temperature parameter fed to the sampling procedure.
kwargs : dict
Additional arguments passed to the GNN model.
@@ -650,9 +685,35 @@ class HeteroPGExplainer(PGExplainer):
Tensor
A scalar tensor representing the loss.
"""
return super().train_step(graph, feat, tmp=tmp, **kwargs)
assert (
not self.graph_explanation
), '"explain_graph" must be False when initializing the module.'
def explain_graph(self, graph, feat, tmp=1.0, training=False, **kwargs):
self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
prob, _, batched_graph, inverse_indices = self.explain_node(
nodes, graph, feat, temperature, training=True, **kwargs
)
pred = self.model(
batched_graph, self.batched_feats, embed=False, **kwargs
)
pred = {ntype: pred[ntype].argmax(-1).data for ntype in pred.keys()}
loss = self.loss(
torch.cat(
[prob[ntype][nid] for ntype, nid in inverse_indices.items()]
),
torch.cat(
[pred[ntype][nid] for ntype, nid in inverse_indices.items()]
),
)
return loss
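# A hedged toy illustration of the gather-and-concatenate step above, with
# stand-in dicts for prob and inverse_indices.
import torch
prob = {"user": torch.rand(3, 2), "game": torch.rand(2, 2)}
inverse_indices = {"user": torch.tensor([0]), "game": torch.tensor([1])}
flat = torch.cat([prob[nt][ix] for nt, ix in inverse_indices.items()])
# flat stacks the center-node rows from every node type: shape (2, 2).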
def explain_graph(
self, graph, feat, temperature=1.0, training=False, **kwargs
):
r"""Learn and return an edge mask that plays a crucial role to
explain the prediction made by the GNN for a graph. Also, return
the prediction made with the edges chosen based on the edge mask.
@@ -663,10 +724,10 @@ class HeteroPGExplainer(PGExplainer):
A heterogeneous graph.
feat : dict[str, Tensor]
A dict mapping node types (keys) to feature tensors (values).
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is the
number of nodes for node type :math:`t`, and :math:`D_t` is the feature
size for node type :math:`t`
tmp : float
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
the number of nodes for node type :math:`t`, and :math:`D_t` is the
feature size for node type :math:`t`
temperature : float
The temperature parameter fed to the sampling procedure.
training : bool
Training the explanation network.
@@ -676,13 +737,14 @@ class HeteroPGExplainer(PGExplainer):
Returns
-------
Tensor
Classification probabilities given the masked graph. It is a tensor of
shape :math:`(B, L)`, where :math:`L` is the different types of label
in the dataset, and :math:`B` is the batch size.
Classification probabilities given the masked graph. It is a tensor
of shape :math:`(B, L)`, where :math:`L` is the number of label
types in the dataset, and :math:`B` is the batch size.
dict[str, Tensor]
A dict mapping edge types (keys) to edge tensors (values) of shape :math:`(E_t)`,
where :math:`E_t` is the number of edges in the graph for edge type :math:`t`.
A higher weight suggests a larger contribution of the edge.
A dict mapping edge types (keys) to edge tensors (values) of shape
:math:`(E_t)`, where :math:`E_t` is the number of edges in the graph
for edge type :math:`t`. A higher weight suggests a larger
contribution of the edge.
Examples
--------
@@ -761,6 +823,10 @@ class HeteroPGExplainer(PGExplainer):
>>> feat = g.ndata.pop("h")
>>> probs, edge_mask = explainer.explain_graph(g, feat)
"""
assert (
self.graph_explanation
), '"explain_graph" must be True when initializing the module.'
self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
@@ -770,16 +836,16 @@ class HeteroPGExplainer(PGExplainer):
homo_graph = to_homogeneous(graph, ndata=["emb"])
homo_embed = homo_graph.ndata["emb"]
edge_idx = homo_graph.edges()
col, row = edge_idx
col, row = homo_graph.edges()
col_emb = homo_embed[col.long()]
row_emb = homo_embed[row.long()]
emb = torch.cat([col_emb, row_emb], dim=-1)
emb = self.elayers(emb)
values = emb.reshape(-1)
values = self.concrete_sample(values, beta=tmp, training=training)
values = self.concrete_sample(
values, beta=temperature, training=training
)
self.sparse_mask_values = values
reverse_eids = homo_graph.edge_ids(row, col).long()
@@ -788,14 +854,11 @@ class HeteroPGExplainer(PGExplainer):
self.set_masks(homo_graph, edge_mask)
# convert the edge mask back into heterogeneous format
hetero_edge_mask = {
etype: edge_mask[
(homo_graph.edata[ETYPE] == graph.get_etype_id(etype))
.nonzero()
.squeeze(1)
]
for etype in graph.canonical_etypes
}
hetero_edge_mask = self._edge_mask_to_heterogeneous(
edge_mask=edge_mask,
homograph=homo_graph,
heterograph=graph,
)
# the model prediction with the updated edge mask
logits = self.model(graph, feat, edge_weight=hetero_edge_mask, **kwargs)
@@ -807,3 +870,270 @@ class HeteroPGExplainer(PGExplainer):
self.clear_masks()
return (probs, hetero_edge_mask)
def explain_node(
self, nodes, graph, feat, temperature=1.0, training=False, **kwargs
):
r"""Learn and return an edge mask that plays a crucial role to
explain the prediction made by the GNN for provided set of node IDs.
Also, return the prediction made with the batched graph and edge mask.
Parameters
----------
nodes : dict[str, Iterable[int]]
A dict mapping node types (keys) to an iterable set of node ids (values).
graph : DGLGraph
A heterogeneous graph.
feat : dict[str, Tensor]
A dict mapping node types (keys) to feature tensors (values).
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
the number of nodes for node type :math:`t`, and :math:`D_t` is the
feature size for node type :math:`t`
temperature : float
The temperature parameter fed to the sampling procedure.
training : bool
Training the explanation network.
kwargs : dict
Additional arguments passed to the GNN model.
Returns
-------
dict[str, Tensor]
A dict mapping node types (keys) to classification probabilities
for node labels (values). The values are tensors of shape
:math:`(N_t, L)`, where :math:`L` is the number of node label
types in the dataset, and :math:`N_t` is the number of nodes in
the graph for node type :math:`t`.
dict[str, Tensor]
A dict mapping edge types (keys) to edge tensors (values) of shape
:math:`(E_t)`, where :math:`E_t` is the number of edges in the graph
for edge type :math:`t`. A higher weight suggests a larger
contribution of the edge.
DGLGraph
The batched set of subgraphs induced on the k-hop in-neighborhood
of the input center nodes.
dict[str, Tensor]
A dict mapping node types (keys) to a tensor of node IDs (values)
which correspond to the subgraph center nodes.
Examples
--------
>>> import dgl
>>> import torch as th
>>> import torch.nn as nn
>>> import numpy as np
>>> # Define the model
>>> class Model(nn.Module):
... def __init__(self, in_feats, hid_feats, out_feats, rel_names):
... super().__init__()
... self.conv = dgl.nn.HeteroGraphConv(
... {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
... aggregate="sum",
... )
... self.fc = nn.Linear(hid_feats, out_feats)
... nn.init.xavier_uniform_(self.fc.weight)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         if edge_weight:
...             mod_kwargs = {
...                 etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
...             }
...             h = self.conv(g, h, mod_kwargs=mod_kwargs)
...         else:
...             h = self.conv(g, h)
...         if embed:
...             return h
...         # Apply the linear head so embed=False yields class logits.
...         return {ntype: self.fc(emb) for ntype, emb in h.items()}
>>> # Load dataset
>>> input_dim = 5
>>> hidden_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
... logits = model(g, g.ndata["h"])['user']
... loss = th.nn.functional.cross_entropy(logits, th.tensor([1,1,1]))
... optimizer.zero_grad()
... loss.backward()
... optimizer.step()
>>> # Initialize the explainer
>>> explainer = dgl.nn.HeteroPGExplainer(
... model, hidden_dim, num_hops=2, explain_graph=False
... )
>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
... tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
... loss = explainer.train_step_node(
... { ntype: g.nodes(ntype) for ntype in g.ntypes },
... g, g.ndata["h"], tmp
... )
... optimizer_exp.zero_grad()
... loss.backward()
... optimizer_exp.step()
>>> # Explain the graph
>>> feat = g.ndata.pop("h")
>>> probs, edge_mask, bg, inverse_indices = explainer.explain_node(
... { "user": [0] }, g, feat
... )
"""
assert (
not self.graph_explanation
), '"explain_graph" must be False when initializing the module.'
assert (
self.num_hops is not None
), '"num_hops" must be provided when initializing the module.'
self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
batched_embed = []
batched_homo_graph = []
batched_hetero_graph = []
for target_ntype, target_nids in nodes.items():
if isinstance(target_nids, torch.Tensor):
target_nids = target_nids.tolist()
for target_nid in target_nids:
sg, inverse_indices = khop_in_subgraph(
graph, {target_ntype: target_nid}, self.num_hops
)
for sg_ntype in sg.ntypes:
sg_feat = feat[sg_ntype][sg.ndata[NID][sg_ntype].long()]
train_mask = [
sg_ntype in inverse_indices
and node_id in inverse_indices[sg_ntype]
for node_id in sg.nodes(sg_ntype)
]
sg.nodes[sg_ntype].data["feat"] = sg_feat
sg.nodes[sg_ntype].data["train"] = torch.tensor(
train_mask, device=sg.device
)
embed = self.model(sg, sg.ndata["feat"], embed=True, **kwargs)
for ntype in embed.keys():
sg.nodes[ntype].data["emb"] = embed[ntype].data
homo_sg = to_homogeneous(sg, ndata=["emb"])
homo_sg_embed = homo_sg.ndata["emb"]
col, row = homo_sg.edges()
col_emb = homo_sg_embed[col.long()]
row_emb = homo_sg_embed[row.long()]
self_emb = homo_sg_embed[
inverse_indices[target_ntype][0]
].repeat(sg.num_edges(), 1)
emb = torch.cat([col_emb, row_emb, self_emb], dim=-1)
batched_embed.append(emb)
batched_homo_graph.append(homo_sg)
batched_hetero_graph.append(sg)
batched_homo_graph = batch(batched_homo_graph)
batched_hetero_graph = batch(batched_hetero_graph)
batched_embed = torch.cat(batched_embed)
batched_embed = self.elayers(batched_embed)
values = batched_embed.reshape(-1)
values = self.concrete_sample(
values, beta=temperature, training=training
)
self.sparse_mask_values = values
col, row = batched_homo_graph.edges()
reverse_eids = batched_homo_graph.edge_ids(row, col).long()
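# Average each directed edge's weight with its reverse edge's weight so
# the learned mask is symmetric across edge directions.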
edge_mask = (values + values[reverse_eids]) / 2
self.set_masks(batched_homo_graph, edge_mask)
# Convert the edge mask back into heterogeneous format.
hetero_edge_mask = self._edge_mask_to_heterogeneous(
edge_mask=edge_mask,
homograph=batched_homo_graph,
heterograph=batched_hetero_graph,
)
batched_feats = {
ntype: batched_hetero_graph.nodes[ntype].data["feat"]
for ntype in batched_hetero_graph.ntypes
}
# The model prediction with the updated edge mask.
logits = self.model(
batched_hetero_graph,
batched_feats,
edge_weight=hetero_edge_mask,
**kwargs,
)
probs = {
ntype: F.softmax(logits[ntype], dim=-1) for ntype in logits.keys()
}
batched_inverse_indices = {
ntype: batched_hetero_graph.nodes[ntype]
.data["train"]
.nonzero()
.squeeze(1)
for ntype in batched_hetero_graph.ntypes
}
if training:
self.batched_feats = batched_feats
probs = {ntype: probs[ntype].data for ntype in probs.keys()}
else:
self.clear_masks()
return (
probs,
hetero_edge_mask,
batched_hetero_graph,
batched_inverse_indices,
)
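# A hedged sketch of consuming the heterogeneous return values (trained
# explainer, g, and feat assumed; see the docstring example above).
probs, edge_mask, bg, inverse_indices = explainer.explain_node(
    {"user": [0]}, g, feat
)
center_user_probs = probs["user"][inverse_indices["user"]]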
def _edge_mask_to_heterogeneous(self, edge_mask, homograph, heterograph):
r"""Convert an edge mask from homogeneous mappings built through
embeddings into heterogenous format by leveraging the context from
the source DGLGraphs in homogenous and heterogeneous form.
The `edge_mask` needs to have been built using the embedding of the
homogenous graph format for the mappings to work correctly.
Parameters
----------
edge_mask : Tensor
A tensor of edge weights of shape :math:`(E)` for the homogeneous
form of the graph.
homograph : DGLGraph
The homogeneous form of the source graph.
heterograph : DGLGraph
The heterogeneous form of the source graph.
Returns
-------
dict[str, Tensor]
A dict mapping canonical edge types (keys) to tensors of edge weights (values)
"""
return {
etype: edge_mask[
(homograph.edata[ETYPE] == heterograph.get_etype_id(etype))
.nonzero()
.squeeze(1)
]
for etype in heterograph.canonical_etypes
}
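# The mapping relies on dgl.to_homogeneous recording each edge's original
# type id in edata[dgl.ETYPE]; a minimal round-trip check of that
# assumption on a toy graph:
import dgl
import torch
hg = dgl.heterograph({("user", "plays", "game"): ([0, 1], [0, 1])})
homo = dgl.to_homogeneous(hg)
etype_id = hg.get_etype_id(("user", "plays", "game"))
assert torch.equal(
    (homo.edata[dgl.ETYPE] == etype_id).nonzero().squeeze(1),
    torch.arange(homo.num_edges()),
)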
@@ -1898,8 +1898,11 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes):
g = transform2(g)
class Model(th.nn.Module):
def __init__(self, in_feats, embed_dim, out_feats, canonical_etypes):
def __init__(
self, in_feats, embed_dim, out_feats, canonical_etypes, graph=True
):
super(Model, self).__init__()
self.graph = graph
self.conv = nn.HeteroGraphConv(
{
c_etype: nn.GraphConv(in_feats, embed_dim)
@@ -1918,7 +1921,7 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes):
else:
h = self.conv(g, h)
if embed:
if not self.graph or embed:
return h
with g.local_scope():
@@ -1931,13 +1934,33 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes):
embed_dim = input_dim
# graph explainer
model = Model(input_dim, embed_dim, n_classes, g.canonical_etypes)
model = Model(
input_dim, embed_dim, n_classes, g.canonical_etypes, graph=True
)
model = model.to(ctx)
explainer = nn.HeteroPGExplainer(model, embed_dim)
explainer.train_step(g, feat, 5.0)
probs, edge_weight = explainer.explain_graph(g, feat)
# node explainer
model = Model(
input_dim, embed_dim, n_classes, g.canonical_etypes, graph=False
)
model = model.to(ctx)
explainer = nn.HeteroPGExplainer(
model, embed_dim, num_hops=1, explain_graph=False
)
explainer.train_step_node({g.ntypes[0]: [0]}, g, feat, 5.0)
explainer.train_step_node({g.ntypes[0]: th.tensor([0, 1])}, g, feat, 5.0)
probs, edge_weight, bg, inverse_indices = explainer.explain_node(
{g.ntypes[0]: [0]}, g, feat
)
probs, edge_weight, bg, inverse_indices = explainer.explain_node(
{g.ntypes[0]: th.tensor([0, 1])}, g, feat
)
def test_jumping_knowledge():
ctx = F.ctx()