Unverified Commit ec4271bf authored by YanJun-Zhao's avatar YanJun-Zhao Committed by GitHub
Browse files

[Example] Refactor GNNExplainer Example (#4560)



* debug

* debug

* readme

* fix readme

* fix readme

* Update

* Update

* update

* fix bug of syn2
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-9-26.ap-northeast-1.compute.internal>
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 880b3b1f
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import math
class NodeExplainerModule(nn.Module):
    """
    A PyTorch module for explaining a node's prediction based on its
    computational graph and node features.

    Two learnable masks are used: one over edges and one over node features.
    Due to current limits of DGL on edge-mask operations, the to-be-explained
    model must accept an additional input argument (the edge mask) and apply
    it inside its message passing. This is the current workaround for edge
    masking.
    """

    # Coefficients weighting the individual terms combined in ``_loss``.
    loss_coef = {
        "g_size": 0.05,    # edge mask sparsity
        "feat_size": 1.0,  # feature mask sparsity
        "g_ent": 0.1,      # edge mask entropy
        "feat_ent": 0.1    # feature mask entropy
    }

    def __init__(self,
                 model,
                 num_edges,
                 node_feat_dim,
                 activation='sigmoid',
                 agg_fn='sum',
                 mask_bias=False):
        """
        Parameters
        ----------
        model: nn.Module, the trained model to be explained; put into eval
            mode here so its parameters are frozen during explanation.
        num_edges: int, number of edges in the computation sub-graph.
        node_feat_dim: int, dimension of the node features.
        activation: str, 'sigmoid' or 'relu'; activation applied to the
            masks inside ``_loss``.
        agg_fn: str, aggregation function name (stored but not used here).
        mask_bias: bool, whether to learn an additive bias for the edge mask.
        """
        super(NodeExplainerModule, self).__init__()
        self.model = model
        self.model.eval()
        self.num_edges = num_edges
        self.node_feat_dim = node_feat_dim
        self.activation = activation
        self.agg_fn = agg_fn
        self.mask_bias = mask_bias

        # Initialize learnable parameters of the two masks.
        self.edge_mask, self.edge_mask_bias = self.create_edge_mask(self.num_edges)
        self.node_feat_mask = self.create_node_feat_mask(self.node_feat_dim)

    def create_edge_mask(self, num_edges, init_strategy='normal', const=1.):
        """
        Create a learnable mask over the edges of the computation graph.

        To adapt to DGL, the mask is a vector over edges instead of an
        N*N adjacency matrix.

        Parameters
        ----------
        num_edges: int, number of edges.
        init_strategy: str, 'normal' or 'const' parameter initialization.
        const: float, value used by the constant initialization.

        Returns
        -------
        mask and mask bias: Tensors, both of shape (num_edges, 1);
        the bias is ``None`` unless ``self.mask_bias`` is set.
        """
        mask = nn.Parameter(th.Tensor(num_edges, 1))
        if init_strategy == 'normal':
            std = nn.init.calculate_gain("relu") * math.sqrt(
                1.0 / num_edges
            )
            with th.no_grad():
                mask.normal_(1.0, std)
        elif init_strategy == "const":
            nn.init.constant_(mask, const)

        if self.mask_bias:
            mask_bias = nn.Parameter(th.Tensor(num_edges, 1))
            nn.init.constant_(mask_bias, 0.0)
        else:
            mask_bias = None

        return mask, mask_bias

    def create_node_feat_mask(self, node_feat_dim, init_strategy="normal"):
        """
        Create a learnable mask over node feature dimensions.

        Parameters
        ----------
        node_feat_dim: int, dimension of the node features.
        init_strategy: str, 'normal' or 'constant' parameter initialization.

        Returns
        -------
        mask: Tensor, of shape (node_feat_dim,)
        """
        mask = nn.Parameter(th.Tensor(node_feat_dim))
        if init_strategy == "normal":
            std = 0.1
            with th.no_grad():
                mask.normal_(1.0, std)
        elif init_strategy == "constant":
            with th.no_grad():
                nn.init.constant_(mask, 0.0)
        return mask

    def forward(self, graph, n_feats):
        """
        Calculate prediction results after masking the inputs of the model.

        Parameters
        ----------
        graph: DGLGraph, should be a sub-graph of the target node to be
            explained.
        n_feats: Tensor, node features of the sub-graph.

        Returns
        -------
        new_logits: Tensor, in shape of N * Num_Classes
        """
        # Step 1: mask node features with the sigmoid of the feature mask.
        new_n_feats = n_feats * self.node_feat_mask.sigmoid()
        edge_mask = self.edge_mask.sigmoid()

        # Step 2: compute logits with masked node features and masked edges.
        new_logits = self.model(graph, new_n_feats, edge_mask)

        return new_logits

    def _loss(self, pred_logits, pred_label):
        """
        Compute the losses of this explainer, which include 6 parts in the
        author's code:
        1. The prediction loss between predicted logits before and after node
           and edge masking;
        2. Loss of the edge mask itself, which pushes mask values to 0 or 1;
        3. Loss of the node feature mask itself, same goal as above;
        4. L2 loss of edge mask weights, summed, not averaged;
        5. L2 loss of node feature mask weights (NOT used in author's code);
        6. Laplacian loss of the adjacency matrix.

        The PyG implementation uses 5 losses: the prediction loss, the sum
        and entropy losses of the edge mask, and the sum and entropy losses
        of the node feature mask. This implementation follows PyG.

        Parameters
        ----------
        pred_logits: Tensor, N-dim logits output of the model.
        pred_label: Tensor, N-dim one-hot encoding of the predicted label.

        Returns
        -------
        loss: Scalar, the overall loss of this explainer.
        """
        # 1. prediction loss (cross-entropy against the original prediction)
        log_logit = - F.log_softmax(pred_logits, dim=-1)
        pred_loss = th.sum(log_logit * pred_label)

        # 2. edge mask size loss
        if self.activation == 'sigmoid':
            edge_mask = th.sigmoid(self.edge_mask)
        elif self.activation == 'relu':
            edge_mask = F.relu(self.edge_mask)
        else:
            # BUG FIX: the original raised a bare ValueError with no message.
            raise ValueError('Unsupported activation: {}'.format(self.activation))
        edge_mask_loss = self.loss_coef['g_size'] * th.sum(edge_mask)

        # 3. edge mask entropy loss (pushes mask values towards 0 or 1)
        edge_ent = -edge_mask * \
                   th.log(edge_mask + 1e-8) - \
                   (1 - edge_mask) * \
                   th.log(1 - edge_mask + 1e-8)
        edge_ent_loss = self.loss_coef['g_ent'] * th.mean(edge_ent)

        # 4. node feature mask size loss
        if self.activation == 'sigmoid':
            node_feat_mask = th.sigmoid(self.node_feat_mask)
        elif self.activation == 'relu':
            node_feat_mask = F.relu(self.node_feat_mask)
        else:
            raise ValueError('Unsupported activation: {}'.format(self.activation))
        node_feat_mask_loss = self.loss_coef['feat_size'] * th.sum(node_feat_mask)

        # 5. node feature mask entropy loss
        node_feat_ent = -node_feat_mask * \
                        th.log(node_feat_mask + 1e-8) - \
                        (1 - node_feat_mask) * \
                        th.log(1 - node_feat_mask + 1e-8)
        node_feat_ent_loss = self.loss_coef['feat_ent'] * th.mean(node_feat_ent)

        total_loss = pred_loss + edge_mask_loss + edge_ent_loss + node_feat_mask_loss + node_feat_ent_loss

        return total_loss
\ No newline at end of file
# DGL Implementation of the GNN Explainer
# DGL Implementation of GNNExplainer
This DGL example implements the GNN Explainer model proposed in the paper [GNNExplainer: Generating Explanations for Graph Neural Networks](https://arxiv.org/abs/1903.03894).
The authors' original implementation is available [here](https://github.com/RexYing/gnn-model-explainer).
This is a DGL example for [GNNExplainer: Generating Explanations for Graph Neural Networks](https://arxiv.org/abs/1903.03894). For the authors' original implementation,
see [here](https://github.com/RexYing/gnn-model-explainer).
The authors' implementation is experimental in places, so this example focuses on a subset of
GNNExplainer's functionality — node classification — and may later be extended to edge classification.
Contributors:
- [Jian Zhang](https://github.com/zhjwy9343)
- [Kounianhua Du](https://github.com/KounianhuaDu)
- [Yanjun Zhao](https://github.com/zyj-111)
Example implementor
Datasets
----------------------
This example was implemented by [Jian Zhang](https://github.com/zhjwy9343) and [Kounianhua Du](https://github.com/KounianhuaDu) at the AWS Shanghai AI Lab.
Dependencies
----------------------
- numpy 1.19.4
- pytorch 1.7.1
- dgl 0.5.3
- networkx 2.5
- matplotlib 3.3.4
Four built-in synthetic datasets are used in this example.
Datasets
----------------------
Five synthetic datasets used in the paper are used in this example. The generation codes are referenced from the author implementation.
- Syn1 (BA-SHAPES): Start with a base Barabasi-Albert (BA) graph on 300 nodes and a set of 80 five-node “house”-structured network motifs, which are attached to randomly selected nodes of the base graph. The resulting graph is further perturbed by adding 0.01N random edges. Nodes are assigned to 4 classes based on their structural roles. In a house-structured motif, there are 3 types of roles: the top, middle, and bottom node of the house. Therefore there are 4 different classes, corresponding to nodes at the top, middle, bottom of houses, and nodes that do not belong to a house.
- Syn2 (BA-COMMUNITY): A union of two BA-SHAPES graphs. Nodes have normally distributed feature vectors and are assigned to one of 8 classes based on their structural roles and community memberships.
- Syn3 (BA-GRID): The same as BA-SHAPES except that 3-by-3 grid motifs are attached to the base graph in place of house motifs.
- Syn4 (TREE-CYCLE): Start with a base 8-level balanced binary tree and 60 six-node cycle motifs, which are attached to random nodes of the base graph. Perturbed by adding 0.01N random edges.
- Syn5 (TREE-GRID): Start with a base 8-level balanced binary tree and 80 3-by-3 grid motifs, which are attached to random nodes of the base graph. Perturbed by adding 0.1N random edges.
Demo Usage
- [BA-SHAPES](https://docs.dgl.ai/generated/dgl.data.BAShapeDataset.html#dgl.data.BAShapeDataset)
- [BA-COMMUNITY](https://docs.dgl.ai/generated/dgl.data.BACommunityDataset.html#dgl.data.BACommunityDataset)
- [TREE-CYCLE](https://docs.dgl.ai/generated/dgl.data.TreeCycleDataset.html#dgl.data.TreeCycleDataset)
- [TREE-GRID](https://docs.dgl.ai/generated/dgl.data.TreeGridDataset.html#dgl.data.TreeGridDataset)
Usage
----------------------
**First**, train a demo GNN model by using a synthetic dataset.
``` python
python train_main.py --dataset syn1
```
Replace the argument of the --dataset, available options: syn1, syn2, syn3, syn4, syn5
This command trains a GNN model and save it to the "dummy_model_syn1.pth" file.
**First**, train a GNN model on a dataset.
**Second**, explain the trained model with the same data
``` python
python explain_main.py --dataset syn1 --target_class 1 --hop 2
```bash
python train_main.py --dataset $DATASET
```
Replace the dataset argument value and the target class you want to explain. The code will pick the first node in the specified class to explain. The --hop argument corresponds to the maximum hop number of the computation sub-graph. (For syn1 and syn2, hop=2. For syn3, syn4, and syn5, hop=4.)
Notice
----------------------
DGL does not support a masked adjacency matrix as an input to the forward function of a module.
To use this explainer, you need to add an edge_weight as the **edge mask** argument to your forward function, just like
the dummy model in the models.py file, and you need to adjust your forward function wherever it uses the `.update_all` function.
Please use `dgl.function.u_mul_e` to compute the src nodes' features to the edge_weights as the mask method proposed by the
GNN Explainer paper. Check the models.py for details.
Valid options for `$DATASET`: `BAShape`, `BACommunity`, `TreeCycle`, `TreeGrid`
Results
----------------------
For all the datasets, the first node of target class 1 is picked to be explained. The hop-k computation sub-graph (a compact of 0-hop, 1-hop, ..., k-hop subgraphs) is first extracted and then fed to the models. Followings are the visualization results. Instead of cutting edges that are below the threshold. We use the depth of color of the edges to represent the edge mask weights. The deeper the color of an edge is, the more important the edge is.
The trained model weights will be saved to `model_{dataset}.pth`
NOTE: We do not perform grid search or finetune here, the visualization results are just for reference.
**Second**, install [GNNLens2](https://github.com/dmlc/GNNLens2) with
```bash
pip install -U flask-cors
pip install Flask==2.0.3
pip install gnnlens
```
**Syn1 (BA-SHAPES)**
<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn1.png" width="600">
<br>
<b>Figure</b>: Visualization of syn1 dataset (hop=2).
</p>
**Third**, explain the trained model with the same dataset
**Syn2 (BA-COMMUNITY)**
<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn2.png" width="600">
<br>
<b>Figure</b>: Visualization of syn2 dataset (hop=2).
</p>
```bash
python explain_main.py --dataset $DATASET
```
**Syn3 (BA-GRID)**
**Finally**, launch `GNNLens2` to visualize the explanations
For a more explicit view, we conduct the explanation on both the hop-3 and hop-4 computation sub-graphs in the Syn3 task.
```bash
gnnlens --logdir gnn_subgraph
```
<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn3_3hop.png" width="600">
<br>
<b>Figure</b>: Visualization of syn3 dataset with hop=3.
</p>
By entering `localhost:7777` in your web browser address bar, you can see the GNNLens2 interface. `7777` is the default port GNNLens2 uses. You can specify an alternative one by adding `--port xxxx` after the command line and change the address in the web browser accordingly.
<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn3_4hop.png" width="600">
<br>
<b>Figure</b>: Visualization of syn3 dataset with hop=4.
</p>
**Syn4 (TREE-CYCLE)**
<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn4.png" width="600">
<br>
<b>Figure</b>: Visualization of syn4 dataset (hop=4).
</p>
A sample visualization is available below. For more details of using `GNNLens2`, check its [tutorials](https://github.com/dmlc/GNNLens2#tutorials).
**Syn5 (TREE-GRID)**
<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn5.png" width="600">
<img src="https://data.dgl.ai/asset/image/explain_BAShape.png" width="600">
<br>
<b>Figure</b>: Visualization of syn5 dataset (hop=4).
<b>Figure</b>: Explanation for node 41 of BAShape
</p>
# The major idea of the overall GNN model explanation
import argparse
import os
import dgl
from gnnlens import Writer
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import load_graphs
from models import dummy_gnn_model
from NodeExplainerModule import NodeExplainerModule
from utils_graph import extract_subgraph, visualize_sub_graph
from dgl.nn import GNNExplainer
from models import Model
from dgl.data import BAShapeDataset, BACommunityDataset, TreeCycleDataset, TreeGridDataset
def main(args):
    """Explain a trained GNN's prediction for one node and visualize it.

    Loads the chosen synthetic dataset, restores the model trained by
    train_main.py, runs dgl.nn.GNNExplainer on the first node of the target
    class, and dumps the explanation sub-graph for GNNLens2.

    NOTE(review): the original text of this function contained two
    interleaved versions (a pre-refactor flow and the refactored
    GNNExplainer/GNNLens2 flow); only the refactored flow is kept here.
    """
    # Load the dataset with a fixed seed so node labels match training.
    if args.dataset == 'BAShape':
        dataset = BAShapeDataset(seed=0)
    elif args.dataset == 'BACommunity':
        dataset = BACommunityDataset(seed=0)
    elif args.dataset == 'TreeCycle':
        dataset = TreeCycleDataset(seed=0)
    elif args.dataset == 'TreeGrid':
        dataset = TreeGridDataset(seed=0)
    else:
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))

    graph = dataset[0]
    labels = graph.ndata['label']
    feats = graph.ndata['feat']
    num_classes = dataset.num_classes

    # Load an existing model trained by train_main.py.
    model_path = os.path.join('./', f'model_{args.dataset}.pth')
    if not os.path.exists(model_path):
        # BUG FIX: the original raised FileExistsError for a *missing* file.
        raise FileNotFoundError('No saved model file. Please train a GNN model first...')
    model_stat_dict = th.load(model_path)
    model = Model(feats.shape[-1], num_classes)
    model.load_state_dict(model_stat_dict)

    # Choose the first node of class 1 for explaining its prediction.
    target_class = 1
    for n_idx, n_label in enumerate(labels):
        if n_label == target_class:
            break

    explainer = GNNExplainer(model, num_hops=3)
    new_center, sub_graph, feat_mask, edge_mask = explainer.explain_node(n_idx, graph, feats)

    # Dump the graph and the explanation sub-graph for GNNLens2.
    writer = Writer('gnn_subgraph')
    writer.add_graph(name=args.dataset, graph=graph,
                     nlabels=labels, num_nlabel_types=num_classes)
    writer.add_subgraph(graph_name=args.dataset,
                        subgraph_name='GNNExplainer',
                        node_id=n_idx,
                        subgraph_nids=sub_graph.ndata[dgl.NID],
                        subgraph_eids=sub_graph.edata[dgl.EID],
                        subgraph_eweights=edge_mask)
    # Finish dumping.
    writer.close()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Demo of GNN explainer in DGL')
    # BUG FIX: the original registered --dataset twice, which makes argparse
    # raise ArgumentError at startup; keep a single definition.
    parser.add_argument('--dataset', type=str, default='BAShape',
                        help='The dataset to be explained: BAShape, BACommunity, '
                             'TreeCycle, or TreeGrid (legacy syn1..syn5 also accepted).')
    # Legacy arguments kept for backward compatibility with the older flow.
    parser.add_argument('--target_class', type=int, default=1,
                        help='The class to be explained. '
                             'Will choose the first node in this class to explain.')
    parser.add_argument('--hop', type=int, default=2,
                        help='The hop number of the computation sub-graph.')
    parser.add_argument('--epochs', type=int, default=200, help='The number of epochs.')
    parser.add_argument('--lr', type=float, default=0.01, help='The learning rate.')
    parser.add_argument('--wd', type=float, default=0.0, help='Weight decay.')
    args = parser.parse_args()
    print(args)
......
# This file is copied from the author's implementation.
# <https://github.com/RexYing/gnn-model-explainer/blob/master/utils/featgen.py>.
""" featgen.py
Node feature generators.
"""
import networkx as nx
import numpy as np
import random
import abc
class FeatureGen(metaclass=abc.ABCMeta):
    """Abstract base class for node feature generators."""

    @abc.abstractmethod
    def gen_node_features(self, G):
        """Attach feature attributes to every node of graph ``G``."""
class ConstFeatureGen(FeatureGen):
    """Assign the same constant feature vector ``val`` to every node."""

    def __init__(self, val):
        self.val = val

    def gen_node_features(self, G):
        # float32 matches the dtype expected by the downstream torch models.
        # BUG FIX: removed two leftover debug print statements.
        feat_dict = {
            i: {'feat': np.array(self.val, dtype=np.float32)} for i in G.nodes()
        }
        nx.set_node_attributes(G, feat_dict)
class GaussianFeatureGen(FeatureGen):
    """Sample node features from a multivariate Gaussian distribution."""

    def __init__(self, mu, sigma):
        self.mu = mu
        # Accept either a 1-D vector of variances or a full covariance matrix.
        self.sigma = sigma if sigma.ndim >= 2 else np.diag(sigma)

    def gen_node_features(self, G):
        num_nodes = G.number_of_nodes()
        samples = np.random.multivariate_normal(self.mu, self.sigma, num_nodes)
        attrs = {idx: {"feat": samples[idx]} for idx in range(samples.shape[0])}
        nx.set_node_attributes(G, attrs)
class GridFeatureGen(FeatureGen):
    """Features combining a degree-based community label with Gaussian noise."""

    def __init__(self, mu, sigma, com_choices):
        self.mu = mu                    # Mean
        self.sigma = sigma              # Variance
        self.com_choices = com_choices  # List of possible community labels

    def gen_node_features(self, G):
        # Nodes of degree < 4 get the first community label, others the second.
        community_dict = {}
        for node in G.nodes():
            low_degree = G.degree(node) < 4
            community_dict[node] = self.com_choices[0] if low_degree else self.com_choices[1]

        # One Gaussian noise sample per node.
        noise = np.random.normal(self.mu, self.sigma, G.number_of_nodes())

        # Each feature vector is [community_label, noise].
        feat_dict = {}
        for pos, node in enumerate(G.nodes()):
            feat_dict[node] = {"feat": np.asarray([community_dict[node], noise[pos]])}

        nx.set_node_attributes(G, feat_dict)
        return community_dict
# This file is copied from the author's implementation.
# <https://github.com/RexYing/gnn-model-explainer/blob/master/gengraph.py>.
"""gengraph.py
Generating and manipulaton the synthetic graphs needed for the paper's experiments.
"""
import os
from matplotlib import pyplot as plt
import numpy as np
import networkx as nx
# Set matplotlib backend to file writing
plt.switch_backend("agg")
from synthetic_structsim import *
from featgen import *
def perturb(graph_list, p):
    """Return copies of the input graphs, each with extra random edges.

    Args:
        graph_list: list of (sparse) graphs to perturb.
        p: proportion of edges to add, relative to the current edge count.

    Returns:
        A list of perturbed copies; the original graphs are left untouched.
    """
    result = []
    for original in graph_list:
        g = original.copy()
        n_new_edges = int(g.number_of_edges() * p)
        n_nodes = g.number_of_nodes()
        for _ in range(n_new_edges):
            # Rejection-sample a pair of distinct, not-yet-connected nodes.
            while True:
                u, v = np.random.randint(0, n_nodes), np.random.randint(0, n_nodes)
                if u != v and not g.has_edge(u, v):
                    break
            g.add_edge(u, v)
        result.append(g)
    return result
def join_graph(G1, G2, n_pert_edges):
    """Join two graphs, then add random cross edges between them.

    Args:
        G1, G2: networkx graphs to be joined.
        n_pert_edges: number of random cross edges to add (must be positive).

    Returns:
        A new graph, the result of merging G1 and G2 plus the cross edges.

    Raises:
        ValueError: if ``n_pert_edges`` is not positive.
    """
    # BUG FIX: ``assert`` is stripped under ``python -O``; validate explicitly.
    if n_pert_edges <= 0:
        raise ValueError('n_pert_edges must be positive')
    F = nx.compose(G1, G2)
    edge_cnt = 0
    # Add cross edges one by one; duplicates simply overwrite in networkx.
    while edge_cnt < n_pert_edges:
        node_1 = np.random.choice(G1.nodes())
        node_2 = np.random.choice(G2.nodes())
        F.add_edge(node_1, node_2)
        edge_cnt += 1
    return F
# Generating synthetic graphs
def gen_syn1(nb_shapes=80, width_basis=300, feature_generator=None, m=5):
    """ Synthetic Graph #1: BA base graph with attached 'house' motifs.

    Args:
        nb_shapes : number of house shapes attached to the base graph.
        width_basis : width of the Barabasi-Albert base graph.
        feature_generator : A ``FeatureGen`` for node features. If ``None``,
            constant features are used.
        m : number of edges to attach to an existing node (for the BA graph).

    Returns:
        G : A networkx graph
        role_id : A list with length equal to the number of nodes in the
            entire graph (basis + shapes). role_id[i] is the ID of the role
            of node i. It is the label.
        name : A graph identifier
    """
    basis_type = "ba"
    list_shapes = [["house"]] * nb_shapes

    plt.figure(figsize=(8, 6), dpi=300)

    # BUG FIX: ``m`` was hard-coded to 5 here, silently ignoring the caller's
    # argument (gen_syn2 passes m=4).
    G, role_id, _ = build_graph(
        width_basis, basis_type, list_shapes, start=0, m=m
    )
    G = perturb([G], 0.01)[0]

    if feature_generator is None:
        feature_generator = ConstFeatureGen(1)
    feature_generator.gen_node_features(G)

    name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes)
    return G, role_id, name
def gen_syn2(nb_shapes=100, width_basis=350):
    """ Synthetic Graph #2: two BA-SHAPES communities with Gaussian features.

    Start with a Barabasi-Albert graph and add node features indicative of a
    community label.

    Args:
        nb_shapes : number of house shapes added to each base graph.
        width_basis : width of the Barabasi-Albert base graph.

    Returns:
        G : A networkx graph
        label : Label of the nodes (determined by role_id and community)
        name : A graph identifier
    """
    basis_type = "ba"

    # Per-community Gaussian feature generators: the first two feature
    # dimensions encode the community, the remaining eight are noise.
    noise_mu, noise_sigma = [0.0] * 8, [1.0] * 8
    mu_1 = np.array([-1.0] * 2 + noise_mu)
    sigma_1 = np.array([0.5] * 2 + noise_sigma)
    mu_2 = np.array([1.0] * 2 + noise_mu)
    sigma_2 = np.array([0.5] * 2 + noise_sigma)

    G1, role_id1, _ = gen_syn1(feature_generator=GaussianFeatureGen(mu=mu_1, sigma=sigma_1), m=4)
    G2, role_id2, _ = gen_syn1(feature_generator=GaussianFeatureGen(mu=mu_2, sigma=sigma_2), m=4)

    # Shift the second community's role ids past the first community's range.
    num_roles = max(role_id1) + 1
    label = role_id1 + [r + num_roles for r in role_id2]

    # Relabel node ids so the two graphs do not collide when joined.
    G1_size = G1.number_of_nodes()
    G1 = nx.relabel_nodes(G1, {node: idx for idx, node in enumerate(G1.nodes())})
    G2 = nx.relabel_nodes(G2, {node: idx + G1_size for idx, node in enumerate(G2.nodes())})

    # Join the two communities with ``width_basis`` random cross edges.
    G = join_graph(G1, G2, width_basis)

    name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes) + "_2comm"
    return G, label, name
def gen_syn3(nb_shapes=80, width_basis=300, feature_generator=None, m=5):
    """ Synthetic Graph #3: BA base graph with attached 3x3 grid motifs.

    Args:
        nb_shapes : number of grid shapes attached to the base graph.
        width_basis : width of the Barabasi-Albert base graph.
        feature_generator : A ``FeatureGen`` for node features. If ``None``,
            constant features are used.
        m : number of edges to attach to an existing node (for the BA graph).

    Returns:
        G : A networkx graph
        role_id : Role ID for each node in the synthetic graph.
        name : A graph identifier
    """
    basis_type = "ba"
    list_shapes = [["grid", 3]] * nb_shapes

    plt.figure(figsize=(8, 6), dpi=300)

    # BUG FIX: ``m`` was hard-coded to 5 here, silently ignoring the
    # caller's argument.
    G, role_id, _ = build_graph(
        width_basis, basis_type, list_shapes, start=0, m=m
    )
    G = perturb([G], 0.01)[0]

    if feature_generator is None:
        feature_generator = ConstFeatureGen(1)
    feature_generator.gen_node_features(G)

    name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes)
    return G, role_id, name
def gen_syn4(nb_shapes=60, width_basis=8, feature_generator=None, m=4):
    """ Synthetic Graph #4: balanced binary tree with attached cycle motifs.

    Args:
        nb_shapes : number of six-node cycle shapes attached to the base tree.
        width_basis : height of the balanced binary base tree.
        feature_generator : A ``FeatureGen`` for node features. If ``None``,
            constant features are used.
        m : unused here; kept for signature consistency with the other
            generators (the tree height comes from ``width_basis``).

    Returns:
        G : A networkx graph
        role_id : Role ID for each node in the synthetic graph
        name : A graph identifier
    """
    basis_type = "tree"
    list_shapes = [["cycle", 6]] * nb_shapes

    # FIX: dropped the unused ``fig = ...`` binding and the unused
    # ``plugins`` local; the figure side effect is preserved.
    plt.figure(figsize=(8, 6), dpi=300)

    G, role_id, _ = build_graph(
        width_basis, basis_type, list_shapes, start=0
    )
    G = perturb([G], 0.01)[0]

    if feature_generator is None:
        feature_generator = ConstFeatureGen(1)
    feature_generator.gen_node_features(G)

    name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes)
    return G, role_id, name
def gen_syn5(nb_shapes=80, width_basis=8, feature_generator=None, m=3):
    """ Synthetic Graph #5: balanced binary tree with attached grid motifs.

    Args:
        nb_shapes : number of m-by-m grid shapes attached to the base tree.
        width_basis : height of the balanced binary base tree.
        feature_generator : A ``FeatureGen`` for node features. If ``None``,
            constant features are used.
        m : side length of each grid motif.

    Returns:
        G : A networkx graph
        role_id : Role ID for each node in the synthetic graph
        name : A graph identifier
    """
    basis_kind = "tree"
    shapes = [["grid", m]] * nb_shapes

    plt.figure(figsize=(8, 6), dpi=300)

    G, role_id, _ = build_graph(width_basis, basis_kind, shapes, start=0)
    # Note: syn5 uses a heavier perturbation (10%) than the other datasets.
    G = perturb([G], 0.1)[0]

    if feature_generator is None:
        feature_generator = ConstFeatureGen(1)
    feature_generator.gen_node_features(G)

    name = "_".join([basis_kind, str(width_basis), str(nb_shapes)])
    return G, role_id, name
This diff is collapsed.
{"models": [], "success": true}
\ No newline at end of file
This diff is collapsed.
{"subgraphs": [{"id": 1, "name": "GNNExplainer"}], "success": true}
\ No newline at end of file
{"datasets": [{"id": 1, "name": "BAShape"}], "success": true}
\ No newline at end of file
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
# NOTE(review): this section contained two interleaved versions of the same
# module (the pre-refactor dummy_* classes and the refactored Layer/Model
# classes). Both are reconstructed below so either caller keeps working.


class dummy_layer(nn.Module):
    """GraphSAGE-style layer that accepts an optional edge-weight mask."""

    def __init__(self, in_dim, out_dim):
        super(dummy_layer, self).__init__()
        # Concatenating aggregated neighbor features with self features
        # doubles the linear layer's input dimension.
        self.layer = nn.Linear(in_dim * 2, out_dim, bias=True)

    def forward(self, graph, n_feats, e_weights=None):
        graph.ndata['h'] = n_feats
        # FIX: compare to None with ``is`` rather than ``==``.
        if e_weights is None:
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
        else:
            # Multiply src features by the edge weights (the edge mask).
            graph.edata['ew'] = e_weights
            graph.update_all(fn.u_mul_e('h', 'ew', 'm'), fn.mean('m', 'h'))
        graph.ndata['h'] = self.layer(th.cat([graph.ndata['h'], n_feats], dim=-1))
        output = graph.ndata['h']
        return output


class dummy_gnn_model(nn.Module):
    """
    A dummy GNN model, which is the same as GraphSAGE but can apply an edge
    mask in its forward pass.
    """

    def __init__(self, in_dim, hid_dim, out_dim):
        super(dummy_gnn_model, self).__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim

        self.in_layer = dummy_layer(self.in_dim, self.hid_dim)
        self.hid_layer = dummy_layer(self.hid_dim, self.hid_dim)
        self.out_layer = dummy_layer(self.hid_dim, self.out_dim)

    def forward(self, graph, n_feat, edge_weights=None):
        h = self.in_layer(graph, n_feat, edge_weights)
        h = F.relu(h)
        h = self.hid_layer(graph, h, edge_weights)
        h = F.relu(h)
        h = self.out_layer(graph, h, edge_weights)
        return h


class Layer(nn.Module):
    """Refactored GraphSAGE-style layer with an optional edge-weight mask."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layer = nn.Linear(in_dim * 2, out_dim, bias=True)

    def forward(self, graph, feat, eweight=None):
        # local_scope keeps the temporary ndata/edata off the caller's graph.
        with graph.local_scope():
            graph.ndata['h'] = feat
            if eweight is None:
                graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
            else:
                graph.edata['ew'] = eweight
                graph.update_all(fn.u_mul_e('h', 'ew', 'm'), fn.mean('m', 'h'))
            h = self.layer(th.cat([graph.ndata['h'], feat], dim=-1))
        return h


class Model(nn.Module):
    """Three-layer GNN used for the GNNExplainer demo."""

    def __init__(self, in_dim, out_dim, hid_dim=40):
        super().__init__()
        self.in_layer = Layer(in_dim, hid_dim)
        self.hid_layer = Layer(hid_dim, hid_dim)
        self.out_layer = Layer(hid_dim, out_dim)

    def forward(self, graph, feat, eweight=None):
        h = self.in_layer(graph, feat.float(), eweight)
        h = F.relu(h)
        h = self.hid_layer(graph, h, eweight)
        h = F.relu(h)
        h = self.out_layer(graph, h, eweight)
        return h
\ No newline at end of file
# This file is copied from the author's implementation.
# <https://github.com/RexYing/gnn-model-explainer/blob/master/utils/synthetic_structsim.py>.
"""synthetic_structsim.py
Utilities for generating certain graph shapes.
"""
import math
import networkx as nx
import numpy as np
# Following GraphWave's representation of structural similarity
def clique(start, nb_nodes, nb_to_remove=0, role_start=0):
    """ Build a clique (complete graph) on ``nb_nodes`` nodes, optionally
    removing ``nb_to_remove`` random edges. Node ids start at ``start`` and
    role ids at ``role_start``.

    INPUT:
    -------------
    start : starting index for the shape
    nb_nodes : number of nodes in the clique
    role_start : starting index for the roles
    nb_to_remove : number of edges to remove (uniformly at random)

    OUTPUT:
    -------------
    graph : the clique graph, with ids beginning at start
    roles : list of the roles of the nodes (indexed starting at role_start)
    """
    a = np.ones((nb_nodes, nb_nodes))
    np.fill_diagonal(a, 0)
    graph = nx.from_numpy_matrix(a)
    # BUG FIX: in networkx >= 2, ``graph.edges()`` returns an EdgeView that
    # supports neither ``.keys()`` nor integer indexing; materialize a list.
    # Leftover debug prints were also removed.
    edge_list = list(graph.edges())
    roles = [role_start] * nb_nodes
    if nb_to_remove > 0:
        lst = np.random.choice(len(edge_list), nb_to_remove, replace=False)
        to_delete = [edge_list[e] for e in lst]
        graph.remove_edges_from(to_delete)
        # Bump the role of every endpoint that lost an edge.
        for e in lst:
            roles[edge_list[e][0]] += 1
            roles[edge_list[e][1]] += 1
    mapping_graph = {k: (k + start) for k in range(nb_nodes)}
    graph = nx.relabel_nodes(graph, mapping_graph)
    return graph, roles
def cycle(start, len_cycle, role_start=0):
    """ Build a cycle of ``len_cycle`` nodes whose ids begin at ``start``
    and whose role ids all equal ``role_start``.

    INPUT:
    -------------
    start : starting index for the shape
    role_start : starting index for the roles

    OUTPUT:
    -------------
    graph : the cycle graph, with ids beginning at start
    roles : list of the roles of the nodes (indexed starting at role_start)
    """
    graph = nx.Graph()
    nodes = list(range(start, start + len_cycle))
    graph.add_nodes_from(nodes)
    # Chain consecutive nodes, then close the loop.
    ring = [(nodes[i], nodes[i + 1]) for i in range(len_cycle - 1)]
    ring.append((nodes[-1], nodes[0]))
    graph.add_edges_from(ring)
    return graph, [role_start] * len_cycle
def diamond(start, role_start=0):
    """Builds a diamond graph: a 4-cycle plus two hub nodes, each connected
    to all four cycle nodes. Node ids begin at ``start`` and all roles
    equal ``role_start``.
    INPUT:
    -------------
    start       :    starting index for the shape
    role_start  :    starting index for the roles
    OUTPUT:
    -------------
    graph       :    the diamond graph, with ids beginning at start
    roles       :    list of the roles of the nodes (indexed starting at
                     role_start)
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(start, start + 6))
    # 4-cycle among the first four nodes.
    graph.add_edges_from((start + i, start + (i + 1) % 4) for i in range(4))
    # Nodes 4 and 5 are hubs attached to every cycle node.
    for hub in (start + 4, start + 5):
        graph.add_edges_from((hub, start + i) for i in range(4))
    return graph, [role_start] * 6
def tree(start, height, r=2, role_start=0):
    """Builds a balanced r-tree of height h
    INPUT:
    -------------
    start       :    starting index for the shape
    height      :    int height of the tree
    r           :    int number of branches per node
    role_start  :    starting index for the roles
    OUTPUT:
    -------------
    graph       :    a tree shape graph, with ids beginning at start
    roles       :    list of the roles of the nodes (indexed starting at role_start)
    """
    graph = nx.balanced_tree(r, height)
    # Fix: honor the documented contract. The previous version ignored both
    # `start` (node ids always began at 0) and `role_start` (roles were
    # always 0). Behavior is unchanged for the existing
    # start=0, role_start=0 call sites.
    graph = nx.relabel_nodes(graph, {node: node + start for node in graph})
    roles = [role_start] * graph.number_of_nodes()
    return graph, roles
def fan(start, nb_branches, role_start=0):
    """Builds a fan-like graph: a star whose consecutive leaves are also
    chained together, with node ids starting at ``start`` and role ids at
    ``role_start``.
    INPUT:
    -------------
    nb_branches :    int, number of fan branches
    start       :    starting index for the shape
    role_start  :    starting index for the roles
    OUTPUT:
    -------------
    graph       :    the fan graph, with ids beginning at start
    roles       :    list of the roles of the nodes (indexed starting at
                     role_start)
    """
    graph, roles = star(start, nb_branches, role_start=role_start)
    # Link neighbouring leaves; every added link bumps the role of both
    # endpoints so chained leaves differ from plain star leaves.
    for idx in range(1, nb_branches - 1):
        graph.add_edge(start + idx, start + idx + 1)
        roles[idx] += 1
        roles[idx + 1] += 1
    return graph, roles
def ba(start, width, role_start=0, m=5):
    """Builds a BA preferential attachment graph, with index of nodes starting at start
    and role_ids at role_start
    INPUT:
    -------------
    start       :    starting index for the shape
    width       :    int size of the graph
    role_start  :    starting index for the roles
    m           :    int, number of edges to attach from a new node
    OUTPUT:
    -------------
    graph       :    the BA graph, with ids beginning at start
    roles       :    list of the roles of the nodes (indexed starting at
                     role_start)
    """
    graph = nx.barabasi_albert_graph(width, m)
    # Fix: the previous `graph.add_nodes_from(range(start, start + width))`
    # injected up to `width` extra isolated nodes whenever start > 0,
    # desynchronizing the graph size from the `roles` list below; it was a
    # no-op for start == 0. Shifting the existing ids is all that is needed.
    nids = sorted(graph)
    mapping = {nid: start + i for i, nid in enumerate(nids)}
    graph = nx.relabel_nodes(graph, mapping)
    roles = [role_start] * width
    return graph, roles
def house(start, role_start=0):
    """Builds a house-like graph (a square base plus a roof node), with node
    ids starting at ``start`` and role ids at ``role_start``.
    INPUT:
    -------------
    start       :    starting index for the shape
    role_start  :    starting index for the roles
    OUTPUT:
    -------------
    graph       :    the house graph, with ids beginning at start
    roles       :    list of the roles of the nodes (indexed starting at
                     role_start)
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(start, start + 5))
    # Square base: offsets 0-1-2-3-0.
    base_edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
    graph.add_edges_from((start + u, start + v) for u, v in base_edges)
    # Roof node (offset 4) sits on top of the 0-1 edge.
    graph.add_edges_from([(start + 4, start), (start + 4, start + 1)])
    # Roles: bottom pair, middle pair, roof.
    roles = [role_start, role_start, role_start + 1, role_start + 1, role_start + 2]
    return graph, roles
def grid(start, dim=2, role_start=0):
    """Builds a dim-by-dim grid graph, with node ids starting at ``start``
    and every role equal to ``role_start``.
    """
    graph = nx.convert_node_labels_to_integers(nx.grid_graph([dim, dim]),
                                               first_label=start)
    return graph, [role_start] * graph.number_of_nodes()
def star(start, nb_branches, role_start=0):
    """Builds a star graph, with node ids starting at ``start`` and role ids
    at ``role_start``.
    INPUT:
    -------------
    nb_branches :    int, number of star branches
    start       :    starting index for the shape
    role_start  :    starting index for the roles
    OUTPUT:
    -------------
    graph       :    the star graph, with ids beginning at start
    roles       :    list of the roles of the nodes (indexed starting at
                     role_start)
    """
    center = start
    leaves = range(start + 1, start + nb_branches + 1)
    graph = nx.Graph()
    graph.add_nodes_from(range(start, start + nb_branches + 1))
    graph.add_edges_from((center, leaf) for leaf in leaves)
    # Center gets role_start, every leaf gets role_start + 1.
    roles = [role_start] + [role_start + 1] * nb_branches
    return graph, roles
def path(start, width, role_start=0):
    """Builds a path graph, with node ids starting at ``start`` and role ids
    at ``role_start``.
    INPUT:
    -------------
    start       :    starting index for the shape
    width       :    int, length of the path
    role_start  :    starting index for the roles
    OUTPUT:
    -------------
    graph       :    the path graph, with ids beginning at start
    roles       :    list of the roles of the nodes (indexed starting at
                     role_start)
    """
    nodes = list(range(start, start + width))
    graph = nx.Graph()
    graph.add_nodes_from(nodes)
    graph.add_edges_from(zip(nodes, nodes[1:]))
    # Interior nodes keep role_start; the two endpoints are distinguished.
    roles = [role_start] * width
    roles[0] = role_start + 1
    roles[-1] = role_start + 1
    return graph, roles
def build_graph(
    width_basis,
    basis_type,
    list_shapes,
    start=0,
    rdm_basis_plugins=False,
    add_random_edges=0,
    m=5,
):
    """This function creates a basis (scale-free, path, or cycle)
    and attaches elements of the type in the list randomly along the basis.
    Possibility to add random edges afterwards.
    INPUT:
    --------------------------------------------------------------------------------------
    width_basis      : width (in terms of number of nodes) of the basis
    basis_type       : name of a basis builder in this module ("ba", "tree", "cycle", ...)
    list_shapes      : list of shape lists (1st arg: type of shape,
                       next args: args for building the shape, except for the start)
    start            : initial id for the first node
    rdm_basis_plugins: boolean. Should the shapes be randomly placed
                       along the basis (True) or regularly (False)?
    add_random_edges : nb of edges to randomly add on the structure
    m                : number of edges to attach to existing node (for BA graph)
    OUTPUT:
    --------------------------------------------------------------------------------------
    basis            : a nx graph with the particular shape
    role_ids         : labels for each role
    plugins          : node ids with the attached shapes
    """
    # SECURITY NOTE: basis_type / shape names are resolved via eval(); only
    # call this with trusted, hard-coded shape names from this module.
    if basis_type == "ba":
        basis, role_id = eval(basis_type)(start, width_basis, m=m)
    else:
        basis, role_id = eval(basis_type)(start, width_basis)
    n_basis, n_shapes = nx.number_of_nodes(basis), len(list_shapes)
    start += n_basis  # id of the next node to be created
    # Choose the basis nodes the motifs will be attached to.
    if rdm_basis_plugins is True:
        plugins = np.random.choice(n_basis, n_shapes, replace=False)
    else:
        # Regular placement: evenly spaced along the basis.
        spacing = math.floor(n_basis / n_shapes)
        plugins = [int(k * spacing) for k in range(n_shapes)]
    seen_shapes = {"basis": [0, n_basis]}
    for shape_id, shape in enumerate(list_shapes):
        shape_type = shape[0]
        args = [start]
        if len(shape) > 1:
            args += shape[1:]
        args += [0]
        graph_s, roles_graph_s = eval(shape_type)(*args)
        n_s = nx.number_of_nodes(graph_s)
        try:
            col_start = seen_shapes[shape_type][0]
        except KeyError:
            # First time this shape type appears: give it a fresh role offset
            # past every role used so far. (The previous bare `except:` would
            # also have masked unrelated errors.)
            col_start = np.max(role_id) + 1
            seen_shapes[shape_type] = [col_start, n_s]
        # Attach the shape to the basis.
        basis.add_nodes_from(graph_s.nodes())
        basis.add_edges_from(graph_s.edges())
        basis.add_edges_from([(start, plugins[shape_id])])
        if shape_type == "cycle":
            # With probability 0.5, add a second, noisy anchor edge for cycles.
            if np.random.random() > 0.5:
                a = np.random.randint(1, 4)
                b = np.random.randint(1, 4)
                basis.add_edges_from([(a + start, b + plugins[shape_id])])
        # Offset the shape's roles so role ids are globally unique per type.
        temp_labels = [r + col_start for r in roles_graph_s]
        role_id += temp_labels
        start += n_s
    if add_random_edges > 0:
        # Add random noise edges between arbitrary node pairs.
        for p in range(add_random_edges):
            src, dest = np.random.choice(nx.number_of_nodes(basis), 2, replace=False)
            basis.add_edges_from([(src, dest)])
    return basis, role_id, plugins
# The training codes of the dummy model
import os
import argparse
import dgl
import torch as th
import torch.nn as nn
from dgl import save_graphs
from models import dummy_gnn_model
from gengraph import gen_syn1, gen_syn2, gen_syn3, gen_syn4, gen_syn5
import numpy as np
# NOTE(review): this region of the page is a rendered diff that interleaved
# the pre- and post-refactor training scripts (two `def main` headers,
# duplicated optimizer/training lines, a collapsed `@@` hunk). Reconstructed
# below is the coherent post-commit version of the script.
from dgl.data import BAShapeDataset, BACommunityDataset, TreeCycleDataset, TreeGridDataset
from models import Model


def main(args):
    """Train a GNN on a synthetic explanation benchmark and save its weights.

    Loads one of the DGL synthetic node-classification datasets, trains a
    `Model` with cross-entropy on all nodes, and saves the state dict to
    ``./model_<dataset>.pth``.
    """
    # load dataset
    if args.dataset == 'BAShape':
        dataset = BAShapeDataset(seed=0)
    elif args.dataset == 'BACommunity':
        dataset = BACommunityDataset(seed=0)
    elif args.dataset == 'TreeCycle':
        dataset = TreeCycleDataset(seed=0)
    elif args.dataset == 'TreeGrid':
        dataset = TreeGridDataset(seed=0)
    else:
        raise NotImplementedError

    graph = dataset[0]
    labels = graph.ndata['label']
    n_feats = graph.ndata['feat']
    num_classes = dataset.num_classes

    # create model
    model = Model(n_feats.shape[-1], num_classes)
    loss_fn = nn.CrossEntropyLoss()
    optim = th.optim.Adam(model.parameters(), lr=0.001)

    # train and output
    for epoch in range(500):
        model.train()
        # For demo purpose, we train the model on all datapoints
        # In practice, you should train only on the training datapoints
        logits = model(graph, n_feats)
        loss = loss_fn(logits, labels)
        acc = th.sum(logits.argmax(dim=1) == labels).item() / len(labels)
        # NOTE(review): the zero_grad call fell inside the collapsed diff
        # hunk; it is required before backward() — confirm against upstream.
        optim.zero_grad()
        loss.backward()
        optim.step()
        print(f'In Epoch: {epoch}; Acc: {acc}; Loss: {loss.item()}')

    # save model
    model_stat_dict = model.state_dict()
    model_path = os.path.join('./', f'model_{args.dataset}.pth')
    th.save(model_stat_dict, model_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Dummy model training')
    parser.add_argument('--dataset', type=str, default='BAShape',
                        choices=['BAShape', 'BACommunity', 'TreeCycle', 'TreeGrid'])
    args = parser.parse_args()
    print(args)
    main(args)
# Utility file for graph queries
import tkinter
import matplotlib
matplotlib.use('TkAgg')
import networkx as nx
import matplotlib.pylab as plt
import torch as th
import dgl
from dgl.sampling import sample_neighbors
def extract_subgraph(graph, seed_nodes, hops=2):
    """
    For the explainability, extract the subgraph of a seed node with the hops specified.
    Parameters
    ----------
    graph: DGLGraph, the full graph to extract from. This time, assume it is a homograph
    seed_nodes: Tensor, index of a node in the graph
    hops: Integer, the number of hops to extract
    Returns
    -------
    sub_graph: DGLGraph, a sub graph
    origin_nodes: List, list of node ids in the origin graph, sorted from small to large, whose order is the new id. e.g
    [2, 51, 53, 79] means in the new sug_graph, their new node id is [0,1,2,3], the mapping is 2<>0, 51<>1, 53<>2,
    and 79 <> 3.
    new_seed_node: Scalar, the node index of seed_nodes
    """
    seeds=seed_nodes
    # Expand the frontier hop by hop: sample_neighbors(..., -1) keeps ALL
    # edges incident to the current seed set, and edges()[0] (the source
    # endpoints) are the newly reached neighbors appended to the seeds.
    for i in range(hops):
        i_hop = sample_neighbors(graph, seeds, -1)
        seeds = th.cat([seeds, i_hop.edges()[0]])
    # NOTE(review): only the edges from the FINAL sampling round are used to
    # build the subgraph; since that round samples from the accumulated seed
    # set this should cover the k-hop ball, but the exact edge-direction and
    # multiplicity semantics depend on dgl.sampling internals — confirm
    # against the dgl version in use.
    ori_src, ori_dst = i_hop.edges()
    # Compact the original node ids to a dense 0..N-1 range: th.unique with
    # return_inverse yields the sorted original ids plus, for the
    # concatenated [src, dst] vector, their new dense ids.
    edge_all = th.cat([ori_src, ori_dst])
    origin_nodes, new_edges_all = th.unique(edge_all, return_inverse=True)
    # First half of the inverse vector maps the sources, second half the
    # destinations (edge_all was built as src followed by dst).
    n = int(new_edges_all.shape[0] / 2)
    new_src = new_edges_all[:n]
    new_dst = new_edges_all[n:]
    sub_graph = dgl.DGLGraph((new_src, new_dst))
    # The seed's position in the sorted original-id list is its new id in
    # the subgraph.
    new_seed_node = th.nonzero(origin_nodes==seed_nodes, as_tuple=True)[0][0]
    return sub_graph, origin_nodes, new_seed_node
def visualize_sub_graph(sub_graph, edge_weights=None, origin_nodes=None, center_node=None):
    """
    Draw a subgraph with networkx; if edge weights are given, edges are
    colored with a red colormap according to their weight.
    Parameters
    ----------
    sub_graph: DGLGraph, the sub_graph to be visualized.
    edge_weights: Tensor, the same number of edges. Values are (0,1), default is None
    origin_nodes: List, list of node ids that will be used to replace the node ids in the subgraph in visualization
    center_node: Tensor, the node id in origin node list to be highlighted with different color
    Returns
    show the sub_graph
    -------
    """
    # Convert to a networkx graph; keep the pre-relabel edge view so edge
    # attributes (the dgl edge 'id') remain accessible below.
    nx_g = dgl.to_networkx(sub_graph)
    edge_data = nx_g.edges(data=True)
    if origin_nodes is not None:
        # Display original graph ids instead of the compacted subgraph ids.
        relabel = {new_id: old_id for new_id, old_id in enumerate(origin_nodes.tolist())}
        nx_g = nx.relabel_nodes(nx_g, mapping=relabel)
    layout = nx.spring_layout(nx_g)
    if edge_weights is None:
        draw_opts = {
            "node_size": 1000,
            "alpha": 0.9,
            "font_size": 24,
            "width": 4,
        }
    else:
        # Map each drawn edge to its weight via the dgl edge 'id' attribute.
        edge_colors = [edge_weights[attrs['id']][0] for _, _, attrs in edge_data]
        draw_opts = {
            "node_size": 1000,
            "alpha": 0.3,
            "font_size": 12,
            "edge_color": edge_colors,
            "width": 4,
            "edge_cmap": plt.cm.Reds,
            "edge_vmin": 0,
            "edge_vmax": 1,
            "connectionstyle": "arc3,rad=0.1",
        }
    nx.draw(nx_g, layout, with_labels=True, node_color='b', **draw_opts)
    if center_node is not None:
        # Re-draw the center node on top in red to highlight it.
        nx.draw(nx_g, layout, nodelist=center_node.tolist(), with_labels=True, node_color='r', **draw_opts)
    plt.show()
......@@ -282,7 +282,7 @@ class GNNExplainer(nn.Module):
if self.log:
pbar = tqdm(total=self.num_epochs)
pbar.set_description('Explain node {node_id}')
pbar.set_description(f'Explain node {node_id}')
for _ in range(self.num_epochs):
optimizer.zero_grad()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment