Unverified Commit b0308c85 authored by Nick Baker, committed by GitHub

[Model] Implemented PGExplainer for Heterogeneous graph (#5739)


Co-authored-by: kxm180046 <kxm180046@utdallas.edu>
Co-authored-by: Kunal Mukherjee <kunmukh@gmail.com>
parent 3a33c8b5
......@@ -130,6 +130,7 @@ Utility Modules
~dgl.nn.pytorch.explain.SubgraphX
~dgl.nn.pytorch.explain.HeteroSubgraphX
~dgl.nn.pytorch.explain.PGExplainer
~dgl.nn.pytorch.explain.HeteroPGExplainer
~dgl.nn.pytorch.utils.LabelPropagation
~dgl.nn.pytorch.graph_transformer.DegreeEncoder
~dgl.nn.pytorch.utils.LaplacianPosEnc
......
......@@ -5,7 +5,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ["PGExplainer"]
from .... import ETYPE, to_homogeneous
__all__ = ["PGExplainer", "HeteroPGExplainer"]
class PGExplainer(nn.Module):
......@@ -62,7 +64,7 @@ class PGExplainer(nn.Module):
nn.Linear(self.num_features, 64), nn.ReLU(), nn.Linear(64, 1)
)
def set_masks(self, graph, feat, edge_mask=None):
def set_masks(self, graph, edge_mask=None):
r"""Set the edge mask that plays a crucial role to explain the
prediction made by the GNN for a graph. Initialize learnable edge
mask if it is None.
......@@ -71,16 +73,13 @@ class PGExplainer(nn.Module):
----------
graph : DGLGraph
A homogeneous graph.
feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size.
edge_mask : Tensor, optional
Learned importance mask of the edges in the graph, which is a tensor
of shape :math:`(E)`, where :math:`E` is the number of edges in the
graph. The values are within the range :math:`(0, 1)`; the higher
the value, the more important the edge. Default: None.
"""
num_nodes, _ = feat.shape
num_nodes = graph.num_nodes()
num_edges = graph.num_edges()
init_bias = self.init_bias
......@@ -198,8 +197,7 @@ class PGExplainer(nn.Module):
return gate_inputs
def train_step(self, graph, feat, tmp, **kwargs):
r"""Training the explanation network by gradient descent(GD)
using Adam optimizer
r"""Compute the loss of the explanation network
Parameters
----------
......@@ -223,12 +221,10 @@ class PGExplainer(nn.Module):
pred = self.model(graph, feat, embed=False, **kwargs).argmax(-1).data
prob, edge_mask = self.explain_graph(
prob, _ = self.explain_graph(
graph, feat, tmp=tmp, training=True, **kwargs
)
self.edge_mask = edge_mask
loss_tmp = self.loss(prob, pred)
return loss_tmp
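For context, the `self.loss` term above balances prediction preservation against mask sparsity and entropy, weighted by `coff_budget` and `coff_connect`. A minimal sketch of those three terms (an illustration, not the exact DGL code; `edge_mask` stands in for the explainer's stored sparse mask values):

    import torch

    def explanation_loss(prob, ori_pred, edge_mask,
                         coff_budget=0.01, coff_connect=5e-4):
        # keep the masked prediction close to the original one
        target_prob = prob.gather(-1, ori_pred.unsqueeze(-1)) + 1e-6
        pred_loss = -torch.log(target_prob).mean()
        # size term: prefer sparse explanations
        size_loss = coff_budget * edge_mask.sum()
        # entropy term: push mask values toward 0 or 1
        m = 0.99 * edge_mask + 0.005
        ent = -m * torch.log(m) - (1 - m) * torch.log(1 - m)
        return pred_loss + size_loss + coff_connect * ent.mean()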
......@@ -283,12 +279,14 @@ class PGExplainer(nn.Module):
...
... def forward(self, g, h, embed=False, edge_weight=None):
... h = self.conv(g, h, edge_weight=edge_weight)
... if not embed:
...
... if embed:
... return h
...
... with g.local_scope():
... g.ndata['h'] = h
... hg = dgl.mean_nodes(g, 'h')
... return self.fc(hg)
... else:
... return h
>>> # Load dataset
>>> data = GINDataset('MUTAG', self_loop=True)
......@@ -348,13 +346,225 @@ class PGExplainer(nn.Module):
reverse_eids = graph.edge_ids(row, col).long()
edge_mask = (values + values[reverse_eids]) / 2
self.clear_masks()
self.set_masks(graph, feat, edge_mask)
self.set_masks(graph, edge_mask)
# the model prediction with the updated edge mask
logits = self.model(graph, feat, edge_weight=self.edge_mask, **kwargs)
probs = F.softmax(logits, dim=-1)
if not training:
self.clear_masks()
return (probs, edge_mask) if training else (probs.data, edge_mask)
class HeteroPGExplainer(PGExplainer):
r"""PGExplainer from `Parameterized Explainer for Graph Neural Network
<https://arxiv.org/pdf/2011.04573>`__, adapted for heterogeneous graphs
PGExplainer adopts a deep neural network (explanation network) to parameterize the generation
process of explanations, which enables it to explain multiple instances
collectively. PGExplainer models the underlying structure as edge
distributions, from which the explanatory graph is sampled.
Parameters
----------
model : nn.Module
The GNN model to explain, which tackles multiclass graph classification
* Its forward function must have the form
:attr:`forward(self, graph, nfeat, embed, edge_weight)`.
* Its forward function must return the logits if :attr:`embed=False`,
and the intermediate node embeddings otherwise.
num_features : int
Node embedding size used by :attr:`model`.
coff_budget : float, optional
Size regularization to constrain the explanation size. Default: 0.01.
coff_connect : float, optional
Entropy regularization to constrain the connectivity of explanation. Default: 5e-4.
sample_bias : float, optional
Bias applied to edge sampling so that some edges are systematically
more likely to be selected than others. Default: 0.0.
"""
def train_step(self, graph, feat, tmp, **kwargs):
r"""Compute the loss of the explanation network
Parameters
----------
graph : DGLGraph
Input batched heterogeneous graph.
feat : dict[str, Tensor]
A dict mapping node types (keys) to feature tensors (values).
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is the
number of nodes for node type :math:`t`, and :math:`D_t` is the feature
size for node type :math:`t`.
tmp : float
The temperature parameter fed to the sampling procedure.
kwargs : dict
Additional arguments passed to the GNN model.
Returns
-------
Tensor
A scalar tensor representing the loss.
"""
return super().train_step(graph, feat, tmp=tmp, **kwargs)
def explain_graph(self, graph, feat, tmp=1.0, training=False, **kwargs):
r"""Learn and return an edge mask that plays a crucial role to
explain the prediction made by the GNN for a graph. Also, return
the prediction made with the edges chosen based on the edge mask.
Parameters
----------
graph : DGLGraph
A heterogeneous graph.
feat : dict[str, Tensor]
A dict mapping node types (keys) to feature tensors (values).
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is the
number of nodes for node type :math:`t`, and :math:`D_t` is the feature
size for node type :math:`t`.
tmp : float
The temperature parameter fed to the sampling procedure.
training : bool
Whether to train the explanation network. Default: False.
kwargs : dict
Additional arguments passed to the GNN model.
Returns
-------
Tensor
Classification probabilities given the masked graph. It is a tensor of
shape :math:`(B, L)`, where :math:`L` is the number of distinct label
types in the dataset, and :math:`B` is the batch size.
dict[str, Tensor]
A dict mapping edge types (keys) to edge tensors (values) of shape :math:`(E_t)`,
where :math:`E_t` is the number of edges in the graph for edge type :math:`t`.
A higher weight suggests a larger contribution of the edge.
Examples
--------
>>> import dgl
>>> import torch as th
>>> import torch.nn as nn
>>> import numpy as np
>>> # Define the model
>>> class Model(nn.Module):
... def __init__(self, in_feats, hid_feats, out_feats, rel_names):
... super().__init__()
... self.conv = dgl.nn.HeteroGraphConv(
... {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
... aggregate="sum",
... )
... self.fc = nn.Linear(hid_feats, out_feats)
... nn.init.xavier_uniform_(self.fc.weight)
...
... def forward(self, g, h, embed=False, edge_weight=None):
... if edge_weight:
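... # pass each edge type's mask to its GraphConv via mod_kwargs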
... mod_kwargs = {
... etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
... }
... h = self.conv(g, h, mod_kwargs=mod_kwargs)
... else:
... h = self.conv(g, h)
...
... if embed:
... return h
...
... with g.local_scope():
... g.ndata["h"] = h
... hg = 0
... for ntype in g.ntypes:
... hg = hg + dgl.mean_nodes(g, "h", ntype=ntype)
... return self.fc(hg)
>>> # Load dataset
>>> input_dim = 5
>>> hidden_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
... logits = model(g, g.ndata["h"])
... loss = th.nn.functional.cross_entropy(logits, th.tensor([1]))
... optimizer.zero_grad()
... loss.backward()
... optimizer.step()
>>> # Initialize the explainer
>>> explainer = dgl.nn.HeteroPGExplainer(model, hidden_dim)
>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
... tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
... loss = explainer.train_step(g, g.ndata["h"], tmp)
... optimizer_exp.zero_grad()
... loss.backward()
... optimizer_exp.step()
>>> # Explain the graph
>>> feat = g.ndata.pop("h")
>>> probs, edge_mask = explainer.explain_graph(g, feat)
"""
self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
embed = self.model(graph, feat, embed=True, **kwargs)
for ntype, emb in embed.items():
graph.nodes[ntype].data["emb"] = emb.data
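# flatten to a homogeneous graph so a single MLP can score every edge;
# edata[ETYPE] is used below to map scores back to edge types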
homo_graph = to_homogeneous(graph, ndata=["emb"])
homo_embed = homo_graph.ndata["emb"]
edge_idx = homo_graph.edges()
col, row = edge_idx
col_emb = homo_embed[col.long()]
row_emb = homo_embed[row.long()]
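# an edge is represented by concatenating its endpoint embeddings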
emb = torch.cat([col_emb, row_emb], dim=-1)
emb = self.elayers(emb)
values = emb.reshape(-1)
values = self.concrete_sample(values, beta=tmp, training=training)
self.sparse_mask_values = values
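# symmetrize: an edge and its reverse share one importance score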
reverse_eids = homo_graph.edge_ids(row, col).long()
edge_mask = (values + values[reverse_eids]) / 2
self.set_masks(homo_graph, edge_mask)
# convert the edge mask back into heterogeneous format
hetero_edge_mask = {
etype: edge_mask[
(homo_graph.edata[ETYPE] == graph.get_etype_id(etype))
.nonzero()
.squeeze(1)
]
for etype in graph.canonical_etypes
}
# the model prediction with the updated edge mask
logits = self.model(graph, feat, edge_weight=hetero_edge_mask, **kwargs)
probs = F.softmax(logits, dim=-1)
if not training:
self.clear_masks()
return (
(probs, hetero_edge_mask)
if training
else (probs.data, hetero_edge_mask)
)
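A typical way to consume the returned dict, sketched as hypothetical post-processing (`k` and `top_edges` are illustrative names, not part of this commit):

    import torch

    # assuming: probs, edge_mask = explainer.explain_graph(g, feat)
    k = 3  # edges to keep per edge type
    top_edges = {
        etype: torch.topk(mask, k=min(k, mask.numel())).indices
        for etype, mask in edge_mask.items()
    }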
......@@ -1823,12 +1823,14 @@ def test_pgexplainer(g, idtype, n_classes):
def forward(self, g, h, embed=False, edge_weight=None):
h = self.conv(g, h, edge_weight=edge_weight)
if not embed:
if embed:
return h
with g.local_scope():
g.ndata["h"] = h
hg = dgl.mean_nodes(g, "h")
return self.fc(hg)
else:
return h
model = Model(feat.shape[1], n_classes)
model = model.to(ctx)
......@@ -1839,6 +1841,64 @@ def test_pgexplainer(g, idtype, n_classes):
probs, edge_weight = explainer.explain_graph(g, feat)
@pytest.mark.parametrize("g", get_cases(["hetero"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("n_classes", [2])
def test_heteropgexplainer(g, idtype, input_dim, n_classes):
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
feat = {
ntype: F.randn((g.num_nodes(ntype), input_dim)) for ntype in g.ntypes
}
# add self-loop and reverse edges
transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)
g = transform1(g)
transform2 = dgl.transforms.AddReverse(copy_edata=True)
g = transform2(g)
class Model(th.nn.Module):
def __init__(self, in_feats, embed_dim, out_feats, canonical_etypes):
super(Model, self).__init__()
self.conv = nn.HeteroGraphConv(
{
c_etype: nn.GraphConv(in_feats, embed_dim)
for c_etype in canonical_etypes
}
)
self.fc = th.nn.Linear(embed_dim, out_feats)
def forward(self, g, h, embed=False, edge_weight=None):
if edge_weight is not None:
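# route each edge type's mask to its GraphConv as edge_weight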
mod_kwargs = {
etype: {"edge_weight": mask}
for etype, mask in edge_weight.items()
}
h = self.conv(g, h, mod_kwargs=mod_kwargs)
else:
h = self.conv(g, h)
if embed:
return h
with g.local_scope():
g.ndata["h"] = h
hg = 0
for ntype in g.ntypes:
hg = hg + dgl.mean_nodes(g, "h", ntype=ntype)
return self.fc(hg)
embed_dim = input_dim
model = Model(input_dim, embed_dim, n_classes, g.canonical_etypes)
model = model.to(ctx)
explainer = nn.HeteroPGExplainer(model, embed_dim)
explainer.train_step(g, feat, 5.0)
probs, edge_weight = explainer.explain_graph(g, feat)
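The test ends with a smoke check; a sketch of shape assertions that could follow (hypothetical, not in this commit):

    assert probs.shape[-1] == n_classes
    assert set(edge_weight.keys()) == set(g.canonical_etypes)
    for etype, mask in edge_weight.items():
        assert mask.shape[0] == g.num_edges(etype)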
def test_jumping_knowledge():
ctx = F.ctx()
num_layers = 2
......
......@@ -238,3 +238,29 @@ def two_hetero_batch_with_isolated_ntypes():
num_nodes_dict={"user": 3, "game": 2, "developer": 3, "platform": 3},
)
return [g1, g2]
@register_case(["batched", "hetero"])
def batched_heterograph0():
g1 = dgl.heterograph(
{
("user", "follows", "user"): ([0, 1], [1, 2]),
("user", "follows", "developer"): ([0, 1], [1, 2]),
("user", "plays", "game"): ([0, 1, 2, 3], [0, 0, 1, 1]),
}
)
g2 = dgl.heterograph(
{
("user", "follows", "user"): ([0, 1], [1, 2]),
("user", "follows", "developer"): ([0, 1], [1, 2]),
("user", "plays", "game"): ([0, 1, 2], [0, 0, 1]),
}
)
g3 = dgl.heterograph(
{
("user", "follows", "user"): ([1], [2]),
("user", "follows", "developer"): ([0, 1, 2], [0, 2, 2]),
("user", "plays", "game"): ([0, 1], [0, 0]),
}
)
return dgl.batch([g1, g2, g3])