Unverified Commit eb40ed55 authored by Nick Baker's avatar Nick Baker Committed by GitHub
Browse files

[Model] Add Node explanation for Heterogenous PGExplainer Impl. (#6050)


Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent a848aa3e
...@@ -14,10 +14,11 @@ class PGExplainer(nn.Module): ...@@ -14,10 +14,11 @@ class PGExplainer(nn.Module):
r"""PGExplainer from `Parameterized Explainer for Graph Neural Network r"""PGExplainer from `Parameterized Explainer for Graph Neural Network
<https://arxiv.org/pdf/2011.04573>` <https://arxiv.org/pdf/2011.04573>`
PGExplainer adopts a deep neural network (explanation network) to parameterize the generation PGExplainer adopts a deep neural network (explanation network) to
process of explanations, which enables it to explain multiple instances parameterize the generation process of explanations, which enables it to
collectively. PGExplainer models the underlying structure as edge explain multiple instances collectively. PGExplainer models the underlying
distributions, from which the explanatory graph is sampled. structure as edge distributions, from which the explanatory graph is
sampled.
Parameters Parameters
---------- ----------
...@@ -58,6 +59,7 @@ class PGExplainer(nn.Module): ...@@ -58,6 +59,7 @@ class PGExplainer(nn.Module):
self.model = model self.model = model
self.graph_explanation = explain_graph self.graph_explanation = explain_graph
# Node explanation requires additional self-embedding data.
self.num_features = num_features * (2 if self.graph_explanation else 3) self.num_features = num_features * (2 if self.graph_explanation else 3)
self.num_hops = num_hops self.num_hops = num_hops
...@@ -206,8 +208,8 @@ class PGExplainer(nn.Module): ...@@ -206,8 +208,8 @@ class PGExplainer(nn.Module):
return gate_inputs return gate_inputs
def train_step(self, graph, feat, tmp, **kwargs): def train_step(self, graph, feat, temperature, **kwargs):
r"""Compute the loss of the explanation network r"""Compute the loss of the explanation network for graph classification
Parameters Parameters
---------- ----------
...@@ -216,7 +218,7 @@ class PGExplainer(nn.Module): ...@@ -216,7 +218,7 @@ class PGExplainer(nn.Module):
feat : Tensor feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size. number of nodes, and :math:`D` is the feature size.
tmp : float temperature : float
The temperature parameter fed to the sampling procedure. The temperature parameter fed to the sampling procedure.
kwargs : dict kwargs : dict
Additional arguments passed to the GNN model. Additional arguments passed to the GNN model.
...@@ -228,7 +230,7 @@ class PGExplainer(nn.Module): ...@@ -228,7 +230,7 @@ class PGExplainer(nn.Module):
""" """
assert ( assert (
self.graph_explanation self.graph_explanation
), '"explain_graph" must be True in initializing the module.' ), '"explain_graph" must be True when initializing the module.'
self.model = self.model.to(graph.device) self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device) self.elayers = self.elayers.to(graph.device)
...@@ -237,26 +239,26 @@ class PGExplainer(nn.Module): ...@@ -237,26 +239,26 @@ class PGExplainer(nn.Module):
pred = pred.argmax(-1).data pred = pred.argmax(-1).data
prob, _ = self.explain_graph( prob, _ = self.explain_graph(
graph, feat, tmp=tmp, training=True, **kwargs graph, feat, temperature, training=True, **kwargs
) )
loss = self.loss(prob, pred) loss = self.loss(prob, pred)
return loss return loss
def train_step_node(self, nodes, graph, feat, tmp, **kwargs): def train_step_node(self, nodes, graph, feat, temperature, **kwargs):
r"""Compute the loss of the explanation network r"""Compute the loss of the explanation network for node classification
Parameters Parameters
---------- ----------
nodes : int, iterable[int], tensor nodes : int, iterable[int], tensor
The nodes from the graph used to train the explanation network, which cannot The nodes from the graph used to train the explanation network,
have any duplicate value. which cannot have any duplicate value.
graph : DGLGraph graph : DGLGraph
Input homogeneous graph. Input homogeneous graph.
feat : Tensor feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size. number of nodes, and :math:`D` is the feature size.
tmp : float temperature : float
The temperature parameter fed to the sampling procedure. The temperature parameter fed to the sampling procedure.
kwargs : dict kwargs : dict
Additional arguments passed to the GNN model. Additional arguments passed to the GNN model.
...@@ -268,7 +270,7 @@ class PGExplainer(nn.Module): ...@@ -268,7 +270,7 @@ class PGExplainer(nn.Module):
""" """
assert ( assert (
not self.graph_explanation not self.graph_explanation
), '"explain_graph" must be False in initializing the module.' ), '"explain_graph" must be False when initializing the module.'
self.model = self.model.to(graph.device) self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device) self.elayers = self.elayers.to(graph.device)
...@@ -279,7 +281,7 @@ class PGExplainer(nn.Module): ...@@ -279,7 +281,7 @@ class PGExplainer(nn.Module):
nodes = [nodes] nodes = [nodes]
prob, _, batched_graph, inverse_indices = self.explain_node( prob, _, batched_graph, inverse_indices = self.explain_node(
nodes, graph, feat, tmp=tmp, training=True, **kwargs nodes, graph, feat, temperature, training=True, **kwargs
) )
pred = self.model( pred = self.model(
...@@ -290,7 +292,9 @@ class PGExplainer(nn.Module): ...@@ -290,7 +292,9 @@ class PGExplainer(nn.Module):
loss = self.loss(prob[inverse_indices], pred[inverse_indices]) loss = self.loss(prob[inverse_indices], pred[inverse_indices])
return loss return loss
def explain_graph(self, graph, feat, tmp=1.0, training=False, **kwargs): def explain_graph(
self, graph, feat, temperature=1.0, training=False, **kwargs
):
r"""Learn and return an edge mask that plays a crucial role to r"""Learn and return an edge mask that plays a crucial role to
explain the prediction made by the GNN for a graph. Also, return explain the prediction made by the GNN for a graph. Also, return
the prediction made with the edges chosen based on the edge mask. the prediction made with the edges chosen based on the edge mask.
...@@ -302,7 +306,7 @@ class PGExplainer(nn.Module): ...@@ -302,7 +306,7 @@ class PGExplainer(nn.Module):
feat : Tensor feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size. number of nodes, and :math:`D` is the feature size.
tmp : float temperature : float
The temperature parameter fed to the sampling procedure. The temperature parameter fed to the sampling procedure.
training : bool training : bool
Training the explanation network. Training the explanation network.
...@@ -312,13 +316,13 @@ class PGExplainer(nn.Module): ...@@ -312,13 +316,13 @@ class PGExplainer(nn.Module):
Returns Returns
------- -------
Tensor Tensor
Classification probabilities given the masked graph. It is a tensor of Classification probabilities given the masked graph. It is a tensor
shape :math:`(B, L)`, where :math:`L` is the different types of label of shape :math:`(B, L)`, where :math:`L` is the different types of
in the dataset, and :math:`B` is the batch size. label in the dataset, and :math:`B` is the batch size.
Tensor Tensor
Edge weights which is a tensor of shape :math:`(E)`, where :math:`E` Edge weights which is a tensor of shape :math:`(E)`, where :math:`E`
is the number of edges in the graph. A higher weight suggests a larger is the number of edges in the graph. A higher weight suggests a
contribution of the edge. larger contribution of the edge.
Examples Examples
-------- --------
...@@ -388,7 +392,7 @@ class PGExplainer(nn.Module): ...@@ -388,7 +392,7 @@ class PGExplainer(nn.Module):
""" """
assert ( assert (
self.graph_explanation self.graph_explanation
), '"explain_graph" must be True in initializing the module.' ), '"explain_graph" must be True when initializing the module.'
self.model = self.model.to(graph.device) self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device) self.elayers = self.elayers.to(graph.device)
...@@ -403,7 +407,9 @@ class PGExplainer(nn.Module): ...@@ -403,7 +407,9 @@ class PGExplainer(nn.Module):
emb = self.elayers(emb) emb = self.elayers(emb)
values = emb.reshape(-1) values = emb.reshape(-1)
values = self.concrete_sample(values, beta=tmp, training=training) values = self.concrete_sample(
values, beta=temperature, training=training
)
self.sparse_mask_values = values self.sparse_mask_values = values
reverse_eids = graph.edge_ids(row, col).long() reverse_eids = graph.edge_ids(row, col).long()
...@@ -423,12 +429,11 @@ class PGExplainer(nn.Module): ...@@ -423,12 +429,11 @@ class PGExplainer(nn.Module):
return (probs, edge_mask) return (probs, edge_mask)
def explain_node( def explain_node(
self, nodes, graph, feat, tmp=1.0, training=False, **kwargs self, nodes, graph, feat, temperature=1.0, training=False, **kwargs
): ):
r"""Learn and return an edge mask that plays a crucial role to r"""Learn and return an edge mask that plays a crucial role to
explain the prediction made by the GNN for node :attr:`node_id`. explain the prediction made by the GNN for provided set of node IDs.
Also, return the prediction made with the edges chosen based on Also, return the prediction made with the graph and edge mask.
the edge mask.
Parameters Parameters
---------- ----------
...@@ -439,7 +444,7 @@ class PGExplainer(nn.Module): ...@@ -439,7 +444,7 @@ class PGExplainer(nn.Module):
feat : Tensor feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size. number of nodes, and :math:`D` is the feature size.
tmp : float temperature : float
The temperature parameter fed to the sampling procedure. The temperature parameter fed to the sampling procedure.
training : bool training : bool
Training the explanation network. Training the explanation network.
...@@ -449,13 +454,14 @@ class PGExplainer(nn.Module): ...@@ -449,13 +454,14 @@ class PGExplainer(nn.Module):
Returns Returns
------- -------
Tensor Tensor
Classification probabilities given the masked graph. It is a tensor of Classification probabilities given the masked graph. It is a tensor
shape :math:`(B, L)`, where :math:`L` is the different types of label of shape :math:`(N, L)`, where :math:`L` is the different types
in the dataset, and :math:`B` is the batch size. of node labels in the dataset, and :math:`N` is the number of nodes
in the graph.
Tensor Tensor
Edge weights which is a tensor of shape :math:`(E)`, where :math:`E` Edge weights which is a tensor of shape :math:`(E)`, where :math:`E`
is the number of edges in the graph. A higher weight suggests a larger is the number of edges in the graph. A higher weight suggests a
contribution of the edge. larger contribution of the edge.
DGLGraph DGLGraph
The batched set of subgraphs induced on the k-hop in-neighborhood The batched set of subgraphs induced on the k-hop in-neighborhood
of the input center nodes. of the input center nodes.
...@@ -523,10 +529,10 @@ class PGExplainer(nn.Module): ...@@ -523,10 +529,10 @@ class PGExplainer(nn.Module):
""" """
assert ( assert (
not self.graph_explanation not self.graph_explanation
), '"explain_graph" must be False in initializing the module.' ), '"explain_graph" must be False when initializing the module.'
assert ( assert (
self.num_hops is not None self.num_hops is not None
), '"num_hops" must be provided in initializing the module.' ), '"num_hops" must be provided when initializing the module.'
if isinstance(nodes, torch.Tensor): if isinstance(nodes, torch.Tensor):
nodes = nodes.tolist() nodes = nodes.tolist()
...@@ -537,17 +543,17 @@ class PGExplainer(nn.Module): ...@@ -537,17 +543,17 @@ class PGExplainer(nn.Module):
self.elayers = self.elayers.to(graph.device) self.elayers = self.elayers.to(graph.device)
batched_graph = [] batched_graph = []
batched_feats = []
batched_embed = [] batched_embed = []
batched_inverse_indices = []
node_idx = 0
for node_id in nodes: for node_id in nodes:
sg, inverse_indices = khop_in_subgraph( sg, inverse_indices = khop_in_subgraph(
graph, node_id, self.num_hops graph, node_id, self.num_hops
) )
sg_feat = feat[sg.ndata[NID].long()] sg.ndata["feat"] = feat[sg.ndata[NID].long()]
sg.ndata["train"] = torch.tensor(
[nid in inverse_indices for nid in sg.nodes()], device=sg.device
)
embed = self.model(sg, sg_feat, embed=True, **kwargs) embed = self.model(sg, sg.ndata["feat"], embed=True, **kwargs)
embed = embed.data embed = embed.data
col, row = sg.edges() col, row = sg.edges()
...@@ -557,20 +563,16 @@ class PGExplainer(nn.Module): ...@@ -557,20 +563,16 @@ class PGExplainer(nn.Module):
emb = torch.cat([col_emb, row_emb, self_emb], dim=-1) emb = torch.cat([col_emb, row_emb, self_emb], dim=-1)
batched_embed.append(emb) batched_embed.append(emb)
batched_graph.append(sg) batched_graph.append(sg)
batched_feats.append(sg_feat)
# node id's of subgraph mapped to batch:
# https://docs.dgl.ai/en/latest/generated/dgl.batch.html#dgl.batch
batched_inverse_indices.append(inverse_indices[0].item() + node_idx)
node_idx += sg.num_nodes()
batched_graph = batch(batched_graph) batched_graph = batch(batched_graph)
batched_feats = torch.cat(batched_feats)
batched_embed = torch.cat(batched_embed)
batched_embed = torch.cat(batched_embed)
batched_embed = self.elayers(batched_embed) batched_embed = self.elayers(batched_embed)
values = batched_embed.reshape(-1) values = batched_embed.reshape(-1)
values = self.concrete_sample(values, beta=tmp, training=training) values = self.concrete_sample(
values, beta=temperature, training=training
)
self.sparse_mask_values = values self.sparse_mask_values = values
col, row = batched_graph.edges() col, row = batched_graph.edges()
...@@ -579,12 +581,17 @@ class PGExplainer(nn.Module): ...@@ -579,12 +581,17 @@ class PGExplainer(nn.Module):
self.set_masks(batched_graph, edge_mask) self.set_masks(batched_graph, edge_mask)
batched_feats = batched_graph.ndata["feat"]
# the model prediction with the updated edge mask # the model prediction with the updated edge mask
logits = self.model( logits = self.model(
batched_graph, batched_feats, edge_weight=self.edge_mask, **kwargs batched_graph, batched_feats, edge_weight=self.edge_mask, **kwargs
) )
probs = F.softmax(logits, dim=-1) probs = F.softmax(logits, dim=-1)
batched_inverse_indices = (
batched_graph.ndata["train"].nonzero().squeeze(1)
)
if training: if training:
self.batched_feats = batched_feats self.batched_feats = batched_feats
probs = probs.data probs = probs.data
...@@ -592,7 +599,7 @@ class PGExplainer(nn.Module): ...@@ -592,7 +599,7 @@ class PGExplainer(nn.Module):
self.clear_masks() self.clear_masks()
return ( return (
probs.data, probs,
edge_mask, edge_mask,
batched_graph, batched_graph,
batched_inverse_indices, batched_inverse_indices,
...@@ -603,10 +610,11 @@ class HeteroPGExplainer(PGExplainer): ...@@ -603,10 +610,11 @@ class HeteroPGExplainer(PGExplainer):
r"""PGExplainer from `Parameterized Explainer for Graph Neural Network r"""PGExplainer from `Parameterized Explainer for Graph Neural Network
<https://arxiv.org/pdf/2011.04573>`__, adapted for heterogeneous graphs <https://arxiv.org/pdf/2011.04573>`__, adapted for heterogeneous graphs
PGExplainer adopts a deep neural network (explanation network) to parameterize the generation PGExplainer adopts a deep neural network (explanation network) to
process of explanations, which enables it to explain multiple instances parameterize the generation process of explanations, which enables it to
collectively. PGExplainer models the underlying structure as edge explain multiple instances collectively. PGExplainer models the underlying
distributions, from which the explanatory graph is sampled. structure as edge distributions, from which the explanatory graph is
sampled.
Parameters Parameters
---------- ----------
...@@ -628,8 +636,9 @@ class HeteroPGExplainer(PGExplainer): ...@@ -628,8 +636,9 @@ class HeteroPGExplainer(PGExplainer):
in a sample than others. Default: 0.0. in a sample than others. Default: 0.0.
""" """
def train_step(self, graph, feat, tmp, **kwargs): def train_step(self, graph, feat, temperature, **kwargs):
r"""Compute the loss of the explanation network # pylint: disable=useless-super-delegation
r"""Compute the loss of the explanation network for graph classification
Parameters Parameters
---------- ----------
...@@ -637,10 +646,36 @@ class HeteroPGExplainer(PGExplainer): ...@@ -637,10 +646,36 @@ class HeteroPGExplainer(PGExplainer):
Input batched heterogeneous graph. Input batched heterogeneous graph.
feat : dict[str, Tensor] feat : dict[str, Tensor]
A dict mapping node types (keys) to feature tensors (values). A dict mapping node types (keys) to feature tensors (values).
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is the The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
number of nodes for node type :math:`t`, and :math:`D_t` is the feature the number of nodes for node type :math:`t`, and :math:`D_t` is the
size for node type :math:`t` feature size for node type :math:`t`
tmp : float temperature : float
The temperature parameter fed to the sampling procedure.
kwargs : dict
Additional arguments passed to the GNN model.
Returns
-------
Tensor
A scalar tensor representing the loss.
"""
return super().train_step(graph, feat, temperature, **kwargs)
def train_step_node(self, nodes, graph, feat, temperature, **kwargs):
r"""Compute the loss of the explanation network for node classification
Parameters
----------
nodes : dict[str, Iterable[int]]
A dict mapping node types (keys) to an iterable set of node ids (values).
graph : DGLGraph
Input heterogeneous graph.
feat : dict[str, Tensor]
A dict mapping node types (keys) to feature tensors (values).
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
the number of nodes for node type :math:`t`, and :math:`D_t` is the
feature size for node type :math:`t`
temperature : float
The temperature parameter fed to the sampling procedure. The temperature parameter fed to the sampling procedure.
kwargs : dict kwargs : dict
Additional arguments passed to the GNN model. Additional arguments passed to the GNN model.
...@@ -650,9 +685,35 @@ class HeteroPGExplainer(PGExplainer): ...@@ -650,9 +685,35 @@ class HeteroPGExplainer(PGExplainer):
Tensor Tensor
A scalar tensor representing the loss. A scalar tensor representing the loss.
""" """
return super().train_step(graph, feat, tmp=tmp, **kwargs) assert (
not self.graph_explanation
), '"explain_graph" must be False when initializing the module.'
def explain_graph(self, graph, feat, tmp=1.0, training=False, **kwargs): self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
prob, _, batched_graph, inverse_indices = self.explain_node(
nodes, graph, feat, temperature, training=True, **kwargs
)
pred = self.model(
batched_graph, self.batched_feats, embed=False, **kwargs
)
pred = {ntype: pred[ntype].argmax(-1).data for ntype in pred.keys()}
loss = self.loss(
torch.cat(
[prob[ntype][nid] for ntype, nid in inverse_indices.items()]
),
torch.cat(
[pred[ntype][nid] for ntype, nid in inverse_indices.items()]
),
)
return loss
def explain_graph(
self, graph, feat, temperature=1.0, training=False, **kwargs
):
r"""Learn and return an edge mask that plays a crucial role to r"""Learn and return an edge mask that plays a crucial role to
explain the prediction made by the GNN for a graph. Also, return explain the prediction made by the GNN for a graph. Also, return
the prediction made with the edges chosen based on the edge mask. the prediction made with the edges chosen based on the edge mask.
...@@ -663,10 +724,10 @@ class HeteroPGExplainer(PGExplainer): ...@@ -663,10 +724,10 @@ class HeteroPGExplainer(PGExplainer):
A heterogeneous graph. A heterogeneous graph.
feat : dict[str, Tensor] feat : dict[str, Tensor]
A dict mapping node types (keys) to feature tensors (values). A dict mapping node types (keys) to feature tensors (values).
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is the The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
number of nodes for node type :math:`t`, and :math:`D_t` is the feature the number of nodes for node type :math:`t`, and :math:`D_t` is the
size for node type :math:`t` feature size for node type :math:`t`
tmp : float temperature : float
The temperature parameter fed to the sampling procedure. The temperature parameter fed to the sampling procedure.
training : bool training : bool
Training the explanation network. Training the explanation network.
...@@ -676,13 +737,14 @@ class HeteroPGExplainer(PGExplainer): ...@@ -676,13 +737,14 @@ class HeteroPGExplainer(PGExplainer):
Returns Returns
------- -------
Tensor Tensor
Classification probabilities given the masked graph. It is a tensor of Classification probabilities given the masked graph. It is a tensor
shape :math:`(B, L)`, where :math:`L` is the different types of label of shape :math:`(B, L)`, where :math:`L` is the different types of
in the dataset, and :math:`B` is the batch size. label in the dataset, and :math:`B` is the batch size.
dict[str, Tensor] dict[str, Tensor]
A dict mapping edge types (keys) to edge tensors (values) of shape :math:`(E_t)`, A dict mapping edge types (keys) to edge tensors (values) of shape
where :math:`E_t` is the number of edges in the graph for edge type :math:`t`. :math:`(E_t)`, where :math:`E_t` is the number of edges in the graph
A higher weight suggests a larger contribution of the edge. for edge type :math:`t`. A higher weight suggests a larger
contribution of the edge.
Examples Examples
-------- --------
...@@ -761,6 +823,10 @@ class HeteroPGExplainer(PGExplainer): ...@@ -761,6 +823,10 @@ class HeteroPGExplainer(PGExplainer):
>>> feat = g.ndata.pop("h") >>> feat = g.ndata.pop("h")
>>> probs, edge_mask = explainer.explain_graph(g, feat) >>> probs, edge_mask = explainer.explain_graph(g, feat)
""" """
assert (
self.graph_explanation
), '"explain_graph" must be True when initializing the module.'
self.model = self.model.to(graph.device) self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device) self.elayers = self.elayers.to(graph.device)
...@@ -770,16 +836,16 @@ class HeteroPGExplainer(PGExplainer): ...@@ -770,16 +836,16 @@ class HeteroPGExplainer(PGExplainer):
homo_graph = to_homogeneous(graph, ndata=["emb"]) homo_graph = to_homogeneous(graph, ndata=["emb"])
homo_embed = homo_graph.ndata["emb"] homo_embed = homo_graph.ndata["emb"]
edge_idx = homo_graph.edges() col, row = homo_graph.edges()
col, row = edge_idx
col_emb = homo_embed[col.long()] col_emb = homo_embed[col.long()]
row_emb = homo_embed[row.long()] row_emb = homo_embed[row.long()]
emb = torch.cat([col_emb, row_emb], dim=-1) emb = torch.cat([col_emb, row_emb], dim=-1)
emb = self.elayers(emb) emb = self.elayers(emb)
values = emb.reshape(-1) values = emb.reshape(-1)
values = self.concrete_sample(values, beta=tmp, training=training) values = self.concrete_sample(
values, beta=temperature, training=training
)
self.sparse_mask_values = values self.sparse_mask_values = values
reverse_eids = homo_graph.edge_ids(row, col).long() reverse_eids = homo_graph.edge_ids(row, col).long()
...@@ -788,14 +854,11 @@ class HeteroPGExplainer(PGExplainer): ...@@ -788,14 +854,11 @@ class HeteroPGExplainer(PGExplainer):
self.set_masks(homo_graph, edge_mask) self.set_masks(homo_graph, edge_mask)
# convert the edge mask back into heterogeneous format # convert the edge mask back into heterogeneous format
hetero_edge_mask = { hetero_edge_mask = self._edge_mask_to_heterogeneous(
etype: edge_mask[ edge_mask=edge_mask,
(homo_graph.edata[ETYPE] == graph.get_etype_id(etype)) homograph=homo_graph,
.nonzero() heterograph=graph,
.squeeze(1) )
]
for etype in graph.canonical_etypes
}
# the model prediction with the updated edge mask # the model prediction with the updated edge mask
logits = self.model(graph, feat, edge_weight=hetero_edge_mask, **kwargs) logits = self.model(graph, feat, edge_weight=hetero_edge_mask, **kwargs)
...@@ -807,3 +870,270 @@ class HeteroPGExplainer(PGExplainer): ...@@ -807,3 +870,270 @@ class HeteroPGExplainer(PGExplainer):
self.clear_masks() self.clear_masks()
return (probs, hetero_edge_mask) return (probs, hetero_edge_mask)
def explain_node(
    self, nodes, graph, feat, temperature=1.0, training=False, **kwargs
):
    r"""Learn and return an edge mask that plays a crucial role to
    explain the prediction made by the GNN for provided set of node IDs.
    Also, return the prediction made with the batched graph and edge mask.

    Parameters
    ----------
    nodes : dict[str, Iterable[int]]
        A dict mapping node types (keys) to an iterable set of node ids (values).
    graph : DGLGraph
        A heterogeneous graph.
    feat : dict[str, Tensor]
        A dict mapping node types (keys) to feature tensors (values).
        The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
        the number of nodes for node type :math:`t`, and :math:`D_t` is the
        feature size for node type :math:`t`
    temperature : float
        The temperature parameter fed to the sampling procedure.
    training : bool
        Training the explanation network.
    kwargs : dict
        Additional arguments passed to the GNN model.

    Returns
    -------
    dict[str, Tensor]
        A dict mapping node types (keys) to classification probabilities
        for node labels (values). The values are tensors of shape
        :math:`(N_t, L)`, where :math:`L` is the different types of node
        labels in the dataset, and :math:`N_t` is the number of nodes in
        the graph for node type :math:`t`.
    dict[str, Tensor]
        A dict mapping edge types (keys) to edge tensors (values) of shape
        :math:`(E_t)`, where :math:`E_t` is the number of edges in the graph
        for edge type :math:`t`. A higher weight suggests a larger
        contribution of the edge.
    DGLGraph
        The batched set of subgraphs induced on the k-hop in-neighborhood
        of the input center nodes.
    dict[str, Tensor]
        A dict mapping node types (keys) to a tensor of node IDs (values)
        which correspond to the subgraph center nodes.

    Examples
    --------

    >>> import dgl
    >>> import torch as th
    >>> import torch.nn as nn
    >>> import numpy as np

    >>> # Define the model
    >>> class Model(nn.Module):
    ...     def __init__(self, in_feats, hid_feats, out_feats, rel_names):
    ...         super().__init__()
    ...         self.conv = dgl.nn.HeteroGraphConv(
    ...             {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
    ...             aggregate="sum",
    ...         )
    ...         self.fc = nn.Linear(hid_feats, out_feats)
    ...         nn.init.xavier_uniform_(self.fc.weight)
    ...
    ...     def forward(self, g, h, embed=False, edge_weight=None):
    ...         if edge_weight:
    ...             mod_kwargs = {
    ...                 etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
    ...             }
    ...             h = self.conv(g, h, mod_kwargs=mod_kwargs)
    ...         else:
    ...             h = self.conv(g, h)
    ...
    ...         return h

    >>> # Load dataset
    >>> input_dim = 5
    >>> hidden_dim = 5
    >>> num_classes = 2
    >>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
    >>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
    >>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)

    >>> transform = dgl.transforms.AddReverse()
    >>> g = transform(g)

    >>> # define and train the model
    >>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
    >>> optimizer = th.optim.Adam(model.parameters())
    >>> for epoch in range(10):
    ...     logits = model(g, g.ndata["h"])['user']
    ...     loss = th.nn.functional.cross_entropy(logits, th.tensor([1,1,1]))
    ...     optimizer.zero_grad()
    ...     loss.backward()
    ...     optimizer.step()

    >>> # Initialize the explainer
    >>> explainer = dgl.nn.HeteroPGExplainer(
    ...     model, hidden_dim, num_hops=2, explain_graph=False
    ... )

    >>> # Train the explainer
    >>> # Define explainer temperature parameter
    >>> init_tmp, final_tmp = 5.0, 1.0
    >>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
    >>> for epoch in range(20):
    ...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
    ...     loss = explainer.train_step_node(
    ...         { ntype: g.nodes(ntype) for ntype in g.ntypes },
    ...         g, g.ndata["h"], tmp
    ...     )
    ...     optimizer_exp.zero_grad()
    ...     loss.backward()
    ...     optimizer_exp.step()

    >>> # Explain the graph
    >>> feat = g.ndata.pop("h")
    >>> probs, edge_mask, bg, inverse_indices = explainer.explain_node(
    ...     { "user": [0] }, g, feat
    ... )
    """
    assert (
        not self.graph_explanation
    ), '"explain_graph" must be False when initializing the module.'
    assert (
        self.num_hops is not None
    ), '"num_hops" must be provided when initializing the module.'

    self.model = self.model.to(graph.device)
    self.elayers = self.elayers.to(graph.device)

    batched_embed = []
    batched_homo_graph = []
    batched_hetero_graph = []
    # One k-hop in-subgraph is extracted per requested center node; the
    # subgraphs are later batched so the explanation network can score
    # all of them in a single pass.
    for target_ntype, target_nids in nodes.items():
        if isinstance(target_nids, torch.Tensor):
            target_nids = target_nids.tolist()

        for target_nid in target_nids:
            sg, inverse_indices = khop_in_subgraph(
                graph, {target_ntype: target_nid}, self.num_hops
            )

            for sg_ntype in sg.ntypes:
                # Map subgraph-local node IDs back to parent-graph features.
                sg_feat = feat[sg_ntype][sg.ndata[NID][sg_ntype].long()]
                # "train" marks the center node(s) of this subgraph so they
                # can be recovered after batching (see
                # batched_inverse_indices below).
                train_mask = [
                    sg_ntype in inverse_indices
                    and node_id in inverse_indices[sg_ntype]
                    for node_id in sg.nodes(sg_ntype)
                ]

                sg.nodes[sg_ntype].data["feat"] = sg_feat
                sg.nodes[sg_ntype].data["train"] = torch.tensor(
                    train_mask, device=sg.device
                )

            # Node embeddings from the underlying GNN (gradients detached
            # via .data; only the explanation network is trained here).
            embed = self.model(sg, sg.ndata["feat"], embed=True, **kwargs)
            for ntype in embed.keys():
                sg.nodes[ntype].data["emb"] = embed[ntype].data

            # Work in homogeneous form so every edge can be scored with a
            # single edge-wise MLP regardless of its type.
            homo_sg = to_homogeneous(sg, ndata=["emb"])
            homo_sg_embed = homo_sg.ndata["emb"]

            # Per-edge input: [src_emb, dst_emb, center_node_emb] — node
            # explanation additionally conditions on the center node's
            # embedding (the "self" embedding).
            col, row = homo_sg.edges()
            col_emb = homo_sg_embed[col.long()]
            row_emb = homo_sg_embed[row.long()]
            self_emb = homo_sg_embed[
                inverse_indices[target_ntype][0]
            ].repeat(sg.num_edges(), 1)
            emb = torch.cat([col_emb, row_emb, self_emb], dim=-1)

            batched_embed.append(emb)
            batched_homo_graph.append(homo_sg)
            batched_hetero_graph.append(sg)

    batched_homo_graph = batch(batched_homo_graph)
    batched_hetero_graph = batch(batched_hetero_graph)
    batched_embed = torch.cat(batched_embed)
    batched_embed = self.elayers(batched_embed)
    values = batched_embed.reshape(-1)
    # Relaxed Bernoulli (concrete) sampling of per-edge mask values.
    values = self.concrete_sample(
        values, beta=temperature, training=training
    )
    self.sparse_mask_values = values

    # Symmetrize: average each edge's value with its reverse edge so the
    # mask is direction-agnostic. Assumes every edge has a reverse edge —
    # NOTE(review): presumably guaranteed by the caller (e.g. AddReverse
    # transform); confirm.
    col, row = batched_homo_graph.edges()
    reverse_eids = batched_homo_graph.edge_ids(row, col).long()
    edge_mask = (values + values[reverse_eids]) / 2

    self.set_masks(batched_homo_graph, edge_mask)

    # Convert the edge mask back into heterogeneous format.
    hetero_edge_mask = self._edge_mask_to_heterogeneous(
        edge_mask=edge_mask,
        homograph=batched_homo_graph,
        heterograph=batched_hetero_graph,
    )

    batched_feats = {
        ntype: batched_hetero_graph.nodes[ntype].data["feat"]
        for ntype in batched_hetero_graph.ntypes
    }

    # The model prediction with the updated edge mask.
    logits = self.model(
        batched_hetero_graph,
        batched_feats,
        edge_weight=hetero_edge_mask,
        **kwargs,
    )
    probs = {
        ntype: F.softmax(logits[ntype], dim=-1) for ntype in logits.keys()
    }

    # Recover the batched positions of the original center nodes via the
    # "train" flag set on each subgraph above.
    batched_inverse_indices = {
        ntype: batched_hetero_graph.nodes[ntype]
        .data["train"]
        .nonzero()
        .squeeze(1)
        for ntype in batched_hetero_graph.ntypes
    }

    if training:
        # Keep the masks and feats alive so train_step_node can reuse them
        # for the loss; detach probs from the graph.
        self.batched_feats = batched_feats
        probs = {ntype: probs[ntype].data for ntype in probs.keys()}
    else:
        self.clear_masks()

    return (
        probs,
        hetero_edge_mask,
        batched_hetero_graph,
        batched_inverse_indices,
    )
def _edge_mask_to_heterogeneous(self, edge_mask, homograph, heterograph):
    r"""Convert a homogeneous edge mask back into heterogeneous format.

    The mask is split per canonical edge type by selecting, for each type,
    the homogeneous edge positions whose ``ETYPE`` field matches that
    type's ID in the heterogeneous graph.

    The ``edge_mask`` needs to have been built using the embedding of the
    homogeneous graph format (i.e. indexed by ``homograph``'s edge IDs)
    for the mappings to work correctly.

    Parameters
    ----------
    edge_mask : Tensor
        Edge weights of shape :math:`(E)`, where :math:`E` is the number
        of edges in ``homograph``, ordered by homogeneous edge ID.
    homograph : DGLGraph
        The homogeneous form of the source graph.
    heterograph : DGLGraph
        The heterogeneous form of the source graph.

    Returns
    -------
    dict[str, Tensor]
        A dict mapping canonical edge types (keys) to tensors of edge
        weights (values) of shape :math:`(E_t)`, where :math:`E_t` is the
        number of edges of type :math:`t`.
    """
    return {
        etype: edge_mask[
            (homograph.edata[ETYPE] == heterograph.get_etype_id(etype))
            .nonzero()
            .squeeze(1)
        ]
        for etype in heterograph.canonical_etypes
    }
...@@ -1898,8 +1898,11 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes): ...@@ -1898,8 +1898,11 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes):
g = transform2(g) g = transform2(g)
class Model(th.nn.Module): class Model(th.nn.Module):
def __init__(self, in_feats, embed_dim, out_feats, canonical_etypes): def __init__(
self, in_feats, embed_dim, out_feats, canonical_etypes, graph=True
):
super(Model, self).__init__() super(Model, self).__init__()
self.graph = graph
self.conv = nn.HeteroGraphConv( self.conv = nn.HeteroGraphConv(
{ {
c_etype: nn.GraphConv(in_feats, embed_dim) c_etype: nn.GraphConv(in_feats, embed_dim)
...@@ -1918,7 +1921,7 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes): ...@@ -1918,7 +1921,7 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes):
else: else:
h = self.conv(g, h) h = self.conv(g, h)
if embed: if not self.graph or embed:
return h return h
with g.local_scope(): with g.local_scope():
...@@ -1931,13 +1934,33 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes): ...@@ -1931,13 +1934,33 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes):
embed_dim = input_dim embed_dim = input_dim
# graph explainer # graph explainer
model = Model(input_dim, embed_dim, n_classes, g.canonical_etypes) model = Model(
input_dim, embed_dim, n_classes, g.canonical_etypes, graph=True
)
model = model.to(ctx) model = model.to(ctx)
explainer = nn.HeteroPGExplainer(model, embed_dim) explainer = nn.HeteroPGExplainer(model, embed_dim)
explainer.train_step(g, feat, 5.0) explainer.train_step(g, feat, 5.0)
probs, edge_weight = explainer.explain_graph(g, feat) probs, edge_weight = explainer.explain_graph(g, feat)
# node explainer
model = Model(
input_dim, embed_dim, n_classes, g.canonical_etypes, graph=False
)
model = model.to(ctx)
explainer = nn.HeteroPGExplainer(
model, embed_dim, num_hops=1, explain_graph=False
)
explainer.train_step_node({g.ntypes[0]: [0]}, g, feat, 5.0)
explainer.train_step_node({g.ntypes[0]: th.tensor([0, 1])}, g, feat, 5.0)
probs, edge_weight, bg, inverse_indices = explainer.explain_node(
{g.ntypes[0]: [0]}, g, feat
)
probs, edge_weight, bg, inverse_indices = explainer.explain_node(
{g.ntypes[0]: th.tensor([0, 1])}, g, feat
)
def test_jumping_knowledge(): def test_jumping_knowledge():
ctx = F.ctx() ctx = F.ctx()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment