Unverified Commit dfa32ae0 authored by Mufei Li, committed by GitHub

[NN] GNNExplainer (#3490)

* Update

* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* Fix

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* lint fix

* lint fix

* Fix lint

* Update

* Fix CI

* Fix CI

* Fix

* CI

* Fix

* Update

* Fix

* Fix

* Fix CI

* Fix CI
parent 55f2e872
@@ -19,6 +19,7 @@ requirements:
     - scipy
     - networkx
     - requests
+    - tqdm

 build:
   script_env:
...
@@ -44,7 +44,7 @@ GATConv
 .. autoclass:: dgl.nn.pytorch.conv.GATConv
     :members: forward
     :show-inheritance:

 GATv2Conv
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -163,7 +163,7 @@ TWIRLSUnfoldingAndAttention
 .. autoclass:: dgl.nn.pytorch.conv.TWIRLSUnfoldingAndAttention
     :members: forward
     :show-inheritance:

 GCN2Conv
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -319,3 +319,15 @@ NodeEmbedding
 .. autoclass:: dgl.nn.pytorch.sparse_emb.NodeEmbedding
     :members:
     :show-inheritance:
+
+Explainability Models
+----------------------------------------
+
+.. automodule:: dgl.nn.pytorch.explain
+
+GNNExplainer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: dgl.nn.pytorch.explain.GNNExplainer
+    :members: explain_node, explain_graph
+    :show-inheritance:
@@ -1148,18 +1148,23 @@ def count_nonzero(input):
 # DGL should contain all the operations on index, so this set of operators
 # should be gradually removed.

-def unique(input):
+def unique(input, return_inverse=False):
     """Returns the unique scalar elements in a tensor.

     Parameters
     ----------
     input : Tensor
         Must be a 1-D tensor.
+    return_inverse : bool, optional
+        Whether to also return the indices for where elements in the original
+        input ended up in the returned unique list.

     Returns
     -------
     Tensor
         A 1-D tensor containing unique elements.
+    Tensor
+        A 1-D tensor containing the new positions of the elements in the input.
     """
     pass
...
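For context, a minimal sketch of the return_inverse semantics the new flag exposes,
using the PyTorch backend's torch.unique (the tensor x and its values are
illustrative only):

    import torch

    x = torch.tensor([3, 1, 3, 2, 1])
    values, inverse = torch.unique(x, return_inverse=True)
    # values:  tensor([1, 2, 3])        -- the sorted unique elements
    # inverse: tensor([2, 0, 2, 1, 0])  -- position of each input element in values
    assert torch.equal(values[inverse], x)  # the inverse indices rebuild the input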
@@ -347,11 +347,17 @@ def count_nonzero(input):
     tmp = input.asnumpy()
     return np.count_nonzero(tmp)

-def unique(input):
+def unique(input, return_inverse=False):
     # TODO: fallback to numpy is unfortunate
     tmp = input.asnumpy()
-    tmp = np.unique(tmp)
-    return nd.array(tmp, ctx=input.context, dtype=input.dtype)
+    if return_inverse:
+        tmp, inv = np.unique(tmp, return_inverse=True)
+        tmp = nd.array(tmp, ctx=input.context, dtype=input.dtype)
+        inv = nd.array(inv, ctx=input.context)
+        return tmp, inv
+    else:
+        tmp = np.unique(tmp)
+        return nd.array(tmp, ctx=input.context, dtype=input.dtype)

 def full_1d(length, fill_value, dtype, ctx):
     return nd.full((length,), fill_value, dtype=dtype, ctx=ctx)
...
@@ -295,10 +295,10 @@ def count_nonzero(input):
     # TODO: fallback to numpy for backward compatibility
     return np.count_nonzero(input)

-def unique(input):
+def unique(input, return_inverse=False):
     if input.dtype == th.bool:
         input = input.type(th.int8)
-    return th.unique(input)
+    return th.unique(input, return_inverse=return_inverse)

 def full_1d(length, fill_value, dtype, ctx):
     return th.full((length,), fill_value, dtype=dtype, device=ctx)
...
@@ -413,8 +413,11 @@ def count_nonzero(input):
     return int(tf.math.count_nonzero(input))

-def unique(input):
-    return tf.unique(input).y
+def unique(input, return_inverse=False):
+    if return_inverse:
+        return tf.unique(input)
+    else:
+        return tf.unique(input).y

 def full_1d(length, fill_value, dtype, ctx):
...
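A note on the TensorFlow branch above: tf.unique returns a namedtuple Unique(y, idx),
so returning it directly lets callers unpack (values, inverse) positionally like the
other backends. One caveat: tf.unique keeps elements in order of first appearance,
while torch.unique sorts them (the sample values below are illustrative only):

    import tensorflow as tf

    x = tf.constant([3, 1, 3, 2, 1])
    values, inverse = tf.unique(x)  # Unique(y=..., idx=...) unpacks positionally
    # values:  [3, 1, 2]       -- order of first appearance
    # inverse: [0, 1, 0, 2, 1] -- position of each input element in values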
"""Package for pytorch-specific NN modules.""" """Package for pytorch-specific NN modules."""
from .conv import * from .conv import *
from .explain import *
from .glob import * from .glob import *
from .softmax import * from .softmax import *
from .factory import * from .factory import *
......
"""Torch modules for explanation models."""
# pylint: disable= no-member, arguments-differ, invalid-name
from .gnnexplainer import GNNExplainer
__all__ = ['GNNExplainer']
"""Torch Module for GNNExplainer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from math import sqrt
import torch
from torch import nn
from tqdm import tqdm
from ....base import NID, EID
from ....subgraph import khop_in_subgraph
class GNNExplainer(nn.Module):
r"""
Description
-----------
    GNNExplainer model from the paper `GNNExplainer: Generating Explanations for
Graph Neural Networks <https://arxiv.org/abs/1903.03894>`__ for identifying
compact subgraph structures and small subsets of node features that play a
critical role in GNN-based node classification and graph classification.
    Parameters
    ----------
    model : nn.Module
        The GNN model to explain.

        * The required arguments of its forward function are graph and feat.
          The latter is the input node features.
        * It should also optionally take an eweight argument for edge weights
          and multiply the messages by it in message passing.
        * The output of its forward function is the logits for the predicted
          node/graph classes.

        See also the examples in :func:`explain_node` and :func:`explain_graph`.
    num_hops : int
        The number of hops for GNN information aggregation.
    lr : float, optional
        The learning rate to use, defaults to 0.01.
    num_epochs : int, optional
        The number of epochs to train, defaults to 100.
    log : bool, optional
        If True, it will log the computation process, defaults to True.
    """
coeffs = {
'edge_size': 0.005,
'edge_ent': 1.0,
'node_feat_size': 1.0,
'node_feat_ent': 0.1
}
def __init__(self,
model,
num_hops,
lr=0.01,
num_epochs=100,
log=True):
super(GNNExplainer, self).__init__()
self.model = model
self.num_hops = num_hops
self.lr = lr
self.num_epochs = num_epochs
self.log = log
    def _init_masks(self, graph, feat):
        r"""Initialize the learnable feature and edge masks.
Parameters
----------
graph : DGLGraph
Input graph.
feat : Tensor
Input node features.
Returns
-------
feat_mask : Tensor
Feature mask of shape :math:`(1, D)`, where :math:`D`
is the feature size.
edge_mask : Tensor
Edge mask of shape :math:`(E)`, where :math:`E` is the
number of edges.
"""
num_nodes, feat_size = feat.size()
num_edges = graph.num_edges()
device = feat.device
std = 0.1
feat_mask = nn.Parameter(torch.randn(1, feat_size, device=device) * std)
std = nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * num_nodes))
edge_mask = nn.Parameter(torch.randn(num_edges, device=device) * std)
return feat_mask, edge_mask
def _loss_regularize(self, loss, feat_mask, edge_mask):
r"""Add regularization terms to the loss.
Parameters
----------
loss : Tensor
Loss value.
feat_mask : Tensor
Feature mask of shape :math:`(1, D)`, where :math:`D`
is the feature size.
edge_mask : Tensor
Edge mask of shape :math:`(E)`, where :math:`E`
is the number of edges.
Returns
-------
Tensor
Loss value with regularization terms added.
"""
# epsilon for numerical stability
eps = 1e-15
edge_mask = edge_mask.sigmoid()
# Edge mask sparsity regularization
loss = loss + self.coeffs['edge_size'] * torch.sum(edge_mask)
# Edge mask entropy regularization
ent = - edge_mask * torch.log(edge_mask + eps) - \
(1 - edge_mask) * torch.log(1 - edge_mask + eps)
loss = loss + self.coeffs['edge_ent'] * ent.mean()
feat_mask = feat_mask.sigmoid()
# Feature mask sparsity regularization
loss = loss + self.coeffs['node_feat_size'] * torch.mean(feat_mask)
# Feature mask entropy regularization
ent = -feat_mask * torch.log(feat_mask + eps) - \
(1 - feat_mask) * torch.log(1 - feat_mask + eps)
loss = loss + self.coeffs['node_feat_ent'] * ent.mean()
return loss
    def explain_node(self, node_id, graph, feat, **kwargs):
        r"""Learn and return a node feature mask and a subgraph that play a
        crucial role in explaining the prediction made by the GNN for node
        :attr:`node_id`.
Parameters
----------
node_id : int
The node to explain.
graph : DGLGraph
A homogeneous graph.
feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size.
kwargs : dict
Additional arguments passed to the GNN model. Tensors whose
first dimension is the number of nodes or edges will be
assumed to be node/edge features.
Returns
-------
new_node_id : Tensor
The new ID of the input center node.
sg : DGLGraph
The subgraph induced on the k-hop in-neighborhood of :attr:`node_id`.
feat_mask : Tensor
Learned feature importance mask of shape :math:`(D)`, where :math:`D` is the
feature size. The values are within range :math:`(0, 1)`.
The higher, the more important.
edge_mask : Tensor
Learned importance mask of the edges in the subgraph, which is a tensor
of shape :math:`(E)`, where :math:`E` is the number of edges in the
subgraph. The values are within range :math:`(0, 1)`.
The higher, the more important.
Examples
--------
>>> import dgl
>>> import dgl.function as fn
>>> import torch
>>> import torch.nn as nn
>>> from dgl.data import CoraGraphDataset
>>> from dgl.nn import GNNExplainer
>>> # Load dataset
>>> data = CoraGraphDataset()
>>> g = data[0]
>>> features = g.ndata['feat']
>>> labels = g.ndata['label']
>>> train_mask = g.ndata['train_mask']
>>> # Define a model
>>> class Model(nn.Module):
... def __init__(self, in_feats, out_feats):
... super(Model, self).__init__()
... self.linear = nn.Linear(in_feats, out_feats)
...
... def forward(self, graph, feat, eweight=None):
... with graph.local_scope():
... feat = self.linear(feat)
... graph.ndata['h'] = feat
... if eweight is None:
... graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
... else:
... graph.edata['w'] = eweight
... graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
... return graph.ndata['h']
>>> # Train the model
>>> model = Model(features.shape[1], data.num_classes)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for epoch in range(10):
... logits = model(g, features)
... loss = criterion(logits[train_mask], labels[train_mask])
... optimizer.zero_grad()
... loss.backward()
... optimizer.step()
>>> # Explain the prediction for node 10
>>> explainer = GNNExplainer(model, num_hops=1)
>>> new_center, sg, feat_mask, edge_mask = explainer.explain_node(10, g, features)
>>> new_center
tensor([1])
>>> sg.num_edges()
26
>>> # Old IDs of the nodes in the subgraph
>>> sg.ndata[dgl.NID]
tensor([ 9, 10, 11, 12])
>>> # Old IDs of the edges in the subgraph
>>> sg.edata[dgl.EID]
tensor([51, 53, 56, 48, 52, 57, 47, 50, 55, 46, 49, 54])
>>> feat_mask
tensor([0.2638, 0.2738, 0.3039, ..., 0.2794, 0.2643, 0.2733])
>>> edge_mask
tensor([0.8291, 0.2065, 0.1379, 0.2265, 0.8618, 0.7038, 0.2094, 0.8847, 0.2157,
0.6595, 0.1906, 0.8184, 0.2033, 0.7211, 0.1279, 0.1668, 0.1441, 0.8571,
0.1903, 0.1125, 0.8235, 0.1913, 0.5834, 0.2248, 0.8345, 0.9270])
"""
self.model.eval()
num_nodes = graph.num_nodes()
num_edges = graph.num_edges()
# Extract node-centered k-hop subgraph and
# its associated node and edge features.
sg, inverse_indices = khop_in_subgraph(graph, node_id, self.num_hops)
sg_nodes = sg.ndata[NID].long()
sg_edges = sg.edata[EID].long()
feat = feat[sg_nodes]
for key, item in kwargs.items():
if torch.is_tensor(item) and item.size(0) == num_nodes:
item = item[sg_nodes]
elif torch.is_tensor(item) and item.size(0) == num_edges:
item = item[sg_edges]
kwargs[key] = item
# Get the initial prediction.
with torch.no_grad():
logits = self.model(graph=sg, feat=feat, **kwargs)
pred_label = logits.argmax(dim=-1)
feat_mask, edge_mask = self._init_masks(sg, feat)
params = [feat_mask, edge_mask]
optimizer = torch.optim.Adam(params, lr=self.lr)
if self.log:
pbar = tqdm(total=self.num_epochs)
            pbar.set_description(f'Explain node {node_id}')
for _ in range(self.num_epochs):
optimizer.zero_grad()
h = feat * feat_mask.sigmoid()
logits = self.model(graph=sg, feat=h,
eweight=edge_mask.sigmoid(), **kwargs)
log_probs = logits.log_softmax(dim=-1)
loss = -log_probs[inverse_indices, pred_label[inverse_indices]]
loss = self._loss_regularize(loss, feat_mask, edge_mask)
loss.backward()
optimizer.step()
if self.log:
pbar.update(1)
if self.log:
pbar.close()
feat_mask = feat_mask.detach().sigmoid().squeeze()
edge_mask = edge_mask.detach().sigmoid()
return inverse_indices, sg, feat_mask, edge_mask
    def explain_graph(self, graph, feat, **kwargs):
        r"""Learn and return a node feature mask and an edge mask that play a
        crucial role in explaining the prediction made by the GNN for a graph.
Parameters
----------
graph : DGLGraph
A homogeneous graph.
feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size.
kwargs : dict
Additional arguments passed to the GNN model. Tensors whose
first dimension is the number of nodes or edges will be
assumed to be node/edge features.
Returns
-------
feat_mask : Tensor
Learned feature importance mask of shape :math:`(D)`, where :math:`D` is the
feature size. The values are within range :math:`(0, 1)`.
The higher, the more important.
edge_mask : Tensor
Learned importance mask of the edges in the graph, which is a tensor
of shape :math:`(E)`, where :math:`E` is the number of edges in the
graph. The values are within range :math:`(0, 1)`. The higher,
the more important.
Examples
--------
>>> import dgl.function as fn
>>> import torch
>>> import torch.nn as nn
>>> from dgl.data import GINDataset
>>> from dgl.dataloading import GraphDataLoader
>>> from dgl.nn import AvgPooling, GNNExplainer
>>> # Load dataset
>>> data = GINDataset('MUTAG', self_loop=True)
>>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)
>>> # Define a model
>>> class Model(nn.Module):
... def __init__(self, in_feats, out_feats):
... super(Model, self).__init__()
... self.linear = nn.Linear(in_feats, out_feats)
... self.pool = AvgPooling()
...
... def forward(self, graph, feat, eweight=None):
... with graph.local_scope():
... feat = self.linear(feat)
... graph.ndata['h'] = feat
... if eweight is None:
... graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
... else:
... graph.edata['w'] = eweight
... graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
... return self.pool(graph, graph.ndata['h'])
>>> # Train the model
>>> feat_size = data[0][0].ndata['attr'].shape[1]
>>> model = Model(feat_size, data.gclasses)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for bg, labels in dataloader:
... logits = model(bg, bg.ndata['attr'])
... loss = criterion(logits, labels)
... optimizer.zero_grad()
... loss.backward()
... optimizer.step()
>>> # Explain the prediction for graph 0
>>> explainer = GNNExplainer(model, num_hops=1)
>>> g, _ = data[0]
>>> features = g.ndata['attr']
>>> feat_mask, edge_mask = explainer.explain_graph(g, features)
>>> feat_mask
tensor([0.2362, 0.2497, 0.2622, 0.2675, 0.2649, 0.2962, 0.2533])
>>> edge_mask
tensor([0.2154, 0.2235, 0.8325, ..., 0.7787, 0.1735, 0.1847])
"""
self.model.eval()
# Get the initial prediction.
with torch.no_grad():
logits = self.model(graph=graph, feat=feat, **kwargs)
pred_label = logits.argmax(dim=-1)
feat_mask, edge_mask = self._init_masks(graph, feat)
params = [feat_mask, edge_mask]
optimizer = torch.optim.Adam(params, lr=self.lr)
if self.log:
pbar = tqdm(total=self.num_epochs)
pbar.set_description('Explain graph')
for _ in range(self.num_epochs):
optimizer.zero_grad()
h = feat * feat_mask.sigmoid()
logits = self.model(graph=graph, feat=h,
eweight=edge_mask.sigmoid(), **kwargs)
log_probs = logits.log_softmax(dim=-1)
loss = -log_probs[0, pred_label[0]]
loss = self._loss_regularize(loss, feat_mask, edge_mask)
loss.backward()
optimizer.step()
if self.log:
pbar.update(1)
if self.log:
pbar.close()
feat_mask = feat_mask.detach().sigmoid().squeeze()
edge_mask = edge_mask.detach().sigmoid()
return feat_mask, edge_mask
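A minimal usage sketch for mapping the learned masks back to the original graph,
assuming a trained model, a homogeneous DGLGraph g, and node features feat as in
the docstring examples (the top-5 choice is arbitrary):

    import dgl

    explainer = GNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(10, g, feat)
    orig_eids = sg.edata[dgl.EID]        # original IDs of the subgraph edges
    top = edge_mask.topk(5).indices      # the five most important subgraph edges
    print(orig_eids[top])                # their edge IDs in the full graph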
@@ -594,8 +594,11 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
     Returns
     -------
-    G : DGLGraph
+    DGLGraph
         The subgraph.
+    Tensor or dict[str, Tensor], optional
+        The new IDs of the input :attr:`nodes` after node relabeling. This is returned
+        only when :attr:`relabel_nodes` is True. It is in the same form as :attr:`nodes`.

     Notes
     -----
@@ -615,7 +618,7 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
     >>> g = dgl.graph(([1, 1, 2, 3, 4], [0, 2, 0, 4, 2]))
     >>> g.edata['w'] = torch.arange(10).view(5, 2)
-    >>> sg = dgl.khop_in_subgraph(g, 0, k=2)
+    >>> sg, inverse_indices = dgl.khop_in_subgraph(g, 0, k=2)
     >>> sg
     Graph(num_nodes=4, num_edges=4,
       ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
@@ -630,17 +633,21 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
             [2, 3],
             [4, 5],
             [8, 9]])
+    >>> inverse_indices
+    tensor([0])

     Extract a subgraph from a heterogeneous graph.

     >>> g = dgl.heterograph({
     ...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
     ...     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])})
-    >>> sg = dgl.khop_in_subgraph(g, {'game': 0}, k=2)
+    >>> sg, inverse_indices = dgl.khop_in_subgraph(g, {'game': 0}, k=2)
     >>> sg
     Graph(num_nodes={'game': 1, 'user': 2},
           num_edges={('user', 'follows', 'user'): 1, ('user', 'plays', 'game'): 2},
           metagraph=[('user', 'user', 'follows'), ('user', 'game', 'plays')])
+    >>> inverse_indices
+    {'game': tensor([0])}

     See also
     --------
@@ -649,7 +656,8 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
     if graph.is_block:
         raise DGLError('Extracting subgraph of a block graph is not allowed.')

-    if not isinstance(nodes, Mapping):
+    is_mapping = isinstance(nodes, Mapping)
+    if not is_mapping:
         assert len(graph.ntypes) == 1, \
             'need a dict of node type and IDs for graph with multiple node types'
         nodes = {graph.ntypes[0]: nodes}
@@ -675,12 +683,25 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
         last_hop_nodes = current_hop_nodes

     k_hop_nodes = dict()
+    inverse_indices = dict()
     for nty in graph.ntypes:
-        k_hop_nodes[nty] = F.unique(F.cat([
-            hop_nodes.get(nty, place_holder)
-            for hop_nodes in k_hop_nodes_], dim=0))
+        k_hop_nodes[nty], inverse_indices[nty] = F.unique(F.cat([
+            hop_nodes.get(nty, place_holder)
+            for hop_nodes in k_hop_nodes_], dim=0), return_inverse=True)

-    return node_subgraph(graph, k_hop_nodes, relabel_nodes=relabel_nodes, store_ids=store_ids)
+    sub_g = node_subgraph(graph, k_hop_nodes, relabel_nodes=relabel_nodes, store_ids=store_ids)
+    if relabel_nodes:
+        if is_mapping:
+            seed_inverse_indices = dict()
+            for nty in nodes:
+                seed_inverse_indices[nty] = F.slice_axis(
+                    inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty]))
+        else:
+            seed_inverse_indices = F.slice_axis(
+                inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty]))
+        return sub_g, seed_inverse_indices
+    else:
+        return sub_g

 DGLHeteroGraph.khop_in_subgraph = utils.alias_func(khop_in_subgraph)
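Why slicing from position 0 recovers the seeds' new IDs (a sketch of the invariant
the slice relies on: the seed nodes are the first entries concatenated into
k_hop_nodes_; the names and values below are illustrative only):

    import torch

    seeds = torch.tensor([5, 2])
    neighbors = torch.tensor([7, 2, 9])
    all_nodes = torch.cat([seeds, neighbors])      # seeds come first
    uniq, inv = torch.unique(all_nodes, return_inverse=True)
    new_seed_ids = inv[:len(seeds)]                # the slice_axis(..., 0, 0, len) step
    assert torch.equal(uniq[new_seed_ids], seeds)  # relabeled IDs point back at seeds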
@@ -726,8 +747,11 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
     Returns
     -------
-    G : DGLGraph
+    DGLGraph
         The subgraph.
+    Tensor or dict[str, Tensor], optional
+        The new IDs of the input :attr:`nodes` after node relabeling. This is returned
+        only when :attr:`relabel_nodes` is True. It is in the same form as :attr:`nodes`.

     Notes
     -----
@@ -747,7 +771,7 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
     >>> g = dgl.graph(([0, 2, 0, 4, 2], [1, 1, 2, 3, 4]))
     >>> g.edata['w'] = torch.arange(10).view(5, 2)
-    >>> sg = dgl.khop_out_subgraph(g, 0, k=2)
+    >>> sg, inverse_indices = dgl.khop_out_subgraph(g, 0, k=2)
     >>> sg
     Graph(num_nodes=4, num_edges=4,
       ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
@@ -762,17 +786,21 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
             [4, 5],
             [2, 3],
             [8, 9]])
+    >>> inverse_indices
+    tensor([0])

     Extract a subgraph from a heterogeneous graph.

     >>> g = dgl.heterograph({
     ...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
     ...     ('user', 'follows', 'user'): ([0, 1], [1, 3])})
-    >>> sg = dgl.khop_out_subgraph(g, {'user': 0}, k=2)
+    >>> sg, inverse_indices = dgl.khop_out_subgraph(g, {'user': 0}, k=2)
     >>> sg
     Graph(num_nodes={'game': 2, 'user': 3},
           num_edges={('user', 'follows', 'user'): 2, ('user', 'plays', 'game'): 2},
           metagraph=[('user', 'user', 'follows'), ('user', 'game', 'plays')])
+    >>> inverse_indices
+    {'user': tensor([0])}

     See also
     --------
@@ -781,7 +809,8 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
     if graph.is_block:
         raise DGLError('Extracting subgraph of a block graph is not allowed.')

-    if not isinstance(nodes, Mapping):
+    is_mapping = isinstance(nodes, Mapping)
+    if not is_mapping:
         assert len(graph.ntypes) == 1, \
             'need a dict of node type and IDs for graph with multiple node types'
         nodes = {graph.ntypes[0]: nodes}
@@ -808,12 +837,25 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True):
         last_hop_nodes = current_hop_nodes

     k_hop_nodes = dict()
+    inverse_indices = dict()
     for nty in graph.ntypes:
-        k_hop_nodes[nty] = F.unique(F.cat([
-            hop_nodes.get(nty, place_holder)
-            for hop_nodes in k_hop_nodes_], dim=0))
+        k_hop_nodes[nty], inverse_indices[nty] = F.unique(F.cat([
+            hop_nodes.get(nty, place_holder)
+            for hop_nodes in k_hop_nodes_], dim=0), return_inverse=True)

-    return node_subgraph(graph, k_hop_nodes, relabel_nodes=relabel_nodes, store_ids=store_ids)
+    sub_g = node_subgraph(graph, k_hop_nodes, relabel_nodes=relabel_nodes, store_ids=store_ids)
+    if relabel_nodes:
+        if is_mapping:
+            seed_inverse_indices = dict()
+            for nty in nodes:
+                seed_inverse_indices[nty] = F.slice_axis(
+                    inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty]))
+        else:
+            seed_inverse_indices = F.slice_axis(
+                inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty]))
+        return sub_g, seed_inverse_indices
+    else:
+        return sub_g

 DGLHeteroGraph.khop_out_subgraph = utils.alias_func(khop_out_subgraph)
...
@@ -172,6 +172,7 @@ setup(
         'scipy>=1.1.0',
         'networkx>=2.1',
         'requests>=2.19.0',
+        'tqdm'
     ],
     url='https://github.com/dmlc/dgl',
     distclass=BinaryDistribution,
...
@@ -455,7 +455,7 @@ def test_khop_in_subgraph(idtype):
         [6, 7],
         [8, 9]
     ])
-    sg = dgl.khop_in_subgraph(g, 0, k=2)
+    sg, inv = dgl.khop_in_subgraph(g, 0, k=2)
     assert sg.idtype == g.idtype
     u, v = sg.edges()
     edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
@@ -467,25 +467,27 @@ def test_khop_in_subgraph(idtype):
         [4, 5],
         [8, 9]
     ]))
+    assert F.array_equal(F.astype(inv, idtype), F.tensor([0], idtype))

     # Test multiple nodes
-    sg = dgl.khop_in_subgraph(g, [0, 2], k=1)
+    sg, inv = dgl.khop_in_subgraph(g, [0, 2], k=1)
     assert sg.num_edges() == 4
-    sg = dgl.khop_in_subgraph(g, F.tensor([0, 2], idtype), k=1)
+    sg, inv = dgl.khop_in_subgraph(g, F.tensor([0, 2], idtype), k=1)
     assert sg.num_edges() == 4

     # Test isolated node
-    sg = dgl.khop_in_subgraph(g, 1, k=2)
+    sg, inv = dgl.khop_in_subgraph(g, 1, k=2)
     assert sg.idtype == g.idtype
     assert sg.num_nodes() == 1
     assert sg.num_edges() == 0
+    assert F.array_equal(F.astype(inv, idtype), F.tensor([0], idtype))

     g = dgl.heterograph({
         ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
         ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2]),
     }, idtype=idtype, device=F.ctx())
-    sg = dgl.khop_in_subgraph(g, {'game': 0}, k=2)
+    sg, inv = dgl.khop_in_subgraph(g, {'game': 0}, k=2)
     assert sg.idtype == idtype
     assert sg.num_nodes('game') == 1
     assert sg.num_nodes('user') == 2
@@ -497,23 +499,27 @@ def test_khop_in_subgraph(idtype):
     u, v = sg['plays'].edges()
     edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
     assert edge_set == {(0, 0), (1, 0)}
+    assert F.array_equal(F.astype(inv['game'], idtype), F.tensor([0], idtype))

     # Test isolated node
-    sg = dgl.khop_in_subgraph(g, {'user': 0}, k=2)
+    sg, inv = dgl.khop_in_subgraph(g, {'user': 0}, k=2)
     assert sg.idtype == idtype
     assert sg.num_nodes('game') == 0
     assert sg.num_nodes('user') == 1
     assert sg.num_edges('follows') == 0
     assert sg.num_edges('plays') == 0
+    assert F.array_equal(F.astype(inv['user'], idtype), F.tensor([0], idtype))

     # Test multiple nodes
-    sg = dgl.khop_in_subgraph(g, {'user': F.tensor([0, 1], idtype), 'game': 0}, k=1)
+    sg, inv = dgl.khop_in_subgraph(g, {'user': F.tensor([0, 1], idtype), 'game': 0}, k=1)
     u, v = sg['follows'].edges()
     edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
     assert edge_set == {(0, 1)}
     u, v = sg['plays'].edges()
     edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
     assert edge_set == {(0, 0), (1, 0)}
+    assert F.array_equal(F.astype(inv['user'], idtype), F.tensor([0, 1], idtype))
+    assert F.array_equal(F.astype(inv['game'], idtype), F.tensor([0], idtype))

@parametrize_dtype
def test_khop_out_subgraph(idtype):
@@ -525,7 +531,7 @@ def test_khop_out_subgraph(idtype):
         [6, 7],
         [8, 9]
     ])
-    sg = dgl.khop_out_subgraph(g, 0, k=2)
+    sg, inv = dgl.khop_out_subgraph(g, 0, k=2)
     assert sg.idtype == g.idtype
     u, v = sg.edges()
     edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
@@ -537,25 +543,27 @@ def test_khop_out_subgraph(idtype):
         [2, 3],
         [8, 9]
     ]))
+    assert F.array_equal(F.astype(inv, idtype), F.tensor([0], idtype))

     # Test multiple nodes
-    sg = dgl.khop_out_subgraph(g, [0, 2], k=1)
+    sg, inv = dgl.khop_out_subgraph(g, [0, 2], k=1)
     assert sg.num_edges() == 4
-    sg = dgl.khop_out_subgraph(g, F.tensor([0, 2], idtype), k=1)
+    sg, inv = dgl.khop_out_subgraph(g, F.tensor([0, 2], idtype), k=1)
     assert sg.num_edges() == 4

     # Test isolated node
-    sg = dgl.khop_out_subgraph(g, 1, k=2)
+    sg, inv = dgl.khop_out_subgraph(g, 1, k=2)
     assert sg.idtype == g.idtype
     assert sg.num_nodes() == 1
     assert sg.num_edges() == 0
+    assert F.array_equal(F.astype(inv, idtype), F.tensor([0], idtype))

     g = dgl.heterograph({
         ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
         ('user', 'follows', 'user'): ([0, 1], [1, 3]),
     }, idtype=idtype, device=F.ctx())
-    sg = dgl.khop_out_subgraph(g, {'user': 0}, k=2)
+    sg, inv = dgl.khop_out_subgraph(g, {'user': 0}, k=2)
     assert sg.idtype == idtype
     assert sg.num_nodes('game') == 2
     assert sg.num_nodes('user') == 3
@@ -567,18 +575,22 @@ def test_khop_out_subgraph(idtype):
     u, v = sg['plays'].edges()
     edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
     assert edge_set == {(0, 0), (1, 0), (1, 1)}
+    assert F.array_equal(F.astype(inv['user'], idtype), F.tensor([0], idtype))

     # Test isolated node
-    sg = dgl.khop_out_subgraph(g, {'user': 3}, k=2)
+    sg, inv = dgl.khop_out_subgraph(g, {'user': 3}, k=2)
     assert sg.idtype == idtype
     assert sg.num_nodes('game') == 0
     assert sg.num_nodes('user') == 1
     assert sg.num_edges('follows') == 0
     assert sg.num_edges('plays') == 0
+    assert F.array_equal(F.astype(inv['user'], idtype), F.tensor([0], idtype))

     # Test multiple nodes
-    sg = dgl.khop_out_subgraph(g, {'user': F.tensor([2], idtype), 'game': 0}, k=1)
+    sg, inv = dgl.khop_out_subgraph(g, {'user': F.tensor([2], idtype), 'game': 0}, k=1)
     assert sg.num_edges('follows') == 0
     u, v = sg['plays'].edges()
     edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
     assert edge_set == {(0, 1)}
+    assert F.array_equal(F.astype(inv['user'], idtype), F.tensor([0], idtype))
+    assert F.array_equal(F.astype(inv['game'], idtype), F.tensor([0], idtype))
@@ -154,7 +154,7 @@ def test_graph_conv_bi(idtype, g, norm, weight, bias, out_dim):
     # Test a pair of tensor inputs
     g = g.astype(idtype).to(F.ctx())
     conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())

     # test pickle
     th.save(conv, tmp_buffer)
@@ -192,7 +192,7 @@ def test_tagconv(out_dim):
     conv = nn.TAGConv(5, out_dim, bias=True)
     conv = conv.to(ctx)
     print(conv)

     # test pickle
     th.save(conv, tmp_buffer)
@@ -610,7 +610,7 @@ def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
 @pytest.mark.parametrize('num_heads', [1, 4])
 def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
     g = g.astype(idtype).to(F.ctx())
     ctx = F.ctx()
     egat = nn.EGATConv(in_node_feats=10,
                        in_edge_feats=5,
                        out_node_feats=out_node_feats,
@@ -618,12 +618,12 @@ def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
                        num_heads=num_heads)
     nfeat = F.randn((g.number_of_nodes(), 10))
     efeat = F.randn((g.number_of_edges(), 5))
     egat = egat.to(ctx)

     h, f = egat(g, nfeat, efeat)
     h, f, attn = egat(g, nfeat, efeat, True)

     th.save(egat, tmp_buffer)

 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@@ -708,7 +708,7 @@ def test_appnp_conv(g, idtype):
     appnp = nn.APPNPConv(10, 0.1)
     feat = F.randn((g.number_of_nodes(), 5))
     appnp = appnp.to(ctx)

     # test pickle
     th.save(appnp, tmp_buffer)
@@ -731,7 +731,7 @@ def test_gin_conv(g, idtype, aggregator_type):
     # test pickle
     th.save(h, tmp_buffer)

     assert h.shape == (g.number_of_dst_nodes(), 12)

 @parametrize_dtype
@@ -914,7 +914,7 @@ def test_edge_conv(g, idtype, out_dim):
     # test pickle
     th.save(edge_conv, tmp_buffer)

     h0 = F.randn((g.number_of_src_nodes(), 5))
     h1 = edge_conv(g, h0)
     assert h1.shape == (g.number_of_dst_nodes(), out_dim)
@@ -931,7 +931,7 @@ def test_edge_conv_bi(g, idtype, out_dim):
     x0 = F.randn((g.number_of_dst_nodes(), 5))
     h1 = edge_conv(g, (h0, x0))
     assert h1.shape == (g.number_of_dst_nodes(), out_dim)

 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
 @pytest.mark.parametrize('out_dim', [1, 2])
@@ -942,10 +942,10 @@ def test_dotgat_conv(g, idtype, out_dim, num_heads):
     dotgat = nn.DotGatConv(5, out_dim, num_heads)
     feat = F.randn((g.number_of_src_nodes(), 5))
     dotgat = dotgat.to(ctx)

     # test pickle
     th.save(dotgat, tmp_buffer)

     h = dotgat(g, feat)
     assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
     _, a = dotgat(g, feat, get_attention=True)
@@ -1186,6 +1186,49 @@ def test_hetero_conv(agg, idtype):
         {'user': uf, 'game': gf, 'store': sf[0:0]}))
     assert set(h.keys()) == {'user', 'game'}
+
+@parametrize_dtype
+@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
+@pytest.mark.parametrize('out_dim', [1, 2])
+def test_gnnexplainer(g, idtype, out_dim):
+    g = g.astype(idtype).to(F.ctx())
+    feat = F.randn((g.num_nodes(), 5))
+
+    class Model(th.nn.Module):
+        def __init__(self, in_feats, out_feats, graph=False):
+            super(Model, self).__init__()
+            self.linear = th.nn.Linear(in_feats, out_feats)
+            if graph:
+                self.pool = nn.AvgPooling()
+            else:
+                self.pool = None
+
+        def forward(self, graph, feat, eweight=None):
+            with graph.local_scope():
+                feat = self.linear(feat)
+                graph.ndata['h'] = feat
+                if eweight is None:
+                    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
+                else:
+                    graph.edata['w'] = eweight
+                    graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
+                if self.pool:
+                    return self.pool(graph, graph.ndata['h'])
+                else:
+                    return graph.ndata['h']
+
+    # Explain node prediction
+    model = Model(5, out_dim)
+    model = model.to(F.ctx())
+    explainer = nn.GNNExplainer(model, num_hops=1)
+    new_center, sg, feat_mask, edge_mask = explainer.explain_node(0, g, feat)
+
+    # Explain graph prediction
+    model = Model(5, out_dim, graph=True)
+    model = model.to(F.ctx())
+    explainer = nn.GNNExplainer(model, num_hops=1)
+    feat_mask, edge_mask = explainer.explain_graph(g, feat)

 if __name__ == '__main__':
     test_graph_conv()
     test_graph_conv_e_weight()
...