Unverified Commit 704bcaf6 authored by Hongzhi (Steve), Chen and committed by GitHub
parent 6bc82161
......@@ -280,7 +280,6 @@ _ = model.to(opt.device)
# Place trainable parameter in list of parameters to train:
if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
all_but_fc_params = list(
filter(lambda x: "last_linear" not in x[0], model.named_parameters())
)
......@@ -376,6 +375,8 @@ def same_model(model1, model2):
"""============================================================================"""
#################### TRAINER FUNCTION ############################
def train_one_epoch_finetune(
train_dataloader, model, optimizer, criterion, opt, epoch
......@@ -403,7 +404,6 @@ def train_one_epoch_finetune(
train_dataloader, desc="Epoch {} Training gt labels...".format(epoch)
)
for i, (class_labels, input) in enumerate(data_iterator):
# Compute embeddings for input batch
features = model(input.to(opt.device))
......
......@@ -263,7 +263,6 @@ _ = model.to(opt.device)
# Place trainable parameter in list of parameters to train:
if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
all_but_fc_params = list(
filter(lambda x: "last_linear" not in x[0], model.named_parameters())
)
......
......@@ -11,6 +11,8 @@ import torch
from scipy import sparse
"""================================================================================================="""
############ LOSS SELECTION FUNCTION #####################
def loss_select(loss, opt, to_optim):
"""
......
......@@ -281,7 +281,6 @@ _ = model.to(opt.device)
# Place trainable parameter in list of parameters to train:
if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
all_but_fc_params = list(
filter(lambda x: "last_linear" not in x[0], model.named_parameters())
)
......
import argparse, time, os, pickle
import argparse, os, pickle, time
import random
import sys
sys.path.append("..")
from utils.deduce import get_edge_dist
import numpy as np
import shutil
import dgl
import numpy as np
import seaborn
import torch
import torch.optim as optim
from models import LANDER
from dataset import LanderDataset
from utils import evaluation, decode, build_next_level, stop_iterating
from matplotlib import pyplot as plt
import seaborn
from models import LANDER
from utils import build_next_level, decode, evaluation, stop_iterating
from utils.deduce import get_edge_dist
STATISTIC = False
......@@ -25,43 +26,47 @@ STATISTIC = False
parser = argparse.ArgumentParser()
# Dataset
parser.add_argument('--data_path', type=str, required=True)
parser.add_argument('--model_filename', type=str, default='lander.pth')
parser.add_argument('--faiss_gpu', action='store_true')
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--output_filename', type=str, default='data/features.pkl')
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--model_filename", type=str, default="lander.pth")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--output_filename", type=str, default="data/features.pkl")
# HyperParam
parser.add_argument('--knn_k', type=int, default=10)
parser.add_argument('--levels', type=int, default=1)
parser.add_argument('--tau', type=float, default=0.5)
parser.add_argument('--threshold', type=str, default='prob')
parser.add_argument('--metrics', type=str, default='pairwise,bcubed,nmi')
parser.add_argument('--early_stop', action='store_true')
parser.add_argument("--knn_k", type=int, default=10)
parser.add_argument("--levels", type=int, default=1)
parser.add_argument("--tau", type=float, default=0.5)
parser.add_argument("--threshold", type=str, default="prob")
parser.add_argument("--metrics", type=str, default="pairwise,bcubed,nmi")
parser.add_argument("--early_stop", action="store_true")
# Model
parser.add_argument('--hidden', type=int, default=512)
parser.add_argument('--num_conv', type=int, default=4)
parser.add_argument('--dropout', type=float, default=0.)
parser.add_argument('--gat', action='store_true')
parser.add_argument('--gat_k', type=int, default=1)
parser.add_argument('--balance', action='store_true')
parser.add_argument('--use_cluster_feat', action='store_true')
parser.add_argument('--use_focal_loss', action='store_true')
parser.add_argument('--use_gt', action='store_true')
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=4)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")
parser.add_argument("--use_gt", action="store_true")
# Subgraph
parser.add_argument('--batch_size', type=int, default=4096)
parser.add_argument('--mode', type=str, default="1head")
parser.add_argument('--midpoint', type=str, default="false")
parser.add_argument('--linsize', type=int, default=29011)
parser.add_argument('--uinsize', type=int, default=18403)
parser.add_argument('--inclasses', type=int, default=948)
parser.add_argument('--thresh', type=float, default=1.0)
parser.add_argument('--draw', type=str, default='false')
parser.add_argument('--density_distance_pkl', type=str, default="density_distance.pkl")
parser.add_argument('--density_lindistance_jpg', type=str, default="density_lindistance.jpg")
parser.add_argument("--batch_size", type=int, default=4096)
parser.add_argument("--mode", type=str, default="1head")
parser.add_argument("--midpoint", type=str, default="false")
parser.add_argument("--linsize", type=int, default=29011)
parser.add_argument("--uinsize", type=int, default=18403)
parser.add_argument("--inclasses", type=int, default=948)
parser.add_argument("--thresh", type=float, default=1.0)
parser.add_argument("--draw", type=str, default="false")
parser.add_argument(
"--density_distance_pkl", type=str, default="density_distance.pkl"
)
parser.add_argument(
"--density_lindistance_jpg", type=str, default="density_lindistance.jpg"
)
args = parser.parse_args()
print(args)
......@@ -70,21 +75,21 @@ linsize = args.linsize
uinsize = args.uinsize
inclasses = args.inclasses
if args.draw == 'false':
if args.draw == "false":
args.draw = False
elif args.draw == 'true':
elif args.draw == "true":
args.draw = True
###########################
# Environment Configuration
if torch.cuda.is_available():
device = torch.device('cuda')
device = torch.device("cuda")
else:
device = torch.device('cpu')
device = torch.device("cpu")
##################
# Data Preparation
with open(args.data_path, 'rb') as f:
with open(args.data_path, "rb") as f:
loaded_data = pickle.load(f)
path2idx, features, pred_labels, labels, masks = loaded_data
......@@ -123,11 +128,12 @@ else:
print("filtered features:", len(features))
global_features = features.copy() # global features
dataset = LanderDataset(features=features, labels=labels, k=args.knn_k,
levels=1, faiss_gpu=False)
dataset = LanderDataset(
features=features, labels=labels, k=args.knn_k, levels=1, faiss_gpu=False
)
g = dataset.gs[0]
g.ndata['pred_den'] = torch.zeros((g.number_of_nodes()))
g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
g.ndata["pred_den"] = torch.zeros((g.number_of_nodes()))
g.edata["prob_conn"] = torch.zeros((g.number_of_edges(), 2))
global_labels = labels.copy()
ids = np.arange(g.number_of_nodes())
global_edges = ([], [])
......@@ -135,7 +141,7 @@ global_peaks = np.array([], dtype=np.long)
global_edges_len = len(global_edges[0])
global_num_nodes = g.number_of_nodes()
global_densities = g.ndata['density'][:linsize]
global_densities = g.ndata["density"][:linsize]
global_densities = np.sort(global_densities)
xs = np.arange(len(global_densities))
......@@ -143,23 +149,30 @@ fanouts = [args.knn_k - 1 for i in range(args.num_conv + 1)]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
# fix the number of edges
test_loader = dgl.dataloading.DataLoader(
g, torch.arange(g.number_of_nodes()), sampler,
g,
torch.arange(g.number_of_nodes()),
sampler,
batch_size=args.batch_size,
shuffle=False,
drop_last=False,
num_workers=args.num_workers
num_workers=args.num_workers,
)
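# With fanouts of knn_k - 1 at every hop, MultiLayerNeighborSampler draws at
# most that many in-edges per destination node per layer, which is what the
# "fix the number of edges" comment above refers to.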
##################
# Model Definition
if not args.use_gt:
feature_dim = g.ndata['features'].shape[1]
model = LANDER(feature_dim=feature_dim, nhid=args.hidden,
num_conv=args.num_conv, dropout=args.dropout,
use_GAT=args.gat, K=args.gat_k,
balance=args.balance,
use_cluster_feat=args.use_cluster_feat,
use_focal_loss=args.use_focal_loss)
feature_dim = g.ndata["features"].shape[1]
model = LANDER(
feature_dim=feature_dim,
nhid=args.hidden,
num_conv=args.num_conv,
dropout=args.dropout,
use_GAT=args.gat,
K=args.gat_k,
balance=args.balance,
use_cluster_feat=args.use_cluster_feat,
use_focal_loss=args.use_focal_loss,
)
model.load_state_dict(torch.load(args.model_filename))
model = model.to(device)
model.eval()
......@@ -179,46 +192,82 @@ for level in range(args.levels):
with torch.no_grad():
output_bipartite = model(bipartites)
global_nid = output_bipartite.dstdata[dgl.NID]
global_eid = output_bipartite.edata['global_eid']
g.ndata['pred_den'][global_nid] = output_bipartite.dstdata['pred_den'].to('cpu')
g.edata['prob_conn'][global_eid] = output_bipartite.edata['prob_conn'].to('cpu')
global_eid = output_bipartite.edata["global_eid"]
g.ndata["pred_den"][global_nid] = output_bipartite.dstdata[
"pred_den"
].to("cpu")
g.edata["prob_conn"][global_eid] = output_bipartite.edata[
"prob_conn"
].to("cpu")
torch.cuda.empty_cache()
if (batch + 1) % 10 == 0:
print('Batch %d / %d for inference' % (batch, total_batches))
new_pred_labels, peaks, \
global_edges, global_pred_labels, global_peaks = decode(g, args.tau, args.threshold, args.use_gt,
ids, global_edges, global_num_nodes,
global_peaks)
print("Batch %d / %d for inference" % (batch, total_batches))
(
new_pred_labels,
peaks,
global_edges,
global_pred_labels,
global_peaks,
) = decode(
g,
args.tau,
args.threshold,
args.use_gt,
ids,
global_edges,
global_num_nodes,
global_peaks,
)
if level == 0:
global_pred_densities = g.ndata['pred_den']
global_densities = g.ndata['density']
g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
global_pred_densities = g.ndata["pred_den"]
global_densities = g.ndata["density"]
g.edata["prob_conn"] = torch.zeros((g.number_of_edges(), 2))
ids = ids[peaks]
new_global_edges_len = len(global_edges[0])
num_edges_add_this_level = new_global_edges_len - global_edges_len
if stop_iterating(level, args.levels, args.early_stop, num_edges_add_this_level, num_edges_add_last_level,
args.knn_k):
if stop_iterating(
level,
args.levels,
args.early_stop,
num_edges_add_this_level,
num_edges_add_last_level,
args.knn_k,
):
break
global_edges_len = new_global_edges_len
num_edges_add_last_level = num_edges_add_this_level
# build new dataset
features, labels, cluster_features = build_next_level(features, labels, peaks,
global_features, global_pred_labels, global_peaks)
features, labels, cluster_features = build_next_level(
features,
labels,
peaks,
global_features,
global_pred_labels,
global_peaks,
)
# After the first level, the number of nodes drops a lot, so CPU faiss is faster (an illustrative faiss sketch follows this loop).
dataset = LanderDataset(features=features, labels=labels, k=args.knn_k,
levels=1, faiss_gpu=False, cluster_features=cluster_features)
dataset = LanderDataset(
features=features,
labels=labels,
k=args.knn_k,
levels=1,
faiss_gpu=False,
cluster_features=cluster_features,
)
g = dataset.gs[0]
g.ndata['pred_den'] = torch.zeros((g.number_of_nodes()))
g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
g.ndata["pred_den"] = torch.zeros((g.number_of_nodes()))
g.edata["prob_conn"] = torch.zeros((g.number_of_edges(), 2))
test_loader = dgl.dataloading.DataLoader(
g, torch.arange(g.number_of_nodes()), sampler,
g,
torch.arange(g.number_of_nodes()),
sampler,
batch_size=args.batch_size,
shuffle=False,
drop_last=False,
num_workers=args.num_workers
num_workers=args.num_workers,
)
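# Illustrative sketch: a minimal CPU-faiss k-NN search of the kind LanderDataset
# relies on when faiss_gpu=False; the helper name knn_cpu and the assumption
# that feats is an (N, D) float array are illustrative, not part of the commit.
import faiss
import numpy as np

def knn_cpu(feats, k):
    feats = np.ascontiguousarray(feats, dtype=np.float32)
    index = faiss.IndexFlatL2(feats.shape[1])  # exact L2 index, CPU only
    index.add(feats)                           # register every node feature
    dists, nbrs = index.search(feats, k)       # k nearest neighbours per node
    return dists, nbrs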
if MODE == "selectbydensity":
......@@ -261,8 +310,10 @@ if MODE == "selectbydensity":
idx = np.where(l_in_gt_new == i)
prototypes[i] = np.mean(l_in_features[idx], axis=0)
similarity_matrix = torch.mm(torch.from_numpy(global_features.astype(np.float32)),
torch.from_numpy(prototypes.astype(np.float32)).t())
similarity_matrix = torch.mm(
torch.from_numpy(global_features.astype(np.float32)),
torch.from_numpy(prototypes.astype(np.float32)).t(),
)
similarity_matrix = (1 - similarity_matrix) / 2
minvalues, selected_pred_labels = torch.min(similarity_matrix, 1)
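# Assuming the feature vectors are L2-normalised, the mm above gives cosine
# similarities in [-1, 1]; (1 - s) / 2 rescales them into a distance in [0, 1],
# and torch.min then picks the closest prototype (and that distance) per sample.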
# far-close ratio
......@@ -274,7 +325,7 @@ if MODE == "selectbydensity":
cutidx = np.where(global_pred_densities >= 0.5)
draw_minvalues = minvalues[cutidx]
draw_densities = global_pred_densities[cutidx]
with open(args.density_distance_pkl, 'wb') as f:
with open(args.density_distance_pkl, "wb") as f:
pickle.dump((global_pred_densities, minvalues), f)
print("dumped.")
plt.clf()
......@@ -283,15 +334,29 @@ if MODE == "selectbydensity":
if len(draw_densities) > 10000:
samples_idx = random.sample(range(len(draw_minvalues)), 10000)
ax.plot(draw_densities[random], draw_minvalues[random], color='tab:blue', marker='*', linestyle="None",
markersize=1)
ax.plot(
draw_densities[samples_idx],  # plot only the 10k points sampled above
draw_minvalues[samples_idx],
color="tab:blue",
marker="*",
linestyle="None",
markersize=1,
)
else:
ax.plot(draw_densities[random], draw_minvalues[random], color='tab:blue', marker='*', linestyle="None",
markersize=1)
ax.plot(
draw_densities,  # few enough points in this branch; plot them all
draw_minvalues,
color="tab:blue",
marker="*",
linestyle="None",
markersize=1,
)
plt.savefig(args.density_lindistance_jpg)
global_pred_labels_new[Tidx] = l_in_gt_new
global_pred_labels[selectidx] = global_pred_labels[selectidx] + len(l_in_unique)
global_pred_labels[selectidx] = global_pred_labels[selectidx] + len(
l_in_unique
)
global_pred_labels_new[selectedidx] = global_pred_labels
global_pred_labels = global_pred_labels_new
......@@ -332,7 +397,9 @@ if MODE == "recluster":
global_pred_labels_new[Tidx] = l_in_gt_new
print(len(global_pred_labels))
print(len(selectedidx[0]))
global_pred_labels_new[selectedidx[0]] = global_pred_labels + len(l_in_unique)
global_pred_labels_new[selectedidx[0]] = global_pred_labels + len(
l_in_unique
)
global_pred_labels = global_pred_labels_new
global_masks = masks
print("mask0", len(np.where(global_masks == 0)[0]))
......@@ -348,23 +415,29 @@ if MODE == "donothing":
print("##################### L_in ########################")
print(linsize)
if len(global_pred_labels) >= linsize:
evaluation(global_pred_labels[:linsize], global_gt_labels[:linsize], args.metrics)
evaluation(
global_pred_labels[:linsize], global_gt_labels[:linsize], args.metrics
)
else:
print("No samples in L_in!")
print("##################### U_in ########################")
uinidx = np.where(global_pred_labels[linsize:linsize + uinsize] != -1)[0]
uinidx = np.where(global_pred_labels[linsize : linsize + uinsize] != -1)[0]
uinidx = uinidx + linsize
print(len(uinidx))
if len(uinidx):
evaluation(global_pred_labels[uinidx], global_gt_labels[uinidx], args.metrics)
evaluation(
global_pred_labels[uinidx], global_gt_labels[uinidx], args.metrics
)
else:
print("No samples in U_in!")
print("##################### U_out ########################")
uoutidx = np.where(global_pred_labels[linsize + uinsize:] != -1)[0]
uoutidx = np.where(global_pred_labels[linsize + uinsize :] != -1)[0]
uoutidx = uoutidx + linsize + uinsize
print(len(uoutidx))
if len(uoutidx):
evaluation(global_pred_labels[uoutidx], global_gt_labels[uoutidx], args.metrics)
evaluation(
global_pred_labels[uoutidx], global_gt_labels[uoutidx], args.metrics
)
else:
print("No samples in U_out!")
print("##################### U ########################")
......@@ -390,9 +463,18 @@ print(len(nsidx))
if len(nsidx) != 0:
evaluation(global_pred_labels[nsidx], global_gt_labels[nsidx], args.metrics)
with open(args.output_filename, 'wb') as f:
with open(args.output_filename, "wb") as f:
print(orifeatures.shape)
print(global_pred_labels.shape)
print(global_gt_labels.shape)
print(global_masks.shape)
pickle.dump([path2idx, orifeatures, global_pred_labels, global_gt_labels, global_masks], f)
pickle.dump(
[
path2idx,
orifeatures,
global_pred_labels,
global_gt_labels,
global_masks,
],
f,
)
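# The dumped pickle mirrors the layout of the input file read at the top of
# the script: (path2idx, features, predicted labels, ground-truth labels, masks).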
import argparse, time, os, pickle
import argparse, os, pickle, time
import random
import numpy as np
import sys
import dgl
import numpy as np
import torch
import torch.optim as optim
import sys
sys.path.append("..")
from models import LANDER
from dataset import LanderDataset
from models import LANDER
###########
# ArgParser
parser = argparse.ArgumentParser()
# Dataset
parser.add_argument('--data_path', type=str, required=True)
parser.add_argument('--levels', type=str, default='1')
parser.add_argument('--faiss_gpu', action='store_true')
parser.add_argument('--model_filename', type=str, default='lander.pth')
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--levels", type=str, default="1")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--model_filename", type=str, default="lander.pth")
# KNN
parser.add_argument('--knn_k', type=str, default='10')
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument("--knn_k", type=str, default="10")
parser.add_argument("--num_workers", type=int, default=0)
# Model
parser.add_argument('--hidden', type=int, default=512)
parser.add_argument('--num_conv', type=int, default=1)
parser.add_argument('--dropout', type=float, default=0.)
parser.add_argument('--gat', action='store_true')
parser.add_argument('--gat_k', type=int, default=1)
parser.add_argument('--balance', action='store_true')
parser.add_argument('--use_cluster_feat', action='store_true')
parser.add_argument('--use_focal_loss', action='store_true')
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=1)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")
# Training
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=1024)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight_decay', type=float, default=1e-5)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--batch_size", type=int, default=1024)
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--weight_decay", type=float, default=1e-5)
args = parser.parse_args()
print(args)
......@@ -49,9 +50,9 @@ print(args)
###########################
# Environment Configuration
if torch.cuda.is_available():
device = torch.device('cuda')
device = torch.device("cuda")
else:
device = torch.device('cpu')
device = torch.device("cpu")
def setup_seed(seed):
......@@ -66,7 +67,7 @@ def setup_seed(seed):
##################
# Data Preparation
with open(args.data_path, 'rb') as f:
with open(args.data_path, "rb") as f:
path2idx, features, labels, _, masks = pickle.load(f)
# lidx = np.where(masks==0)
# features = features[lidx]
......@@ -75,8 +76,8 @@ with open(args.data_path, 'rb') as f:
print("labels.shape:", labels.shape)
k_list = [int(k) for k in args.knn_k.split(',')]
lvl_list = [int(l) for l in args.levels.split(',')]
k_list = [int(k) for k in args.knn_k.split(",")]
lvl_list = [int(l) for l in args.levels.split(",")]
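# Both flags take comma-separated values, e.g. --knn_k 10,5 --levels 2,3 yields
# k_list = [10, 5] and lvl_list = [2, 3]; each (k, level) pair below builds its
# own LanderDataset.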
gs = []
nbrs = []
ks = []
......@@ -84,8 +85,13 @@ datasets = []
for k, l in zip(k_list, lvl_list):
print("k:", k)
print("levels:", l)
dataset = LanderDataset(features=features, labels=labels, k=k,
levels=l, faiss_gpu=args.faiss_gpu)
dataset = LanderDataset(
features=features,
labels=labels,
k=k,
levels=l,
faiss_gpu=args.faiss_gpu,
)
gs += [g for g in dataset.gs]
ks += [k for g in dataset.gs]
nbrs += [nbr for nbr in dataset.nbrs]
......@@ -101,24 +107,28 @@ for k, l in zip(k_list, lvl_list):
# nbrs += [nbr for nbr in dataset.nbrs]
with open("./dataset.pkl", 'wb') as f:
with open("./dataset.pkl", "wb") as f:
pickle.dump(datasets, f)
print('Dataset Prepared.')
print("Dataset Prepared.")
def set_train_sampler_loader(g, k):
fanouts = [k-1 for i in range(args.num_conv + 1)]
fanouts = [k - 1 for i in range(args.num_conv + 1)]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
# fix the number of edges
train_dataloader = dgl.dataloading.DataLoader(
g, torch.arange(g.number_of_nodes()), sampler,
g,
torch.arange(g.number_of_nodes()),
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers
num_workers=args.num_workers,
)
return train_dataloader
train_loaders = []
for gidx, g in enumerate(gs):
train_dataloader = set_train_sampler_loader(gs[gidx], ks[gidx])
......@@ -126,31 +136,40 @@ for gidx, g in enumerate(gs):
##################
# Model Definition
feature_dim = gs[0].ndata['features'].shape[1]
feature_dim = gs[0].ndata["features"].shape[1]
print("feature dimension:", feature_dim)
model = LANDER(feature_dim=feature_dim, nhid=args.hidden,
num_conv=args.num_conv, dropout=args.dropout,
use_GAT=args.gat, K=args.gat_k,
balance=args.balance,
use_cluster_feat=args.use_cluster_feat,
use_focal_loss=args.use_focal_loss)
model = LANDER(
feature_dim=feature_dim,
nhid=args.hidden,
num_conv=args.num_conv,
dropout=args.dropout,
use_GAT=args.gat,
K=args.gat_k,
balance=args.balance,
use_cluster_feat=args.use_cluster_feat,
use_focal_loss=args.use_focal_loss,
)
model = model.to(device)
model.train()
#################
# Hyperparameters
opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
weight_decay=args.weight_decay)
opt = optim.SGD(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
# keep num_batch_per_loader the same for every sub_dataloader
num_batch_per_loader = len(train_loaders[0])
train_loaders = [iter(train_loader) for train_loader in train_loaders]
num_loaders = len(train_loaders)
scheduler = optim.lr_scheduler.CosineAnnealingLR(opt,
T_max=args.epochs * num_batch_per_loader * num_loaders,
eta_min=1e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
opt, T_max=args.epochs * num_batch_per_loader * num_loaders, eta_min=1e-5
)
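# T_max is sized to epochs * num_batch_per_loader * num_loaders; with
# illustrative numbers of 100 epochs, 50 batches per loader and 2 loaders,
# that is 10,000 scheduler steps over which the learning rate is annealed
# from args.lr down to eta_min=1e-5.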
print('Start Training.')
print("Start Training.")
###############
# Training Loop
......@@ -163,7 +182,9 @@ for epoch in range(args.epochs):
try:
minibatch = next(train_loaders[loader_id])
except StopIteration:  # this sub-loader is exhausted; rebuild it and draw again
train_loaders[loader_id] = iter(set_train_sampler_loader(gs[loader_id], ks[loader_id]))
train_loaders[loader_id] = iter(
set_train_sampler_loader(gs[loader_id], ks[loader_id])
)
minibatch = next(train_loaders[loader_id])
input_nodes, sub_g, bipartites = minibatch
sub_g = sub_g.to(device)
......@@ -171,20 +192,38 @@ for epoch in range(args.epochs):
# get the feature for the input_nodes
opt.zero_grad()
output_bipartite = model(bipartites)
loss, loss_den_val, loss_conn_val = model.compute_loss(output_bipartite)
loss, loss_den_val, loss_conn_val = model.compute_loss(
output_bipartite
)
loss_den_val_total.append(loss_den_val)
loss_conn_val_total.append(loss_conn_val)
loss_val_total.append(loss.item())
loss.backward()
opt.step()
if (batch + 1) % 10 == 0:
print('epoch: %d, batch: %d / %d, loader_id : %d / %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f'%
(epoch, batch, num_batch_per_loader, loader_id, num_loaders,
loss.item(), loss_den_val, loss_conn_val))
print(
"epoch: %d, batch: %d / %d, loader_id : %d / %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
% (
epoch,
batch,
num_batch_per_loader,
loader_id,
num_loaders,
loss.item(),
loss_den_val,
loss_conn_val,
)
)
scheduler.step()
print('epoch: %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f'%
(epoch, np.array(loss_val_total).mean(),
np.array(loss_den_val_total).mean(), np.array(loss_conn_val_total).mean()))
print(
"epoch: %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
% (
epoch,
np.array(loss_val_total).mean(),
np.array(loss_den_val_total).mean(),
np.array(loss_conn_val_total).mean(),
)
)
torch.save(model.state_dict(), args.model_filename)
torch.save(model.state_dict(), args.model_filename)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import dgl.function as fn
from dgl.nn.pytorch import GATConv
from torch.nn import init
class GraphConvLayer(nn.Module):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import dgl
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from .focal_loss import FocalLoss
from .graphconv import GraphConv
......
......@@ -3,6 +3,8 @@ import os
import pickle
import time
import dgl
import numpy as np
import torch
import torch.optim as optim
......@@ -10,8 +12,6 @@ from dataset import LanderDataset
from models import LANDER
from utils import build_next_level, decode, evaluation, stop_iterating
import dgl
###########
# ArgParser
parser = argparse.ArgumentParser()
......
......@@ -3,6 +3,8 @@ import os
import pickle
import time
import dgl
import numpy as np
import torch
import torch.optim as optim
......@@ -10,8 +12,6 @@ from dataset import LanderDataset
from models import LANDER
from utils import build_next_level, decode, evaluation, stop_iterating
import dgl
###########
# ArgParser
parser = argparse.ArgumentParser()
......
......@@ -3,14 +3,14 @@ import os
import pickle
import time
import dgl
import numpy as np
import torch
import torch.optim as optim
from dataset import LanderDataset
from models import LANDER
import dgl
###########
# ArgParser
parser = argparse.ArgumentParser()
......@@ -50,6 +50,7 @@ if torch.cuda.is_available():
else:
device = torch.device("cpu")
##################
# Data Preparation
def prepare_dataset_graphs(data_path, k_list, lvl_list):
......
......@@ -3,14 +3,14 @@ import os
import pickle
import time
import dgl
import numpy as np
import torch
import torch.optim as optim
from dataset import LanderDataset
from models import LANDER
import dgl
###########
# ArgParser
parser = argparse.ArgumentParser()
......
"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""
import dgl
import numpy as np
import torch
from sklearn import mixture
import dgl
from .density import density_to_peaks, density_to_peaks_vectorize
__all__ = [
......
......@@ -6,7 +6,7 @@ import inspect
import numpy as np
from clustering_benchmark import ClusteringBenchmark
from utils import TextColors, Timer, metrics
from utils import metrics, TextColors, Timer
def _read_meta(fn):
......
......@@ -101,7 +101,6 @@ def faiss_search_knn(
sort=True,
verbose=False,
):
dists, nbrs = faiss_search_approx_knn(
query=feat, target=feat, k=k, nprobe=nprobe, verbose=verbose
)
......
......@@ -70,7 +70,6 @@ def svc_classify(x, y, search):
kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)
accuracies = []
for train_index, test_index in kf.split(x, y):
x_train, x_test = x[train_index], x[test_index]
y_train, y_test = y[train_index], y[test_index]
......
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import GRU, BatchNorm1d, Linear, ModuleList, ReLU, Sequential
from utils import global_global_loss_, local_global_loss_
from dgl.nn import GINConv, NNConv, Set2Set
from dgl.nn.pytorch.glob import SumPooling
from torch.nn import BatchNorm1d, GRU, Linear, ModuleList, ReLU, Sequential
from utils import global_global_loss_, local_global_loss_
""" Feedforward neural network"""
......@@ -102,7 +102,6 @@ class GINEncoder(nn.Module):
self.pool = SumPooling()
def forward(self, graph, feat):
xs = []
x = feat
for i in range(self.n_layer):
......@@ -163,7 +162,6 @@ class InfoGraph(nn.Module):
return global_emb
def forward(self, graph, feat, graph_id):
global_emb, local_emb = self.encoder(graph, feat)
global_h = self.global_d(global_emb) # global hidden representation
......@@ -221,7 +219,6 @@ class NNConvEncoder(nn.Module):
self.set2set = Set2Set(hid_dim, n_iters=3, n_layers=1)
def forward(self, graph, nfeat, efeat):
out = F.relu(self.lin0(nfeat))
h = out.unsqueeze(0)
......@@ -279,7 +276,6 @@ class InfoGraphS(nn.Module):
self.unsup_d = FeedforwardNetwork(2 * hid_dim, hid_dim)
def forward(self, graph, nfeat, efeat):
sup_global_emb, sup_local_emb = self.sup_encoder(graph, nfeat, efeat)
sup_global_pred = self.fc2(F.relu(self.fc1(sup_global_emb)))
......@@ -288,7 +284,6 @@ class InfoGraphS(nn.Module):
return sup_global_pred
def unsup_forward(self, graph, nfeat, efeat, graph_id):
sup_global_emb, sup_local_emb = self.sup_encoder(graph, nfeat, efeat)
unsup_global_emb, unsup_local_emb = self.unsup_encoder(
graph, nfeat, efeat
......
import argparse
import dgl
import numpy as np
import torch as th
import torch.nn.functional as F
from model import InfoGraphS
import dgl
from dgl.data import QM9EdgeDataset
from dgl.data.utils import Subset
from dgl.dataloading import GraphDataLoader
from model import InfoGraphS
def argument():
......@@ -160,7 +160,6 @@ def evaluate(model, loader, num, device):
if __name__ == "__main__":
# Step 1: Prepare graph data ===================================== #
args = argument()
label_keys = [args.target]
......
import argparse
import torch as th
from evaluate_embedding import evaluate_embedding
from model import InfoGraph
import dgl
import torch as th
from dgl.data import GINDataset
from dgl.dataloading import GraphDataLoader
from evaluate_embedding import evaluate_embedding
from model import InfoGraph
def argument():
......@@ -75,7 +75,6 @@ def collate(samples):
if __name__ == "__main__":
# Step 1: Prepare graph data ===================================== #
args = argument()
print(args)
......@@ -131,7 +130,6 @@ if __name__ == "__main__":
model.train()
for graph, label in dataloader:
graph = graph.to(args.device)
feat = graph.ndata["attr"]
graph_id = graph.ndata["graph_id"]
......@@ -147,7 +145,6 @@ if __name__ == "__main__":
print("Epoch {}, Loss {:.4f}".format(epoch, loss_all))
if epoch % args.log_interval == 0:
# evaluate embeddings
model.eval()
emb = model.get_embedding(wholegraph, wholefeat).cpu()
......
......@@ -41,7 +41,6 @@ def get_negative_expectation(q_samples, average=True):
def local_global_loss_(l_enc, g_enc, graph_id):
num_graphs = g_enc.shape[0]
num_nodes = l_enc.shape[0]
......@@ -51,7 +50,6 @@ def local_global_loss_(l_enc, g_enc, graph_id):
neg_mask = th.ones((num_nodes, num_graphs)).to(device)
for nodeidx, graphidx in enumerate(graph_id):
pos_mask[nodeidx][graphidx] = 1.0
neg_mask[nodeidx][graphidx] = 0.0
......@@ -66,7 +64,6 @@ def local_global_loss_(l_enc, g_enc, graph_id):
def global_global_loss_(sup_enc, unsup_enc):
num_graphs = sup_enc.shape[0]
device = sup_enc.device
......