Unverified Commit 9c41c22d authored by Tong He, committed by GitHub

[Model] Official implementation for HiLANDER model. (#3087)



* add hilander model implementation draft

* use focal loss

* fix

* change data root

* add necessary scripts

* update download links

* update

* update example table

* fix

* update readme with numbers

* add empty folder

* only eval at the end

* set up hilander

* inform results may fluctuate

* address comments
Co-authored-by: sneakerkg <xiaotj1990327@gmail.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-19-212.us-east-2.compute.internal>
parent 66ad774f
@@ -10,6 +10,7 @@ The folder contains example implementations of selected research papers related
| ------------------------------------------------------------ | ------------------- | -------------------------------- | ------------------------- | ------------------ | ------------------ |
| [Latent Dirichlet Allocation](#lda) | :heavy_check_mark: | :heavy_check_mark: | | | |
| [Network Embedding with Completely-imbalanced Labels](#rect) | :heavy_check_mark: | | | | |
| [Learning Hierarchical Graph Neural Networks for Image Clustering](#hilander) | | | | | |
| [Boost then Convolve: Gradient Boosting Meets Graph Neural Networks](#bgnn) | :heavy_check_mark: | | | | |
| [Contrastive Multi-View Representation Learning on Graphs](#mvgrl) | :heavy_check_mark: | | :heavy_check_mark: | | |
| [Deep Graph Contrastive Representation Learning](#grace) | :heavy_check_mark: | | | | |
@@ -106,6 +107,9 @@ The folder contains example implementations of selected research papers related
## 2021
- <a name="hilander"></a> Xing et al. Learning Hierarchical Graph Neural Networks for Image Clustering.
- Example code: [PyTorch](../examples/pytorch/hilander)
- Tags: clustering
- <a name="bgnn"></a> Ivanov et al. Boost then Convolve: Gradient Boosting Meets Graph Neural Networks. [Paper link](https://openreview.net/forum?id=ebS5NUfoMKL).
- Example code: [PyTorch](../examples/pytorch/bgnn)
- Tags: semi-supervised node classification, tabular data, GBDT
Learning Hierarchical Graph Neural Networks for Image Clustering
================================================================
This folder contains the official code for "Learning Hierarchical Graph Neural Networks for Image Clustering" (link needed).
## Setup
We use Python 3.7 and CUDA 10.2. Besides DGL (>=0.5.2), we depend on several packages. To install the dependencies with conda:
```bash
conda create -n Hilander # create env
conda activate Hilander # activate env
conda install pytorch==1.7.0 torchvision==0.8.0 cudatoolkit=10.2 -c pytorch # install pytorch 1.7 version
conda install -c pytorch faiss-gpu cudatoolkit=10.2 # install faiss gpu version matching cuda 10.2
pip install dgl-cu102 # install dgl for cuda 10.2
pip install tqdm # install tqdm
git clone https://github.com/yjxiong/clustering-benchmark.git # install clustering-benchmark for evaluation
cd clustering-benchmark
python setup.py install
cd ../
```
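As a quick sanity check (a minimal sketch, assuming the environment created above), the following should run without errors and confirm that the GPU builds of PyTorch, DGL and faiss are importable:
```bash
python -c "import torch, dgl, faiss; print(torch.__version__, dgl.__version__)"
python -c "import torch; assert torch.cuda.is_available()"
```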
## Data
The datasets used for training and testing are hosted on several services.
[AWS S3](https://dgl-data.s3.us-west-2.amazonaws.com/dataset/hilander/data.tar.gz) | [Google Drive](https://drive.google.com/file/d/1KLa3uu9ndaCc7YjnSVRLHpcJVMSz868v/view?usp=sharing) | [BaiduPan](https://pan.baidu.com/s/11iRcp84esfkkvdcw3kmPAw) (pwd: wbmh)
After downloading, unpack the pickled files into `data/`.
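For example, to fetch the AWS S3 copy from the command line (a sketch; adjust the extraction target if the archive does not already unpack into `data/`):
```bash
wget https://dgl-data.s3.us-west-2.amazonaws.com/dataset/hilander/data.tar.gz
tar -xzf data.tar.gz
```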
## Training
We provide training scripts for different datasets.
For training on DeepGlint, one can run
```bash
bash scripts/train_deepglint.sh
```
DeepGlint is a large-scale dataset, so we randomly select 10% of the classes to construct a training subset.
For training on the full iNaturalist dataset, one can run
```bash
bash scripts/train_inat.sh
```
For training on the re-sampled iNaturalist dataset, one can run
```bash
bash scripts/train_inat_resampled_1_in_6_per_class.sh
```
We sample a subset of the full iNat2018-Train set to obtain a train-time cluster-size distribution drastically different from that of iNat2018-Test; this subset is named inat_resampled_1_in_6_per_class.
## Inference
In the paper, we have two experiment settings: Clustering with Seen Test Data Distribution and Clustering with Unseen Test Data Distribution.
For Clustering with Seen Test Data Distribution, one can run
```bash
bash scripts/test_deepglint_imbd_sampled_as_deepglint.sh
bash scripts/test_inat.sh
```
**Clustering with Seen Test Data Distribution Performance**
| | IMDB-Test-SameDist | iNat2018-Test |
| ------------------ | ------------------------------: | ------------------------------: |
| Fp (pairwise F-score) | 0.793 | 0.330 |
| Fb (BCubed F-score) | 0.795 | 0.350 |
| NMI | 0.947 | 0.774 |
* The results may fluctuate slightly due to the randomness introduced by GPU k-NN graph construction with faiss-gpu.
For Clustering with Unseen Test Data Distribution, one can run
```bash
bash scripts/test_deepglint_hannah.sh
bash scripts/test_deepglint_imdb.sh
bash scripts/test_inat_train_on_resampled_1_in_6_per_class.sh
```
**Clustering with Unseen Test Data Distribution Performance**
| | Hannah | IMDB | iNat2018-Test |
| ------------------ | ------------------------------: | ------------------------------: | ------------------------------: |
| Fp (pairwise F-score) | 0.720 | 0.765 | 0.294 |
| Fb (BCubed F-score) | 0.700 | 0.796 | 0.352 |
| NMI | 0.810 | 0.953 | 0.764 |
* The results may fluctuate slightly due to the randomness introduced by GPU k-NN graph construction with faiss-gpu.
from .dataset import LanderDataset
import numpy as np
import pickle
import dgl
import torch
from utils import (build_knns, fast_knns2spmat, row_normalize, knns2ordered_nbrs,
density_estimation, sparse_mx_to_indices_values, l2norm,
decode, build_next_level)
class LanderDataset(object):
def __init__(self, features, labels, cluster_features=None, k=10, levels=1, faiss_gpu=False):
self.k = k
self.gs = []
self.nbrs = []
self.dists = []
self.levels = levels
# Initialize features and labels
features = l2norm(features.astype('float32'))
global_features = features.copy()
if cluster_features is None:
cluster_features = features
global_num_nodes = features.shape[0]
global_edges = ([], [])
global_peaks = np.array([], dtype=np.long)
ids = np.arange(global_num_nodes)
# Recursive graph construction
for lvl in range(self.levels):
if features.shape[0] <= self.k:
self.levels = lvl
break
if faiss_gpu:
knns = build_knns(features, self.k, 'faiss_gpu')
else:
knns = build_knns(features, self.k, 'faiss')
dists, nbrs = knns2ordered_nbrs(knns)
self.nbrs.append(nbrs)
self.dists.append(dists)
density = density_estimation(dists, nbrs, labels)
g = self._build_graph(features, cluster_features, labels, density, knns)
self.gs.append(g)
if lvl >= self.levels - 1:
break
# Decode peak nodes
new_pred_labels, peaks,\
global_edges, global_pred_labels, global_peaks = decode(g, 0, 'sim', True,
ids, global_edges, global_num_nodes,
global_peaks)
ids = ids[peaks]
features, labels, cluster_features = build_next_level(features, labels, peaks,
global_features, global_pred_labels, global_peaks)
def _build_graph(self, features, cluster_features, labels, density, knns):
adj = fast_knns2spmat(knns, self.k)
adj, adj_row_sum = row_normalize(adj)
indices, values, shape = sparse_mx_to_indices_values(adj)
g = dgl.graph((indices[1], indices[0]))
g.ndata['features'] = torch.FloatTensor(features)
g.ndata['cluster_features'] = torch.FloatTensor(cluster_features)
g.ndata['labels'] = torch.LongTensor(labels)
g.ndata['density'] = torch.FloatTensor(density)
g.edata['affine'] = torch.FloatTensor(values)
# A bipartite graph produced by the DGL sampler does not keep global edge IDs, so we explicitly save them here
g.edata['global_eid'] = g.edges(form='eid')
g.ndata['norm'] = torch.FloatTensor(adj_row_sum)
g.apply_edges(lambda edges: {'raw_affine': edges.data['affine'] / edges.dst['norm']})
g.apply_edges(lambda edges: {'labels_conn': (edges.src['labels'] == edges.dst['labels']).long()})
g.apply_edges(lambda edges: {'mask_conn': (edges.src['density'] > edges.dst['density']).bool()})
return g
def __getitem__(self, index):
assert index < len(self.gs)
return self.gs[index]
def __len__(self):
return len(self.gs)
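# Minimal usage sketch (illustrative; the pickle path below is a placeholder):
#   with open('data/some_features.pkl', 'rb') as f:
#       features, labels = pickle.load(f)  # features: (N, D) float array, labels: (N,) int array
#   dataset = LanderDataset(features, labels, k=10, levels=3, faiss_gpu=False)
#   g0 = dataset[0]  # DGLGraph of the level-0 k-NN graph; len(dataset) == number of levels actually built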
from .lander import LANDER
from .graphconv import GraphConv
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# The code below is based on
# https://zhuanlan.zhihu.com/p/28527749
class FocalLoss(nn.Module):
r"""
This criterion is an implementation of Focal Loss, which is proposed in
Focal Loss for Dense Object Detection.
Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
The losses are averaged across observations for each minibatch.
Args:
alpha(1D Tensor, Variable) : the scalar factor for this criterion
gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
putting more focus on hard, misclassified examples
size_average(bool): By default, the losses are averaged over observations for each minibatch.
However, if the field size_average is set to False, the losses are
instead summed for each minibatch.
"""
def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
super(FocalLoss, self).__init__()
if alpha is None:
self.alpha = Variable(torch.ones(class_num, 1))
else:
if isinstance(alpha, Variable):
self.alpha = alpha
else:
self.alpha = Variable(alpha)
self.gamma = gamma
self.class_num = class_num
self.size_average = size_average
def forward(self, inputs, targets):
N = inputs.size(0)
C = inputs.size(1)
P = F.softmax(inputs, dim=1)
class_mask = inputs.data.new(N, C).fill_(0)
class_mask = Variable(class_mask)
ids = targets.view(-1, 1)
class_mask.scatter_(1, ids.data, 1.)
if inputs.is_cuda and not self.alpha.is_cuda:
self.alpha = self.alpha.cuda()
alpha = self.alpha[ids.data.view(-1)]
probs = (P*class_mask).sum(1).view(-1,1)
log_p = probs.log()
batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p
if self.size_average:
loss = batch_loss.mean()
else:
loss = batch_loss.sum()
return loss
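# Usage sketch (illustrative): the 2-way edge-connectivity case used in LANDER
#   criterion = FocalLoss(class_num=2, gamma=2)
#   logits = torch.randn(8, 2)            # (N, C) raw scores
#   targets = torch.randint(0, 2, (8,))   # (N,) class indices
#   loss = criterion(logits, targets)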
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import dgl.function as fn
from dgl.nn.pytorch import GATConv
class GraphConvLayer(nn.Module):
def __init__(self, in_feats, out_feats, bias=True):
super(GraphConvLayer, self).__init__()
self.mlp = nn.Linear(in_feats * 2, out_feats, bias=bias)
def forward(self, bipartite, feat):
# take a local copy of the graph first so num_dst_nodes() is available in both branches
graph = bipartite.local_var()
if isinstance(feat, tuple):
srcfeat, dstfeat = feat
else:
srcfeat = feat
dstfeat = feat[:graph.num_dst_nodes()]
graph.srcdata['h'] = srcfeat
graph.update_all(fn.u_mul_e('h', 'affine', 'm'),
fn.sum(msg='m', out='h'))
gcn_feat = torch.cat([dstfeat, graph.dstdata['h']], dim=-1)
out = self.mlp(gcn_feat)
return out
class GraphConv(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0, use_GAT = False, K = 1):
super(GraphConv, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
if use_GAT:
self.gcn_layer = GATConv(in_dim, out_dim, K, allow_zero_in_degree = True)
self.bias = nn.Parameter(torch.Tensor(K, out_dim))
init.constant_(self.bias, 0)
else:
self.gcn_layer = GraphConvLayer(in_dim, out_dim, bias=True)
self.dropout = dropout
self.use_GAT = use_GAT
def forward(self, bipartite, features):
out = self.gcn_layer(bipartite, features)
if self.use_GAT:
out = torch.mean(out + self.bias, dim = 1)
out = out.reshape(out.shape[0], -1)
out = F.relu(out)
if self.dropout > 0:
out = F.dropout(out, self.dropout, training=self.training)
return out
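# Usage sketch (illustrative): one conv step over a sampled bipartite block
#   conv = GraphConv(in_dim=256, out_dim=512, dropout=0.1, use_GAT=True, K=1)
#   h_dst = h_src[:block.num_dst_nodes()]
#   h = conv(block, (h_src, h_dst))       # output shape: (block.num_dst_nodes(), 512)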
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from .graphconv import GraphConv
from .focal_loss import FocalLoss
class LANDER(nn.Module):
def __init__(self, feature_dim, nhid, num_conv=4, dropout=0,
use_GAT=True, K=1, balance=False,
use_cluster_feat = True, use_focal_loss = True, **kwargs):
super(LANDER, self).__init__()
nhid_half = int(nhid / 2)
self.use_cluster_feat = use_cluster_feat
self.use_focal_loss = use_focal_loss
if self.use_cluster_feat:
self.feature_dim = feature_dim * 2
else:
self.feature_dim = feature_dim
input_dim = (feature_dim, nhid, nhid, nhid_half)
output_dim = (nhid, nhid, nhid_half, nhid_half)
self.conv = nn.ModuleList()
self.conv.append(GraphConv(self.feature_dim, nhid, dropout, use_GAT, K))
for i in range(1, num_conv):
self.conv.append(GraphConv(input_dim[i], output_dim[i], dropout, use_GAT, K))
self.src_mlp = nn.Linear(output_dim[num_conv - 1], nhid_half)
self.dst_mlp = nn.Linear(output_dim[num_conv - 1], nhid_half)
self.classifier_conn = nn.Sequential(nn.PReLU(nhid_half),
nn.Linear(nhid_half, nhid_half),
nn.PReLU(nhid_half),
nn.Linear(nhid_half, 2))
if self.use_focal_loss:
self.loss_conn = FocalLoss(2)
else:
self.loss_conn = nn.CrossEntropyLoss()
self.loss_den = nn.MSELoss()
self.balance = balance
def pred_conn(self, edges):
src_feat = self.src_mlp(edges.src['conv_features'])
dst_feat = self.dst_mlp(edges.dst['conv_features'])
pred_conn = self.classifier_conn(src_feat + dst_feat)
return {'pred_conn': pred_conn}
def pred_den_msg(self, edges):
prob = edges.data['prob_conn']
res = edges.data['raw_affine'] * (prob[:, 1] - prob[:, 0])
return {'pred_den_msg': res}
def forward(self, bipartites):
if isinstance(bipartites, dgl.DGLGraph):
bipartites = [bipartites] * len(self.conv)
if self.use_cluster_feat:
neighbor_x = torch.cat([bipartites[0].ndata['features'], bipartites[0].ndata['cluster_features']], axis=1)
else:
neighbor_x = bipartites[0].ndata['features']
for i in range(len(self.conv)):
neighbor_x = self.conv[i](bipartites[i], neighbor_x)
output_bipartite = bipartites[-1]
output_bipartite.ndata['conv_features'] = neighbor_x
else:
if self.use_cluster_feat:
neighbor_x_src = torch.cat([bipartites[0].srcdata['features'], bipartites[0].srcdata['cluster_features']], axis=1)
center_x_src = torch.cat([bipartites[1].srcdata['features'], bipartites[1].srcdata['cluster_features']], axis=1)
else:
neighbor_x_src = bipartites[0].srcdata['features']
center_x_src = bipartites[1].srcdata['features']
for i in range(len(self.conv)):
neighbor_x_dst = neighbor_x_src[:bipartites[i].num_dst_nodes()]
neighbor_x_src = self.conv[i](bipartites[i], (neighbor_x_src, neighbor_x_dst))
center_x_dst = center_x_src[:bipartites[i+1].num_dst_nodes()]
center_x_src = self.conv[i](bipartites[i+1], (center_x_src, center_x_dst))
output_bipartite = bipartites[-1]
output_bipartite.srcdata['conv_features'] = neighbor_x_src
output_bipartite.dstdata['conv_features'] = center_x_src
output_bipartite.apply_edges(self.pred_conn)
output_bipartite.edata['prob_conn'] = F.softmax(output_bipartite.edata['pred_conn'], dim=1)
output_bipartite.update_all(self.pred_den_msg, fn.mean('pred_den_msg', 'pred_den'))
return output_bipartite
def compute_loss(self, bipartite):
pred_den = bipartite.dstdata['pred_den']
loss_den = self.loss_den(pred_den, bipartite.dstdata['density'])
labels_conn = bipartite.edata['labels_conn']
mask_conn = bipartite.edata['mask_conn']
if self.balance:
labels_conn = bipartite.edata['labels_conn']
neg_check = torch.logical_and(bipartite.edata['labels_conn'] == 0, mask_conn)
num_neg = torch.sum(neg_check).item()
neg_indices = torch.where(neg_check)[0]
pos_check = torch.logical_and(bipartite.edata['labels_conn'] == 1, mask_conn)
num_pos = torch.sum(pos_check).item()
pos_indices = torch.where(pos_check)[0]
if num_pos > num_neg:
mask_conn[pos_indices[np.random.choice(num_pos, num_pos - num_neg, replace = False)]] = 0
elif num_pos < num_neg:
mask_conn[neg_indices[np.random.choice(num_neg, num_neg - num_pos, replace = False)]] = 0
# In subgraph training, it may happen that all edges are masked in a batch
if mask_conn.sum() > 0:
loss_conn = self.loss_conn(bipartite.edata['pred_conn'][mask_conn], labels_conn[mask_conn])
loss = loss_den + loss_conn
loss_den_val = loss_den.item()
loss_conn_val = loss_conn.item()
else:
loss = loss_den
loss_den_val = loss_den.item()
loss_conn_val = 0
return loss, loss_den_val, loss_conn_val
python test_subg.py --data_path data/subcenter_arcface_deepglint_hannah_features.pkl --model_filename checkpoint/deepglint_sampler.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop
python test_subg.py --data_path data/subcenter_arcface_deepglint_imdb_features.pkl --model_filename checkpoint/deepglint_sampler.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop
python test_subg.py --data_path data/subcenter_arcface_deepglint_imdb_features_sampled_as_deepglint_1_in_10.pkl --model_filename checkpoint/deepglint_sampler.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop
python test_subg.py --data_path data/inat2018_test.pkl --model_filename checkpoint/inat.ckpt --knn_k 10 --tau 0.1 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop
python test_subg.py --data_path data/inat2018_test.pkl --model_filename checkpoint/inat_resampled_1_in_6_per_class.ckpt --knn_k 10 --tau 0.1 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop
python train_subg.py --data_path data/subcenter_arcface_deepglint_train_1_in_10_recreated.pkl --model_filename checkpoint/deepglint_sampler.pth --knn_k 10,5,3 --levels 2,3,4 --faiss_gpu --hidden 512 --epochs 250 --lr 0.01 --batch_size 4096 --num_conv 1 --gat --balance
python train_subg.py --data_path data/inat2018_train_dedup_inter_intra.pkl --model_filename checkpoint/inat.ckpt --knn_k 10,5,3 --levels 2,3,4 --faiss_gpu --hidden 512 --epochs 250 --lr 0.01 --batch_size 4096 --num_conv 1 --gat --balance
python train_subg.py --data_path data/inat2018_train_dedup_inter_intra_1_in_6_per_class.pkl --model_filename checkpoint/inat_resampled_1_in_6_per_class.ckpt --knn_k 10,5,3 --levels 2,3,4 --faiss_gpu --hidden 512 --epochs 250 --lr 0.01 --batch_size 4096 --num_conv 1 --gat --balance
import argparse, time, os, pickle
import numpy as np
import dgl
import torch
import torch.optim as optim
from models import LANDER
from dataset import LanderDataset
from utils import evaluation, decode, build_next_level, stop_iterating
###########
# ArgParser
parser = argparse.ArgumentParser()
# Dataset
parser.add_argument('--data_path', type=str, required=True)
parser.add_argument('--model_filename', type=str, default='lander.pth')
parser.add_argument('--faiss_gpu', action='store_true')
parser.add_argument('--early_stop', action='store_true')
# HyperParam
parser.add_argument('--knn_k', type=int, default=10)
parser.add_argument('--levels', type=int, default=1)
parser.add_argument('--tau', type=float, default=0.5)
parser.add_argument('--threshold', type=str, default='prob')
parser.add_argument('--metrics', type=str, default='pairwise,bcubed,nmi')
# Model
parser.add_argument('--hidden', type=int, default=512)
parser.add_argument('--num_conv', type=int, default=4)
parser.add_argument('--dropout', type=float, default=0.)
parser.add_argument('--gat', action='store_true')
parser.add_argument('--gat_k', type=int, default=1)
parser.add_argument('--balance', action='store_true')
parser.add_argument('--use_cluster_feat', action='store_true')
parser.add_argument('--use_focal_loss', action='store_true')
parser.add_argument('--use_gt', action='store_true')
args = parser.parse_args()
###########################
# Environment Configuration
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
##################
# Data Preparation
with open(args.data_path, 'rb') as f:
features, labels = pickle.load(f)
global_features = features.copy()
dataset = LanderDataset(features=features, labels=labels, k=args.knn_k,
levels=1, faiss_gpu=args.faiss_gpu)
g = dataset.gs[0].to(device)
global_labels = labels.copy()
ids = np.arange(g.number_of_nodes())
global_edges = ([], [])
global_edges_len = len(global_edges[0])
global_num_nodes = g.number_of_nodes()
##################
# Model Definition
if not args.use_gt:
feature_dim = g.ndata['features'].shape[1]
model = LANDER(feature_dim=feature_dim, nhid=args.hidden,
num_conv=args.num_conv, dropout=args.dropout,
use_GAT=args.gat, K=args.gat_k,
balance=args.balance,
use_cluster_feat=args.use_cluster_feat,
use_focal_loss=args.use_focal_loss)
model.load_state_dict(torch.load(args.model_filename))
model = model.to(device)
model.eval()
# number of edges added is the indicator for early stopping
num_edges_add_last_level = np.Inf
##################################
# Predict connectivity and density
for level in range(args.levels):
if not args.use_gt:
with torch.no_grad():
g = model(g)
new_pred_labels, peaks,\
global_edges, global_pred_labels, global_peaks = decode(g, args.tau, args.threshold, args.use_gt,
ids, global_edges, global_num_nodes)
ids = ids[peaks]
new_global_edges_len = len(global_edges[0])
num_edges_add_this_level = new_global_edges_len - global_edges_len
if stop_iterating(level, args.levels, args.early_stop, num_edges_add_this_level, num_edges_add_last_level, args.knn_k):
break
global_edges_len = new_global_edges_len
num_edges_add_last_level = num_edges_add_this_level
# build new dataset
features, labels, cluster_features = build_next_level(features, labels, peaks,
global_features, global_pred_labels, global_peaks)
# After the first level, the number of nodes drops a lot, so CPU faiss is faster.
dataset = LanderDataset(features=features, labels=labels, k=args.knn_k,
levels=1, faiss_gpu=False, cluster_features = cluster_features)
if len(dataset.gs) == 0:
break
g = dataset.gs[0].to(device)
evaluation(global_pred_labels, global_labels, args.metrics)
import argparse, time, os, pickle
import numpy as np
import dgl
import torch
import torch.optim as optim
from models import LANDER
from dataset import LanderDataset
from utils import evaluation, decode, build_next_level, stop_iterating
###########
# ArgParser
parser = argparse.ArgumentParser()
# Dataset
parser.add_argument('--data_path', type=str, required=True)
parser.add_argument('--model_filename', type=str, default='lander.pth')
parser.add_argument('--faiss_gpu', action='store_true')
parser.add_argument('--num_workers', type=int, default=0)
# HyperParam
parser.add_argument('--knn_k', type=int, default=10)
parser.add_argument('--levels', type=int, default=1)
parser.add_argument('--tau', type=float, default=0.5)
parser.add_argument('--threshold', type=str, default='prob')
parser.add_argument('--metrics', type=str, default='pairwise,bcubed,nmi')
parser.add_argument('--early_stop', action='store_true')
# Model
parser.add_argument('--hidden', type=int, default=512)
parser.add_argument('--num_conv', type=int, default=4)
parser.add_argument('--dropout', type=float, default=0.)
parser.add_argument('--gat', action='store_true')
parser.add_argument('--gat_k', type=int, default=1)
parser.add_argument('--balance', action='store_true')
parser.add_argument('--use_cluster_feat', action='store_true')
parser.add_argument('--use_focal_loss', action='store_true')
parser.add_argument('--use_gt', action='store_true')
# Subgraph
parser.add_argument('--batch_size', type=int, default=4096)
args = parser.parse_args()
print(args)
###########################
# Environment Configuration
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
##################
# Data Preparation
with open(args.data_path, 'rb') as f:
features, labels = pickle.load(f)
global_features = features.copy()
dataset = LanderDataset(features=features, labels=labels, k=args.knn_k,
levels=1, faiss_gpu=args.faiss_gpu)
g = dataset.gs[0]
g.ndata['pred_den'] = torch.zeros((g.number_of_nodes()))
g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
global_labels = labels.copy()
ids = np.arange(g.number_of_nodes())
global_edges = ([], [])
global_peaks = np.array([], dtype=np.long)
global_edges_len = len(global_edges[0])
global_num_nodes = g.number_of_nodes()
fanouts = [args.knn_k-1 for i in range(args.num_conv + 1)]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
# the neighbor sampler fixes the number of edges sampled per node (fanout = knn_k - 1)
test_loader = dgl.dataloading.NodeDataLoader(
g, torch.arange(g.number_of_nodes()), sampler,
batch_size=args.batch_size,
shuffle=False,
drop_last=False,
num_workers=args.num_workers
)
##################
# Model Definition
if not args.use_gt:
feature_dim = g.ndata['features'].shape[1]
model = LANDER(feature_dim=feature_dim, nhid=args.hidden,
num_conv=args.num_conv, dropout=args.dropout,
use_GAT=args.gat, K=args.gat_k,
balance=args.balance,
use_cluster_feat=args.use_cluster_feat,
use_focal_loss=args.use_focal_loss)
model.load_state_dict(torch.load(args.model_filename))
model = model.to(device)
model.eval()
# number of edges added is the indicator for early stopping
num_edges_add_last_level = np.Inf
##################################
# Predict connectivity and density
for level in range(args.levels):
if not args.use_gt:
total_batches = len(test_loader)
for batch, minibatch in enumerate(test_loader):
input_nodes, sub_g, bipartites = minibatch
sub_g = sub_g.to(device)
bipartites = [b.to(device) for b in bipartites]
with torch.no_grad():
output_bipartite = model(bipartites)
global_nid = output_bipartite.dstdata[dgl.NID]
global_eid = output_bipartite.edata['global_eid']
g.ndata['pred_den'][global_nid] = output_bipartite.dstdata['pred_den'].to('cpu')
g.edata['prob_conn'][global_eid] = output_bipartite.edata['prob_conn'].to('cpu')
torch.cuda.empty_cache()
if (batch + 1) % 10 == 0:
print('Batch %d / %d for inference' % (batch + 1, total_batches))
new_pred_labels, peaks,\
global_edges, global_pred_labels, global_peaks = decode(g, args.tau, args.threshold, args.use_gt,
ids, global_edges, global_num_nodes,
global_peaks)
ids = ids[peaks]
new_global_edges_len = len(global_edges[0])
num_edges_add_this_level = new_global_edges_len - global_edges_len
if stop_iterating(level, args.levels, args.early_stop, num_edges_add_this_level, num_edges_add_last_level, args.knn_k):
break
global_edges_len = new_global_edges_len
num_edges_add_last_level = num_edges_add_this_level
# build new dataset
features, labels, cluster_features = build_next_level(features, labels, peaks,
global_features, global_pred_labels, global_peaks)
# After the first level, the number of nodes drops a lot, so CPU faiss is faster.
dataset = LanderDataset(features=features, labels=labels, k=args.knn_k,
levels=1, faiss_gpu=False, cluster_features = cluster_features)
g = dataset.gs[0]
g.ndata['pred_den'] = torch.zeros((g.number_of_nodes()))
g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
test_loader = dgl.dataloading.NodeDataLoader(
g, torch.arange(g.number_of_nodes()), sampler,
batch_size=args.batch_size,
shuffle=False,
drop_last=False,
num_workers=args.num_workers
)
evaluation(global_pred_labels, global_labels, args.metrics)