Unverified Commit 72ef642f authored by 张天启, committed by GitHub

[Example] Add sagpool example for pytorch backend (#2429)



* add sagpool example for pytorch backend

* polish sagpool example for pytorch backend

* [Example] SAGPool: use std variance

* [Example] SAGPool: change to std

* add sagpool example to index page

* add graph property prediction tag to sagpool
Co-authored-by: zhangtianqi <tianqizh@amazon.com>
parent 5d8330cc
@@ -35,6 +35,7 @@
| [Molecular Graph Convolutions: Moving Beyond Fingerprints](#weave) | | | :heavy_check_mark: | | |
| [LINE: Large-scale Information Network Embedding](#line) | | :heavy_check_mark: | | | :heavy_check_mark: |
| [DeepWalk: Online Learning of Social Representations](#deepwalk) | | :heavy_check_mark: | | | :heavy_check_mark: |
| [Self-Attention Graph Pooling](#sagpool) | | | :heavy_check_mark: | | |
| | | | | | |
| | | | | | |
@@ -130,6 +131,10 @@
  - Example code: [PyTorch](../examples/pytorch/mixhop)
  - Tags: node classification
- <a name="sagpool"></a> Lee, Junhyun, et al. Self-Attention Graph Pooling. [Paper link](https://arxiv.org/abs/1904.08082).
  - Example code: [PyTorch](../examples/pytorch/sagpool)
  - Tags: graph classification, pooling
## 2018
- <a name="dgmg"></a> Li et al. Learning Deep Generative Models of Graphs. [Paper link](https://arxiv.org/abs/1803.03324).
......
# DGL Implementation of the SAGPool Paper
This DGL example implements the GNN model proposed in the paper [Self-Attention Graph Pooling](https://arxiv.org/pdf/1904.08082.pdf).
The authors' implementation is available [here](https://github.com/inyeoplee77/SAGPool).
The graph datasets used in this example
---------------------------------------
This example uses DGL's built-in LegacyTUDataset, a series of graph kernel datasets for graph classification. We use 'DD', 'PROTEINS', 'NCI1', 'NCI109' and 'Mutagenicity' in this SAGPool implementation. Each dataset is randomly split into training, validation and test sets with a ratio of 0.8, 0.1 and 0.1.
NOTE: Since some of these datasets have no node attributes, we use the node_id (as a one-hot vector whose length is the maximum number of nodes across all graphs) as the node feature. Also note that node_id is not unique in some datasets (e.g. a graph may have two nodes with the same id).
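For reference, a minimal sketch (assuming DGL with `LegacyTUDataset` available, as used in `main.py`) of loading one of these datasets and checking the statistics listed below:
```python
from dgl.data import LegacyTUDataset

# Load (and download if needed) the dataset under ./dataset, as main.py does.
dataset = LegacyTUDataset("PROTEINS", raw_dir="./dataset")
num_feature, num_classes, _ = dataset.statistics()
print("NumGraphs:", len(dataset))
print("NumFeats:", int(num_feature), "NumClasses:", int(num_classes))
```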
DD
- NumGraphs: 1178
- AvgNodesPerGraph: 284.32
- AvgEdgesPerGraph: 715.66
- NumFeats: 89
- NumClasses: 2
PROTEINS
- NumGraphs: 1113
- AvgNodesPerGraph: 39.06
- AvgEdgesPerGraph: 72.82
- NumFeats: 1
- NumClasses: 2
NCI1
- NumGraphs: 4110
- AvgNodesPerGraph: 29.87
- AvgEdgesPerGraph: 32.30
- NumFeats: 37
- NumClasses: 2
NCI109
- NumGraphs: 4127
- AvgNodesPerGraph: 29.68
- AvgEdgesPerGraph: 32.13
- NumFeats: 38
- NumClasses: 2
Mutagenicity
- NumGraphs: 4337
- AvgNodesPerGraph: 30.32
- AvgEdgesPerGraph: 30.77
- NumFeats: 14
- NumClasses: 2
How to run example files
--------------------------------
The valid dataset names (you can find a full list [here](https://chrsmrrs.github.io/datasets/docs/datasets/)):
- 'DD' for D&D
- 'PROTEINS' for PROTEINS
- 'NCI1' for NCI1
- 'NCI109' for NCI109
- 'Mutagenicity' for Mutagenicity
In the sagpool folder, run
```bash
python main.py --dataset ${your_dataset_name_here}
```
If you want to use a GPU, run
```bash
python main.py --device ${your_device_id_here} --dataset ${your_dataset_name_here}
```
If you want to perform a grid search, modify the parameter settings in `grid_search_config.json` and run
```bash
python grid_search.py --device ${your_device_id_here} --num_trials ${num_of_trials_here}
```
Performance
-------------------------
NOTE: We do not perform a grid search or fine-tune hyperparameters here, so there may be a gap between the results reported in the paper and ours. Also, we run only 10 trials per experiment, rather than the 200 trials per experiment used in the paper. Results are reported as mean accuracy with the standard deviation in parentheses.
**The global architecture result**
| Dataset | paper result (global) | ours (global) |
| ------------- | -------------------------------- | --------------------------- |
| D&D | 76.19 (0.94) | 74.79 (2.69) |
| PROTEINS | 70.04 (1.47) | 70.36 (5.90) |
| NCI1 | 74.18 (1.20) | 72.82 (2.36) |
| NCI109 | 74.06 (0.78) | 71.64 (2.65) |
| Mutagenicity | N/A | 76.55 (2.89) |
**The hierarchical architecture result**
| Dataset | paper result (hierarchical) | ours (hierarchical) |
| ------------- | -------------------------------- | --------------------------- |
| D&D | 76.45 (0.97) | 75.38 (4.17) |
| PROTEINS | 71.86 (0.97) | 70.36 (5.68) |
| NCI1 | 67.45 (1.11) | 70.61 (2.25) |
| NCI109 | 67.86 (1.41) | 69.13 (3.85) |
| Mutagenicity | N/A | 75.20 (1.95) |
import torch.utils.data
from torch.utils.data.dataloader import DataLoader
import dgl
import numpy as np
def collate_fn(batch):
"""
collate_fn for dataset batching:
cast node features to float tensors and batch the graphs (labels are collected into a LongTensor)
"""
graphs, labels = map(list, zip(*batch))
# batch graphs and cast to PyTorch tensor
for graph in graphs:
for (key, value) in graph.ndata.items():
graph.ndata[key] = value.float()
batched_graphs = dgl.batch(graphs)
# cast to PyTorch tensor
batched_labels = torch.LongTensor(np.array(labels))
return batched_graphs, batched_labels
class GraphDataLoader(DataLoader):
def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
super(GraphDataLoader, self).__init__(dataset, batch_size, shuffle,
collate_fn=collate_fn, **kwargs)
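# Example usage (mirroring main.py): wrap a dataset split and iterate over
# (batched_graph, labels) pairs, e.g.
#   train_loader = GraphDataLoader(train_set, batch_size=128, shuffle=True, num_workers=6)
#   for batch_graphs, batch_labels in train_loader:
#       ...  # batch_graphs is a batched DGLGraph, batch_labels a LongTensor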
import json
import os
from copy import deepcopy
from main import main, parse_args
from utils import get_stats
def load_config(path="./grid_search_config.json"):
with open(path, "r") as f:
return json.load(f)
def run_experiments(args):
res = []
for i in range(args.num_trials):
print("Trial {}/{}".format(i + 1, args.num_trials))
acc, _ = main(args)
res.append(acc)
mean, err_bd = get_stats(res, conf_interval=True)
return mean, err_bd
def grid_search(config:dict):
args = parse_args()
results = {}
for d in config["dataset"]:
args.dataset = d
best_acc, err_bd = 0., 0.
best_args = vars(args)
for arch in config["arch"]:
args.architecture = arch
for hidden in config["hidden"]:
args.hid_dim = hidden
for pool_ratio in config["pool_ratio"]:
args.pool_ratio = pool_ratio
for lr in config["lr"]:
args.lr = lr
for weight_decay in config["weight_decay"]:
args.weight_decay = weight_decay
acc, bd = run_experiments(args)
if acc > best_acc:
best_acc = acc
err_bd = bd
best_args = deepcopy(vars(args))
args.output_path = "./output"
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
args.output_path = "./output/{}.log".format(d)
result = {
"params": best_args,
"result": "{:.4f}({:.4f})".format(best_acc, err_bd)
}
with open(args.output_path, "w") as f:
json.dump(result, f, sort_keys=True, indent=4)
grid_search(load_config())
{
"arch": ["hierarchical", "global"],
"hidden": [16, 32, 64, 128],
"pool_ratio": [0.25, 0.5],
"lr": [1e-2, 5e-2, 1e-3, 5e-3, 1e-4, 5e-4],
"weight_decay": [1e-2, 1e-3, 1e-4, 1e-5],
"dataset": ["DD", "PROTEINS", "NCI1", "NCI109", "Mutagenicity"]
}
import torch
import torch.nn.functional as F
import dgl
from dgl.nn import GraphConv, AvgPooling, MaxPooling
from utils import topk, get_batch_id
class SAGPool(torch.nn.Module):
"""The Self-Attention Pooling layer in paper
`Self Attention Graph Pooling <https://arxiv.org/pdf/1904.08082.pdf>`
Args:
in_dim (int): The dimension of node feature.
ratio (float, optional): The pool ratio, which determines the fraction of nodes
that remain after pooling. (default: :obj:`0.5`)
conv_op (torch.nn.Module, optional): The graph convolution layer in dgl used to
compute the attention score for each node. (default: :obj:`dgl.nn.GraphConv`)
non_linearity (Callable, optional): The non-linearity function, a pytorch function.
(default: :obj:`torch.tanh`)
"""
def __init__(self, in_dim:int, ratio=0.5, conv_op=GraphConv, non_linearity=torch.tanh):
super(SAGPool, self).__init__()
self.in_dim = in_dim
self.ratio = ratio
self.score_layer = conv_op(in_dim, 1)
self.non_linearity = non_linearity
def forward(self, graph:dgl.DGLGraph, feature:torch.Tensor):
score = self.score_layer(graph, feature).squeeze()
perm, next_batch_num_nodes = topk(score, self.ratio, get_batch_id(graph.batch_num_nodes()), graph.batch_num_nodes())
feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1)
graph = dgl.node_subgraph(graph, perm)
# node_subgraph currently does not support batched graphs:
# the 'batch_num_nodes' of the resulting subgraph is None,
# so we set 'batch_num_nodes' manually here.
# Since global pooling does not depend on 'batch_num_edges',
# we can leave it as None or unchanged.
graph.set_batch_num_nodes(next_batch_num_nodes)
return graph, feature, perm
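# Rough usage sketch (names are illustrative; see ConvPoolBlock below and
# SAGNetworkGlobal in network.py for the actual call sites):
#   pool = SAGPool(in_dim=hidden_dim, ratio=0.5)
#   graph, feat, perm = pool(batched_graph, node_feat)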
class ConvPoolBlock(torch.nn.Module):
"""A combination of GCN layer and SAGPool layer,
followed by a concatenated (mean||max) readout operation.
"""
def __init__(self, in_dim:int, out_dim:int, pool_ratio=0.8):
super(ConvPoolBlock, self).__init__()
self.conv = GraphConv(in_dim, out_dim)
self.pool = SAGPool(out_dim, ratio=pool_ratio)
self.avgpool = AvgPooling()
self.maxpool = MaxPooling()
def forward(self, graph, feature):
out = F.relu(self.conv(graph, feature))
graph, out, _ = self.pool(graph, out)
g_out = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1)
return graph, out, g_out
import argparse
import json
import logging
import os
from time import time
import dgl
import torch
import torch.nn
import torch.nn.functional as F
from dgl.data import LegacyTUDataset
from torch.utils.data import random_split
from dataloader import GraphDataLoader
from network import get_sag_network
from utils import get_stats
def parse_args():
parser = argparse.ArgumentParser(description="Self-Attention Graph Pooling")
parser.add_argument("--dataset", type=str, default="DD",
choices=["DD", "PROTEINS", "NCI1", "NCI109", "Mutagenicity"],
help="DD/PROTEINS/NCI1/NCI109/Mutagenicity")
parser.add_argument("--batch_size", type=int, default=128,
help="batch size")
parser.add_argument("--lr", type=float, default=5e-4,
help="learning rate")
parser.add_argument("--weight_decay", type=float, default=1e-4,
help="weight decay")
parser.add_argument("--pool_ratio", type=float, default=0.5,
help="pooling ratio")
parser.add_argument("--hid_dim", type=int, default=128,
help="hidden size")
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout ratio")
parser.add_argument("--epochs", type=int, default=100000,
help="max number of training epochs")
parser.add_argument("--patience", type=int, default=50,
help="patience for early stopping")
parser.add_argument("--device", type=int, default=-1,
help="device id, -1 for cpu")
parser.add_argument("--architecture", type=str, default="hierarchical",
choices=["hierarchical", "global"],
help="model architecture")
parser.add_argument("--dataset_path", type=str, default="./dataset",
help="path to dataset")
parser.add_argument("--conv_layers", type=int, default=3,
help="number of conv layers")
parser.add_argument("--print_every", type=int, default=10,
help="print trainlog every k epochs, -1 for silent training")
parser.add_argument("--num_trials", type=int, default=1,
help="number of trials")
parser.add_argument("--output_path", type=str, default="./output")
args = parser.parse_args()
# device
args.device = "cpu" if args.device == -1 else "cuda:{}".format(args.device)
if not torch.cuda.is_available():
logging.warning("CUDA is not available, use CPU for training.")
args.device = "cpu"
# print every
if args.print_every == -1:
args.print_every = args.epochs + 1
# paths
if not os.path.exists(args.dataset_path):
os.makedirs(args.dataset_path)
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
name = "Data={}_Hidden={}_Arch={}_Pool={}_WeightDecay={}_Lr={}.log".format(
args.dataset, args.hid_dim, args.architecture, args.pool_ratio, args.weight_decay, args.lr)
args.output_path = os.path.join(args.output_path, name)
return args
def train(model:torch.nn.Module, optimizer, trainloader, device):
model.train()
total_loss = 0.
for batch in trainloader:
optimizer.zero_grad()
batch_graphs, batch_labels = batch
batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.to(device)
out = model(batch_graphs)
loss = F.nll_loss(out, batch_labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(trainloader.dataset)
@torch.no_grad()
def test(model:torch.nn.Module, loader, device):
model.eval()
correct = 0.
loss = 0.
num_graphs = len(loader.dataset)
for batch in loader:
batch_graphs, batch_labels = batch
batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.to(device)
out = model(batch_graphs)
pred = out.argmax(dim=1)
loss += F.nll_loss(out, batch_labels, reduction="sum").item()
correct += pred.eq(batch_labels).sum().item()
return correct / num_graphs, loss / num_graphs
def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
dataset = LegacyTUDataset(args.dataset, raw_dir=args.dataset_path)
# Add a self-loop to each graph. We do this graph by graph here since the function
# "add_self_loop" does not support batched graphs.
for i in range(len(dataset)):
dataset.graph_lists[i] = dgl.add_self_loop(dataset.graph_lists[i])
num_training = int(len(dataset) * 0.8)
num_val = int(len(dataset) * 0.1)
num_test = len(dataset) - num_val - num_training
train_set, val_set, test_set = random_split(dataset, [num_training, num_val, num_test])
train_loader = GraphDataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=6)
val_loader = GraphDataLoader(val_set, batch_size=args.batch_size, num_workers=2)
test_loader = GraphDataLoader(test_set, batch_size=args.batch_size, num_workers=2)
device = torch.device(args.device)
# Step 2: Create model =================================================================== #
num_feature, num_classes, _ = dataset.statistics()
model_op = get_sag_network(args.architecture)
model = model_op(in_dim=num_feature, hid_dim=args.hid_dim, out_dim=num_classes,
num_convs=args.conv_layers, pool_ratio=args.pool_ratio, dropout=args.dropout).to(device)
args.num_feature = int(num_feature)
args.num_classes = int(num_classes)
# Step 3: Create training components ===================================================== #
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# Step 4: training epochs ================================================================ #
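# Early stopping: track the best validation loss and keep the test accuracy measured
# at that epoch; stop once validation loss has not improved for `patience` epochs.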
bad_cound = 0
best_val_loss = float("inf")
final_test_acc = 0.
best_epoch = 0
train_times = []
for e in range(args.epochs):
s_time = time()
train_loss = train(model, optimizer, train_loader, device)
train_times.append(time() - s_time)
val_acc, val_loss = test(model, val_loader, device)
test_acc, _ = test(model, test_loader, device)
if best_val_loss > val_loss:
best_val_loss = val_loss
final_test_acc = test_acc
bad_cound = 0
best_epoch = e + 1
else:
bad_cound += 1
if bad_cound >= args.patience:
break
if (e + 1) % args.print_every == 0:
log_format = "Epoch {}: loss={:.4f}, val_acc={:.4f}, final_test_acc={:.4f}"
print(log_format.format(e + 1, train_loss, val_acc, final_test_acc))
print("Best Epoch {}, final test acc {:.4f}".format(best_epoch, final_test_acc))
return final_test_acc, sum(train_times) / len(train_times)
if __name__ == "__main__":
args = parse_args()
res = []
train_times = []
for i in range(args.num_trials):
print("Trial {}/{}".format(i + 1, args.num_trials))
acc, train_time = main(args)
res.append(acc)
train_times.append(train_time)
mean, err_bd = get_stats(res)
print("mean acc: {:.4f}, error bound: {:.4f}".format(mean, err_bd))
out_dict = {"hyper-parameters": vars(args),
"result": "{:.4f}(+-{:.4f})".format(mean, err_bd),
"train_time": "{:.4f}".format(sum(train_times) / len(train_times))}
with open(args.output_path, "w") as f:
json.dump(out_dict, f, sort_keys=True, indent=4)
import torch
import torch.nn
import torch.nn.functional as F
import dgl
from dgl.nn import GraphConv, AvgPooling, MaxPooling
from layer import ConvPoolBlock, SAGPool
class SAGNetworkHierarchical(torch.nn.Module):
"""The Self-Attention Graph Pooling Network with hierarchical readout in paper
`Self Attention Graph Pooling <https://arxiv.org/pdf/1904.08082.pdf>`
Args:
in_dim (int): The input node feature dimension.
hid_dim (int): The hidden dimension for node feature.
out_dim (int): The output dimension.
num_convs (int, optional): The number of graph convolution layers.
(default: 3)
pool_ratio (float, optional): The pool ratio, which determines the fraction of nodes
that remain after pooling. (default: :obj:`0.5`)
dropout (float, optional): The dropout ratio for each layer. (default: 0)
"""
def __init__(self, in_dim:int, hid_dim:int, out_dim:int, num_convs=3,
pool_ratio:float=0.5, dropout:float=0.0):
super(SAGNetworkHierarchical, self).__init__()
self.dropout = dropout
self.num_convpools = num_convs
convpools = []
for i in range(num_convs):
_i_dim = in_dim if i == 0 else hid_dim
_o_dim = hid_dim
convpools.append(ConvPoolBlock(_i_dim, _o_dim, pool_ratio=pool_ratio))
self.convpools = torch.nn.ModuleList(convpools)
self.lin1 = torch.nn.Linear(hid_dim * 2, hid_dim)
self.lin2 = torch.nn.Linear(hid_dim, hid_dim // 2)
self.lin3 = torch.nn.Linear(hid_dim // 2, out_dim)
def forward(self, graph:dgl.DGLGraph):
feat = graph.ndata["feat"]
final_readout = None
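# The (mean||max) readout produced after every conv+pool block is summed into a
# single graph representation (the hierarchical readout).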
for i in range(self.num_convpools):
graph, feat, readout = self.convpools[i](graph, feat)
if final_readout is None:
final_readout = readout
else:
final_readout = final_readout + readout
feat = F.relu(self.lin1(final_readout))
feat = F.dropout(feat, p=self.dropout, training=self.training)
feat = F.relu(self.lin2(feat))
feat = F.log_softmax(self.lin3(feat), dim=-1)
return feat
class SAGNetworkGlobal(torch.nn.Module):
"""The Self-Attention Graph Pooling Network with global readout in paper
`Self Attention Graph Pooling <https://arxiv.org/pdf/1904.08082.pdf>`
Args:
in_dim (int): The input node feature dimension.
hid_dim (int): The hidden dimension for node feature.
out_dim (int): The output dimension.
num_convs (int, optional): The number of graph convolution layers.
(default: 3)
pool_ratio (float, optional): The pool ratio, which determines the fraction of nodes
that remain after pooling. (default: :obj:`0.5`)
dropout (float, optional): The dropout ratio for each layer. (default: 0)
"""
def __init__(self, in_dim:int, hid_dim:int, out_dim:int, num_convs=3,
pool_ratio:float=0.5, dropout:float=0.0):
super(SAGNetworkGlobal, self).__init__()
self.dropout = dropout
self.num_convs = num_convs
convs = []
for i in range(num_convs):
_i_dim = in_dim if i == 0 else hid_dim
_o_dim = hid_dim
convs.append(GraphConv(_i_dim, _o_dim))
self.convs = torch.nn.ModuleList(convs)
concat_dim = num_convs * hid_dim
self.pool = SAGPool(concat_dim, ratio=pool_ratio)
self.avg_readout = AvgPooling()
self.max_readout = MaxPooling()
self.lin1 = torch.nn.Linear(concat_dim * 2, hid_dim)
self.lin2 = torch.nn.Linear(hid_dim, hid_dim // 2)
self.lin3 = torch.nn.Linear(hid_dim // 2, out_dim)
def forward(self, graph:dgl.DGLGraph):
feat = graph.ndata["feat"]
conv_res = []
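# Apply all conv layers first and concatenate their outputs, then pool once with
# SAGPool and apply a single global (mean||max) readout.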
for i in range(self.num_convs):
feat = self.convs[i](graph, feat)
conv_res.append(feat)
conv_res = torch.cat(conv_res, dim=-1)
graph, feat, _ = self.pool(graph, conv_res)
feat = torch.cat([self.avg_readout(graph, feat), self.max_readout(graph, feat)], dim=-1)
feat = F.relu(self.lin1(feat))
feat = F.dropout(feat, p=self.dropout, training=self.training)
feat = F.relu(self.lin2(feat))
feat = F.log_softmax(self.lin3(feat), dim=-1)
return feat
def get_sag_network(net_type:str="hierarchical"):
if net_type == "hierarchical":
return SAGNetworkHierarchical
elif net_type == "global":
return SAGNetworkGlobal
else:
raise ValueError("SAGNetwork type {} is not supported.".format(net_type))
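# Example (mirroring main.py):
#   model_cls = get_sag_network("hierarchical")
#   model = model_cls(in_dim=num_feature, hid_dim=128, out_dim=num_classes,
#                     num_convs=3, pool_ratio=0.5, dropout=0.5)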
import torch
import logging
from scipy.stats import t
import math
def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False):
"""Compute mean and standard deviation from an numerical array
Args:
array (array like obj): The numerical array; this array can be
converted to :obj:`torch.Tensor`.
conf_interval (bool, optional): If True, compute the confidence interval bound (95%)
instead of the std value. (default: :obj:`False`)
name (str, optional): The name of this numerical array, for log usage.
(default: :obj:`None`)
stdout (bool, optional): Whether to output result to the terminal.
(default: :obj:`False`)
logout (bool, optional): Whether to output result via logging module.
(default: :obj:`False`)
"""
eps = 1e-9
array = torch.Tensor(array)
std, mean = torch.std_mean(array)
std = std.item()
mean = mean.item()
center = mean
if conf_interval:
n = array.size(0)
se = std / (math.sqrt(n) + eps)
t_value = t.ppf(0.975, df=n-1)
err_bound = t_value * se
else:
err_bound = std
# log and print
if name is None:
name = "array {}".format(id(array))
log = "{}: {:.4f}(+-{:.4f})".format(name, center, err_bound)
if stdout:
print(log)
if logout:
logging.info(log)
return center, err_bound
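# Example:
#   get_stats([0.71, 0.74, 0.69])                      # -> (mean, std)
#   get_stats([0.71, 0.74, 0.69], conf_interval=True)  # -> (mean, 95% CI half-width)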
def get_batch_id(num_nodes:torch.Tensor):
"""Convert the num_nodes array obtained from batch graph to batch_id array
for each node.
Args:
num_nodes (torch.Tensor): The tensor whose element is the number of nodes
in each graph in the batch graph.
"""
batch_size = num_nodes.size(0)
batch_ids = []
for i in range(batch_size):
item = torch.full((num_nodes[i],), i, dtype=torch.long, device=num_nodes.device)
batch_ids.append(item)
return torch.cat(batch_ids)
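# Example: get_batch_id(torch.tensor([2, 3])) -> tensor([0, 0, 1, 1, 1])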
def topk(x:torch.Tensor, ratio:float, batch_id:torch.Tensor, num_nodes:torch.Tensor):
"""The top-k pooling method. Given a graph batch, this method will pool out some
nodes from input node feature tensor for each graph according to the given ratio.
Args:
x (torch.Tensor): The input node feature batch-tensor to be pooled.
ratio (float): the pool ratio. For example, if :obj:`ratio=0.5`, then half of the nodes
in each graph will be kept after pooling.
batch_id (torch.Tensor): The batch_id of each element in the input tensor.
num_nodes (torch.Tensor): The number of nodes of each graph in batch.
Returns:
perm (torch.Tensor): The index in batch to be kept.
k (torch.Tensor): The remaining number of nodes for each graph.
"""
batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
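# Scatter the scores of all graphs into one dense buffer of length
# batch_size * max_num_nodes, padded with the smallest representable value, so that
# sorting each row of its (batch_size, max_num_nodes) view ranks the nodes of every
# graph by score and the padding slots never land in the top-k.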
cum_num_nodes = torch.cat(
[num_nodes.new_zeros(1),
num_nodes.cumsum(dim=0)[:-1]], dim=0)
index = torch.arange(batch_id.size(0), dtype=torch.long, device=x.device)
index = (index - cum_num_nodes[batch_id]) + (batch_id * max_num_nodes)
dense_x = x.new_full((batch_size * max_num_nodes, ), torch.finfo(x.dtype).min)
dense_x[index] = x
dense_x = dense_x.view(batch_size, max_num_nodes)
_, perm = dense_x.sort(dim=-1, descending=True)
perm = perm + cum_num_nodes.view(-1, 1)
perm = perm.view(-1)
k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)
mask = [
torch.arange(k[i], dtype=torch.long, device=x.device) +
i * max_num_nodes for i in range(batch_size)]
mask = torch.cat(mask, dim=0)
perm = perm[mask]
return perm, k