Unverified Commit b36b6c26 authored by 张天启, committed by GitHub

[Example] Add HGP-SL example for pytorch backend (#2515)



* add sagpool example for pytorch backend

* polish sagpool example for pytorch backend

* [Example] SAGPool: use std variance

* [Example] SAGPool: change to std

* add sagpool example to index page

* add graph property prediction tag to sagpool

* [Example] add graph classification example HGP-SL

* [Example] fix sagpool

* fix bug

* [Example] change tab to space in README of hgp-sl

* remove redundant files

* remove redundant network

* [Example]: change link from code to doc in HGP-SL

* [Example] in HGP-SL, change to meaningful name

* [Example] Fix path mistake for 'hardgat'
Co-authored-by: zhangtianqi <tianqizh@amazon.com>
parent 1caf01d0
@@ -43,6 +43,7 @@ The folder contains example implementations of selected research papers related
| [Self-Attention Graph Pooling](#sagpool) | | | :heavy_check_mark: | | |
| [Convolutional Networks on Graphs for Learning Molecular Fingerprints](#nf) | | | :heavy_check_mark: | | |
| [GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation](#gnnfilm) | :heavy_check_mark: | | | | |
| [Hierarchical Graph Pooling with Structure Learning](#hgp-sl) | | | :heavy_check_mark: | | |
| [Graph Representation Learning via Hard and Channel-Wise Attention Networks](#hardgat) |:heavy_check_mark: | | | | |
## 2020
@@ -144,8 +145,12 @@ The folder contains example implementations of selected research papers related
- Example code: [PyTorch](../examples/pytorch/sagpool)
- Tags: graph classification, pooling
- <a name="hgp-sl"></a> Zhang, Zhen, et al. Hierarchical Graph Pooling with Structure Learning. [Paper link](https://arxiv.org/abs/1911.05954).
- Example code: [PyTorch](../examples/pytorch/hgp_sl)
- Tags: graph classification, pooling
- <a name='hardgat'></a> Gao, Hongyang, et al. Graph Representation Learning via Hard and Channel-Wise Attention Networks. [Paper link](https://arxiv.org/abs/1907.04652).
- Example code: [Pytorch](../examples/pytorch/hgat)
- Example code: [Pytorch](../examples/pytorch/hardgat)
- Tags: node classification, graph attention
## 2018
# DGL Implementation of the HGP-SL Paper
This DGL example implements the GNN model proposed in the paper [Hierarchical Graph Pooling with Structure Learning](https://arxiv.org/pdf/1911.05954.pdf).
The author's original implementation is available [here](https://github.com/cszhangzhen/HGP-SL).
Example implementor
----------------------
This example was implemented by [Tianqi Zhang](https://github.com/lygztq) during his Applied Scientist Intern work at the AWS Shanghai AI Lab.
The graph dataset used in this example
---------------------------------------
This example uses DGL's built-in [LegacyTUDataset](https://docs.dgl.ai/api/python/dgl.data.html?highlight=tudataset#dgl.data.LegacyTUDataset), a series of graph kernel datasets for graph classification. We use 'DD', 'PROTEINS', 'NCI1', 'NCI109', 'Mutagenicity' and 'ENZYMES' in this HGP-SL implementation. Each dataset is randomly split into train, validation and test sets with ratios of 0.8, 0.1 and 0.1.
NOTE: Since some of these datasets have no node attributes, we use the node_id (as a one-hot vector whose length is the maximum number of nodes across all graphs) as the node feature. Also note that the node_id in some datasets is not unique (e.g. a graph may have two nodes with the same id).
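For illustration, here is a minimal sketch of that idea. `LegacyTUDataset` performs this preprocessing itself; the snippet below simply uses the node index as a stand-in id, so treat it as an approximation rather than the exact procedure:

```python
import torch
import torch.nn.functional as F
from dgl.data import LegacyTUDataset

dataset = LegacyTUDataset("ENZYMES")
# length of the one-hot vector = largest graph size in the dataset
max_num_nodes = max(g.num_nodes() for g, _ in dataset)
for g, _ in dataset:
    ids = torch.arange(g.num_nodes())  # node index used as a stand-in "node_id"
    g.ndata["feat"] = F.one_hot(ids, num_classes=max_num_nodes).float()
```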
DD
- NumGraphs: 1178
- AvgNodesPerGraph: 284.32
- AvgEdgesPerGraph: 715.66
- NumFeats: 89
- NumClasses: 2
PROTEINS
- NumGraphs: 1113
- AvgNodesPerGraph: 39.06
- AvgEdgesPerGraph: 72.82
- NumFeats: 1
- NumClasses: 2
NCI1
- NumGraphs: 4110
- AvgNodesPerGraph: 29.87
- AvgEdgesPerGraph: 32.30
- NumFeats: 37
- NumClasses: 2
NCI109
- NumGraphs: 4127
- AvgNodesPerGraph: 29.68
- AvgEdgesPerGraph: 32.13
- NumFeats: 38
- NumClasses: 2
Mutagenicity
- NumGraphs: 4337
- AvgNodesPerGraph: 30.32
- AvgEdgesPerGraph: 30.77
- NumFeats: 14
- NumClasses: 2
ENZYMES
- NumGraphs: 600
- AvgNodesPerGraph: 32.63
- AvgEdgesPerGraph: 62.14
- NumFeats: 18
- NumClasses: 6
How to run example files
--------------------------------
In this example's folder, run
```bash
python main.py --dataset ${your_dataset_name_here}
```
If you want to use a GPU, run
```bash
python main.py --device ${your_device_id_here} --dataset ${your_dataset_name_here}
```
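The `--sample` flag of `main.py` toggles the k-hop sample mode used in the "sample" rows of the tables below; passing `false` runs full-graph structure learning instead. For example:

```bash
# full-graph structure learning instead of the k-hop sample mode
python main.py --dataset Mutagenicity --sample false
```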
Performance
-------------------------
**Hyper-parameters**
These hyper-parameters are taken directly from the [author's implementation](https://github.com/cszhangzhen/HGP-SL):
| Datasets | lr | weight_decay | batch_size | pool_ratio | dropout | net_layers |
| ------------- | --------- | -------------- | --------------- | -------------- | -------- | ---------- |
| PROTEINS | 0.001 | 0.001 | 512 | 0.5 | 0.0 | 3 |
| Mutagenicity | 0.001 | 0.001 | 512 | 0.8 | 0.0 | 3 |
| NCI109 | 0.001 | 0.001 | 512 | 0.8 | 0.0 | 3 |
| NCI1 | 0.001 | 0.001 | 512 | 0.8 | 0.0 | 3 |
| DD | 0.0001 | 0.001 | 64 | 0.3 | 0.5 | 2 |
| ENZYMES | 0.001 | 0.001 | 128 | 0.8 | 0.0 | 2 |
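For instance, assuming the table's `net_layers` corresponds to the `--conv_layers` flag of `main.py`, the PROTEINS setting can be reproduced with:

```bash
python main.py --dataset PROTEINS --lr 0.001 --weight_decay 0.001 --batch_size 512 --pool_ratio 0.5 --dropout 0.0 --conv_layers 3
```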
**Accuracy**
**NOTE**: We find a gap between the accuracy obtained with the author's code and the one reported in the [paper](https://arxiv.org/pdf/1911.05954.pdf). An issue has been opened in the author's repo (see [here](https://github.com/cszhangzhen/HGP-SL/issues/8)).
| | Mutagenicity | NCI109 | NCI1 | DD |
| -------------------------- | ------------ | ----------- | ----------- | ----------- |
| Reported in Paper | 82.15(0.58) | 80.67(1.16) | 78.45(0.77) | 80.96(1.26) |
| Author's Code (full graph) | 78.44(2.10) | 74.44(2.05) | 77.37(2.09) | OOM |
| Author's Code (sample) | 79.68(1.68) | 73.86(1.72) | 76.29(2.14) | 75.46(3.86) |
| DGL (full graph) | 79.52(2.21) | 74.86(1.99) | 74.62(2.22) | OOM |
| DGL (sample) | 79.15(1.62) | 75.39(1.86) | 73.77(2.04) | 76.47(2.14) |
**Speed**
Device: Tesla V100-SXM2 16GB
All numbers below are in seconds.
| | DD(batchsize=64), large graph | Mutagenicity(batchsize=512), small graph |
| ----------------------------- | ----------------------------- | ---------------------------------------- |
| Author's code (sample) | 9.96 | 12.91 |
| Author's code (full graph) | OOM | 13.03 |
| DGL (sample) | 9.50 | 3.59 |
| DGL (full graph) | OOM | 3.56 |
"""
An original implementation of sparsemax (Martins & Astudillo, 2016) is available at
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py.
See `From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification, ICML 2016`
for a detailed description.
Here we implement a graph-edge version of sparsemax, where sparsemax is applied over all edges
that share the same end node.
"""
import dgl
import torch
from dgl.backend import astype
from dgl.base import ALL, is_all
from dgl.heterograph_index import HeteroGraphIndex
from dgl.sparse import _gsddmm, _gspmm
from torch import Tensor
from torch.autograd import Function
def _neighbor_sort(scores:Tensor, end_n_ids:Tensor, in_degrees:Tensor, cum_in_degrees:Tensor):
"""Sort edge scores for each node"""
num_nodes, max_in_degree = in_degrees.size(0), int(in_degrees.max().item())
# Compute the index for a dense score matrix with size (N x D_{max}).
# Note that end_n_ids here is the end-node tensor of the dgl graph,
# which is not necessarily grouped by node id (i.e. not in the form 0,0,1,1,1,...,N,N).
# Thus we first sort the end-node tensor to make it easier to compute
# indices in the dense edge score matrix. Since we will need the original order
# for the following gspmm and gsddmm operations, we also keep the reverse mapping
# (reverse_perm) here.
end_n_ids, perm = torch.sort(end_n_ids)
scores = scores[perm]
_, reverse_perm = torch.sort(perm)
index = torch.arange(end_n_ids.size(0), dtype=torch.long, device=scores.device)
index = (index - cum_in_degrees[end_n_ids]) + (end_n_ids * max_in_degree)
index = index.long()
dense_scores = scores.new_full((num_nodes * max_in_degree, ), torch.finfo(scores.dtype).min)
dense_scores[index] = scores
dense_scores = dense_scores.view(num_nodes, max_in_degree)
sorted_dense_scores, dense_reverse_perm = dense_scores.sort(dim=-1, descending=True)
_, dense_reverse_perm = torch.sort(dense_reverse_perm, dim=-1)
dense_reverse_perm = dense_reverse_perm + cum_in_degrees.view(-1, 1)
dense_reverse_perm = dense_reverse_perm.view(-1)
cumsum_sorted_dense_scores = sorted_dense_scores.cumsum(dim=-1).view(-1)
sorted_dense_scores = sorted_dense_scores.view(-1)
arange_vec = torch.arange(1, max_in_degree + 1, dtype=torch.long, device=end_n_ids.device)
arange_vec = torch.repeat_interleave(arange_vec.view(1, -1), num_nodes, dim=0).view(-1)
valid_mask = (sorted_dense_scores != torch.finfo(scores.dtype).min)
sorted_scores = sorted_dense_scores[valid_mask]
cumsum_sorted_scores = cumsum_sorted_dense_scores[valid_mask]
arange_vec = arange_vec[valid_mask]
dense_reverse_perm = dense_reverse_perm[valid_mask].long()
return sorted_scores, cumsum_sorted_scores, arange_vec, reverse_perm, dense_reverse_perm
def _threshold_and_support_graph(gidx:HeteroGraphIndex, scores:Tensor, end_n_ids:Tensor):
"""Find the threshold for each node and its edges"""
in_degrees = _gspmm(gidx, "copy_rhs", "sum", None, torch.ones_like(scores))[0]
cum_in_degrees = torch.cat([in_degrees.new_zeros(1), in_degrees.cumsum(dim=0)[:-1]], dim=0)
# perform sort on edges for each node
sorted_scores, cumsum_scores, rhos, reverse_perm, dense_reverse_perm = _neighbor_sort(scores, end_n_ids,
in_degrees, cum_in_degrees)
cumsum_scores = cumsum_scores - 1.
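# sparsemax support: keep the k largest scores for which 1 + k * z_(k) > cumsum_k(z) (Martins & Astudillo, 2016)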
support = rhos * sorted_scores > cumsum_scores
support = support[dense_reverse_perm] # from sorted order to unsorted order
support = support[reverse_perm] # from src-dst order to eid order
support_size = _gspmm(gidx, "copy_rhs", "sum", None, support.float())[0]
support_size = support_size.long()
idx = support_size + cum_in_degrees - 1
# mask invalid indices; for example, if the batch does not start from 0 or is not contiguous, negative indices may appear
mask = idx < 0
idx[mask] = 0
tau = cumsum_scores.gather(0, idx.long())
tau /= support_size.to(scores.dtype)
return tau, support_size
class EdgeSparsemaxFunction(Function):
r"""
Description
-----------
Pytorch Auto-Grad Function for edge sparsemax.
We define this autograd function here since
sparsemax involves sort and select operations,
which are not differentiable.
"""
@staticmethod
def forward(ctx, gidx:HeteroGraphIndex, scores:Tensor,
eids:Tensor, end_n_ids:Tensor, norm_by:str):
if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == "src":
gidx = gidx.reverse()
# use feat - max(feat) for numerical stability.
scores = scores.float()
scores_max = _gspmm(gidx, "copy_rhs", "max", None, scores)[0]
scores = _gsddmm(gidx, "sub", scores, scores_max, "e", "v")
# find threshold for each node and perform ReLU(u-t(u)) operation.
tau, supp_size = _threshold_and_support_graph(gidx, scores, end_n_ids)
out = torch.clamp(_gsddmm(gidx, "sub", scores, tau, "e", "v"), min=0)
ctx.backward_cache = gidx
ctx.save_for_backward(supp_size, out)
torch.cuda.empty_cache()
return out
@staticmethod
def backward(ctx, grad_out):
gidx = ctx.backward_cache
supp_size, out = ctx.saved_tensors
grad_in = grad_out.clone()
# grad for ReLU
grad_in[out == 0] = 0
# dL/dv_i = dL/do_i - 1/k \sum_{j=1}^k dL/do_j
v_hat = _gspmm(gidx, "copy_rhs", "sum", None, grad_in)[0] / supp_size.to(out.dtype)
grad_in_modify = _gsddmm(gidx, "sub", grad_in, v_hat, "e", "v")
grad_in = torch.where(out != 0, grad_in_modify, grad_in)
del gidx
torch.cuda.empty_cache()
return None, grad_in, None, None, None
def edge_sparsemax(graph:dgl.DGLGraph, logits, eids=ALL, norm_by="dst"):
r"""
Description
-----------
Compute edge sparsemax. For a node :math:`i`, edge sparsemax is an operation that computes
.. math::
a_{ij} = \text{ReLU}(z_{ij} - \tau(z_{i,:}))
where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of sparsemax. :math:`\tau` is a function
that can be found at the `From Softmax to Sparsemax <https://arxiv.org/pdf/1602.02068.pdf>`
paper.
NOTE: currently only homogeneous graphs are supported.
Parameters
----------
graph : DGLGraph
The graph to perform edge sparsemax on.
logits : torch.Tensor
The input edge feature.
eids : torch.Tensor or ALL, optional
A tensor of edge index on which to apply edge sparsemax. If ALL, apply edge
sparsemax on all edges in the graph. Default: ALL.
norm_by : str, could be 'src' or 'dst'
Normalize by source nodes or destination nodes. Default: `dst`.
Returns
-------
Tensor
Sparsemax value.
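Examples
--------
A minimal usage sketch on a small homogeneous graph (values are illustrative only):
    g = dgl.graph(([0, 0, 1], [1, 2, 2]))
    logits = torch.tensor([1.0, 2.0, 0.5])
    out = edge_sparsemax(g, logits)  # sparse weights over the in-edges of each destination node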
"""
# we get edge index tensors here since it is
# hard to get edge index with HeteroGraphIndex
# object without other information like edge_type.
row, col = graph.all_edges(order="eid")
assert norm_by in ["dst", "src"]
end_n_ids = col if norm_by == "dst" else row
if not is_all(eids):
eids = astype(eids, graph.idtype)
end_n_ids = end_n_ids[eids]
return EdgeSparsemaxFunction.apply(graph._graph, logits,
eids, end_n_ids, norm_by)
class EdgeSparsemax(torch.nn.Module):
r"""
Description
-----------
Compute edge sparsemax. For a node :math:`i`, edge sparsemax is an operation that computes
.. math::
a_{ij} = \text{ReLU}(z_{ij} - \tau(z_{i,:}))
where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of sparsemax. :math:`\tau` is a function
that can be found at the `From Softmax to Sparsemax <https://arxiv.org/pdf/1602.02068.pdf>`
paper.
Parameters
----------
graph : DGLGraph
The graph to perform edge sparsemax on.
logits : torch.Tensor
The input edge feature.
eids : torch.Tensor or ALL, optional
A tensor of edge index on which to apply edge sparsemax. If ALL, apply edge
sparsemax on all edges in the graph. Default: ALL.
norm_by : str, could be 'src' or 'dst'
Normalize by source nodes or destination nodes. Default: `dst`.
NOTE: currently only homogeneous graphs are supported.
Returns
-------
Tensor
Sparsemax value.
"""
def __init__(self):
super(EdgeSparsemax, self).__init__()
def forward(self, graph, logits, eids=ALL, norm_by="dst"):
return edge_sparsemax(graph, logits, eids, norm_by)
import dgl
import dgl.function as fn
import scipy.sparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.nn import AvgPooling, GraphConv, MaxPooling
from dgl.ops import edge_softmax
from torch import Tensor
from torch.nn import Parameter
from functions import edge_sparsemax
from utils import get_batch_id, topk
class WeightedGraphConv(GraphConv):
r"""
Description
-----------
GraphConv with edge weights on homogeneous graphs.
If edge weights are not given, directly call GraphConv instead.
Parameters
----------
graph : DGLGraph
The graph to perform this operation.
n_feat : torch.Tensor
The node features
e_feat : torch.Tensor, optional
The edge features. Default: :obj:`None`
"""
def forward(self, graph:DGLGraph, n_feat, e_feat=None):
if e_feat is None:
return super(WeightedGraphConv, self).forward(graph, n_feat)
with graph.local_scope():
if self.weight is not None:
n_feat = torch.matmul(n_feat, self.weight)
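# symmetric normalization: scale by D_out^{-1/2} on the source side, multiply each message by its edge weight, then scale by D_in^{-1/2} on the destination side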
src_norm = torch.pow(graph.out_degrees().float().clamp(min=1), -0.5)
src_norm = src_norm.view(-1, 1)
dst_norm = torch.pow(graph.in_degrees().float().clamp(min=1), -0.5)
dst_norm = dst_norm.view(-1, 1)
n_feat = n_feat * src_norm
graph.ndata["h"] = n_feat
graph.edata["e"] = e_feat
graph.update_all(fn.src_mul_edge("h", "e", "m"),
fn.sum("m", "h"))
n_feat = graph.ndata.pop("h")
n_feat = n_feat * dst_norm
if self.bias is not None:
n_feat = n_feat + self.bias
if self._activation is not None:
n_feat = self._activation(n_feat)
return n_feat
class NodeInfoScoreLayer(nn.Module):
r"""
Description
-----------
Compute a score for each node for sort-pooling. The score of each node
is computed via the absolute difference of its first-order random walk
result and its features.
Arguments
---------
sym_norm : bool, optional
If true, use symmetric norm for adjacency.
Default: :obj:`True`
Parameters
----------
graph : DGLGraph
The graph to perform this operation.
feat : torch.Tensor
The node features
e_feat : torch.Tensor, optional
The edge features. Default: :obj:`None`
Returns
-------
Tensor
Score for each node.
"""
def __init__(self, sym_norm:bool=True):
super(NodeInfoScoreLayer, self).__init__()
self.sym_norm = sym_norm
def forward(self, graph:dgl.DGLGraph, feat:Tensor, e_feat:Tensor):
with graph.local_scope():
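# information score: L1 norm of the difference between a node's features and its (degree-normalized) one-hop aggregated neighborhood features; self-loops are removed before aggregation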
if self.sym_norm:
src_norm = torch.pow(graph.out_degrees().float().clamp(min=1), -0.5)
src_norm = src_norm.view(-1, 1).to(feat.device)
dst_norm = torch.pow(graph.in_degrees().float().clamp(min=1), -0.5)
dst_norm = dst_norm.view(-1, 1).to(feat.device)
src_feat = feat * src_norm
graph.ndata["h"] = src_feat
graph.edata["e"] = e_feat
graph = dgl.remove_self_loop(graph)
graph.update_all(fn.src_mul_edge("h", "e", "m"), fn.sum("m", "h"))
dst_feat = graph.ndata.pop("h") * dst_norm
feat = feat - dst_feat
else:
dst_norm = 1. / graph.in_degrees().float().clamp(min=1)
dst_norm = dst_norm.view(-1, 1)
graph.ndata["h"] = feat
graph.edata["e"] = e_feat
graph = dgl.remove_self_loop(graph)
graph.update_all(fn.src_mul_edge("h", "e", "m"), fn.sum("m", "h"))
feat = feat - dst_norm * graph.ndata.pop("h")
score = torch.sum(torch.abs(feat), dim=1)
return score
class HGPSLPool(nn.Module):
r"""
Description
-----------
The HGP-SL pooling layer from
`Hierarchical Graph Pooling with Structure Learning <https://arxiv.org/pdf/1911.05954.pdf>`
Parameters
----------
in_feat : int
The number of input node feature's channels
ratio : float, optional
Pooling ratio. Default: 0.8
sample : bool, optional
Whether to use the k-hop union graph (sample mode) to increase efficiency.
Default: :obj:`True`
sym_score_norm : bool, optional
Use symmetric norm for adjacency or not. Default: :obj:`True`
sparse : bool, optional
Use edge sparsemax instead of edge softmax. Default: :obj:`True`
sl : bool, optional
Use the structure learning module or not. Default: :obj:`True`
lamb : float, optional
The lambda parameter as weight of raw adjacency as described in the
HGP-SL paper. Default: 1.0
negative_slop : float, optional
Negative slope for leaky_relu. Default: 0.2
k_hop : int, optional
Number of hops used when building the union graph in sample mode. Default: 3
Returns
-------
DGLGraph
The pooled graph.
torch.Tensor
Node features
torch.Tensor
Edge features
torch.Tensor
Permutation index
"""
def __init__(self, in_feat:int, ratio=0.8, sample=True,
sym_score_norm=True, sparse=True, sl=True,
lamb=1.0, negative_slop=0.2, k_hop=3):
super(HGPSLPool, self).__init__()
self.in_feat = in_feat
self.ratio = ratio
self.sample = sample
self.sparse = sparse
self.sl = sl
self.lamb = lamb
self.negative_slop = negative_slop
self.k_hop = k_hop
self.att = Parameter(torch.Tensor(1, self.in_feat * 2))
self.calc_info_score = NodeInfoScoreLayer(sym_norm=sym_score_norm)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.att.data)
def forward(self, graph:DGLGraph, feat:Tensor, e_feat=None):
# top-k pool first
if e_feat is None:
e_feat = torch.ones((graph.number_of_edges(),),
dtype=feat.dtype, device=feat.device)
batch_num_nodes = graph.batch_num_nodes()
x_score = self.calc_info_score(graph, feat, e_feat)
perm, next_batch_num_nodes = topk(x_score, self.ratio,
get_batch_id(batch_num_nodes),
batch_num_nodes)
feat = feat[perm]
pool_graph = None
if not self.sample or not self.sl:
# pool graph
graph.edata["e"] = e_feat
pool_graph = dgl.node_subgraph(graph, perm)
e_feat = pool_graph.edata.pop("e")
pool_graph.set_batch_num_nodes(next_batch_num_nodes)
# no structure learning layer, directly return.
if not self.sl:
return pool_graph, feat, e_feat, perm
# Structure Learning
if self.sample:
# A fast mode for large graphs.
# In large graphs, learning the possible edge weights between each
# pair of nodes is time consuming. To accelerate this process,
# we sample the k-hop neighbors of each node and then learn the
# edge weights between them.
# first build multi-hop graph
row, col = graph.all_edges()
num_nodes = graph.num_nodes()
scipy_adj = scipy.sparse.coo_matrix((e_feat.detach().cpu(), (row.detach().cpu(), col.detach().cpu())), shape=(num_nodes, num_nodes))
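# repeatedly square the adjacency to add k-hop connectivity; the new entries are rescaled to a tiny magnitude (at most 1e-5) so the original edge weights stay dominant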
for _ in range(self.k_hop):
two_hop = scipy_adj ** 2
two_hop = two_hop * (1e-5 / two_hop.max())
scipy_adj = two_hop + scipy_adj
row, col = scipy_adj.nonzero()
row = torch.tensor(row, dtype=torch.long, device=graph.device)
col = torch.tensor(col, dtype=torch.long, device=graph.device)
e_feat = torch.tensor(scipy_adj.data, dtype=torch.float, device=feat.device)
# perform pooling on multi-hop graph
mask = perm.new_full((num_nodes, ), -1)
i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
mask[perm] = i
row, col = mask[row], mask[col]
mask = (row >= 0) & (col >= 0)
row, col = row[mask], col[mask]
e_feat = e_feat[mask]
# add remaining self loops
mask = row != col
num_nodes = perm.size(0) # num nodes after pool
loop_index = torch.arange(0, num_nodes, dtype=row.dtype, device=row.device)
inv_mask = ~mask
loop_weight = torch.full((num_nodes, ), 0, dtype=e_feat.dtype, device=e_feat.device)
remaining_e_feat = e_feat[inv_mask]
if remaining_e_feat.numel() > 0:
loop_weight[row[inv_mask]] = remaining_e_feat
e_feat = torch.cat([e_feat[mask], loop_weight], dim=0)
row, col = row[mask], col[mask]
row = torch.cat([row, loop_index], dim=0)
col = torch.cat([col, loop_index], dim=0)
# attention scores
weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(dim=-1)
weights = F.leaky_relu(weights, self.negative_slop) + e_feat * self.lamb
# sl and normalization
sl_graph = dgl.graph((row, col))
if self.sparse:
weights = edge_sparsemax(sl_graph, weights)
else:
weights = edge_softmax(sl_graph, weights)
# get final graph
mask = torch.abs(weights) > 0
row, col, weights = row[mask], col[mask], weights[mask]
pool_graph = dgl.graph((row, col))
pool_graph.set_batch_num_nodes(next_batch_num_nodes)
e_feat = weights
else:
# Learn the possible edge weights between each pair of
# nodes in the pooled subgraph; relatively slower.
# Construct a complete graph for each graph in the batch,
# build it densely, then convert to sparse.
# Maybe there is a more efficient way?
batch_num_nodes = next_batch_num_nodes
block_begin_idx = torch.cat([batch_num_nodes.new_zeros(1),
batch_num_nodes.cumsum(dim=0)[:-1]], dim=0)
block_end_idx = batch_num_nodes.cumsum(dim=0)
dense_adj = torch.zeros((pool_graph.num_nodes(),
pool_graph.num_nodes()),
dtype=torch.float,
device=feat.device)
for idx_b, idx_e in zip(block_begin_idx, block_end_idx):
dense_adj[idx_b:idx_e, idx_b:idx_e] = 1.
row, col = torch.nonzero(dense_adj).t().contiguous()
# compute weights for node-pairs
weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(dim=-1)
weights = F.leaky_relu(weights, self.negative_slop)
dense_adj[row, col] = weights
# add pooled graph structure to weight matrix
pool_row, pool_col = pool_graph.all_edges()
dense_adj[pool_row, pool_col] += self.lamb * e_feat
weights = dense_adj[row, col]
del dense_adj
torch.cuda.empty_cache()
# edge softmax/sparsemax
complete_graph = dgl.graph((row, col))
if self.sparse:
weights = edge_sparsemax(complete_graph, weights)
else:
weights = edge_softmax(complete_graph, weights)
# get new e_feat and graph structure, clean up.
mask = torch.abs(weights) > 1e-9
row, col, weights = row[mask], col[mask], weights[mask]
e_feat = weights
pool_graph = dgl.graph((row, col))
pool_graph.set_batch_num_nodes(next_batch_num_nodes)
return pool_graph, feat, e_feat, perm
class ConvPoolReadout(torch.nn.Module):
"""A helper class. (GraphConv -> Pooling -> Readout)"""
def __init__(self, in_feat:int, out_feat:int, pool_ratio=0.8,
sample:bool=False, sparse:bool=True, sl:bool=True,
lamb:float=1., pool:bool=True):
super(ConvPoolReadout, self).__init__()
self.use_pool = pool
self.conv = WeightedGraphConv(in_feat, out_feat)
if pool:
self.pool = HGPSLPool(out_feat, ratio=pool_ratio, sparse=sparse,
sample=sample, sl=sl, lamb=lamb)
else:
self.pool = None
self.avgpool = AvgPooling()
self.maxpool = MaxPooling()
def forward(self, graph, feature, e_feat=None):
out = F.relu(self.conv(graph, feature, e_feat))
if self.use_pool:
graph, out, e_feat, _ = self.pool(graph, out, e_feat)
readout = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1)
return graph, out, e_feat, readout
import argparse
import json
import logging
import os
from time import time
import dgl
import torch
import torch.nn
import torch.nn.functional as F
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
from torch.utils.data import random_split
from networks import HGPSLModel
from utils import get_stats
def parse_args():
parser = argparse.ArgumentParser(description="HGP-SL-DGL")
parser.add_argument("--dataset", type=str, default="DD",
choices=["DD", "PROTEINS", "NCI1", "NCI109", "Mutagenicity", "ENZYMES"],
help="DD/PROTEINS/NCI1/NCI109/Mutagenicity/ENZYMES")
parser.add_argument("--batch_size", type=int, default=512,
help="batch size")
parser.add_argument("--sample", type=str, default="true",
help="use sample method")
parser.add_argument("--lr", type=float, default=1e-3,
help="learning rate")
parser.add_argument("--weight_decay", type=float, default=1e-3,
help="weight decay")
parser.add_argument("--pool_ratio", type=float, default=0.5,
help="pooling ratio")
parser.add_argument("--hid_dim", type=int, default=128,
help="hidden size")
parser.add_argument("--conv_layers", type=int, default=3,
help="number of conv layers")
parser.add_argument("--dropout", type=float, default=0.0,
help="dropout ratio")
parser.add_argument("--lamb", type=float, default=1.0,
help="trade-off parameter")
parser.add_argument("--epochs", type=int, default=1000,
help="max number of training epochs")
parser.add_argument("--patience", type=int, default=100,
help="patience for early stopping")
parser.add_argument("--device", type=int, default=-1,
help="device id, -1 for cpu")
parser.add_argument("--dataset_path", type=str, default="./dataset",
help="path to dataset")
parser.add_argument("--print_every", type=int, default=10,
help="print trainlog every k epochs, -1 for silent training")
parser.add_argument("--num_trials", type=int, default=1,
help="number of trials")
parser.add_argument("--output_path", type=str, default="./output")
args = parser.parse_args()
# device
args.device = "cpu" if args.device == -1 else "cuda:{}".format(args.device)
if not torch.cuda.is_available():
logging.warning("CUDA is not available, use CPU for training.")
args.device = "cpu"
# print every
if args.print_every == -1:
args.print_every = args.epochs + 1
# bool args
if args.sample.lower() == "true":
args.sample = True
else:
args.sample = False
# paths
if not os.path.exists(args.dataset_path):
os.makedirs(args.dataset_path)
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
name = "Data={}_Hidden={}_Pool={}_WeightDecay={}_Lr={}_Sample={}.log".format(
args.dataset, args.hid_dim, args.pool_ratio, args.weight_decay, args.lr, args.sample)
args.output_path = os.path.join(args.output_path, name)
return args
def train(model:torch.nn.Module, optimizer, trainloader, device):
model.train()
total_loss = 0.
num_batches = len(trainloader)
for batch in trainloader:
optimizer.zero_grad()
batch_graphs, batch_labels = batch
for (key, value) in batch_graphs.ndata.items():
batch_graphs.ndata[key] = value.float()
batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.long().to(device)
out = model(batch_graphs, batch_graphs.ndata["feat"])
loss = F.nll_loss(out, batch_labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / num_batches
@torch.no_grad()
def test(model:torch.nn.Module, loader, device):
model.eval()
correct = 0.
loss = 0.
num_graphs = 0
for batch in loader:
batch_graphs, batch_labels = batch
num_graphs += batch_labels.size(0)
for (key, value) in batch_graphs.ndata.items():
batch_graphs.ndata[key] = value.float()
batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.long().to(device)
out = model(batch_graphs, batch_graphs.ndata["feat"])
pred = out.argmax(dim=1)
loss += F.nll_loss(out, batch_labels, reduction="sum").item()
correct += pred.eq(batch_labels).sum().item()
return correct / num_graphs, loss / num_graphs
def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
dataset = LegacyTUDataset(args.dataset, raw_dir=args.dataset_path)
# Add self-loops. We add a self-loop to each graph here since the function "add_self_loop" does not
# support batched graphs.
for i in range(len(dataset)):
dataset.graph_lists[i] = dgl.add_self_loop(dataset.graph_lists[i])
num_training = int(len(dataset) * 0.8)
num_val = int(len(dataset) * 0.1)
num_test = len(dataset) - num_val - num_training
train_set, val_set, test_set = random_split(dataset, [num_training, num_val, num_test])
train_loader = GraphDataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=6)
val_loader = GraphDataLoader(val_set, batch_size=args.batch_size, num_workers=2)
test_loader = GraphDataLoader(test_set, batch_size=args.batch_size, num_workers=2)
device = torch.device(args.device)
# Step 2: Create model =================================================================== #
num_feature, num_classes, _ = dataset.statistics()
model = HGPSLModel(in_feat=num_feature, out_feat=num_classes, hid_feat=args.hid_dim,
conv_layers=args.conv_layers, dropout=args.dropout, pool_ratio=args.pool_ratio,
lamb=args.lamb, sample=args.sample).to(device)
args.num_feature = int(num_feature)
args.num_classes = int(num_classes)
# Step 3: Create training components ===================================================== #
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# Step 4: training epochs =============================================================== #
bad_count = 0
best_val_loss = float("inf")
final_test_acc = 0.
best_epoch = 0
train_times = []
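# early stopping on validation loss: keep the test accuracy from the epoch with the lowest validation loss and stop after `patience` epochs without improvement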
for e in range(args.epochs):
s_time = time()
train_loss = train(model, optimizer, train_loader, device)
train_times.append(time() - s_time)
val_acc, val_loss = test(model, val_loader, device)
test_acc, _ = test(model, test_loader, device)
if best_val_loss > val_loss:
best_val_loss = val_loss
final_test_acc = test_acc
bad_count = 0
best_epoch = e + 1
else:
bad_count += 1
if bad_count >= args.patience:
break
if (e + 1) % args.print_every == 0:
log_format = "Epoch {}: loss={:.4f}, val_acc={:.4f}, final_test_acc={:.4f}"
print(log_format.format(e + 1, train_loss, val_acc, final_test_acc))
print("Best Epoch {}, final test acc {:.4f}".format(best_epoch, final_test_acc))
return final_test_acc, sum(train_times) / len(train_times)
if __name__ == "__main__":
args = parse_args()
res = []
train_times = []
for i in range(args.num_trials):
print("Trial {}/{}".format(i + 1, args.num_trials))
acc, train_time = main(args)
res.append(acc)
train_times.append(train_time)
mean, err_bd = get_stats(res, conf_interval=False)
print("mean acc: {:.4f}, error bound: {:.4f}".format(mean, err_bd))
out_dict = {"hyper-parameters": vars(args),
"result": "{:.4f}(+-{:.4f})".format(mean, err_bd),
"train_time": "{:.4f}".format(sum(train_times) / len(train_times))}
with open(args.output_path, "w") as f:
json.dump(out_dict, f, sort_keys=True, indent=4)
import torch
from dgl.nn import AvgPooling, MaxPooling
import torch.nn.functional as F
import torch.nn
from layers import ConvPoolReadout
class HGPSLModel(torch.nn.Module):
r"""
Description
-----------
The graph classification model using HGP-SL pooling.
Parameters
----------
in_feat : int
The number of input node feature's channels.
out_feat : int
The number of output node feature's channels.
hid_feat : int
The number of hidden state's channels.
dropout : float, optional
The dropout rate. Default: 0
pool_ratio : float, optional
The pooling ratio for each pooling layer. Default: 0.5
conv_layers : int, optional
The number of graph convolution and pooling layers. Default: 3
sample : bool, optional
Whether to use the k-hop union graph (sample mode) to increase efficiency.
Default: :obj:`False`
sparse : bool, optional
Use edge sparsemax instead of edge softmax. Default: :obj:`True`
sl : bool, optional
Use the structure learning module or not. Default: :obj:`True`
lamb : float, optional
The lambda parameter as weight of raw adjacency as described in the
HGP-SL paper. Default: 1.0
"""
def __init__(self, in_feat:int, out_feat:int, hid_feat:int,
dropout:float=0., pool_ratio:float=.5, conv_layers:int=3,
sample:bool=False, sparse:bool=True, sl:bool=True,
lamb:float=1.):
super(HGPSLModel, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.hid_feat = hid_feat
self.dropout = dropout
self.num_layers = conv_layers
self.pool_ratio = pool_ratio
convpools = []
for i in range(conv_layers):
c_in = in_feat if i == 0 else hid_feat
c_out = hid_feat
use_pool = (i != conv_layers - 1)
convpools.append(ConvPoolReadout(c_in, c_out, pool_ratio=pool_ratio,
sample=sample, sparse=sparse, sl=sl,
lamb=lamb, pool=use_pool))
self.convpool_layers = torch.nn.ModuleList(convpools)
self.lin1 = torch.nn.Linear(hid_feat * 2, hid_feat)
self.lin2 = torch.nn.Linear(hid_feat, hid_feat // 2)
self.lin3 = torch.nn.Linear(hid_feat // 2, self.out_feat)
def forward(self, graph, n_feat):
final_readout = None
e_feat = None
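# sum the concatenated (mean || max) readouts produced after every conv(+pool) block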
for i in range(self.num_layers):
graph, n_feat, e_feat, readout = self.convpool_layers[i](graph, n_feat, e_feat)
if final_readout is None:
final_readout = readout
else:
final_readout = final_readout + readout
n_feat = F.relu(self.lin1(final_readout))
n_feat = F.dropout(n_feat, p=self.dropout, training=self.training)
n_feat = F.relu(self.lin2(n_feat))
n_feat = F.dropout(n_feat, p=self.dropout, training=self.training)
n_feat = self.lin3(n_feat)
return F.log_softmax(n_feat, dim=-1)
import torch
import logging
from scipy.stats import t
import math
def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False):
"""Compute mean and standard deviation from an numerical array
Args:
array (array like obj): The numerical array, this array can be
convert to :obj:`torch.Tensor`.
conf_interval (bool, optional): If True, compute the confidence interval bound (95%)
instead of the std value. (default: :obj:`False`)
name (str, optional): The name of this numerical array, for log usage.
(default: :obj:`None`)
stdout (bool, optional): Whether to output result to the terminal.
(default: :obj:`False`)
logout (bool, optional): Whether to output result via logging module.
(default: :obj:`False`)
"""
eps = 1e-9
array = torch.Tensor(array)
std, mean = torch.std_mean(array)
std = std.item()
mean = mean.item()
center = mean
if conf_interval:
n = array.size(0)
se = std / (math.sqrt(n) + eps)
t_value = t.ppf(0.975, df=n-1)
err_bound = t_value * se
else:
err_bound = std
# log and print
if name is None:
name = "array {}".format(id(array))
log = "{}: {:.4f}(+-{:.4f})".format(name, center, err_bound)
if stdout:
print(log)
if logout:
logging.info(log)
return center, err_bound
def get_batch_id(num_nodes:torch.Tensor):
"""Convert the num_nodes array obtained from batch graph to batch_id array
for each node.
Args:
num_nodes (torch.Tensor): The tensor whose element is the number of nodes
in each graph in the batch graph.
"""
batch_size = num_nodes.size(0)
batch_ids = []
for i in range(batch_size):
item = torch.full((num_nodes[i],), i, dtype=torch.long, device=num_nodes.device)
batch_ids.append(item)
return torch.cat(batch_ids)
def topk(x:torch.Tensor, ratio:float, batch_id:torch.Tensor, num_nodes:torch.Tensor):
"""The top-k pooling method. Given a graph batch, this method will pool out some
nodes from input node feature tensor for each graph according to the given ratio.
Args:
x (torch.Tensor): The input node feature batch-tensor to be pooled.
ratio (float): the pool ratio. For example if :obj:`ratio=0.5` then half of the input
tensor will be pooled out.
batch_id (torch.Tensor): The batch_id of each element in the input tensor.
num_nodes (torch.Tensor): The number of nodes of each graph in batch.
Returns:
perm (torch.Tensor): The index in batch to be kept.
k (torch.Tensor): The remaining number of nodes for each graph.
"""
batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
cum_num_nodes = torch.cat(
[num_nodes.new_zeros(1),
num_nodes.cumsum(dim=0)[:-1]], dim=0)
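# scatter scores into a dense (batch_size x max_num_nodes) matrix padded with the dtype's minimum, so one per-row descending sort ranks nodes within each graph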
index = torch.arange(batch_id.size(0), dtype=torch.long, device=x.device)
index = (index - cum_num_nodes[batch_id]) + (batch_id * max_num_nodes)
dense_x = x.new_full((batch_size * max_num_nodes, ), torch.finfo(x.dtype).min)
dense_x[index] = x
dense_x = dense_x.view(batch_size, max_num_nodes)
_, perm = dense_x.sort(dim=-1, descending=True)
perm = perm + cum_num_nodes.view(-1, 1)
perm = perm.view(-1)
k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)
mask = [
torch.arange(k[i], dtype=torch.long, device=x.device) +
i * max_num_nodes for i in range(batch_size)]
mask = torch.cat(mask, dim=0)
perm = perm[mask]
return perm, k
@@ -101,9 +101,10 @@ def test(model:torch.nn.Module, loader, device):
model.eval()
correct = 0.
loss = 0.
num_graphs = len(loader)
num_graphs = 0
for batch in loader:
batch_graphs, batch_labels = batch
num_graphs += batch_labels.size(0)
for (key, value) in batch_graphs.ndata.items():
batch_graphs.ndata[key] = value.float()
batch_graphs = batch_graphs.to(device)