Unverified Commit f19f05ce authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4651)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 977b1ba4
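The commit is mechanical reformatting only (Black plus import sorting); no behavior changes. A representative sketch of the rewrite, reusing a LanderDataset call that appears later in this diff, is shown below. String literals are likewise normalized from single to double quotes throughout.

# Before: hand-wrapped call with aligned arguments
dataset = LanderDataset(features=features, labels=labels, k=args.knn_k,
                        levels=1, faiss_gpu=args.faiss_gpu)

# After Black: exploded call, one argument per line, trailing comma
dataset = LanderDataset(
    features=features,
    labels=labels,
    k=args.knn_k,
    levels=1,
    faiss_gpu=args.faiss_gpu,
)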
import argparse
import os
import pickle
import time

import numpy as np
import torch
import torch.optim as optim
from dataset import LanderDataset
from models import LANDER
from utils import build_next_level, decode, evaluation, stop_iterating

import dgl

###########
# ArgParser
parser = argparse.ArgumentParser()
# Dataset
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--model_filename", type=str, default="lander.pth")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--num_workers", type=int, default=0)
# HyperParam
parser.add_argument("--knn_k", type=int, default=10)
parser.add_argument("--levels", type=int, default=1)
parser.add_argument("--tau", type=float, default=0.5)
parser.add_argument("--threshold", type=str, default="prob")
parser.add_argument("--metrics", type=str, default="pairwise,bcubed,nmi")
parser.add_argument("--early_stop", action="store_true")
# Model
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=4)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")
parser.add_argument("--use_gt", action="store_true")
# Subgraph
parser.add_argument("--batch_size", type=int, default=4096)
args = parser.parse_args()
print(args)
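As a quick sanity check of the parser above, the flags can be exercised in-process; the values below are placeholders, not ones taken from this commit:

# Hypothetical invocation with placeholder values; the real script reads sys.argv.
example_args = parser.parse_args(
    ["--data_path", "features.pkl", "--knn_k", "10", "--tau", "0.5", "--early_stop"]
)
print(example_args.knn_k, example_args.tau, example_args.early_stop)  # 10 0.5 True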
@@ -47,20 +50,25 @@ print(args)

###########################
# Environment Configuration
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

##################
# Data Preparation
with open(args.data_path, "rb") as f:
    features, labels = pickle.load(f)
global_features = features.copy()
dataset = LanderDataset(
    features=features,
    labels=labels,
    k=args.knn_k,
    levels=1,
    faiss_gpu=args.faiss_gpu,
)
g = dataset.gs[0]
g.ndata["pred_den"] = torch.zeros((g.number_of_nodes()))
g.edata["prob_conn"] = torch.zeros((g.number_of_edges(), 2))
global_labels = labels.copy()
ids = np.arange(g.number_of_nodes())
global_edges = ([], [])

@@ -68,27 +76,34 @@ global_peaks = np.array([], dtype=np.long)
global_edges_len = len(global_edges[0])
global_num_nodes = g.number_of_nodes()

fanouts = [args.knn_k - 1 for i in range(args.num_conv + 1)]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
# fix the number of edges
test_loader = dgl.dataloading.DataLoader(
    g,
    torch.arange(g.number_of_nodes()),
    sampler,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
)

##################
# Model Definition
if not args.use_gt:
    feature_dim = g.ndata["features"].shape[1]
    model = LANDER(
        feature_dim=feature_dim,
        nhid=args.hidden,
        num_conv=args.num_conv,
        dropout=args.dropout,
        use_GAT=args.gat,
        K=args.gat_k,
        balance=args.balance,
        use_cluster_feat=args.use_cluster_feat,
        use_focal_loss=args.use_focal_loss,
    )
    model.load_state_dict(torch.load(args.model_filename))
    model = model.to(device)
    model.eval()
@@ -107,39 +122,76 @@ for level in range(args.levels):
        with torch.no_grad():
            output_bipartite = model(bipartites)
        global_nid = output_bipartite.dstdata[dgl.NID]
        global_eid = output_bipartite.edata["global_eid"]
        g.ndata["pred_den"][global_nid] = output_bipartite.dstdata[
            "pred_den"
        ].to("cpu")
        g.edata["prob_conn"][global_eid] = output_bipartite.edata[
            "prob_conn"
        ].to("cpu")
        torch.cuda.empty_cache()
        if (batch + 1) % 10 == 0:
            print("Batch %d / %d for inference" % (batch, total_batches))
    (
        new_pred_labels,
        peaks,
        global_edges,
        global_pred_labels,
        global_peaks,
    ) = decode(
        g,
        args.tau,
        args.threshold,
        args.use_gt,
        ids,
        global_edges,
        global_num_nodes,
        global_peaks,
    )
    ids = ids[peaks]
    new_global_edges_len = len(global_edges[0])
    num_edges_add_this_level = new_global_edges_len - global_edges_len
    if stop_iterating(
        level,
        args.levels,
        args.early_stop,
        num_edges_add_this_level,
        num_edges_add_last_level,
        args.knn_k,
    ):
        break
    global_edges_len = new_global_edges_len
    num_edges_add_last_level = num_edges_add_this_level

    # build new dataset
    features, labels, cluster_features = build_next_level(
        features,
        labels,
        peaks,
        global_features,
        global_pred_labels,
        global_peaks,
    )
    # After the first level, the number of nodes reduce a lot. Using cpu faiss is faster.
    dataset = LanderDataset(
        features=features,
        labels=labels,
        k=args.knn_k,
        levels=1,
        faiss_gpu=False,
        cluster_features=cluster_features,
    )
    g = dataset.gs[0]
    g.ndata["pred_den"] = torch.zeros((g.number_of_nodes()))
    g.edata["prob_conn"] = torch.zeros((g.number_of_edges(), 2))
    test_loader = dgl.dataloading.DataLoader(
        g,
        torch.arange(g.number_of_nodes()),
        sampler,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_workers,
    )

evaluation(global_pred_labels, global_labels, args.metrics)
import argparse
import os
import pickle
import time

import numpy as np
import torch
import torch.optim as optim
from dataset import LanderDataset
from models import LANDER

import dgl

###########
# ArgParser
parser = argparse.ArgumentParser()
# Dataset
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--test_data_path", type=str, required=True)
parser.add_argument("--levels", type=str, default="1")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--model_filename", type=str, default="lander.pth")
# KNN
parser.add_argument("--knn_k", type=str, default="10")
# Model
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=4)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")
# Training
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--weight_decay", type=float, default=1e-5)
args = parser.parse_args()

###########################
# Environment Configuration
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

##################
# Data Preparation
def prepare_dataset_graphs(data_path, k_list, lvl_list):
    with open(data_path, "rb") as f:
        features, labels = pickle.load(f)
    gs = []
    for k, l in zip(k_list, lvl_list):
        dataset = LanderDataset(
            features=features,
            labels=labels,
            k=k,
            levels=l,
            faiss_gpu=args.faiss_gpu,
        )
        gs += [g.to(device) for g in dataset.gs]
    return gs


k_list = [int(k) for k in args.knn_k.split(",")]
lvl_list = [int(l) for l in args.levels.split(",")]
gs = prepare_dataset_graphs(args.data_path, k_list, lvl_list)
test_gs = prepare_dataset_graphs(args.test_data_path, k_list, lvl_list)

##################
# Model Definition
feature_dim = gs[0].ndata["features"].shape[1]
model = LANDER(
    feature_dim=feature_dim,
    nhid=args.hidden,
    num_conv=args.num_conv,
    dropout=args.dropout,
    use_GAT=args.gat,
    K=args.gat_k,
    balance=args.balance,
    use_cluster_feat=args.use_cluster_feat,
    use_focal_loss=args.use_focal_loss,
)
model = model.to(device)
model.train()
best_model = None
@@ -81,9 +94,15 @@ best_loss = np.Inf

#################
# Hyperparameters
opt = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=args.epochs, eta_min=1e-5
)

###############
# Training Loop

@@ -99,8 +118,10 @@ for epoch in range(args.epochs):
        loss.backward()
        opt.step()
    scheduler.step()
    print(
        "Training, epoch: %d, loss_den: %.6f, loss_conn: %.6f"
        % (epoch, all_loss_den_val, all_loss_conn_val)
    )
    # Report test
    all_test_loss_den_val = 0
    all_test_loss_conn_val = 0

@@ -110,12 +131,14 @@ for epoch in range(args.epochs):
        loss, loss_den_val, loss_conn_val = model.compute_loss(g)
        all_test_loss_den_val += loss_den_val
        all_test_loss_conn_val += loss_conn_val
    print(
        "Testing, epoch: %d, loss_den: %.6f, loss_conn: %.6f"
        % (epoch, all_test_loss_den_val, all_test_loss_conn_val)
    )
    if all_test_loss_conn_val + all_test_loss_den_val < best_loss:
        best_loss = all_test_loss_conn_val + all_test_loss_den_val
        print("New best epoch", epoch)
        torch.save(model.state_dict(), args.model_filename + "_best")
    torch.save(model.state_dict(), args.model_filename)

torch.save(model.state_dict(), args.model_filename)
import argparse
import os
import pickle
import time

import numpy as np
import torch
import torch.optim as optim
from dataset import LanderDataset
from models import LANDER

import dgl

###########
# ArgParser
parser = argparse.ArgumentParser()
# Dataset
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--levels", type=str, default="1")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--model_filename", type=str, default="lander.pth")
# KNN
parser.add_argument("--knn_k", type=str, default="10")
parser.add_argument("--num_workers", type=int, default=0)
# Model
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=1)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")
# Training
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--batch_size", type=int, default=1024)
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--weight_decay", type=float, default=1e-5)
args = parser.parse_args()
print(args)
@@ -46,42 +48,51 @@ print(args)

###########################
# Environment Configuration
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

##################
# Data Preparation
with open(args.data_path, "rb") as f:
    features, labels = pickle.load(f)

k_list = [int(k) for k in args.knn_k.split(",")]
lvl_list = [int(l) for l in args.levels.split(",")]
gs = []
nbrs = []
ks = []
for k, l in zip(k_list, lvl_list):
    dataset = LanderDataset(
        features=features,
        labels=labels,
        k=k,
        levels=l,
        faiss_gpu=args.faiss_gpu,
    )
    gs += [g for g in dataset.gs]
    ks += [k for g in dataset.gs]
    nbrs += [nbr for nbr in dataset.nbrs]
print("Dataset Prepared.")


def set_train_sampler_loader(g, k):
    fanouts = [k - 1 for i in range(args.num_conv + 1)]
    sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
    # fix the number of edges
    train_dataloader = dgl.dataloading.DataLoader(
        g,
        torch.arange(g.number_of_nodes()),
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
    )
    return train_dataloader


train_loaders = []
for gidx, g in enumerate(gs):
    train_dataloader = set_train_sampler_loader(gs[gidx], ks[gidx])
@@ -89,30 +100,39 @@ for gidx, g in enumerate(gs):

##################
# Model Definition
feature_dim = gs[0].ndata["features"].shape[1]
model = LANDER(
    feature_dim=feature_dim,
    nhid=args.hidden,
    num_conv=args.num_conv,
    dropout=args.dropout,
    use_GAT=args.gat,
    K=args.gat_k,
    balance=args.balance,
    use_cluster_feat=args.use_cluster_feat,
    use_focal_loss=args.use_focal_loss,
)
model = model.to(device)
model.train()

#################
# Hyperparameters
opt = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)

# keep num_batch_per_loader the same for every sub_dataloader
num_batch_per_loader = len(train_loaders[0])
train_loaders = [iter(train_loader) for train_loader in train_loaders]
num_loaders = len(train_loaders)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=args.epochs * num_batch_per_loader * num_loaders, eta_min=1e-5
)

print("Start Training.")

###############
# Training Loop
@@ -125,7 +145,9 @@ for epoch in range(args.epochs):
            try:
                minibatch = next(train_loaders[loader_id])
            except:
                train_loaders[loader_id] = iter(
                    set_train_sampler_loader(gs[loader_id], ks[loader_id])
                )
                minibatch = next(train_loaders[loader_id])
            input_nodes, sub_g, bipartites = minibatch
            sub_g = sub_g.to(device)

@@ -133,20 +155,38 @@ for epoch in range(args.epochs):
            # get the feature for the input_nodes
            opt.zero_grad()
            output_bipartite = model(bipartites)
            loss, loss_den_val, loss_conn_val = model.compute_loss(
                output_bipartite
            )
            loss_den_val_total.append(loss_den_val)
            loss_conn_val_total.append(loss_conn_val)
            loss_val_total.append(loss.item())
            loss.backward()
            opt.step()
            if (batch + 1) % 10 == 0:
                print(
                    "epoch: %d, batch: %d / %d, loader_id : %d / %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
                    % (
                        epoch,
                        batch,
                        num_batch_per_loader,
                        loader_id,
                        num_loaders,
                        loss.item(),
                        loss_den_val,
                        loss_conn_val,
                    )
                )
            scheduler.step()
    print(
        "epoch: %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
        % (
            epoch,
            np.array(loss_val_total).mean(),
            np.array(loss_den_val_total).mean(),
            np.array(loss_conn_val_total).mean(),
        )
    )
    torch.save(model.state_dict(), args.model_filename)

torch.save(model.state_dict(), args.model_filename)
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from .adjacency import *
from .deduce import *
from .density import *
from .evaluate import *
from .faiss_gpu import faiss_search_approx_knn
from .faiss_search import faiss_search_knn
from .knn import *
from .metrics import *
from .misc import *
@@ -8,17 +8,19 @@ import numpy as np
import scipy.sparse as sp
from scipy.sparse import coo_matrix


def row_normalize(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    # if rowsum <= 0, keep its previous value
    rowsum[rowsum <= 0] = 1
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.0
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx, r_inv


def sparse_mx_to_indices_values(sparse_mx):
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
...
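A quick, self-contained check of the row_normalize helper above, using a made-up 2x2 sparse matrix (not part of the commit):

import numpy as np
import scipy.sparse as sp

mx = sp.csr_matrix(np.array([[1.0, 3.0], [0.0, 0.0]]))
normalized, r_inv = row_normalize(mx)
print(normalized.toarray())  # [[0.25 0.75] [0. 0.]] -- each non-empty row now sums to 1
print(r_inv)  # [0.25 1.] -- rows with sum <= 0 are scaled by 1, i.e. left unchanged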
@@ -2,25 +2,32 @@
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""
import numpy as np
import torch
from sklearn import mixture

import dgl

from .density import density_to_peaks, density_to_peaks_vectorize

__all__ = [
    "peaks_to_labels",
    "edge_to_connected_graph",
    "decode",
    "build_next_level",
]


def _find_parent(parent, u):
    idx = []
    # parent is a fixed point
    while u != parent[u]:
        idx.append(u)
        u = parent[u]
    for i in idx:
        parent[i] = u
    return u


def edge_to_connected_graph(edges, num):
    parent = list(range(num))
    for u, v in edges:

@@ -37,6 +44,7 @@ def edge_to_connected_graph(edges, num):
    cluster_id = np.array([remap[f] for f in parent])
    return cluster_id


def peaks_to_edges(peaks, dist2peak, tau):
    edges = []
    for src in peaks:
@@ -48,73 +56,105 @@ def peaks_to_edges(peaks, dist2peak, tau):
            edges.append([src, dst])
    return edges


def peaks_to_labels(peaks, dist2peak, tau, inst_num):
    edges = peaks_to_edges(peaks, dist2peak, tau)
    pred_labels = edge_to_connected_graph(edges, inst_num)
    return pred_labels, edges


def get_dists(g, nbrs, use_gt):
    k = nbrs.shape[1]
    src_id = nbrs[:, 1:].reshape(-1)
    dst_id = nbrs[:, 0].repeat(k - 1)
    eids = g.edge_ids(src_id, dst_id)
    if use_gt:
        new_dists = (
            (1 - g.edata["labels_edge"][eids]).reshape(-1, k - 1).float()
        )
    else:
        new_dists = g.edata["prob_conn"][eids, 0].reshape(-1, k - 1)
    ind = torch.argsort(new_dists, 1)
    offset = torch.LongTensor(
        (nbrs[:, 0] * (k - 1)).repeat(k - 1).reshape(-1, k - 1)
    ).to(g.device)
    ind = ind + offset
    nbrs = torch.LongTensor(nbrs).to(g.device)
    new_nbrs = torch.take(nbrs[:, 1:], ind)
    new_dists = torch.cat(
        [torch.zeros((new_dists.shape[0], 1)).to(g.device), new_dists], dim=1
    )
    new_nbrs = torch.cat(
        [torch.arange(new_nbrs.shape[0]).view(-1, 1).to(g.device), new_nbrs],
        dim=1,
    )
    return new_nbrs.cpu().detach().numpy(), new_dists.cpu().detach().numpy()


def get_edge_dist(g, threshold):
    if threshold == "prob":
        return g.edata["prob_conn"][:, 0]
    return 1 - g.edata["raw_affine"]


def tree_generation(ng):
    ng.ndata["keep_eid"] = torch.zeros(ng.number_of_nodes()).long() - 1

    def message_func(edges):
        return {"mval": edges.data["edge_dist"], "meid": edges.data[dgl.EID]}

    def reduce_func(nodes):
        ind = torch.min(nodes.mailbox["mval"], dim=1)[1]
        keep_eid = nodes.mailbox["meid"].gather(1, ind.view(-1, 1))
        return {"keep_eid": keep_eid[:, 0]}

    node_order = dgl.traversal.topological_nodes_generator(ng)
    ng.prop_nodes(node_order, message_func, reduce_func)
    eids = ng.ndata["keep_eid"]
    eids = eids[eids > -1]
    edges = ng.find_edges(eids)
    treeg = dgl.graph(edges, num_nodes=ng.number_of_nodes())
    return treeg


def peak_propogation(treeg):
    treeg.ndata["pred_labels"] = torch.zeros(treeg.number_of_nodes()).long() - 1
    peaks = torch.where(treeg.in_degrees() == 0)[0].cpu().numpy()
    treeg.ndata["pred_labels"][peaks] = torch.arange(peaks.shape[0])

    def message_func(edges):
        return {"mlb": edges.src["pred_labels"]}

    def reduce_func(nodes):
        return {"pred_labels": nodes.mailbox["mlb"][:, 0]}

    node_order = dgl.traversal.topological_nodes_generator(treeg)
    treeg.prop_nodes(node_order, message_func, reduce_func)
    pred_labels = treeg.ndata["pred_labels"].cpu().numpy()
    return peaks, pred_labels


def decode(
    g,
    tau,
    threshold,
    use_gt,
    ids=None,
    global_edges=None,
    global_num_nodes=None,
    global_peaks=None,
):
    # Edge filtering with tau and density
    den_key = "density" if use_gt else "pred_den"
    g = g.local_var()
    g.edata["edge_dist"] = get_edge_dist(g, threshold)
    g.apply_edges(
        lambda edges: {
            "keep": (edges.src[den_key] > edges.dst[den_key]).long()
            * (edges.data["edge_dist"] < 1 - tau).long()
        }
    )
    eids = torch.where(g.edata["keep"] == 0)[0]
    ng = dgl.remove_edges(g, eids)

    # Tree generation
@@ -128,23 +168,37 @@ def decode(g, tau, threshold, use_gt,
    # Merge with previous layers
    src, dst = treeg.edges()
    new_global_edges = (
        global_edges[0] + ids[src.numpy()].tolist(),
        global_edges[1] + ids[dst.numpy()].tolist(),
    )
    global_treeg = dgl.graph(new_global_edges, num_nodes=global_num_nodes)
    global_peaks, global_pred_labels = peak_propogation(global_treeg)
    return (
        pred_labels,
        peaks,
        new_global_edges,
        global_pred_labels,
        global_peaks,
    )


def build_next_level(
    features, labels, peaks, global_features, global_pred_labels, global_peaks
):
    global_peak_to_label = global_pred_labels[global_peaks]
    global_label_to_peak = np.zeros_like(global_peak_to_label)
    for i, pl in enumerate(global_peak_to_label):
        global_label_to_peak[pl] = i
    cluster_ind = np.split(
        np.argsort(global_pred_labels),
        np.unique(np.sort(global_pred_labels), return_index=True)[1][1:],
    )
    cluster_features = np.zeros((len(peaks), global_features.shape[1]))
    for pi in range(len(peaks)):
        cluster_features[global_label_to_peak[pi], :] = np.mean(
            global_features[cluster_ind[pi], :], axis=0
        )
    features = features[peaks]
    labels = labels[peaks]
    return features, labels, cluster_features
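A toy illustration of the union-find helper _find_parent defined earlier in this file; the parent list is made up:

parent = [0, 0, 1, 3]  # 2 -> 1 -> 0 form one chain, 3 is its own root
root = _find_parent(parent, 2)
print(root)  # 0
print(parent)  # [0, 0, 0, 3] -- path compression relinks nodes 1 and 2 directly to the root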
@@ -4,53 +4,72 @@
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""
from itertools import groupby

import numpy as np
import torch
from tqdm import tqdm

__all__ = [
    "density_estimation",
    "density_to_peaks",
    "density_to_peaks_vectorize",
]


def density_estimation(dists, nbrs, labels, **kwargs):
    """use supervised density defined on neigborhood"""
    num, k_knn = dists.shape
    conf = np.ones((num,), dtype=np.float32)
    ind_array = labels[nbrs] == np.expand_dims(labels, 1).repeat(k_knn, 1)
    pos = ((1 - dists[:, 1:]) * ind_array[:, 1:]).sum(1)
    neg = ((1 - dists[:, 1:]) * (1 - ind_array[:, 1:])).sum(1)
    conf = (pos - neg) * conf
    conf /= k_knn - 1
    return conf
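A small hypothetical check of density_estimation above: three points with k = 2 neighbors each, where column 0 of nbrs/dists is the point itself:

dists = np.array([[0.0, 0.2], [0.0, 0.2], [0.0, 0.9]])
nbrs = np.array([[0, 1], [1, 0], [2, 0]])
labels = np.array([0, 0, 1])
print(density_estimation(dists, nbrs, labels))
# approximately [0.8 0.8 -0.1]: positive when a point's neighbor shares its label, negative otherwise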


def density_to_peaks_vectorize(dists, nbrs, density, max_conn=1, name=""):
    # just calculate 1 connectivity
    assert dists.shape[0] == density.shape[0]
    assert dists.shape == nbrs.shape
    num, k = dists.shape
    if name == "gcn_feat":
        include_mask = nbrs != np.arange(0, num).reshape(-1, 1)
        secondary_mask = (
            np.sum(include_mask, axis=1) == k
        )  # TODO: the condition == k should not happen as distance to the node self should be smallest, check for numerical stability; TODO: make top M instead of only supporting top 1
        include_mask[secondary_mask, -1] = False
        nbrs_exclude_self = nbrs[include_mask].reshape(-1, k - 1)  # (V, 79)
        dists_exclude_self = dists[include_mask].reshape(-1, k - 1)  # (V, 79)
    else:
        include_mask = nbrs != np.arange(0, num).reshape(-1, 1)
        nbrs_exclude_self = nbrs[include_mask].reshape(-1, k - 1)  # (V, 79)
        dists_exclude_self = dists[include_mask].reshape(-1, k - 1)  # (V, 79)
    compare_map = density[nbrs_exclude_self] > density.reshape(-1, 1)
    peak_index = np.argmax(np.where(compare_map, 1, 0), axis=1)  # (V,)
    compare_map_sum = np.sum(compare_map.cpu().data.numpy(), axis=1)  # (V,)
    dist2peak = {
        i: []
        if compare_map_sum[i] == 0
        else [dists_exclude_self[i, peak_index[i]]]
        for i in range(num)
    }
    peaks = {
        i: []
        if compare_map_sum[i] == 0
        else [nbrs_exclude_self[i, peak_index[i]]]
        for i in range(num)
    }
    return dist2peak, peaks


def density_to_peaks(dists, nbrs, density, max_conn=1, sort="dist"):
    # Note that dists has been sorted in ascending order
    assert dists.shape[0] == density.shape[0]
    assert dists.shape == nbrs.shape
...
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import inspect

import numpy as np
from clustering_benchmark import ClusteringBenchmark
from utils import TextColors, Timer, metrics


def _read_meta(fn):
    labels = list()

@@ -19,39 +20,49 @@ def _read_meta(fn):
    return np.array(labels), lb_set


def evaluate(gt_labels, pred_labels, metric="pairwise"):
    if isinstance(gt_labels, str) and isinstance(pred_labels, str):
        print("[gt_labels] {}".format(gt_labels))
        print("[pred_labels] {}".format(pred_labels))
        gt_labels, gt_lb_set = _read_meta(gt_labels)
        pred_labels, pred_lb_set = _read_meta(pred_labels)
        print(
            "#inst: gt({}) vs pred({})".format(len(gt_labels), len(pred_labels))
        )
        print(
            "#cls: gt({}) vs pred({})".format(len(gt_lb_set), len(pred_lb_set))
        )

    metric_func = metrics.__dict__[metric]
    with Timer(
        "evaluate with {}{}{}".format(TextColors.FATAL, metric, TextColors.ENDC)
    ):
        result = metric_func(gt_labels, pred_labels)
        if isinstance(result, np.float):
            print(
                "{}{}: {:.4f}{}".format(
                    TextColors.OKGREEN, metric, result, TextColors.ENDC
                )
            )
        else:
            ave_pre, ave_rec, fscore = result
            print(
                "{}ave_pre: {:.4f}, ave_rec: {:.4f}, fscore: {:.4f}{}".format(
                    TextColors.OKGREEN, ave_pre, ave_rec, fscore, TextColors.ENDC
                )
            )


def evaluation(pred_labels, labels, metrics):
    print("==> evaluation")
    # pred_labels = g.ndata['pred_labels'].cpu().numpy()
    max_cluster = np.max(pred_labels)
    # gt_labels_all = g.ndata['labels'].cpu().numpy()
    gt_labels_all = labels
    pred_labels_all = pred_labels
    metric_list = metrics.split(",")
    for metric in metric_list:
        evaluate(gt_labels_all, pred_labels_all, metric)
        # H and C-scores
...
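For reference, a hypothetical call to the evaluation entry point above with made-up label arrays; it assumes the metrics module imported from utils provides the named metric functions, and note that evaluate() still uses np.float, which requires an older NumPy:

import numpy as np

pred = np.array([0, 0, 1, 1])
gt = np.array([0, 0, 1, 2])
evaluation(pred, gt, "pairwise,nmi")  # prints one timed result block per metric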
""" """
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
""" """
import os
import gc import gc
import os
import faiss
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
import faiss __all__ = ["faiss_search_approx_knn"]
__all__ = ['faiss_search_approx_knn']
class faiss_index_wrapper(): class faiss_index_wrapper:
def __init__(self, def __init__(
target, self,
nprobe=128, target,
index_factory_str=None, nprobe=128,
verbose=False, index_factory_str=None,
mode='proxy', verbose=False,
using_gpu=True): mode="proxy",
using_gpu=True,
):
self._res_list = [] self._res_list = []
num_gpu = faiss.get_num_gpus() num_gpu = faiss.get_num_gpus()
print('[faiss gpu] #GPU: {}'.format(num_gpu)) print("[faiss gpu] #GPU: {}".format(num_gpu))
size, dim = target.shape size, dim = target.shape
assert size > 0, "size: {}".format(size) assert size > 0, "size: {}".format(size)
index_factory_str = "IVF{},PQ{}".format( index_factory_str = (
min(8192, 16 * round(np.sqrt(size))), "IVF{},PQ{}".format(min(8192, 16 * round(np.sqrt(size))), 32)
32) if index_factory_str is None else index_factory_str if index_factory_str is None
else index_factory_str
)
cpu_index = faiss.index_factory(dim, index_factory_str) cpu_index = faiss.index_factory(dim, index_factory_str)
cpu_index.nprobe = nprobe cpu_index.nprobe = nprobe
if mode == 'proxy': if mode == "proxy":
co = faiss.GpuClonerOptions() co = faiss.GpuClonerOptions()
co.useFloat16 = True co.useFloat16 = True
co.usePrecomputed = False co.usePrecomputed = False
...@@ -40,17 +45,18 @@ class faiss_index_wrapper(): ...@@ -40,17 +45,18 @@ class faiss_index_wrapper():
for i in range(num_gpu): for i in range(num_gpu):
res = faiss.StandardGpuResources() res = faiss.StandardGpuResources()
self._res_list.append(res) self._res_list.append(res)
sub_index = faiss.index_cpu_to_gpu( sub_index = (
res, i, cpu_index, co) if using_gpu else cpu_index faiss.index_cpu_to_gpu(res, i, cpu_index, co)
if using_gpu
else cpu_index
)
index.addIndex(sub_index) index.addIndex(sub_index)
elif mode == 'shard': elif mode == "shard":
co = faiss.GpuMultipleClonerOptions() co = faiss.GpuMultipleClonerOptions()
co.useFloat16 = True co.useFloat16 = True
co.usePrecomputed = False co.usePrecomputed = False
co.shard = True co.shard = True
index = faiss.index_cpu_to_all_gpus(cpu_index, index = faiss.index_cpu_to_all_gpus(cpu_index, co, ngpu=num_gpu)
co,
ngpu=num_gpu)
else: else:
raise KeyError("Unknown index mode") raise KeyError("Unknown index mode")
...@@ -58,14 +64,19 @@ class faiss_index_wrapper(): ...@@ -58,14 +64,19 @@ class faiss_index_wrapper():
index.verbose = verbose index.verbose = verbose
# get nlist to decide how many samples used for training # get nlist to decide how many samples used for training
nlist = int(float([ nlist = int(
item for item in index_factory_str.split(",") if 'IVF' in item float(
][0].replace("IVF", ""))) [
item
for item in index_factory_str.split(",")
if "IVF" in item
][0].replace("IVF", "")
)
)
# training # training
if not index.is_trained: if not index.is_trained:
indexes_sample_for_train = np.random.randint( indexes_sample_for_train = np.random.randint(0, size, nlist * 256)
0, size, nlist * 256)
index.train(target[indexes_sample_for_train]) index.train(target[indexes_sample_for_train])
# add with ids # add with ids
...@@ -88,25 +99,29 @@ def batch_search(index, query, k, bs, verbose=False): ...@@ -88,25 +99,29 @@ def batch_search(index, query, k, bs, verbose=False):
dists = np.zeros((n, k), dtype=np.float32) dists = np.zeros((n, k), dtype=np.float32)
nbrs = np.zeros((n, k), dtype=np.int64) nbrs = np.zeros((n, k), dtype=np.int64)
for sid in tqdm(range(0, n, bs), for sid in tqdm(
desc="faiss searching...", range(0, n, bs), desc="faiss searching...", disable=not verbose
disable=not verbose): ):
eid = min(n, sid + bs) eid = min(n, sid + bs)
dists[sid:eid], nbrs[sid:eid] = index.search(query[sid:eid], k) dists[sid:eid], nbrs[sid:eid] = index.search(query[sid:eid], k)
return dists, nbrs return dists, nbrs
def faiss_search_approx_knn(query, def faiss_search_approx_knn(
target, query,
k, target,
nprobe=128, k,
bs=int(1e6), nprobe=128,
index_factory_str=None, bs=int(1e6),
verbose=False): index_factory_str=None,
index = faiss_index_wrapper(target, verbose=False,
nprobe=nprobe, ):
index_factory_str=index_factory_str, index = faiss_index_wrapper(
verbose=verbose) target,
nprobe=nprobe,
index_factory_str=index_factory_str,
verbose=verbose,
)
dists, nbrs = batch_search(index, query, k=k, bs=bs, verbose=verbose) dists, nbrs = batch_search(index, query, k=k, bs=bs, verbose=verbose)
del index del index
......
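A hedged usage sketch of faiss_search_approx_knn above; the array sizes are arbitrary, and a GPU build of faiss is assumed because the wrapper clones the index onto the available GPUs:

import numpy as np

feats = np.random.rand(50000, 256).astype("float32")  # made-up features
dists, nbrs = faiss_search_approx_knn(query=feats, target=feats, k=10, nprobe=64)
print(dists.shape, nbrs.shape)  # (50000, 10) (50000, 10)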
@@ -2,105 +2,114 @@
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""
import gc

from tqdm import tqdm

from .faiss_gpu import faiss_search_approx_knn

__all__ = ["faiss_search_knn"]


def precise_dist(feat, nbrs, num_process=4, sort=True, verbose=False):
    import torch

    feat_share = torch.from_numpy(feat).share_memory_()
    nbrs_share = torch.from_numpy(nbrs).share_memory_()
    dist_share = torch.zeros_like(nbrs_share).float().share_memory_()
    precise_dist_share_mem(
        feat_share,
        nbrs_share,
        dist_share,
        num_process=num_process,
        sort=sort,
        verbose=verbose,
    )
    del feat_share
    gc.collect()
    return dist_share.numpy(), nbrs_share.numpy()


def precise_dist_share_mem(
    feat,
    nbrs,
    dist,
    num_process=16,
    sort=True,
    process_unit=4000,
    verbose=False,
):
    from torch import multiprocessing as mp

    num, _ = feat.shape
    num_per_proc = int(num / num_process) + 1
    for pi in range(num_process):
        sid = pi * num_per_proc
        eid = min(sid + num_per_proc, num)
        kwargs = {
            "feat": feat,
            "nbrs": nbrs,
            "dist": dist,
            "sid": sid,
            "eid": eid,
            "sort": sort,
            "process_unit": process_unit,
            "verbose": verbose,
        }
        bmm(**kwargs)


def bmm(
    feat, nbrs, dist, sid, eid, sort=True, process_unit=4000, verbose=False
):
    import torch

    _, cols = dist.shape
    batch_sim = torch.zeros((eid - sid, cols), dtype=torch.float32)
    for s in tqdm(
        range(sid, eid, process_unit), desc="bmm", disable=not verbose
    ):
        e = min(eid, s + process_unit)
        query = feat[s:e].unsqueeze(1)
        gallery = feat[nbrs[s:e]].permute(0, 2, 1)
        batch_sim[s - sid : e - sid] = torch.clamp(
            torch.bmm(query, gallery).view(-1, cols), 0.0, 1.0
        )

    if sort:
        sort_unit = int(1e6)
        batch_nbr = nbrs[sid:eid]
        for s in range(0, batch_sim.shape[0], sort_unit):
            e = min(s + sort_unit, eid)
            batch_sim[s:e], indices = torch.sort(
                batch_sim[s:e], descending=True
            )
            batch_nbr[s:e] = torch.gather(batch_nbr[s:e], 1, indices)
        nbrs[sid:eid] = batch_nbr
    dist[sid:eid] = 1.0 - batch_sim


def faiss_search_knn(
    feat,
    k,
    nprobe=128,
    num_process=4,
    is_precise=True,
    sort=True,
    verbose=False,
):
    dists, nbrs = faiss_search_approx_knn(
        query=feat, target=feat, k=k, nprobe=nprobe, verbose=verbose
    )

    if is_precise:
        print("compute precise dist among k={} nearest neighbors".format(k))
        dists, nbrs = precise_dist(
            feat, nbrs, num_process=num_process, sort=sort, verbose=verbose
        )
    return dists, nbrs
@@ -4,21 +4,25 @@
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""
import math
import multiprocessing as mp
import os

import numpy as np
from tqdm import tqdm

from utils import Timer

from .faiss_search import faiss_search_knn

__all__ = [
    "knn_faiss",
    "knn_faiss_gpu",
    "fast_knns2spmat",
    "build_knns",
    "knns2ordered_nbrs",
]


def knns2ordered_nbrs(knns, sort=True):
    if isinstance(knns, list):
        knns = np.array(knns)

@@ -32,9 +36,11 @@ def knns2ordered_nbrs(knns, sort=True):
    nbrs = nbrs[idxs, nb_idx]
    return dists, nbrs


def fast_knns2spmat(knns, k, th_sim=0, use_sim=True, fill_value=None):
    # convert knns to symmetric sparse matrix
    from scipy.sparse import csr_matrix

    eps = 1e-5
    n = len(knns)
    if isinstance(knns, list):

@@ -52,14 +58,15 @@ def fast_knns2spmat(knns, k, th_sim=0, use_sim=True, fill_value=None):
        knns = ndarr
    nbrs = knns[:, 0, :]
    dists = knns[:, 1, :]
    assert (
        -eps <= dists.min() <= dists.max() <= 1 + eps
    ), "min: {}, max: {}".format(dists.min(), dists.max())
    if use_sim:
        sims = 1.0 - dists
    else:
        sims = dists
    if fill_value is not None:
        print("[fast_knns2spmat] edge fill value:", fill_value)
        sims.fill(fill_value)
    row, col = np.where(sims >= th_sim)
    # remove the self-loop
...@@ -72,24 +79,25 @@ def fast_knns2spmat(knns, k, th_sim=0, use_sim=True, fill_value=None): ...@@ -72,24 +79,25 @@ def fast_knns2spmat(knns, k, th_sim=0, use_sim=True, fill_value=None):
spmat = csr_matrix((data, (row, col)), shape=(n, n)) spmat = csr_matrix((data, (row, col)), shape=(n, n))
return spmat return spmat
def build_knns(feats, k, knn_method, dump=True):
    with Timer("build index"):
        if knn_method == "faiss":
            index = knn_faiss(feats, k, omp_num_threads=None)
        elif knn_method == "faiss_gpu":
            index = knn_faiss_gpu(feats, k)
        else:
            raise KeyError(
                "Only support faiss and faiss_gpu currently ({}).".format(
                    knn_method
                )
            )
        knns = index.get_knns()
    return knns
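The reformat does not change how these helpers chain together. A minimal usage sketch (not part of this commit; the feature matrix `feats` and the parameter values are hypothetical, and faiss must be installed):

    import numpy as np

    feats = np.random.rand(1000, 256).astype("float32")   # hypothetical L2-normalized features
    feats /= np.linalg.norm(feats, axis=1, keepdims=True)

    knns = build_knns(feats, k=10, knn_method="faiss")    # list of (neighbor ids, distances) per node
    adj = fast_knns2spmat(knns, k=10, th_sim=0.0)         # symmetric sparse affinity matrix (n x n)
    dists, nbrs = knns2ordered_nbrs(knns, sort=True)      # distance-sorted neighbor arrays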
class knn:
    def __init__(self, feats, k, index_path="", verbose=True):
        pass

    def filter_by_th(self, i):
@@ -106,68 +114,87 @@ class knn():
        return (th_nbrs, th_dists)

    def get_knns(self, th=None):
        if th is None or th <= 0.0:
            return self.knns
        # TODO: optimize the filtering process by numpy
        # nproc = mp.cpu_count()
        nproc = 1
        with Timer(
            "filter edges by th {} (CPU={})".format(th, nproc), self.verbose
        ):
            self.th = th
            self.th_knns = []
            tot = len(self.knns)
            if nproc > 1:
                pool = mp.Pool(nproc)
                th_knns = list(
                    tqdm(pool.imap(self.filter_by_th, range(tot)), total=tot)
                )
                pool.close()
            else:
                th_knns = [self.filter_by_th(i) for i in range(tot)]
            return th_knns
class knn_faiss(knn):
    def __init__(
        self,
        feats,
        k,
        nprobe=128,
        omp_num_threads=None,
        rebuild_index=True,
        verbose=True,
        **kwargs
    ):
        import faiss

        if omp_num_threads is not None:
            faiss.omp_set_num_threads(omp_num_threads)
        self.verbose = verbose

        with Timer("[faiss] build index", verbose):
            feats = feats.astype("float32")
            size, dim = feats.shape
            index = faiss.IndexFlatIP(dim)
            index.add(feats)
        with Timer("[faiss] query topk {}".format(k), verbose):
            sims, nbrs = index.search(feats, k=k)
            self.knns = [
                (
                    np.array(nbr, dtype=np.int32),
                    1 - np.array(sim, dtype=np.float32),
                )
                for nbr, sim in zip(nbrs, sims)
            ]
class knn_faiss_gpu(knn):
    def __init__(
        self,
        feats,
        k,
        nprobe=128,
        num_process=4,
        is_precise=True,
        sort=True,
        verbose=True,
        **kwargs
    ):
        with Timer("[faiss_gpu] query topk {}".format(k), verbose):
            dists, nbrs = faiss_search_knn(
                feats,
                k=k,
                nprobe=nprobe,
                num_process=num_process,
                is_precise=is_precise,
                sort=sort,
                verbose=verbose,
            )

            self.knns = [
                (
                    np.array(nbr, dtype=np.int32),
                    np.array(dist, dtype=np.float32),
                )
                for nbr, dist in zip(nbrs, dists)
            ]
@@ -7,25 +7,32 @@ This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
from __future__ import division

import numpy as np
from sklearn.metrics import precision_score, recall_score
from sklearn.metrics.cluster import (
    contingency_matrix,
    normalized_mutual_info_score,
)

__all__ = ["pairwise", "bcubed", "nmi", "precision", "recall", "accuracy"]


def _check(gt_labels, pred_labels):
    if gt_labels.ndim != 1:
        raise ValueError(
            "gt_labels must be 1D: shape is %r" % (gt_labels.shape,)
        )
    if pred_labels.ndim != 1:
        raise ValueError(
            "pred_labels must be 1D: shape is %r" % (pred_labels.shape,)
        )
    if gt_labels.shape != pred_labels.shape:
        raise ValueError(
            "gt_labels and pred_labels must have same size, got %d and %d"
            % (gt_labels.shape[0], pred_labels.shape[0])
        )
    return gt_labels, pred_labels
def _get_lb2idxs(labels):
    lb2idxs = {}
    for idx, lb in enumerate(labels):
@@ -34,20 +41,22 @@ def _get_lb2idxs(labels):
            lb2idxs[lb].append(idx)
    return lb2idxs
def _compute_fscore(pre, rec):
    return 2.0 * pre * rec / (pre + rec)


def fowlkes_mallows_score(gt_labels, pred_labels, sparse=True):
    """The original function is from `sklearn.metrics.fowlkes_mallows_score`.
    We output the pairwise precision, pairwise recall and F-measure,
    instead of calculating the geometry mean of precision and recall.
    """
    (n_samples,) = gt_labels.shape

    c = contingency_matrix(gt_labels, pred_labels, sparse=sparse)
    tk = np.dot(c.data, c.data) - n_samples
    pk = np.sum(np.asarray(c.sum(axis=0)).ravel() ** 2) - n_samples
    qk = np.sum(np.asarray(c.sum(axis=1)).ravel() ** 2) - n_samples

    avg_pre = tk / pk
    avg_rec = tk / qk
@@ -55,10 +64,12 @@ def fowlkes_mallows_score(gt_labels, pred_labels, sparse=True):
    return avg_pre, avg_rec, fscore


def pairwise(gt_labels, pred_labels, sparse=True):
    _check(gt_labels, pred_labels)
    return fowlkes_mallows_score(gt_labels, pred_labels, sparse)
def bcubed(gt_labels, pred_labels):
    _check(gt_labels, pred_labels)
@@ -75,7 +86,7 @@ def bcubed(gt_labels, pred_labels):
        gt_num[i] = len(gt_idxs)
        for pred_lb in all_pred_lbs:
            pred_idxs = pred_lb2idxs[pred_lb]
            n = 1.0 * np.intersect1d(gt_idxs, pred_idxs).size
            pre[i] += n**2 / len(pred_idxs)
            rec[i] += n**2 / gt_num[i]
@@ -86,14 +97,18 @@ def bcubed(gt_labels, pred_labels):
    return avg_pre, avg_rec, fscore


def nmi(gt_labels, pred_labels):
    return normalized_mutual_info_score(pred_labels, gt_labels)


def precision(gt_labels, pred_labels):
    return precision_score(gt_labels, pred_labels)


def recall(gt_labels, pred_labels):
    return recall_score(gt_labels, pred_labels)


def accuracy(gt_labels, pred_labels):
    return np.mean(gt_labels == pred_labels)
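All of these metrics take the flat `(gt_labels, pred_labels)` arrays validated by `_check`. A short illustrative sketch (not part of this commit; the toy label arrays are hypothetical):

    import numpy as np

    gt = np.array([0, 0, 1, 1, 2, 2])      # hypothetical ground-truth cluster ids
    pred = np.array([0, 0, 1, 2, 2, 2])    # hypothetical predicted cluster ids

    pre, rec, fscore = pairwise(gt, pred)  # pairwise precision, recall, F-score
    b_pre, b_rec, b_f = bcubed(gt, pred)   # BCubed precision, recall, F-score
    score = nmi(gt, pred)                  # normalized mutual information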
@@ -4,25 +4,28 @@
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""
import json
import os
import pickle
import random
import time

import numpy as np


class TextColors:
    HEADER = "\033[35m"
    OKBLUE = "\033[34m"
    OKGREEN = "\033[32m"
    WARNING = "\033[33m"
    FATAL = "\033[31m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
class Timer:
    def __init__(self, name="task", verbose=True):
        self.name = name
        self.verbose = verbose
@@ -32,49 +35,66 @@ class Timer():
    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.verbose:
            print(
                "[Time] {} consumes {:.4f} s".format(
                    self.name, time.time() - self.start
                )
            )
        return exc_type is None
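`Timer` is used as a context manager throughout these files (for example in `build_knns` above). A minimal usage sketch (not part of this commit; it assumes the elided `__enter__` records `self.start`):

    import time

    with Timer("toy task", verbose=True):
        time.sleep(0.1)  # stand-in for real work; prints "[Time] toy task consumes 0.1xxx s" on exit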
def set_random_seed(seed, cuda=False):
    import torch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)


def l2norm(vec):
    vec /= np.linalg.norm(vec, axis=1).reshape(-1, 1)
    return vec


def is_l2norm(features, size):
    rand_i = random.choice(range(size))
    norm_ = np.dot(features[rand_i, :], features[rand_i, :])
    return abs(norm_ - 1) < 1e-6


def is_spmat_eq(a, b):
    return (a != b).nnz == 0


def aggregate(features, adj, times):
    dtype = features.dtype
    for i in range(times):
        features = adj * features
    return features.astype(dtype)
def mkdir_if_no_exists(path, subdirs=[""], is_folder=False):
    if path == "":
        return
    for sd in subdirs:
        if sd != "" or is_folder:
            d = os.path.dirname(os.path.join(path, sd))
        else:
            d = os.path.dirname(path)
        if not os.path.exists(d):
            os.makedirs(d)
def stop_iterating(
    current_l,
    total_l,
    early_stop,
    num_edges_add_this_level,
    num_edges_add_last_level,
    knn_k,
):
    # Stopping rule 1: run all levels
    if current_l == total_l - 1:
        return True
@@ -82,6 +102,10 @@ def stop_iterating(current_l, total_l, early_stop, num_edges_add_this_level, num
    if num_edges_add_this_level == 0:
        return True
    # Stopping rule 3: early stopping, two levels start to produce similar numbers of edges
    if (
        early_stop
        and float(num_edges_add_last_level) / num_edges_add_this_level
        < knn_k - 1
    ):
        return True
    return False
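The three stopping rules above are checked once per clustering level. A small hypothetical sketch of the call, with made-up edge counts, to illustrate the early-stop ratio test (not part of this commit):

    done = stop_iterating(
        current_l=2,
        total_l=10,
        early_stop=True,
        num_edges_add_this_level=1200,
        num_edges_add_last_level=9000,
        knn_k=10,
    )
    # 9000 / 1200 = 7.5 < knn_k - 1 = 9, so rule 3 fires and done is True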
""" Evaluate unsupervised embedding using a variety of basic classifiers. """
""" Credit: https://github.com/fanyun-sun/InfoGraph """
from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.svm import SVC

import numpy as np
import torch
import torch.nn as nn
class LogReg(nn.Module):
    def __init__(self, ft_in, nb_classes):
@@ -26,7 +25,8 @@ class LogReg(nn.Module):
        ret = self.fc(seq)
        return ret


def logistic_classify(x, y, device="cpu"):
    nb_classes = np.unique(y).shape[0]
    xent = nn.CrossEntropyLoss()
    hid_units = x.shape[1]
@@ -35,11 +35,14 @@ def logistic_classify(x, y, device = 'cpu'):
    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)
    for train_index, test_index in kf.split(x, y):
        train_embs, test_embs = x[train_index], x[test_index]
        train_lbls, test_lbls = y[train_index], y[test_index]

        train_embs, train_lbls = torch.from_numpy(train_embs).to(
            device
        ), torch.from_numpy(train_lbls).to(device)
        test_embs, test_lbls = torch.from_numpy(test_embs).to(
            device
        ), torch.from_numpy(test_lbls).to(device)

        log = LogReg(hid_units, nb_classes)
        log = log.to(device)
@@ -62,6 +65,7 @@ def logistic_classify(x, y, device = 'cpu'):
        accs.append(acc.item())
    return np.mean(accs)
def svc_classify(x, y, search):
    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)
    accuracies = []
@@ -71,21 +75,24 @@ def svc_classify(x, y, search):
        y_train, y_test = y[train_index], y[test_index]

        if search:
            params = {"C": [0.001, 0.01, 0.1, 1, 10, 100, 1000]}
            classifier = GridSearchCV(
                SVC(), params, cv=5, scoring="accuracy", verbose=0
            )
        else:
            classifier = SVC(C=10)
        classifier.fit(x_train, y_train)
        accuracies.append(accuracy_score(y_test, classifier.predict(x_test)))
    return np.mean(accuracies)


def evaluate_embedding(embeddings, labels, search=True, device="cpu"):
    labels = preprocessing.LabelEncoder().fit_transform(labels)
    x, y = np.array(embeddings), np.array(labels)

    logreg_accuracy = logistic_classify(x, y, device)
    print("LogReg", logreg_accuracy)
    svc_accuracy = svc_classify(x, y, search)
    print("svc", svc_accuracy)

    return logreg_accuracy, svc_accuracy
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import GRU, BatchNorm1d, Linear, ModuleList, ReLU, Sequential
from utils import global_global_loss_, local_global_loss_

from dgl.nn import GINConv, NNConv, Set2Set
from dgl.nn.pytorch.glob import SumPooling

""" Feedforward neural network"""
class FeedforwardNetwork(nn.Module):
    """
    3-layer feed-forward neural networks with jumping connections
    Parameters
    -----------
@@ -26,18 +26,19 @@ class FeedforwardNetwork(nn.Module):
    forward(feat):
        feat: Tensor
            [N * D], input features
    """

    def __init__(self, in_dim, hid_dim):
        super(FeedforwardNetwork, self).__init__()

        self.block = Sequential(
            Linear(in_dim, hid_dim),
            ReLU(),
            Linear(hid_dim, hid_dim),
            ReLU(),
            Linear(hid_dim, hid_dim),
            ReLU(),
        )

        self.jump_con = Linear(in_dim, hid_dim)
@@ -50,10 +51,11 @@ class FeedforwardNetwork(nn.Module):
        return out


""" Unsupervised Setting """


class GINEncoder(nn.Module):
    """
    Encoder based on dgl.nn.GINConv & dgl.nn.SumPooling
    Parameters
    -----------
@@ -61,7 +63,7 @@ class GINEncoder(nn.Module):
        Input feature size.
    hid_dim: int
        Hidden feature size.
    n_layer:
        Number of GIN layers.
    Functions
@@ -70,7 +72,7 @@ class GINEncoder(nn.Module):
        graph: DGLGraph
        feat: Tensor
            [N * D], node features
    """

    def __init__(self, in_dim, hid_dim, n_layer):
        super(GINEncoder, self).__init__()
@@ -86,12 +88,11 @@ class GINEncoder(nn.Module):
            else:
                n_in = hid_dim
                n_out = hid_dim
            block = Sequential(
                Linear(n_in, n_out), ReLU(), Linear(hid_dim, hid_dim)
            )

            conv = GINConv(apply_func=block, aggregator_type="sum")
            bn = BatchNorm1d(hid_dim)

            self.convs.append(conv)
@@ -109,8 +110,8 @@ class GINEncoder(nn.Module):
            x = self.bns[i](x)
            xs.append(x)

        local_emb = th.cat(xs, 1)  # patch-level embedding
        global_emb = self.pool(graph, local_emb)  # graph-level embedding

        return global_emb, local_emb
@@ -125,7 +126,7 @@ class InfoGraph(nn.Module):
        Input feature size.
    hid_dim: int
        Hidden feature size.
    n_layer: int
        Number of the GNN encoder layers.
    Functions
@@ -146,8 +147,12 @@ class InfoGraph(nn.Module):
        self.encoder = GINEncoder(in_dim, hid_dim, n_layer)

        self.local_d = FeedforwardNetwork(
            embedding_dim, embedding_dim
        )  # local discriminator (node-level)
        self.global_d = FeedforwardNetwork(
            embedding_dim, embedding_dim
        )  # global discriminator (graph-level)

    def get_embedding(self, graph, feat):
        # get_embedding function for evaluation the learned embeddings
@@ -161,19 +166,20 @@ class InfoGraph(nn.Module):
        global_emb, local_emb = self.encoder(graph, feat)

        global_h = self.global_d(global_emb)  # global hidden representation
        local_h = self.local_d(local_emb)  # local hidden representation

        loss = local_global_loss_(local_h, global_h, graph_id)

        return loss
""" Semisupervised Setting """


class NNConvEncoder(nn.Module):
    """
    Encoder based on dgl.nn.NNConv & GRU & dgl.nn.set2set pooling
    Parameters
    -----------
@@ -190,7 +196,7 @@ class NNConvEncoder(nn.Module):
            [N * D1], node features
        efeat: Tensor
            [E * D2], edge features
    """

    def __init__(self, in_dim, hid_dim):
        super(NNConvEncoder, self).__init__()
@@ -198,9 +204,17 @@ class NNConvEncoder(nn.Module):
        self.lin0 = Linear(in_dim, hid_dim)

        # mlp for edge convolution in NNConv
        block = Sequential(
            Linear(5, 128), ReLU(), Linear(128, hid_dim * hid_dim)
        )

        self.conv = NNConv(
            hid_dim,
            hid_dim,
            edge_func=block,
            aggregator_type="mean",
            residual=False,
        )
        self.gru = GRU(hid_dim, hid_dim)

        # set2set pooling
@@ -228,7 +242,7 @@ class NNConvEncoder(nn.Module):
class InfoGraphS(nn.Module):
    """
    InfoGraph* model for semi-supervised setting
    Parameters
    -----------
@@ -244,8 +258,8 @@ class InfoGraphS(nn.Module):
    unsupforward(graph):
        graph: DGLGraph
    """

    def __init__(self, in_dim, hid_dim):
        super(InfoGraphS, self).__init__()
@@ -265,19 +279,21 @@ class InfoGraphS(nn.Module):
        self.unsup_d = FeedforwardNetwork(2 * hid_dim, hid_dim)

    def forward(self, graph, nfeat, efeat):
        sup_global_emb, sup_local_emb = self.sup_encoder(graph, nfeat, efeat)
        sup_global_pred = self.fc2(F.relu(self.fc1(sup_global_emb)))
        sup_global_pred = sup_global_pred.view(-1)

        return sup_global_pred

    def unsup_forward(self, graph, nfeat, efeat, graph_id):
        sup_global_emb, sup_local_emb = self.sup_encoder(graph, nfeat, efeat)
        unsup_global_emb, unsup_local_emb = self.unsup_encoder(
            graph, nfeat, efeat
        )

        g_enc = self.unsup_global_d(unsup_global_emb)
        l_enc = self.unsup_local_d(unsup_local_emb)
@@ -287,5 +303,5 @@ class InfoGraphS(nn.Module):
        # Calculate loss
        unsup_loss = local_global_loss_(l_enc, g_enc, graph_id)
        con_loss = global_global_loss_(sup_g_enc, unsup_g_enc)

        return unsup_loss, con_loss
import argparse

import numpy as np
import torch as th
import torch.nn.functional as F
from model import InfoGraphS

import dgl
from dgl.data import QM9EdgeDataset
from dgl.data.utils import Subset
from dgl.dataloading import GraphDataLoader
def argument():
    parser = argparse.ArgumentParser(description="InfoGraphS")

    # data source params
    parser.add_argument(
        "--target", type=str, default="mu", help="Choose regression task"
    )
    parser.add_argument(
        "--train_num", type=int, default=5000, help="Size of training set"
    )

    # training params
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index, default:-1, using CPU."
    )
    parser.add_argument(
        "--epochs", type=int, default=200, help="Training epochs."
    )
    parser.add_argument(
        "--batch_size", type=int, default=20, help="Training batch size."
    )
    parser.add_argument(
        "--val_batch_size", type=int, default=100, help="Validation batch size."
    )
    parser.add_argument(
        "--lr", type=float, default=0.001, help="Learning rate."
    )
    parser.add_argument("--wd", type=float, default=0, help="Weight decay.")

    # model params
    parser.add_argument(
        "--hid_dim", type=int, default=64, help="Hidden layer dimensionality"
    )
    parser.add_argument(
        "--reg", type=float, default=0.001, help="Regularization coefficient"
    )

    args = parser.parse_args()

    # check cuda
    if args.gpu != -1 and th.cuda.is_available():
        args.device = "cuda:{}".format(args.gpu)
    else:
        args.device = "cpu"

    return args
class DenseQM9EdgeDataset(QM9EdgeDataset):
    def __getitem__(self, idx):
        r"""Get graph and label by index

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        dgl.DGLGraph
            The graph contains:
            - ``ndata['pos']``: the coordinates of each atom
            - ``ndata['attr']``: the features of each atom
            - ``edata['edge_attr']``: the features of each bond
        Tensor
            Property values of molecular graphs
        """
        pos = self.node_pos[self.n_cumsum[idx] : self.n_cumsum[idx + 1]]
        src = self.src[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]]
        dst = self.dst[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]]

        g = dgl.graph((src, dst))
        g.ndata["pos"] = th.tensor(pos).float()
        g.ndata["attr"] = th.tensor(
            self.node_attr[self.n_cumsum[idx] : self.n_cumsum[idx + 1]]
        ).float()
        g.edata["edge_attr"] = th.tensor(
            self.edge_attr[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]]
        ).float()
        label = th.tensor(self.targets[idx][self.label_keys]).float()

        n_nodes = g.num_nodes()
        row = th.arange(n_nodes)
        col = th.arange(n_nodes)
        row = row.view(-1, 1).repeat(1, n_nodes).view(-1)
        col = col.repeat(n_nodes)

        src = g.edges()[0]
        dst = g.edges()[1]
        idx = src * n_nodes + dst

        size = list(g.edata["edge_attr"].size())
        size[0] = n_nodes * n_nodes
        edge_attr = g.edata["edge_attr"].new_zeros(size)
        edge_attr[idx] = g.edata["edge_attr"]

        pos = g.ndata["pos"]
        dist = th.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1)
        new_edge_attr = th.cat([edge_attr, dist.type_as(edge_attr)], dim=-1)

        graph = dgl.graph((row, col))
        graph.ndata["attr"] = g.ndata["attr"]
        graph.edata["edge_attr"] = new_edge_attr
        graph = graph.remove_self_loop()

        return graph, label
def collate(samples):
    """collate function for building graph dataloader"""
    # generate batched graphs and labels
    graphs, targets = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    batched_targets = th.Tensor(targets)

    n_graphs = len(graphs)
    graph_id = th.arange(n_graphs)
    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)
    batched_graph.ndata["graph_id"] = graph_id

    return batched_graph, batched_targets
def evaluate(model, loader, num, device):
    error = 0
    for graphs, targets in loader:
        graphs = graphs.to(device)
        nfeat, efeat = graphs.ndata["attr"], graphs.edata["edge_attr"]
        targets = targets.to(device)
        error += (model(graphs, nfeat, efeat) - targets).abs().sum().item()
    error = error / num
    return error


if __name__ == "__main__":
    # Step 1: Prepare graph data ===================================== #
    args = argument()
    label_keys = [args.target]
    print(args)
    dataset = DenseQM9EdgeDataset(label_keys=label_keys)

    # Train/Val/Test Splitting
    N = dataset.targets.shape[0]
    all_idx = np.arange(N)
@@ -151,63 +178,75 @@ if __name__ == '__main__':
    val_idx = all_idx[:val_num]
    test_idx = all_idx[val_num : val_num + test_num]
    train_idx = all_idx[
        val_num + test_num : val_num + test_num + args.train_num
    ]

    train_data = Subset(dataset, train_idx)
    val_data = Subset(dataset, val_idx)
    test_data = Subset(dataset, test_idx)

    unsup_idx = all_idx[val_num + test_num :]
    unsup_data = Subset(dataset, unsup_idx)

    # generate supervised training dataloader and unsupervised training dataloader
    train_loader = GraphDataLoader(
        train_data,
        batch_size=args.batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )

    unsup_loader = GraphDataLoader(
        unsup_data,
        batch_size=args.batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )

    # generate validation & testing dataloader
    val_loader = GraphDataLoader(
        val_data,
        batch_size=args.val_batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )

    test_loader = GraphDataLoader(
        test_data,
        batch_size=args.val_batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )
    print("======== target = {} ========".format(args.target))

    in_dim = dataset[0][0].ndata["attr"].shape[1]

    # Step 2: Create model =================================================================== #
    model = InfoGraphS(in_dim, args.hid_dim)
    model = model.to(args.device)

    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.wd
    )
    scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.7, patience=5, min_lr=0.000001
    )

    # Step 4: training epochs =============================================================== #
    best_val_error = float("inf")
    test_error = float("inf")

    for epoch in range(args.epochs):
        """Training"""
        model.train()
        lr = scheduler.optimizer.param_groups[0]["lr"]
        iteration = 0
        sup_loss_all = 0
@@ -220,18 +259,28 @@ if __name__ == '__main__':
            sup_graph = sup_graph.to(args.device)
            unsup_graph = unsup_graph.to(args.device)

            sup_nfeat, sup_efeat = (
                sup_graph.ndata["attr"],
                sup_graph.edata["edge_attr"],
            )
            unsup_nfeat, unsup_efeat, unsup_graph_id = (
                unsup_graph.ndata["attr"],
                unsup_graph.edata["edge_attr"],
                unsup_graph.ndata["graph_id"],
            )

            sup_target = sup_target
            sup_target = sup_target.to(args.device)

            optimizer.zero_grad()

            sup_loss = F.mse_loss(
                model(sup_graph, sup_nfeat, sup_efeat), sup_target
            )
            unsup_loss, consis_loss = model.unsup_forward(
                unsup_graph, unsup_nfeat, unsup_efeat, unsup_graph_id
            )

            loss = sup_loss + unsup_loss + args.reg * consis_loss
@@ -243,17 +292,23 @@ if __name__ == '__main__':
            optimizer.step()

        print(
            "Epoch: {}, Sup_Loss: {:4f}, Unsup_loss: {:.4f}, Consis_loss: {:.4f}".format(
                epoch, sup_loss_all, unsup_loss_all, consis_loss_all
            )
        )

        model.eval()
        val_error = evaluate(model, val_loader, val_num, args.device)
        scheduler.step(val_error)

        if val_error < best_val_error:
            best_val_error = val_error
            test_error = evaluate(model, test_loader, test_num, args.device)

        print(
            "Epoch: {}, LR: {}, val_error: {:.4f}, best_test_error: {:.4f}".format(
                epoch, lr, val_error, test_error
            )
        )
import argparse

import torch as th
from evaluate_embedding import evaluate_embedding
from model import InfoGraph

import dgl
from dgl.data import GINDataset
from dgl.dataloading import GraphDataLoader
def argument():
    parser = argparse.ArgumentParser(description="InfoGraph")
    # data source params
    parser.add_argument(
        "--dataname", type=str, default="MUTAG", help="Name of dataset."
    )

    # training params
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index, default:-1, using CPU."
    )
    parser.add_argument(
        "--epochs", type=int, default=20, help="Training epochs."
    )
    parser.add_argument(
        "--batch_size", type=int, default=128, help="Training batch size."
    )
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument(
        "--log_interval",
        type=int,
        default=1,
        help="Interval between two evaluations.",
    )

    # model params
    parser.add_argument(
        "--n_layers",
        type=int,
        default=3,
        help="Number of graph convolution layers before each pooling.",
    )
    parser.add_argument(
        "--hid_dim", type=int, default=32, help="Hidden layer dimensionalities."
    )

    args = parser.parse_args()

    # check cuda
    if args.gpu != -1 and th.cuda.is_available():
        args.device = "cuda:{}".format(args.gpu)
    else:
        args.device = "cpu"

    return args
def collate(samples):
    """collate function for building graph dataloader"""
    graphs, labels = map(list, zip(*samples))

    # generate batched graphs and labels
@@ -49,35 +69,37 @@ def collate(samples):
    graph_id = th.arange(n_graphs)
    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)
    batched_graph.ndata["graph_id"] = graph_id

    return batched_graph, batched_labels
if __name__ == "__main__":
    # Step 1: Prepare graph data ===================================== #
    args = argument()
    print(args)

    # load dataset from dgl.data.GINDataset
    dataset = GINDataset(args.dataname, self_loop=False)

    # get graphs and labels
    graphs, labels = map(list, zip(*dataset))

    # generate a full-graph with all examples for evaluation
    wholegraph = dgl.batch(graphs)
    wholegraph.ndata["attr"] = wholegraph.ndata["attr"].to(th.float32)

    # create dataloader for batch training
    dataloader = GraphDataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )

    in_dim = wholegraph.ndata["attr"].shape[1]

    # Step 2: Create model =================================================================== #
    model = InfoGraph(in_dim, args.hid_dim, args.n_layers)
@@ -85,19 +107,19 @@ if __name__ == '__main__':
    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)

    print("===== Before training ======")

    wholegraph = wholegraph.to(args.device)
    wholefeat = wholegraph.ndata["attr"]

    emb = model.get_embedding(wholegraph, wholefeat).cpu()
    res = evaluate_embedding(emb, labels, args.device)

    """ Evaluate the initialized embeddings """
    """ using logistic regression and SVM(non-linear) """
    print("logreg {:4f}, svc {:4f}".format(res[0], res[1]))

    best_logreg = 0
    best_logreg_epoch = 0
    best_svc = 0
@@ -107,30 +129,30 @@ if __name__ == '__main__':
    for epoch in range(args.epochs):
        loss_all = 0
        model.train()

        for graph, label in dataloader:
            graph = graph.to(args.device)
            feat = graph.ndata["attr"]
            graph_id = graph.ndata["graph_id"]

            n_graph = label.shape[0]

            optimizer.zero_grad()
            loss = model(graph, feat, graph_id)
            loss.backward()
            optimizer.step()
            loss_all += loss.item()

        print("Epoch {}, Loss {:.4f}".format(epoch, loss_all))

        if epoch % args.log_interval == 0:
            # evaluate embeddings
            model.eval()
            emb = model.get_embedding(wholegraph, wholefeat).cpu()
            res = evaluate_embedding(emb, labels, args.device)

            if res[0] > best_logreg:
                best_logreg = res[0]
                best_logreg_epoch = epoch
@@ -139,7 +161,11 @@ if __name__ == '__main__':
                best_svc = res[1]
                best_svc_epoch = epoch

            print(
                "best logreg {:4f}, epoch {} | best svc: {:4f}, epoch {}".format(
                    best_logreg, best_logreg_epoch, best_svc, best_svc_epoch
                )
            )

    print("Training End")
    print("best logreg {:4f} ,best svc {:4f}".format(best_logreg, best_svc))