Unverified commit 70ad5083, authored by Nick Baker, committed by GitHub

[Model] Add Node explanation for Homogeneous PGExplainer Impl. (#5839)


Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent e2d35f62
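In short, this commit teaches `PGExplainer` to explain node-level predictions: construct it with `explain_graph=False` and a `num_hops` value matching the number of message-passing layers in the GNN being explained, train it with the new `train_step_node`, and query it with `explain_node`. A minimal usage sketch, condensed from the doctest added in this commit (the two-layer `Model`, the Cora dataset, and the hyperparameters are illustrative choices, not requirements):

    import dgl
    import numpy as np
    import torch

    # A GNN whose forward exposes intermediate node embeddings via `embed`,
    # as PGExplainer expects; two layers, hence num_hops=2 below.
    class Model(torch.nn.Module):
        def __init__(self, in_feats, out_feats):
            super().__init__()
            self.conv1 = dgl.nn.GraphConv(in_feats, out_feats)
            self.conv2 = dgl.nn.GraphConv(out_feats, out_feats)

        def forward(self, g, h, embed=False, edge_weight=None):
            h = self.conv1(g, h, edge_weight=edge_weight)
            if embed:
                return h
            return self.conv2(g, h)

    data = dgl.data.CoraGraphDataset(verbose=False)
    g = data[0]
    features = g.ndata["feat"]

    model = Model(features.shape[1], data.num_classes)
    # ... train `model` on the node-classification task ...

    explainer = dgl.nn.PGExplainer(
        model, data.num_classes, num_hops=2, explain_graph=False
    )
    optimizer_exp = torch.optim.Adam(explainer.parameters(), lr=0.01)
    init_tmp, final_tmp, epochs = 5.0, 1.0, 10
    for epoch in range(epochs):
        # anneal the sampling temperature from init_tmp down to final_tmp
        tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / epochs))
        loss = explainer.train_step_node(g.nodes(), g, features, tmp)
        optimizer_exp.zero_grad()
        loss.backward()
        optimizer_exp.step()

    # Explain the prediction for node 0.
    probs, edge_weight, bg, inverse_indices = explainer.explain_node(0, g, features)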
@@ -7,4 +7,4 @@
 .. autoclass:: {{ name }}
     :show-inheritance:
-    :members: __getitem__, __len__, collate_fn, forward, reset_parameters, rel_emb, rel_project, explain_node, explain_graph, train_step
+    :members: __getitem__, __len__, collate_fn, forward, reset_parameters, rel_emb, rel_project, explain_node, explain_graph, train_step, train_step_node
@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from .... import ETYPE, to_homogeneous
+from .... import batch, ETYPE, khop_in_subgraph, NID, to_homogeneous

 __all__ = ["PGExplainer", "HeteroPGExplainer"]
@@ -30,6 +30,11 @@ class PGExplainer(nn.Module):
         the intermediate node embeddings.
     num_features : int
         Node embedding size used by :attr:`model`.
+    num_hops : int, optional
+        The number of hops for GNN information aggregation, which must match the
+        number of message passing layers employed by the GNN to be explained.
+    explain_graph : bool, optional
+        Whether to initialize the model for graph-level or node-level predictions.
     coff_budget : float, optional
         Size regularization to constrain the explanation size. Default: 0.01.
     coff_connect : float, optional
@@ -43,6 +48,8 @@ class PGExplainer(nn.Module):
         self,
         model,
         num_features,
+        num_hops=None,
+        explain_graph=True,
         coff_budget=0.01,
         coff_connect=5e-4,
         sample_bias=0.0,
@@ -50,7 +57,9 @@ class PGExplainer(nn.Module):
         super(PGExplainer, self).__init__()
         self.model = model
-        self.num_features = num_features * 2
+        self.graph_explanation = explain_graph
+        self.num_features = num_features * (2 if self.graph_explanation else 3)
+        self.num_hops = num_hops

         # training hyperparameters for PGExplainer
         self.coff_budget = coff_budget
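The new `num_features * 3` width reflects how node-level explanation builds the input to the edge-mask MLP: for every edge of the extracted subgraph it concatenates the source-node, destination-node, and center-node embeddings, versus two embeddings per edge in the graph-level case. A shape-check sketch with made-up dimensions:

    import torch

    emb_dim, num_edges = 16, 5
    col_emb = torch.randn(num_edges, emb_dim)  # source-node embeddings
    row_emb = torch.randn(num_edges, emb_dim)  # destination-node embeddings
    # embedding of the node being explained, repeated across all subgraph edges
    self_emb = torch.randn(1, emb_dim).repeat(num_edges, 1)

    edge_input = torch.cat([col_emb, row_emb, self_emb], dim=-1)
    assert edge_input.shape == (num_edges, emb_dim * 3)  # hence num_features * 3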
@@ -79,13 +88,14 @@
             graph. The values are within range :math:`(0, 1)`. The higher,
             the more important. Default: None.
         """
-        num_nodes = graph.num_nodes()
-        num_edges = graph.num_edges()
-        init_bias = self.init_bias
-        std = nn.init.calculate_gain("relu") * math.sqrt(2.0 / (2 * num_nodes))
-
         if edge_mask is None:
+            num_nodes = graph.num_nodes()
+            num_edges = graph.num_edges()
+            init_bias = self.init_bias
+            std = nn.init.calculate_gain("relu") * math.sqrt(
+                2.0 / (2 * num_nodes)
+            )
             self.edge_mask = torch.randn(num_edges) * std + init_bias
         else:
             self.edge_mask = edge_mask
@@ -126,7 +136,7 @@
             different types of label in the dataset and :math:`B` is
             the batch size.
         ori_pred: Tensor
-            Tensor of shape ::math:`(B, 1)`, representing the original prediction
+            Tensor of shape :math:`(B, 1)`, representing the original prediction
             for the graph, where :math:`B` is the batch size.

         Returns
@@ -216,17 +226,69 @@
         Tensor
             A scalar tensor representing the loss.
         """
+        assert (
+            self.graph_explanation
+        ), '"explain_graph" must be True in initializing the module.'
+
         self.model = self.model.to(graph.device)
         self.elayers = self.elayers.to(graph.device)

-        pred = self.model(graph, feat, embed=False, **kwargs).argmax(-1).data
+        pred = self.model(graph, feat, embed=False, **kwargs)
+        pred = pred.argmax(-1).data

         prob, _ = self.explain_graph(
             graph, feat, tmp=tmp, training=True, **kwargs
         )

-        loss_tmp = self.loss(prob, pred)
-        return loss_tmp
+        loss = self.loss(prob, pred)
+        return loss
+    def train_step_node(self, nodes, graph, feat, tmp, **kwargs):
+        r"""Compute the loss of the explanation network
+
+        Parameters
+        ----------
+        nodes : int, iterable[int], tensor
+            The nodes from the graph used to train the explanation network,
+            which cannot have any duplicate value.
+        graph : DGLGraph
+            Input homogeneous graph.
+        feat : Tensor
+            The input feature of shape :math:`(N, D)`. :math:`N` is the
+            number of nodes, and :math:`D` is the feature size.
+        tmp : float
+            The temperature parameter fed to the sampling procedure.
+        kwargs : dict
+            Additional arguments passed to the GNN model.
+
+        Returns
+        -------
+        Tensor
+            A scalar tensor representing the loss.
+        """
+        assert (
+            not self.graph_explanation
+        ), '"explain_graph" must be False in initializing the module.'
+
+        self.model = self.model.to(graph.device)
+        self.elayers = self.elayers.to(graph.device)
+
+        if isinstance(nodes, torch.Tensor):
+            nodes = nodes.tolist()
+        if isinstance(nodes, int):
+            nodes = [nodes]
+
+        prob, _, batched_graph, inverse_indices = self.explain_node(
+            nodes, graph, feat, tmp=tmp, training=True, **kwargs
+        )
+        pred = self.model(
+            batched_graph, self.batched_feats, embed=False, **kwargs
+        )
+        pred = pred.argmax(-1).data
+
+        loss = self.loss(prob[inverse_indices], pred[inverse_indices])
+        return loss
+
     def explain_graph(self, graph, feat, tmp=1.0, training=False, **kwargs):
         r"""Learn and return an edge mask that plays a crucial role to
@@ -324,19 +386,20 @@
         >>> graph_feat = graph.ndata.pop("attr")
         >>> probs, edge_weight = explainer.explain_graph(graph, graph_feat)
         """
+        assert (
+            self.graph_explanation
+        ), '"explain_graph" must be True in initializing the module.'
+
         self.model = self.model.to(graph.device)
         self.elayers = self.elayers.to(graph.device)

         embed = self.model(graph, feat, embed=True, **kwargs)
         embed = embed.data

-        edge_idx = graph.edges()
-        col, row = edge_idx
+        col, row = graph.edges()
         col_emb = embed[col.long()]
         row_emb = embed[row.long()]
         emb = torch.cat([col_emb, row_emb], dim=-1)
         emb = self.elayers(emb)
         values = emb.reshape(-1)
@@ -352,10 +415,188 @@
         logits = self.model(graph, feat, edge_weight=self.edge_mask, **kwargs)
         probs = F.softmax(logits, dim=-1)

-        if not training:
+        if training:
+            probs = probs.data
+        else:
             self.clear_masks()

-        return (probs, edge_mask) if training else (probs.data, edge_mask)
+        return (probs, edge_mask)
+    def explain_node(
+        self, nodes, graph, feat, tmp=1.0, training=False, **kwargs
+    ):
+        r"""Learn and return an edge mask that plays a crucial role to
+        explain the predictions made by the GNN for the nodes in
+        :attr:`nodes`. Also, return the prediction made with the edges
+        chosen based on the edge mask.
+
+        Parameters
+        ----------
+        nodes : int, iterable[int], tensor
+            The nodes from the graph, which cannot have any duplicate value.
+        graph : DGLGraph
+            A homogeneous graph.
+        feat : Tensor
+            The input feature of shape :math:`(N, D)`. :math:`N` is the
+            number of nodes, and :math:`D` is the feature size.
+        tmp : float
+            The temperature parameter fed to the sampling procedure.
+        training : bool
+            Whether to train the explanation network.
+        kwargs : dict
+            Additional arguments passed to the GNN model.
+
+        Returns
+        -------
+        Tensor
+            Classification probabilities given the masked graph. It is a
+            tensor of shape :math:`(B, L)`, where :math:`L` is the different
+            types of label in the dataset, and :math:`B` is the batch size.
+        Tensor
+            Edge weights, a tensor of shape :math:`(E)`, where :math:`E`
+            is the number of edges in the graph. A higher weight suggests
+            a larger contribution of the edge.
+        DGLGraph
+            The batched set of subgraphs induced on the k-hop in-neighborhood
+            of the input center nodes.
+        Tensor
+            The new IDs of the subgraph center nodes.
+
+        Examples
+        --------
+
+        >>> import dgl
+        >>> import numpy as np
+        >>> import torch
+
+        >>> # Define the model
+        >>> class Model(torch.nn.Module):
+        ...     def __init__(self, in_feats, out_feats):
+        ...         super().__init__()
+        ...         self.conv1 = dgl.nn.GraphConv(in_feats, out_feats)
+        ...         self.conv2 = dgl.nn.GraphConv(out_feats, out_feats)
+        ...
+        ...     def forward(self, g, h, embed=False, edge_weight=None):
+        ...         h = self.conv1(g, h, edge_weight=edge_weight)
+        ...         if embed:
+        ...             return h
+        ...         return self.conv2(g, h)

+        >>> # Load dataset
+        >>> data = dgl.data.CoraGraphDataset(verbose=False)
+        >>> g = data[0]
+        >>> features = g.ndata["feat"]
+        >>> labels = g.ndata["label"]

+        >>> # Train the model
+        >>> model = Model(features.shape[1], data.num_classes)
+        >>> criterion = torch.nn.CrossEntropyLoss()
+        >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
+        >>> for epoch in range(20):
+        ...     logits = model(g, features)
+        ...     loss = criterion(logits, labels)
+        ...     optimizer.zero_grad()
+        ...     loss.backward()
+        ...     optimizer.step()

+        >>> # Initialize the explainer
+        >>> explainer = dgl.nn.PGExplainer(
+        ...     model, data.num_classes, num_hops=2, explain_graph=False
+        ... )

+        >>> # Train the explainer
+        >>> # Define explainer temperature parameter
+        >>> init_tmp, final_tmp = 5.0, 1.0
+        >>> optimizer_exp = torch.optim.Adam(explainer.parameters(), lr=0.01)
+        >>> epochs = 10
+        >>> for epoch in range(epochs):
+        ...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / epochs))
+        ...     loss = explainer.train_step_node(g.nodes(), g, features, tmp)
+        ...     optimizer_exp.zero_grad()
+        ...     loss.backward()
+        ...     optimizer_exp.step()

+        >>> # Explain the prediction for node 0
+        >>> probs, edge_weight, bg, inverse_indices = explainer.explain_node(
+        ...     0, g, features
+        ... )
+        """
+        assert (
+            not self.graph_explanation
+        ), '"explain_graph" must be False in initializing the module.'
+        assert (
+            self.num_hops is not None
+        ), '"num_hops" must be provided in initializing the module.'
+
+        if isinstance(nodes, torch.Tensor):
+            nodes = nodes.tolist()
+        if isinstance(nodes, int):
+            nodes = [nodes]
+
+        self.model = self.model.to(graph.device)
+        self.elayers = self.elayers.to(graph.device)
+
+        batched_graph = []
+        batched_feats = []
+        batched_embed = []
+        batched_inverse_indices = []
+        node_idx = 0
+        for node_id in nodes:
+            sg, inverse_indices = khop_in_subgraph(
+                graph, node_id, self.num_hops
+            )
+            sg_feat = feat[sg.ndata[NID].long()]
+
+            embed = self.model(sg, sg_feat, embed=True, **kwargs)
+            embed = embed.data
+
+            col, row = sg.edges()
+            col_emb = embed[col.long()]
+            row_emb = embed[row.long()]
+            self_emb = embed[inverse_indices[0]].repeat(sg.num_edges(), 1)
+            emb = torch.cat([col_emb, row_emb, self_emb], dim=-1)
+            batched_embed.append(emb)
+            batched_graph.append(sg)
+            batched_feats.append(sg_feat)
+
+            # node IDs of the subgraph mapped to the batched graph:
+            # https://docs.dgl.ai/en/latest/generated/dgl.batch.html#dgl.batch
+            batched_inverse_indices.append(
+                inverse_indices[0].item() + node_idx
+            )
+            node_idx += sg.num_nodes()
+
+        batched_graph = batch(batched_graph)
+        batched_feats = torch.cat(batched_feats)
+
+        batched_embed = torch.cat(batched_embed)
+        batched_embed = self.elayers(batched_embed)
+        values = batched_embed.reshape(-1)
+
+        values = self.concrete_sample(values, beta=tmp, training=training)
+        self.sparse_mask_values = values
+
+        col, row = batched_graph.edges()
+        reverse_eids = batched_graph.edge_ids(row, col).long()
+        edge_mask = (values + values[reverse_eids]) / 2
+
+        self.set_masks(batched_graph, edge_mask)
+
+        # the model prediction with the updated edge mask
+        logits = self.model(
+            batched_graph, batched_feats, edge_weight=self.edge_mask, **kwargs
+        )
+        probs = F.softmax(logits, dim=-1)
+
+        if training:
+            self.batched_feats = batched_feats
+            probs = probs.data
+        else:
+            self.clear_masks()
+
+        return (
+            probs.data,
+            edge_mask,
+            batched_graph,
+            batched_inverse_indices,
+        )
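The symmetrization `(values + values[reverse_eids]) / 2` above assumes every edge of the subgraph has its reverse present, since `edge_ids(row, col)` requires each queried edge to exist. A minimal sketch of the averaging step on a hypothetical two-edge graph:

    import dgl
    import torch

    # hypothetical symmetric graph with edges 0->1 (eid 0) and 1->0 (eid 1)
    g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 0])))
    values = torch.tensor([0.2, 0.8])  # per-edge samples, e.g. from concrete_sample

    col, row = g.edges()
    reverse_eids = g.edge_ids(row, col).long()  # eid of each edge's reverse
    edge_mask = (values + values[reverse_eids]) / 2
    print(edge_mask)  # tensor([0.5000, 0.5000])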
 class HeteroPGExplainer(PGExplainer):
@@ -560,11 +801,9 @@
         logits = self.model(graph, feat, edge_weight=hetero_edge_mask, **kwargs)
         probs = F.softmax(logits, dim=-1)

-        if not training:
+        if training:
+            probs = probs.data
+        else:
             self.clear_masks()

-        return (
-            (probs, hetero_edge_mask)
-            if training
-            else (probs.data, hetero_edge_mask)
-        )
+        return (probs, hetero_edge_mask)
@@ -1826,8 +1826,9 @@ def test_pgexplainer(g, idtype, n_classes):
     g = transform(g)

     class Model(th.nn.Module):
-        def __init__(self, in_feats, out_feats):
+        def __init__(self, in_feats, out_feats, graph=False):
             super(Model, self).__init__()
+            self.graph = graph
             self.conv = nn.GraphConv(in_feats, out_feats)
             self.fc = th.nn.Linear(out_feats, out_feats)
             th.nn.init.xavier_uniform_(self.fc.weight)

@@ -1835,7 +1836,7 @@ def test_pgexplainer(g, idtype, n_classes):
         def forward(self, g, h, embed=False, edge_weight=None):
             h = self.conv(g, h, edge_weight=edge_weight)
-            if embed:
+            if not self.graph or embed:
                 return h

             with g.local_scope():

@@ -1843,14 +1844,36 @@ def test_pgexplainer(g, idtype, n_classes):
             hg = dgl.mean_nodes(g, "h")
             return self.fc(hg)

-    model = Model(feat.shape[1], n_classes)
+    # graph explainer
+    model = Model(feat.shape[1], n_classes, graph=True)
     model = model.to(ctx)
     explainer = nn.PGExplainer(model, n_classes)

     explainer.train_step(g, g.ndata["attr"], 5.0)
     probs, edge_weight = explainer.explain_graph(g, feat)
+    # node explainer
+    model = Model(feat.shape[1], n_classes, graph=False)
+    model = model.to(ctx)
+    explainer = nn.PGExplainer(
+        model, n_classes, num_hops=1, explain_graph=False
+    )
+
+    explainer.train_step_node(0, g, g.ndata["attr"], 5.0)
+    explainer.train_step_node([0, 1], g, g.ndata["attr"], 5.0)
+    explainer.train_step_node(th.tensor(0), g, g.ndata["attr"], 5.0)
+    explainer.train_step_node(th.tensor([0, 1]), g, g.ndata["attr"], 5.0)
+
+    probs, edge_weight, bg, inverse_indices = explainer.explain_node(0, g, feat)
+    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
+        [0, 1], g, feat
+    )
+    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
+        th.tensor(0), g, feat
+    )
+    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
+        th.tensor([0, 1]), g, feat
+    )
 @pytest.mark.parametrize("g", get_cases(["hetero"]))
 @pytest.mark.parametrize("idtype", [F.int64])

@@ -1901,9 +1924,10 @@ def test_heteropgexplainer(g, idtype, input_dim, n_classes):
             return self.fc(hg)

     embed_dim = input_dim
+    # graph explainer
     model = Model(input_dim, embed_dim, n_classes, g.canonical_etypes)
     model = model.to(ctx)
     explainer = nn.HeteroPGExplainer(model, embed_dim)

     explainer.train_step(g, feat, 5.0)