"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a7361dccdc581147620bbd74a6d295cd92daf616"
Unverified Commit 226d1159 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Auto fix by black. (#4952)



* black on explain_main

* isort

* add dot
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent f118ea95
import argparse import argparse
import os import os
import dgl
from gnnlens import Writer
import torch as th import torch as th
from gnnlens import Writer
from models import Model
import dgl
from dgl import load_graphs from dgl import load_graphs
from dgl.data import (BACommunityDataset, BAShapeDataset, TreeCycleDataset,
TreeGridDataset)
from dgl.nn import GNNExplainer from dgl.nn import GNNExplainer
from models import Model
from dgl.data import BAShapeDataset, BACommunityDataset, TreeCycleDataset, TreeGridDataset
def main(args): def main(args):
if args.dataset == 'BAShape': if args.dataset == "BAShape":
dataset = BAShapeDataset(seed=0) dataset = BAShapeDataset(seed=0)
elif args.dataset == 'BACommunity': elif args.dataset == "BACommunity":
dataset = BACommunityDataset(seed=0) dataset = BACommunityDataset(seed=0)
elif args.dataset == 'TreeCycle': elif args.dataset == "TreeCycle":
dataset = TreeCycleDataset(seed=0) dataset = TreeCycleDataset(seed=0)
elif args.dataset == 'TreeGrid': elif args.dataset == "TreeGrid":
dataset = TreeGridDataset(seed=0) dataset = TreeGridDataset(seed=0)
graph = dataset[0] graph = dataset[0]
labels = graph.ndata['label'] labels = graph.ndata["label"]
feats = graph.ndata['feat'] feats = graph.ndata["feat"]
num_classes = dataset.num_classes num_classes = dataset.num_classes
# load an existing model # load an existing model
model_path = os.path.join('./', f'model_{args.dataset}.pth') model_path = os.path.join("./", f"model_{args.dataset}.pth")
model_stat_dict = th.load(model_path) model_stat_dict = th.load(model_path)
model = Model(feats.shape[-1], num_classes) model = Model(feats.shape[-1], num_classes)
model.load_state_dict(model_stat_dict) model.load_state_dict(model_stat_dict)
...@@ -38,29 +40,41 @@ def main(args): ...@@ -38,29 +40,41 @@ def main(args):
break break
explainer = GNNExplainer(model, num_hops=3) explainer = GNNExplainer(model, num_hops=3)
new_center, sub_graph, feat_mask, edge_mask = explainer.explain_node(n_idx, graph, feats) new_center, sub_graph, feat_mask, edge_mask = explainer.explain_node(
n_idx, graph, feats
)
# gnnlens2 # gnnlens2
# Specify the path to create a new directory for dumping data files. # Specify the path to create a new directory for dumping data files.
writer = Writer('gnn_subgraph') writer = Writer("gnn_subgraph")
writer.add_graph(name=args.dataset, graph=graph, writer.add_graph(
nlabels=labels, num_nlabel_types=num_classes) name=args.dataset,
writer.add_subgraph(graph_name=args.dataset, graph=graph,
subgraph_name='GNNExplainer', nlabels=labels,
node_id=n_idx, num_nlabel_types=num_classes,
subgraph_nids=sub_graph.ndata[dgl.NID], )
subgraph_eids=sub_graph.edata[dgl.EID], writer.add_subgraph(
subgraph_eweights=edge_mask) graph_name=args.dataset,
subgraph_name="GNNExplainer",
node_id=n_idx,
subgraph_nids=sub_graph.ndata[dgl.NID],
subgraph_eids=sub_graph.edata[dgl.EID],
subgraph_eweights=edge_mask,
)
# Finish dumping # Finish dumping.
writer.close() writer.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Demo of GNN explainer in DGL') if __name__ == "__main__":
parser.add_argument('--dataset', type=str, default='BAShape', parser = argparse.ArgumentParser(description="Demo of GNN explainer in DGL")
choices=['BAShape', 'BACommunity', 'TreeCycle', 'TreeGrid']) parser.add_argument(
"--dataset",
type=str,
default="BAShape",
choices=["BAShape", "BACommunity", "TreeCycle", "TreeGrid"],
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment