Unverified Commit f19f05ce authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4651)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 977b1ba4
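The change itself is mechanical: the touched example files were rewritten by the Black code formatter rather than edited by hand. As a rough illustration (not part of the commit), Black can also be driven from its Python API; the 79-character line length below is an assumption, since the commit does not state the project's Black configuration.

    # Hypothetical sketch of the kind of rewrite this commit applies.
    import black

    src = (
        'parser.add_argument("--print_every", type=int, default=10, '
        'help="print trainlog every k epochs, -1 for silent training")\n'
    )
    # With the assumed line length, Black should split this call into the
    # one-keyword-argument-per-line form that appears in the diff below.
    print(black.format_str(src, mode=black.Mode(line_length=79)))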
@@ -4,55 +4,73 @@ import logging
import os
from time import time

import dgl
import torch
import torch.nn
import torch.nn.functional as F
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
from torch.utils.data import random_split

from networks import HGPSLModel
from utils import get_stats


def parse_args():
    parser = argparse.ArgumentParser(description="HGP-SL-DGL")
    parser.add_argument(
        "--dataset",
        type=str,
        default="DD",
        choices=["DD", "PROTEINS", "NCI1", "NCI109", "Mutagenicity", "ENZYMES"],
        help="DD/PROTEINS/NCI1/NCI109/Mutagenicity/ENZYMES",
    )
    parser.add_argument(
        "--batch_size", type=int, default=512, help="batch size"
    )
    parser.add_argument(
        "--sample", type=str, default="true", help="use sample method"
    )
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument(
        "--weight_decay", type=float, default=1e-3, help="weight decay"
    )
    parser.add_argument(
        "--pool_ratio", type=float, default=0.5, help="pooling ratio"
    )
    parser.add_argument("--hid_dim", type=int, default=128, help="hidden size")
    parser.add_argument(
        "--conv_layers", type=int, default=3, help="number of conv layers"
    )
    parser.add_argument(
        "--dropout", type=float, default=0.0, help="dropout ratio"
    )
    parser.add_argument(
        "--lamb", type=float, default=1.0, help="trade-off parameter"
    )
    parser.add_argument(
        "--epochs", type=int, default=1000, help="max number of training epochs"
    )
    parser.add_argument(
        "--patience", type=int, default=100, help="patience for early stopping"
    )
    parser.add_argument(
        "--device", type=int, default=-1, help="device id, -1 for cpu"
    )
    parser.add_argument(
        "--dataset_path", type=str, default="./dataset", help="path to dataset"
    )
    parser.add_argument(
        "--print_every",
        type=int,
        default=10,
        help="print trainlog every k epochs, -1 for silent training",
    )
    parser.add_argument(
        "--num_trials", type=int, default=1, help="number of trials"
    )
    parser.add_argument("--output_path", type=str, default="./output")
    args = parser.parse_args()

    # device
@@ -76,16 +94,24 @@ def parse_args():
        os.makedirs(args.dataset_path)
    if not os.path.exists(args.output_path):
        os.makedirs(args.output_path)
    name = (
        "Data={}_Hidden={}_Pool={}_WeightDecay={}_Lr={}_Sample={}.log".format(
            args.dataset,
            args.hid_dim,
            args.pool_ratio,
            args.weight_decay,
            args.lr,
            args.sample,
        )
    )
    args.output_path = os.path.join(args.output_path, name)

    return args


def train(model: torch.nn.Module, optimizer, trainloader, device):
    model.train()
    total_loss = 0.0
    num_batches = len(trainloader)
    for batch in trainloader:
        optimizer.zero_grad()
@@ -98,15 +124,15 @@ def train(model:torch.nn.Module, optimizer, trainloader, device):
        optimizer.step()
        total_loss += loss.item()
    return total_loss / num_batches


@torch.no_grad()
def test(model: torch.nn.Module, loader, device):
    model.eval()
    correct = 0.0
    loss = 0.0
    num_graphs = 0
    for batch in loader:
        batch_graphs, batch_labels = batch
@@ -132,30 +158,47 @@ def main(args):
    num_training = int(len(dataset) * 0.8)
    num_val = int(len(dataset) * 0.1)
    num_test = len(dataset) - num_val - num_training
    train_set, val_set, test_set = random_split(
        dataset, [num_training, num_val, num_test]
    )

    train_loader = GraphDataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, num_workers=6
    )
    val_loader = GraphDataLoader(
        val_set, batch_size=args.batch_size, num_workers=2
    )
    test_loader = GraphDataLoader(
        test_set, batch_size=args.batch_size, num_workers=2
    )

    device = torch.device(args.device)

    # Step 2: Create model =================================================================== #
    num_feature, num_classes, _ = dataset.statistics()
    model = HGPSLModel(
        in_feat=num_feature,
        out_feat=num_classes,
        hid_feat=args.hid_dim,
        conv_layers=args.conv_layers,
        dropout=args.dropout,
        pool_ratio=args.pool_ratio,
        lamb=args.lamb,
        sample=args.sample,
    ).to(device)
    args.num_feature = int(num_feature)
    args.num_classes = int(num_classes)

    # Step 3: Create training components ===================================================== #
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # Step 4: training epoches =============================================================== #
    bad_cound = 0
    best_val_loss = float("inf")
    final_test_acc = 0.0
    best_epoch = 0
    train_times = []
    for e in range(args.epochs):
@@ -173,11 +216,17 @@ def main(args):
            bad_cound += 1
            if bad_cound >= args.patience:
                break
        if (e + 1) % args.print_every == 0:
            log_format = (
                "Epoch {}: loss={:.4f}, val_acc={:.4f}, final_test_acc={:.4f}"
            )
            print(log_format.format(e + 1, train_loss, val_acc, final_test_acc))
    print(
        "Best Epoch {}, final test acc {:.4f}".format(
            best_epoch, final_test_acc
        )
    )
    return final_test_acc, sum(train_times) / len(train_times)
@@ -194,9 +243,11 @@ if __name__ == "__main__":
    mean, err_bd = get_stats(res, conf_interval=False)
    print("mean acc: {:.4f}, error bound: {:.4f}".format(mean, err_bd))

    out_dict = {
        "hyper-parameters": vars(args),
        "result": "{:.4f}(+-{:.4f})".format(mean, err_bd),
        "train_time": "{:.4f}".format(sum(train_times) / len(train_times)),
    }

    with open(args.output_path, "w") as f:
        json.dump(out_dict, f, sort_keys=True, indent=4)
import torch
import torch.nn
import torch.nn.functional as F
from layers import ConvPoolReadout
from dgl.nn import AvgPooling, MaxPooling


class HGPSLModel(torch.nn.Module):
    r"""
@@ -27,7 +28,7 @@ class HGPSLModel(torch.nn.Module):
    conv_layers : int, optional
        The number of graph convolution and pooling layers. Default: 3
    sample : bool, optional
        Whether use k-hop union graph to increase efficiency.
        Currently we only support full graph. Default: :obj:`False`
    sparse : bool, optional
        Use edge sparsemax instead of edge softmax. Default: :obj:`True`
@@ -37,10 +38,20 @@ class HGPSLModel(torch.nn.Module):
        The lambda parameter as weight of raw adjacency as described in the
        HGP-SL paper. Default: 1.0
    """

    def __init__(
        self,
        in_feat: int,
        out_feat: int,
        hid_feat: int,
        dropout: float = 0.0,
        pool_ratio: float = 0.5,
        conv_layers: int = 3,
        sample: bool = False,
        sparse: bool = True,
        sl: bool = True,
        lamb: float = 1.0,
    ):
        super(HGPSLModel, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
@@ -53,10 +64,19 @@ class HGPSLModel(torch.nn.Module):
        for i in range(conv_layers):
            c_in = in_feat if i == 0 else hid_feat
            c_out = hid_feat
            use_pool = i != conv_layers - 1
            convpools.append(
                ConvPoolReadout(
                    c_in,
                    c_out,
                    pool_ratio=pool_ratio,
                    sample=sample,
                    sparse=sparse,
                    sl=sl,
                    lamb=lamb,
                    pool=use_pool,
                )
            )
        self.convpool_layers = torch.nn.ModuleList(convpools)

        self.lin1 = torch.nn.Linear(hid_feat * 2, hid_feat)
@@ -68,12 +88,14 @@ class HGPSLModel(torch.nn.Module):
        e_feat = None
        for i in range(self.num_layers):
            graph, n_feat, e_feat, readout = self.convpool_layers[i](
                graph, n_feat, e_feat
            )
            if final_readout is None:
                final_readout = readout
            else:
                final_readout = final_readout + readout

        n_feat = F.relu(self.lin1(final_readout))
        n_feat = F.dropout(n_feat, p=self.dropout, training=self.training)
        n_feat = F.relu(self.lin2(n_feat))
...
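For readers skimming the reformatted constructor above, a small usage sketch may help; the feature and class counts below are made up for illustration and are not taken from the commit.

    # Hypothetical instantiation of the HGPSLModel defined in this example;
    # in_feat/out_feat values are illustrative only.
    from networks import HGPSLModel

    model = HGPSLModel(in_feat=89, out_feat=2, hid_feat=128, conv_layers=3)
    print(model)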
import logging
import math

import torch
from scipy.stats import t


def get_stats(
    array, conf_interval=False, name=None, stdout=False, logout=False
):
    """Compute mean and standard deviation from an numerical array

    Args:
        array (array like obj): The numerical array, this array can be
            convert to :obj:`torch.Tensor`.
        conf_interval (bool, optional): If True, compute the confidence interval bound (95%)
            instead of the std value. (default: :obj:`False`)
        name (str, optional): The name of this numerical array, for log usage.
            (default: :obj:`None`)
        stdout (bool, optional): Whether to output result to the terminal.
            (default: :obj:`False`)
        logout (bool, optional): Whether to output result via logging module.
            (default: :obj:`False`)
@@ -29,7 +32,7 @@ def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False)
    if conf_interval:
        n = array.size(0)
        se = std / (math.sqrt(n) + eps)
        t_value = t.ppf(0.975, df=n - 1)
        err_bound = t_value * se
    else:
        err_bound = std
@@ -46,7 +49,7 @@ def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False)
    return center, err_bound


def get_batch_id(num_nodes: torch.Tensor):
    """Convert the num_nodes array obtained from batch graph to batch_id array
    for each node.
@@ -57,12 +60,19 @@ def get_batch_id(num_nodes:torch.Tensor):
    batch_size = num_nodes.size(0)
    batch_ids = []
    for i in range(batch_size):
        item = torch.full(
            (num_nodes[i],), i, dtype=torch.long, device=num_nodes.device
        )
        batch_ids.append(item)
    return torch.cat(batch_ids)


def topk(
    x: torch.Tensor,
    ratio: float,
    batch_id: torch.Tensor,
    num_nodes: torch.Tensor,
):
    """The top-k pooling method. Given a graph batch, this method will pool out some
    nodes from input node feature tensor for each graph according to the given ratio.
@@ -72,21 +82,23 @@ def topk(x:torch.Tensor, ratio:float, batch_id:torch.Tensor, num_nodes:torch.Ten
            tensor will be pooled out.
        batch_id (torch.Tensor): The batch_id of each element in the input tensor.
        num_nodes (torch.Tensor): The number of nodes of each graph in batch.

    Returns:
        perm (torch.Tensor): The index in batch to be kept.
        k (torch.Tensor): The remaining number of nodes for each graph.
    """
    batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()

    cum_num_nodes = torch.cat(
        [num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0
    )

    index = torch.arange(batch_id.size(0), dtype=torch.long, device=x.device)
    index = (index - cum_num_nodes[batch_id]) + (batch_id * max_num_nodes)

    dense_x = x.new_full(
        (batch_size * max_num_nodes,), torch.finfo(x.dtype).min
    )
    dense_x[index] = x
    dense_x = dense_x.view(batch_size, max_num_nodes)
@@ -96,8 +108,10 @@ def topk(x:torch.Tensor, ratio:float, batch_id:torch.Tensor, num_nodes:torch.Ten
    k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)

    mask = [
        torch.arange(k[i], dtype=torch.long, device=x.device)
        + i * max_num_nodes
        for i in range(batch_size)
    ]

    mask = torch.cat(mask, dim=0)
    perm = perm[mask]
...
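The get_stats helper reformatted above is what main.py uses to aggregate per-trial accuracies; per its docstring it accepts any array-like object. A short usage sketch with made-up numbers (not from the commit) follows.

    # Hypothetical accuracies from three trials; values are illustrative only.
    from utils import get_stats

    res = [0.781, 0.803, 0.792]
    mean, err_bd = get_stats(res, conf_interval=True)  # 95% CI half-width
    print("mean acc: {:.4f}, error bound: {:.4f}".format(mean, err_bd))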
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dgl.nn.functional import edge_softmax


class HGTLayer(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        node_dict,
        edge_dict,
        n_heads,
        dropout=0.2,
        use_norm=False,
    ):
        super(HGTLayer, self).__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.node_dict = node_dict
        self.edge_dict = edge_dict
        self.num_types = len(node_dict)
        self.num_relations = len(edge_dict)
        self.total_rel = self.num_types * self.num_relations * self.num_types
        self.n_heads = n_heads
        self.d_k = out_dim // n_heads
        self.sqrt_dk = math.sqrt(self.d_k)
        self.att = None

        self.k_linears = nn.ModuleList()
        self.q_linears = nn.ModuleList()
        self.v_linears = nn.ModuleList()
        self.a_linears = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.use_norm = use_norm

        for t in range(self.num_types):
            self.k_linears.append(nn.Linear(in_dim, out_dim))
            self.q_linears.append(nn.Linear(in_dim, out_dim))
            self.v_linears.append(nn.Linear(in_dim, out_dim))
            self.a_linears.append(nn.Linear(out_dim, out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))

        self.relation_pri = nn.Parameter(
            torch.ones(self.num_relations, self.n_heads)
        )
        self.relation_att = nn.Parameter(
            torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k)
        )
        self.relation_msg = nn.Parameter(
            torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k)
        )
        self.skip = nn.Parameter(torch.ones(self.num_types))
        self.drop = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.relation_att)
        nn.init.xavier_uniform_(self.relation_msg)
@@ -76,38 +87,62 @@ class HGTLayer(nn.Module):
                k = torch.einsum("bij,ijk->bik", k, relation_att)
                v = torch.einsum("bij,ijk->bik", v, relation_msg)

                sub_graph.srcdata["k"] = k
                sub_graph.dstdata["q"] = q
                sub_graph.srcdata["v_%d" % e_id] = v

                sub_graph.apply_edges(fn.v_dot_u("q", "k", "t"))
                attn_score = (
                    sub_graph.edata.pop("t").sum(-1)
                    * relation_pri
                    / self.sqrt_dk
                )
                attn_score = edge_softmax(sub_graph, attn_score, norm_by="dst")

                sub_graph.edata["t"] = attn_score.unsqueeze(-1)

            G.multi_update_all(
                {
                    etype: (
                        fn.u_mul_e("v_%d" % e_id, "t", "m"),
                        fn.sum("m", "t"),
                    )
                    for etype, e_id in edge_dict.items()
                },
                cross_reducer="mean",
            )

            new_h = {}
            for ntype in G.ntypes:
                """
                Step 3: Target-specific Aggregation
                x = norm( W[node_type] * gelu( Agg(x) ) + x )
                """
                n_id = node_dict[ntype]
                alpha = torch.sigmoid(self.skip[n_id])
                t = G.nodes[ntype].data["t"].view(-1, self.out_dim)
                trans_out = self.drop(self.a_linears[n_id](t))
                trans_out = trans_out * alpha + h[ntype] * (1 - alpha)
                if self.use_norm:
                    new_h[ntype] = self.norms[n_id](trans_out)
                else:
                    new_h[ntype] = trans_out
            return new_h


class HGT(nn.Module):
    def __init__(
        self,
        G,
        node_dict,
        edge_dict,
        n_inp,
        n_hid,
        n_out,
        n_layers,
        n_heads,
        use_norm=True,
    ):
        super(HGT, self).__init__()
        self.node_dict = node_dict
        self.edge_dict = edge_dict
@@ -116,29 +151,39 @@ class HGT(nn.Module):
        self.n_hid = n_hid
        self.n_out = n_out
        self.n_layers = n_layers
        self.adapt_ws = nn.ModuleList()
        for t in range(len(node_dict)):
            self.adapt_ws.append(nn.Linear(n_inp, n_hid))
        for _ in range(n_layers):
            self.gcs.append(
                HGTLayer(
                    n_hid,
                    n_hid,
                    node_dict,
                    edge_dict,
                    n_heads,
                    use_norm=use_norm,
                )
            )
        self.out = nn.Linear(n_hid, n_out)

    def forward(self, G, out_key):
        h = {}
        for ntype in G.ntypes:
            n_id = self.node_dict[ntype]
            h[ntype] = F.gelu(self.adapt_ws[n_id](G.nodes[ntype].data["inp"]))
        for i in range(self.n_layers):
            h = self.gcs[i](G, h)
        return self.out(h[out_key])


class HeteroRGCNLayer(nn.Module):
    def __init__(self, in_size, out_size, etypes):
        super(HeteroRGCNLayer, self).__init__()
        # W_r for each relation
        self.weight = nn.ModuleDict(
            {name: nn.Linear(in_size, out_size) for name in etypes}
        )

    def forward(self, G, feat_dict):
        # The input is a dictionary of node features for each type
@@ -147,18 +192,18 @@ class HeteroRGCNLayer(nn.Module):
            # Compute W_r * h
            Wh = self.weight[etype](feat_dict[srctype])
            # Save it in graph for message passing
            G.nodes[srctype].data["Wh_%s" % etype] = Wh
            # Specify per-relation message passing functions: (message_func, reduce_func).
            # Note that the results are saved to the same destination feature 'h', which
            # hints the type wise reducer for aggregation.
            funcs[etype] = (fn.copy_u("Wh_%s" % etype, "m"), fn.mean("m", "h"))
        # Trigger message passing of multiple types.
        # The first argument is the message passing functions for each relation.
        # The second one is the type wise reducer, could be "sum", "max",
        # "min", "mean", "stack"
        G.multi_update_all(funcs, "sum")
        # return the updated node feature dictionary
        return {ntype: G.nodes[ntype].data["h"] for ntype in G.ntypes}


class HeteroRGCN(nn.Module):
@@ -169,9 +214,9 @@ class HeteroRGCN(nn.Module):
        self.layer2 = HeteroRGCNLayer(hidden_size, out_size, G.etypes)

    def forward(self, G, out_key):
        input_dict = {ntype: G.nodes[ntype].data["inp"] for ntype in G.ntypes}
        h_dict = self.layer1(G, input_dict)
        h_dict = {k: F.leaky_relu(h) for k, h in h_dict.items()}
        h_dict = self.layer2(G, h_dict)
        # get appropriate logits
        return h_dict[out_key]
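To make the per-relation multi_update_all pattern above concrete, here is a small, self-contained sketch that runs one HeteroRGCNLayer on a toy heterograph. The graph, the feature sizes, and the module name in the import are assumptions for illustration, not part of the commit.

    import dgl
    import torch
    import torch.nn.functional as F

    from model import HeteroRGCNLayer  # module name assumed for illustration

    # A toy heterograph with two relations; every node type receives messages.
    G = dgl.heterograph(
        {
            ("user", "follows", "user"): (torch.tensor([0, 1]), torch.tensor([1, 2])),
            ("user", "likes", "item"): (torch.tensor([0, 2]), torch.tensor([0, 1])),
        }
    )
    feat_dict = {
        "user": torch.randn(G.number_of_nodes("user"), 8),
        "item": torch.randn(G.number_of_nodes("item"), 8),
    }
    layer = HeteroRGCNLayer(in_size=8, out_size=4, etypes=G.etypes)
    # One round of relation-wise message passing followed by the nonlinearity
    # used in HeteroRGCN.forward above.
    h_dict = {k: F.leaky_relu(h) for k, h in layer(G, feat_dict).items()}
    print({k: v.shape for k, v in h_dict.items()})  # 4-dim output per node type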
@@ -4,50 +4,55 @@
# In[1]:

import argparse
import math
import urllib.request

import dgl
import numpy as np
import scipy.io

from model import *

torch.manual_seed(0)
data_url = "https://data.dgl.ai/dataset/ACM.mat"
data_file_path = "/tmp/ACM.mat"

urllib.request.urlretrieve(data_url, data_file_path)
data = scipy.io.loadmat(data_file_path)

parser = argparse.ArgumentParser(
    description="Training GNN on ogbn-products benchmark"
)
parser.add_argument("--n_epoch", type=int, default=200)
parser.add_argument("--n_hid", type=int, default=256)
parser.add_argument("--n_inp", type=int, default=256)
parser.add_argument("--clip", type=int, default=1.0)
parser.add_argument("--max_lr", type=float, default=1e-3)
args = parser.parse_args()


def get_n_params(model):
    pp = 0
    for p in list(model.parameters()):
        nn = 1
        for s in list(p.size()):
            nn = nn * s
        pp += nn
    return pp


def train(model, G):
    best_val_acc = torch.tensor(0)
    best_test_acc = torch.tensor(0)
    train_step = torch.tensor(0)
    for epoch in np.arange(args.n_epoch) + 1:
        model.train()
        logits = model(G, "paper")
        # The loss is computed only for labeled nodes.
        loss = F.cross_entropy(logits[train_idx], labels[train_idx].to(device))
        optimizer.zero_grad()
@@ -58,38 +63,44 @@ def train(model, G):
        scheduler.step(train_step)
        if epoch % 5 == 0:
            model.eval()
            logits = model(G, "paper")
            pred = logits.argmax(1).cpu()
            train_acc = (pred[train_idx] == labels[train_idx]).float().mean()
            val_acc = (pred[val_idx] == labels[val_idx]).float().mean()
            test_acc = (pred[test_idx] == labels[test_idx]).float().mean()
            if best_val_acc < val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc
            print(
                "Epoch: %d LR: %.5f Loss %.4f, Train Acc %.4f, Val Acc %.4f (Best %.4f), Test Acc %.4f (Best %.4f)"
                % (
                    epoch,
                    optimizer.param_groups[0]["lr"],
                    loss.item(),
                    train_acc.item(),
                    val_acc.item(),
                    best_val_acc.item(),
                    test_acc.item(),
                    best_test_acc.item(),
                )
            )


device = torch.device("cuda:0")

G = dgl.heterograph(
    {
        ("paper", "written-by", "author"): data["PvsA"].nonzero(),
        ("author", "writing", "paper"): data["PvsA"].transpose().nonzero(),
        ("paper", "citing", "paper"): data["PvsP"].nonzero(),
        ("paper", "cited", "paper"): data["PvsP"].transpose().nonzero(),
        ("paper", "is-about", "subject"): data["PvsL"].nonzero(),
        ("subject", "has", "paper"): data["PvsL"].transpose().nonzero(),
    }
)
print(G)

pvc = data["PvsC"].tocsr()
p_selected = pvc.tocoo()
# generate labels
labels = pvc.indices
@@ -108,51 +119,67 @@ for ntype in G.ntypes:
    node_dict[ntype] = len(node_dict)
for etype in G.etypes:
    edge_dict[etype] = len(edge_dict)
    G.edges[etype].data["id"] = (
        torch.ones(G.number_of_edges(etype), dtype=torch.long)
        * edge_dict[etype]
    )

# Random initialize input feature
for ntype in G.ntypes:
    emb = nn.Parameter(
        torch.Tensor(G.number_of_nodes(ntype), 256), requires_grad=False
    )
    nn.init.xavier_uniform_(emb)
    G.nodes[ntype].data["inp"] = emb

G = G.to(device)

model = HGT(
    G,
    node_dict,
    edge_dict,
    n_inp=args.n_inp,
    n_hid=args.n_hid,
    n_out=labels.max().item() + 1,
    n_layers=2,
    n_heads=4,
    use_norm=True,
).to(device)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, total_steps=args.n_epoch, max_lr=args.max_lr
)
print("Training HGT with #param: %d" % (get_n_params(model)))
train(model, G)


model = HeteroRGCN(
    G,
    in_size=args.n_inp,
    hidden_size=args.n_hid,
    out_size=labels.max().item() + 1,
).to(device)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, total_steps=args.n_epoch, max_lr=args.max_lr
)
print("Training RGCN with #param: %d" % (get_n_params(model)))
train(model, G)


model = HGT(
    G,
    node_dict,
    edge_dict,
    n_inp=args.n_inp,
    n_hid=args.n_hid,
    n_out=labels.max().item() + 1,
    n_layers=0,
    n_heads=4,
).to(device)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, total_steps=args.n_epoch, max_lr=args.max_lr
)
print("Training MLP with #param: %d" % (get_n_params(model)))
train(model, G)
...@@ -2,21 +2,29 @@ ...@@ -2,21 +2,29 @@
################## LIBRARIES ############################## ################## LIBRARIES ##############################
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import numpy as np, os, csv, datetime, torch, faiss import csv
from PIL import Image import datetime
import os
import pickle as pkl
import faiss
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from tqdm import tqdm import numpy as np
import torch
from PIL import Image
from sklearn import metrics from sklearn import metrics
import pickle as pkl
from torch import nn from torch import nn
from tqdm import tqdm
"""=============================================================================================================""" """============================================================================================================="""
################### TensorBoard Settings ################### ################### TensorBoard Settings ###################
def args2exp_name(args): def args2exp_name(args):
exp_name = f"{args.dataset}_{args.loss}_{args.lr}_bs{args.bs}_spc{args.samples_per_class}_embed{args.embed_dim}_arch{args.arch}_decay{args.decay}_fclr{args.fc_lr_mul}_anneal{args.sigmoid_temperature}" exp_name = f"{args.dataset}_{args.loss}_{args.lr}_bs{args.bs}_spc{args.samples_per_class}_embed{args.embed_dim}_arch{args.arch}_decay{args.decay}_fclr{args.fc_lr_mul}_anneal{args.sigmoid_temperature}"
return exp_name return exp_name
################# ACQUIRE NUMBER OF WEIGHTS ################# ################# ACQUIRE NUMBER OF WEIGHTS #################
def gimme_params(model): def gimme_params(model):
...@@ -32,6 +40,7 @@ def gimme_params(model): ...@@ -32,6 +40,7 @@ def gimme_params(model):
params = sum([np.prod(p.size()) for p in model_parameters]) params = sum([np.prod(p.size()) for p in model_parameters])
return params return params
################# SAVE TRAINING PARAMETERS IN NICE STRING ################# ################# SAVE TRAINING PARAMETERS IN NICE STRING #################
def gimme_save_string(opt): def gimme_save_string(opt):
""" """
...@@ -43,20 +52,24 @@ def gimme_save_string(opt): ...@@ -43,20 +52,24 @@ def gimme_save_string(opt):
string, returns string summary of parameters. string, returns string summary of parameters.
""" """
varx = vars(opt) varx = vars(opt)
base_str = '' base_str = ""
for key in varx: for key in varx:
base_str += str(key) base_str += str(key)
if isinstance(varx[key],dict): if isinstance(varx[key], dict):
for sub_key, sub_item in varx[key].items(): for sub_key, sub_item in varx[key].items():
base_str += '\n\t'+str(sub_key)+': '+str(sub_item) base_str += "\n\t" + str(sub_key) + ": " + str(sub_item)
else: else:
base_str += '\n\t'+str(varx[key]) base_str += "\n\t" + str(varx[key])
base_str+='\n\n' base_str += "\n\n"
return base_str return base_str
def f1_score(
def f1_score(model_generated_cluster_labels, target_labels, feature_coll, computed_centroids): model_generated_cluster_labels,
target_labels,
feature_coll,
computed_centroids,
):
""" """
NOTE: MOSTLY ADAPTED FROM https://github.com/wzzheng/HDML on Hardness-Aware Deep Metric Learning. NOTE: MOSTLY ADAPTED FROM https://github.com/wzzheng/HDML on Hardness-Aware Deep Metric Learning.
...@@ -72,7 +85,10 @@ def f1_score(model_generated_cluster_labels, target_labels, feature_coll, comput ...@@ -72,7 +85,10 @@ def f1_score(model_generated_cluster_labels, target_labels, feature_coll, comput
d = np.zeros(len(feature_coll)) d = np.zeros(len(feature_coll))
for i in range(len(feature_coll)): for i in range(len(feature_coll)):
d[i] = np.linalg.norm(feature_coll[i,:] - computed_centroids[model_generated_cluster_labels[i],:]) d[i] = np.linalg.norm(
feature_coll[i, :]
- computed_centroids[model_generated_cluster_labels[i], :]
)
labels_pred = np.zeros(len(feature_coll)) labels_pred = np.zeros(len(feature_coll))
for i in np.unique(model_generated_cluster_labels): for i in np.unique(model_generated_cluster_labels):
...@@ -81,40 +97,38 @@ def f1_score(model_generated_cluster_labels, target_labels, feature_coll, comput ...@@ -81,40 +97,38 @@ def f1_score(model_generated_cluster_labels, target_labels, feature_coll, comput
cid = index[ind] cid = index[ind]
labels_pred[index] = cid labels_pred[index] = cid
N = len(target_labels) N = len(target_labels)
#Cluster n_labels # Cluster n_labels
avail_labels = np.unique(target_labels) avail_labels = np.unique(target_labels)
n_labels = len(avail_labels) n_labels = len(avail_labels)
#Count the number of objects in each cluster # Count the number of objects in each cluster
count_cluster = np.zeros(n_labels) count_cluster = np.zeros(n_labels)
for i in range(n_labels): for i in range(n_labels):
count_cluster[i] = len(np.where(target_labels == avail_labels[i])[0]) count_cluster[i] = len(np.where(target_labels == avail_labels[i])[0])
#Build a mapping from item_id to item index # Build a mapping from item_id to item index
keys = np.unique(labels_pred) keys = np.unique(labels_pred)
num_item = len(keys) num_item = len(keys)
values = range(num_item) values = range(num_item)
item_map = dict() item_map = dict()
for i in range(len(keys)): for i in range(len(keys)):
item_map.update([(keys[i], values[i])]) item_map.update([(keys[i], values[i])])
# Count the number of objects of each item
#Count the number of objects of each item
count_item = np.zeros(num_item) count_item = np.zeros(num_item)
for i in range(N): for i in range(N):
index = item_map[labels_pred[i]] index = item_map[labels_pred[i]]
count_item[index] = count_item[index] + 1 count_item[index] = count_item[index] + 1
#Compute True Positive (TP) plus False Positive (FP) count # Compute True Positive (TP) plus False Positive (FP) count
tp_fp = 0 tp_fp = 0
for k in range(n_labels): for k in range(n_labels):
if count_cluster[k] > 1: if count_cluster[k] > 1:
tp_fp = tp_fp + comb(count_cluster[k], 2) tp_fp = tp_fp + comb(count_cluster[k], 2)
#Compute True Positive (TP) count # Compute True Positive (TP) count
tp = 0 tp = 0
for k in range(n_labels): for k in range(n_labels):
member = np.where(target_labels == avail_labels[k])[0] member = np.where(target_labels == avail_labels[k])[0]
...@@ -129,10 +143,10 @@ def f1_score(model_generated_cluster_labels, target_labels, feature_coll, comput ...@@ -129,10 +143,10 @@ def f1_score(model_generated_cluster_labels, target_labels, feature_coll, comput
if count[i] > 1: if count[i] > 1:
tp = tp + comb(count[i], 2) tp = tp + comb(count[i], 2)
#Compute False Positive (FP) count # Compute False Positive (FP) count
fp = tp_fp - tp fp = tp_fp - tp
#Compute False Negative (FN) count # Compute False Negative (FN) count
count = 0 count = 0
for j in range(num_item): for j in range(num_item):
if count_item[j] > 1: if count_item[j] > 1:
...@@ -141,16 +155,16 @@ def f1_score(model_generated_cluster_labels, target_labels, feature_coll, comput ...@@ -141,16 +155,16 @@ def f1_score(model_generated_cluster_labels, target_labels, feature_coll, comput
# compute F measure # compute F measure
beta = 1 beta = 1
P = tp / (tp + fp) P = tp / (tp + fp)
R = tp / (tp + fn) R = tp / (tp + fn)
F1 = (beta*beta + 1) * P * R / (beta*beta * P + R) F1 = (beta * beta + 1) * P * R / (beta * beta * P + R)
return F1 return F1
"""============================================================================================================="""
"""============================================================================================================="""
def eval_metrics_one_dataset(model, test_dataloader, device, k_vals, opt): def eval_metrics_one_dataset(model, test_dataloader, device, k_vals, opt):
""" """
Compute evaluation metrics on test-dataset, e.g. NMI, F1 and Recall @ k. Compute evaluation metrics on test-dataset, e.g. NMI, F1 and Recall @ k.
...@@ -170,63 +184,82 @@ def eval_metrics_one_dataset(model, test_dataloader, device, k_vals, opt): ...@@ -170,63 +184,82 @@ def eval_metrics_one_dataset(model, test_dataloader, device, k_vals, opt):
n_classes = len(test_dataloader.dataset.avail_classes) n_classes = len(test_dataloader.dataset.avail_classes)
with torch.no_grad(): with torch.no_grad():
### For all test images, extract features ### For all test images, extract features
target_labels, feature_coll = [],[] target_labels, feature_coll = [], []
final_iter = tqdm(test_dataloader, desc='Computing Evaluation Metrics...') final_iter = tqdm(
image_paths= [x[0] for x in test_dataloader.dataset.image_list] test_dataloader, desc="Computing Evaluation Metrics..."
)
image_paths = [x[0] for x in test_dataloader.dataset.image_list]
for idx, inp in enumerate(final_iter): for idx, inp in enumerate(final_iter):
input_img, target = inp[-1], inp[0] input_img, target = inp[-1], inp[0]
target_labels.extend(target.numpy().tolist()) target_labels.extend(target.numpy().tolist())
out = model(input_img.to(device), feature=True) out = model(input_img.to(device), feature=True)
feature_coll.extend(out.cpu().detach().numpy().tolist()) feature_coll.extend(out.cpu().detach().numpy().tolist())
#pdb.set_trace() # pdb.set_trace()
target_labels = np.hstack(target_labels).reshape(-1,1) target_labels = np.hstack(target_labels).reshape(-1, 1)
feature_coll = np.vstack(feature_coll).astype('float32') feature_coll = np.vstack(feature_coll).astype("float32")
torch.cuda.empty_cache() torch.cuda.empty_cache()
### Set Faiss CPU Cluster index ### Set Faiss CPU Cluster index
cpu_cluster_index = faiss.IndexFlatL2(feature_coll.shape[-1]) cpu_cluster_index = faiss.IndexFlatL2(feature_coll.shape[-1])
kmeans = faiss.Clustering(feature_coll.shape[-1], n_classes) kmeans = faiss.Clustering(feature_coll.shape[-1], n_classes)
kmeans.niter = 20 kmeans.niter = 20
kmeans.min_points_per_centroid = 1 kmeans.min_points_per_centroid = 1
kmeans.max_points_per_centroid = 1000000000 kmeans.max_points_per_centroid = 1000000000
### Train Kmeans ### Train Kmeans
kmeans.train(feature_coll, cpu_cluster_index) kmeans.train(feature_coll, cpu_cluster_index)
computed_centroids = faiss.vector_float_to_array(kmeans.centroids).reshape(n_classes, feature_coll.shape[-1]) computed_centroids = faiss.vector_float_to_array(
kmeans.centroids
).reshape(n_classes, feature_coll.shape[-1])
### Assign feature points to clusters ### Assign feature points to clusters
faiss_search_index = faiss.IndexFlatL2(computed_centroids.shape[-1]) faiss_search_index = faiss.IndexFlatL2(computed_centroids.shape[-1])
faiss_search_index.add(computed_centroids) faiss_search_index.add(computed_centroids)
_, model_generated_cluster_labels = faiss_search_index.search(feature_coll, 1) _, model_generated_cluster_labels = faiss_search_index.search(
feature_coll, 1
)
### Compute NMI ### Compute NMI
NMI = metrics.cluster.normalized_mutual_info_score(model_generated_cluster_labels.reshape(-1), target_labels.reshape(-1)) NMI = metrics.cluster.normalized_mutual_info_score(
model_generated_cluster_labels.reshape(-1),
target_labels.reshape(-1),
)
### Recover max(k_vals) nehbours to use for recall computation ### Recover max(k_vals) nehbours to use for recall computation
faiss_search_index = faiss.IndexFlatL2(feature_coll.shape[-1]) faiss_search_index = faiss.IndexFlatL2(feature_coll.shape[-1])
faiss_search_index.add(feature_coll) faiss_search_index.add(feature_coll)
_, k_closest_points = faiss_search_index.search(feature_coll, int(np.max(k_vals)+1)) _, k_closest_points = faiss_search_index.search(
k_closest_classes = target_labels.reshape(-1)[k_closest_points[:,1:]] feature_coll, int(np.max(k_vals) + 1)
print('computing recalls') )
k_closest_classes = target_labels.reshape(-1)[k_closest_points[:, 1:]]
print("computing recalls")
### Compute Recall ### Compute Recall
recall_all_k = [] recall_all_k = []
for k in k_vals: for k in k_vals:
recall_at_k = np.sum([1 for target, recalled_predictions in zip(target_labels, k_closest_classes) if target in recalled_predictions[:k]])/len(target_labels) recall_at_k = np.sum(
[
1
for target, recalled_predictions in zip(
target_labels, k_closest_classes
)
if target in recalled_predictions[:k]
]
) / len(target_labels)
recall_all_k.append(recall_at_k) recall_all_k.append(recall_at_k)
print('finished recalls') print("finished recalls")
print('computing F1') print("computing F1")
### Compute F1 Score ### Compute F1 Score
F1 = 0 F1 = 0
# F1 = f1_score(model_generated_cluster_labels, target_labels, feature_coll, computed_centroids) # F1 = f1_score(model_generated_cluster_labels, target_labels, feature_coll, computed_centroids)
print('finished computing f1') print("finished computing f1")
return F1, NMI, recall_all_k, feature_coll return F1, NMI, recall_all_k, feature_coll
def eval_metrics_query_and_gallery_dataset(
def eval_metrics_query_and_gallery_dataset(model, query_dataloader, gallery_dataloader, device, k_vals, opt): model, query_dataloader, gallery_dataloader, device, k_vals, opt
):
""" """
Compute evaluation metrics on test-dataset, e.g. NMI, F1 and Recall @ k. Compute evaluation metrics on test-dataset, e.g. NMI, F1 and Recall @ k.
...@@ -247,75 +280,111 @@ def eval_metrics_query_and_gallery_dataset(model, query_dataloader, gallery_data ...@@ -247,75 +280,111 @@ def eval_metrics_query_and_gallery_dataset(model, query_dataloader, gallery_data
with torch.no_grad(): with torch.no_grad():
### For all query test images, extract features ### For all query test images, extract features
query_target_labels, query_feature_coll = [],[] query_target_labels, query_feature_coll = [], []
query_image_paths = [x[0] for x in query_dataloader.dataset.image_list] query_image_paths = [x[0] for x in query_dataloader.dataset.image_list]
query_iter = tqdm(query_dataloader, desc='Extraction Query Features') query_iter = tqdm(query_dataloader, desc="Extraction Query Features")
for idx,inp in enumerate(query_iter): for idx, inp in enumerate(query_iter):
input_img,target = inp[-1], inp[0] input_img, target = inp[-1], inp[0]
query_target_labels.extend(target.numpy().tolist()) query_target_labels.extend(target.numpy().tolist())
out = model(input_img.to(device), feature=True) out = model(input_img.to(device), feature=True)
query_feature_coll.extend(out.cpu().detach().numpy().tolist()) query_feature_coll.extend(out.cpu().detach().numpy().tolist())
### For all gallery test images, extract features ### For all gallery test images, extract features
gallery_target_labels, gallery_feature_coll = [],[] gallery_target_labels, gallery_feature_coll = [], []
gallery_image_paths = [x[0] for x in gallery_dataloader.dataset.image_list] gallery_image_paths = [
gallery_iter = tqdm(gallery_dataloader, desc='Extraction Gallery Features') x[0] for x in gallery_dataloader.dataset.image_list
for idx,inp in enumerate(gallery_iter): ]
input_img,target = inp[-1], inp[0] gallery_iter = tqdm(
gallery_dataloader, desc="Extraction Gallery Features"
)
for idx, inp in enumerate(gallery_iter):
input_img, target = inp[-1], inp[0]
gallery_target_labels.extend(target.numpy().tolist()) gallery_target_labels.extend(target.numpy().tolist())
out = model(input_img.to(device), feature=True) out = model(input_img.to(device), feature=True)
            gallery_feature_coll.extend(out.cpu().detach().numpy().tolist())
        query_target_labels, query_feature_coll = np.hstack(
            query_target_labels
        ).reshape(-1, 1), np.vstack(query_feature_coll).astype("float32")
        gallery_target_labels, gallery_feature_coll = np.hstack(
            gallery_target_labels
        ).reshape(-1, 1), np.vstack(gallery_feature_coll).astype("float32")
        torch.cuda.empty_cache()
        ### Set CPU Cluster index
        stackset = np.concatenate(
            [query_feature_coll, gallery_feature_coll], axis=0
        )
        stacklabels = np.concatenate(
            [query_target_labels, gallery_target_labels], axis=0
        )
        cpu_cluster_index = faiss.IndexFlatL2(stackset.shape[-1])
        kmeans = faiss.Clustering(stackset.shape[-1], n_classes)
        kmeans.niter = 20
        kmeans.min_points_per_centroid = 1
        kmeans.max_points_per_centroid = 1000000000
        ### Train Kmeans
        kmeans.train(stackset, cpu_cluster_index)
        computed_centroids = faiss.vector_float_to_array(
            kmeans.centroids
        ).reshape(n_classes, stackset.shape[-1])
        ### Assign feature points to clusters
        faiss_search_index = faiss.IndexFlatL2(computed_centroids.shape[-1])
        faiss_search_index.add(computed_centroids)
        _, model_generated_cluster_labels = faiss_search_index.search(
            stackset, 1
        )
        ### Compute NMI
        NMI = metrics.cluster.normalized_mutual_info_score(
            model_generated_cluster_labels.reshape(-1), stacklabels.reshape(-1)
        )
        ### Recover max(k_vals) nearest neighbours to use for recall computation
        faiss_search_index = faiss.IndexFlatL2(gallery_feature_coll.shape[-1])
        faiss_search_index.add(gallery_feature_coll)
        _, k_closest_points = faiss_search_index.search(
            query_feature_coll, int(np.max(k_vals))
        )
        k_closest_classes = gallery_target_labels.reshape(-1)[k_closest_points]
        ### Compute Recall
        recall_all_k = []
        for k in k_vals:
            recall_at_k = np.sum(
                [
                    1
                    for target, recalled_predictions in zip(
                        query_target_labels, k_closest_classes
                    )
                    if target in recalled_predictions[:k]
                ]
            ) / len(query_target_labels)
            recall_all_k.append(recall_at_k)
        recall_str = ", ".join(
            "@{0}: {1:.4f}".format(k, rec)
            for k, rec in zip(k_vals, recall_all_k)
        )
        ### Compute F1 score
        F1 = f1_score(
            model_generated_cluster_labels,
            stacklabels,
            stackset,
            computed_centroids,
        )
    return F1, NMI, recall_all_k, query_feature_coll, gallery_feature_coll
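# Illustrative sketch (hypothetical helper; argument names and shapes are
# assumptions): the same flat-L2 faiss pattern used above, reduced to a single
# Recall@k number for one query/gallery split.
def _recall_at_k_example(
    query_feats, query_labels, gallery_feats, gallery_labels, k=4
):
    # Index the gallery features and look up the k nearest neighbours per query.
    index = faiss.IndexFlatL2(gallery_feats.shape[-1])
    index.add(gallery_feats)
    _, nn_idx = index.search(query_feats, k)
    # A query counts as recalled if its label appears among its k neighbours.
    hits = [
        label in gallery_labels.reshape(-1)[row]
        for label, row in zip(query_labels.reshape(-1), nn_idx)
    ]
    return float(np.mean(hits))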
"""=============================================================================================================""" """============================================================================================================="""
####### RECOVER CLOSEST EXAMPLE IMAGES ####### ####### RECOVER CLOSEST EXAMPLE IMAGES #######
def recover_closest_one_dataset(feature_matrix_all, image_paths, save_path, n_image_samples=10, n_closest=3): def recover_closest_one_dataset(
feature_matrix_all, image_paths, save_path, n_image_samples=10, n_closest=3
):
""" """
Provide sample recoveries. Provide sample recoveries.
...@@ -329,31 +398,45 @@ def recover_closest_one_dataset(feature_matrix_all, image_paths, save_path, n_im ...@@ -329,31 +398,45 @@ def recover_closest_one_dataset(feature_matrix_all, image_paths, save_path, n_im
Nothing! Nothing!
""" """
image_paths = np.array([x[0] for x in image_paths]) image_paths = np.array([x[0] for x in image_paths])
sample_idxs = np.random.choice(np.arange(len(feature_matrix_all)), n_image_samples) sample_idxs = np.random.choice(
np.arange(len(feature_matrix_all)), n_image_samples
)
faiss_search_index = faiss.IndexFlatL2(feature_matrix_all.shape[-1]) faiss_search_index = faiss.IndexFlatL2(feature_matrix_all.shape[-1])
faiss_search_index.add(feature_matrix_all) faiss_search_index.add(feature_matrix_all)
_, closest_feature_idxs = faiss_search_index.search(feature_matrix_all, n_closest+1) _, closest_feature_idxs = faiss_search_index.search(
feature_matrix_all, n_closest + 1
)
sample_paths = image_paths[closest_feature_idxs][sample_idxs] sample_paths = image_paths[closest_feature_idxs][sample_idxs]
f,axes = plt.subplots(n_image_samples, n_closest+1) f, axes = plt.subplots(n_image_samples, n_closest + 1)
for i,(ax,plot_path) in enumerate(zip(axes.reshape(-1), sample_paths.reshape(-1))): for i, (ax, plot_path) in enumerate(
zip(axes.reshape(-1), sample_paths.reshape(-1))
):
ax.imshow(np.array(Image.open(plot_path))) ax.imshow(np.array(Image.open(plot_path)))
ax.set_xticks([]) ax.set_xticks([])
ax.set_yticks([]) ax.set_yticks([])
if i%(n_closest+1): if i % (n_closest + 1):
ax.axvline(x=0, color='g', linewidth=13) ax.axvline(x=0, color="g", linewidth=13)
else: else:
ax.axvline(x=0, color='r', linewidth=13) ax.axvline(x=0, color="r", linewidth=13)
f.set_size_inches(10,20) f.set_size_inches(10, 20)
f.tight_layout() f.tight_layout()
f.savefig(save_path) f.savefig(save_path)
plt.close() plt.close()
####### RECOVER CLOSEST EXAMPLE IMAGES #######
def recover_closest_inshop(
    query_feature_matrix_all,
    gallery_feature_matrix_all,
    query_image_paths,
    gallery_image_paths,
    save_path,
    n_image_samples=10,
    n_closest=3,
):
    """
    Provide sample recoveries.
@@ -368,34 +451,43 @@ def recover_closest_inshop(query_feature_matrix_all, gallery_feature_matrix_all,
    Returns:
        Nothing!
    """
    query_image_paths, gallery_image_paths = np.array(
        query_image_paths
    ), np.array(gallery_image_paths)
    sample_idxs = np.random.choice(
        np.arange(len(query_feature_matrix_all)), n_image_samples
    )
    faiss_search_index = faiss.IndexFlatL2(gallery_feature_matrix_all.shape[-1])
    faiss_search_index.add(gallery_feature_matrix_all)
    _, closest_feature_idxs = faiss_search_index.search(
        query_feature_matrix_all, n_closest
    )
    image_paths = gallery_image_paths[closest_feature_idxs]
    image_paths = np.concatenate(
        [query_image_paths.reshape(-1, 1), image_paths], axis=-1
    )
    sample_paths = image_paths[closest_feature_idxs][sample_idxs]
    f, axes = plt.subplots(n_image_samples, n_closest + 1)
    for i, (ax, plot_path) in enumerate(
        zip(axes.reshape(-1), sample_paths.reshape(-1))
    ):
        ax.imshow(np.array(Image.open(plot_path)))
        ax.set_xticks([])
        ax.set_yticks([])
        if i % (n_closest + 1):
            ax.axvline(x=0, color="g", linewidth=13)
        else:
            ax.axvline(x=0, color="r", linewidth=13)
    f.set_size_inches(10, 20)
    f.tight_layout()
    f.savefig(save_path)
    plt.close()
"""=============================================================================================================""" """============================================================================================================="""
################## SET NETWORK TRAINING CHECKPOINT ##################### ################## SET NETWORK TRAINING CHECKPOINT #####################
def set_checkpoint(model, opt, progress_saver, savepath): def set_checkpoint(model, opt, progress_saver, savepath):
...@@ -411,20 +503,25 @@ def set_checkpoint(model, opt, progress_saver, savepath): ...@@ -411,20 +503,25 @@ def set_checkpoint(model, opt, progress_saver, savepath):
Returns: Returns:
Nothing! Nothing!
""" """
torch.save({'state_dict':model.state_dict(), 'opt':opt, torch.save(
'progress':progress_saver}, savepath) {
"state_dict": model.state_dict(),
"opt": opt,
"progress": progress_saver,
},
savepath,
)
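# Illustrative sketch (hypothetical helper): restoring a checkpoint written by
# set_checkpoint() above; assumes only the dict keys used there.
def _load_checkpoint_example(model, savepath):
    checkpoint = torch.load(savepath, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    # "opt" holds the training options, "progress" the logged metric history.
    return checkpoint["opt"], checkpoint["progress"]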
"""=============================================================================================================""" """============================================================================================================="""
################## WRITE TO CSV FILE ##################### ################## WRITE TO CSV FILE #####################
class CSV_Writer(): class CSV_Writer:
""" """
Class to append newly compute training metrics to a csv file Class to append newly compute training metrics to a csv file
for data logging. for data logging.
Is used together with the LOGGER class. Is used together with the LOGGER class.
""" """
def __init__(self, save_path, columns): def __init__(self, save_path, columns):
""" """
Args: Args:
...@@ -434,7 +531,7 @@ class CSV_Writer(): ...@@ -434,7 +531,7 @@ class CSV_Writer():
Nothing! Nothing!
""" """
self.save_path = save_path self.save_path = save_path
self.columns = columns self.columns = columns
with open(self.save_path, "a") as csv_file: with open(self.save_path, "a") as csv_file:
writer = csv.writer(csv_file, delimiter=",") writer = csv.writer(csv_file, delimiter=",")
...@@ -450,7 +547,7 @@ class CSV_Writer(): ...@@ -450,7 +547,7 @@ class CSV_Writer():
Nothing! Nothing!
""" """
with open(self.save_path, "a") as csv_file: with open(self.save_path, "a") as csv_file:
writer = csv.writer(csv_file, delimiter=',') writer = csv.writer(csv_file, delimiter=",")
writer.writerow(inputs) writer.writerow(inputs)
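# Illustrative usage sketch (file name and metric values are made up):
def _csv_writer_example(save_dir):
    writer = CSV_Writer(
        save_dir + "/log_train_example.csv", ["Epochs", "Time", "Train Loss"]
    )
    writer.log([0, 12.3, 0.87])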
@@ -469,37 +566,41 @@ def set_logging(opt):
    Returns:
        Nothing!
    """
    checkfolder = opt.save_path + "/" + str(opt.iter)
    # Create start-time-based name if opt.savename is not given.
    if opt.savename == "":
        date = datetime.datetime.now()
        checkfolder = opt.save_path + "/" + str(opt.iter)
    # If folder already exists, iterate over it until it doesn't.
    # counter = 1
    # while os.path.exists(checkfolder):
    #     checkfolder = opt.save_path+'/'+opt.savename+'_'+str(counter)
    #     counter += 1
    # Create Folder
    if not os.path.exists(checkfolder):
        os.makedirs(checkfolder)
    opt.save_path = checkfolder
    # Store training parameters as text and pickle in said folder.
    with open(opt.save_path + "/Parameter_Info.txt", "w") as f:
        f.write(gimme_save_string(opt))
    pkl.dump(opt, open(opt.save_path + "/hypa.pkl", "wb"))
import pdb


class LOGGER:
    """
    This class provides a collection of logging properties that are useful for training.
    These include setting the save folder, in which progression of training/testing metrics is visualized,
    csv log-files are stored, sample recoveries are plotted and an internal data saver.
    """

    def __init__(self, opt, metrics_to_log, name="Basic", start_new=True):
        """
        Args:
            opt: argparse.Namespace, contains all training-specific parameters.
@@ -512,18 +613,23 @@ class LOGGER():
        Returns:
            Nothing!
        """
        self.prop = opt
        self.metrics_to_log = metrics_to_log
        ### Make Logging Directories
        if start_new:
            set_logging(opt)
        ### Set Progress Saver Dict
        self.progress_saver = self.provide_progress_saver(metrics_to_log)
        ### Set CSV Writers
        self.csv_loggers = {
            mode: CSV_Writer(
                opt.save_path + "/log_" + mode + "_" + name + ".csv", lognames
            )
            for mode, lognames in metrics_to_log.items()
        }

    def provide_progress_saver(self, metrics_to_log):
        """
@@ -532,7 +638,10 @@ class LOGGER():
        Args:
            metrics_to_log: see __init__(). Describes the structure of Progress_Saver.
        """
        Progress_Saver = {
            key: {sub_key: [] for sub_key in metrics_to_log[key]}
            for key in metrics_to_log.keys()
        }
        return Progress_Saver
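    # Illustrative sketch of the resulting structure, for a made-up
    # metrics_to_log = {"train": ["Epochs", "Train Loss"]}:
    #     {"train": {"Epochs": [], "Train Loss": []}}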
    def log(self, main_keys, metric_keys, values):
@@ -543,16 +652,19 @@ class LOGGER():
            metric_keys: Needs to follow the list length of self.progress_saver[main_key(s)]. List of metric keys that are extended with new values.
            values: Needs to be a list of the same structure as metric_keys. Actual values that are appended.
        """
        if not isinstance(main_keys, list):
            main_keys = [main_keys]
        if not isinstance(metric_keys, list):
            metric_keys = [metric_keys]
        if not isinstance(values, list):
            values = [values]
        # Log data to progress saver dict.
        for main_key in main_keys:
            for value, metric_key in zip(values, metric_keys):
                self.progress_saver[main_key][metric_key].append(value)
            # Append data to csv.
            self.csv_loggers[main_key].log(values)
    def update_info_plot(self):
@@ -564,27 +676,70 @@ class LOGGER():
        Returns:
            Nothing!
        """
        t_epochs = self.progress_saver["val"]["Epochs"]
        t_loss_list = [self.progress_saver["train"]["Train Loss"]]
        t_legend_handles = ["Train Loss"]
        v_epochs = self.progress_saver["val"]["Epochs"]
        # Because Vehicle-ID normally uses three different test sets, a distinction has to be made.
        if self.prop.dataset != "vehicle_id":
            title = " | ".join(
                key + ": {0:3.3f}".format(np.max(item))
                for key, item in self.progress_saver["val"].items()
                if key not in ["Time", "Epochs"]
            )
            self.info_plot.title = title
            v_metric_list = [
                self.progress_saver["val"][key]
                for key in self.progress_saver["val"].keys()
                if key not in ["Time", "Epochs"]
            ]
            v_legend_handles = [
                key
                for key in self.progress_saver["val"].keys()
                if key not in ["Time", "Epochs"]
            ]
            self.info_plot.make_plot(
                t_epochs,
                v_epochs,
                t_loss_list,
                v_metric_list,
                t_legend_handles,
                v_legend_handles,
            )
        else:
            # Iterate over all test sets.
            for i in range(3):
                title = " | ".join(
                    key + ": {0:3.3f}".format(np.max(item))
                    for key, item in self.progress_saver["val"].items()
                    if key not in ["Time", "Epochs"]
                    and "Set {}".format(i) in key
                )
                self.info_plot["Set {}".format(i)].title = title
                v_metric_list = [
                    self.progress_saver["val"][key]
                    for key in self.progress_saver["val"].keys()
                    if key not in ["Time", "Epochs"]
                    and "Set {}".format(i) in key
                ]
                v_legend_handles = [
                    key
                    for key in self.progress_saver["val"].keys()
                    if key not in ["Time", "Epochs"]
                    and "Set {}".format(i) in key
                ]
                self.info_plot["Set {}".format(i)].make_plot(
                    t_epochs,
                    v_epochs,
                    t_loss_list,
                    v_metric_list,
                    t_legend_handles,
                    v_legend_handles,
                    appendix="set_{}".format(i),
                )
def metrics_to_examine(dataset, k_vals):
    """
@@ -598,18 +753,21 @@ def metrics_to_examine(dataset, k_vals):
    Returns:
        metric_dict: Dictionary representing the storing structure for LOGGER.progress_saver. See LOGGER.__init__() for an example.
    """
    metric_dict = {"train": ["Epochs", "Time", "Train Loss"]}
    if dataset == "vehicle_id":
        metric_dict["val"] = ["Epochs", "Time"]
        # Vehicle_ID uses three test sets
        for i in range(3):
            metric_dict["val"] += [
                "Set {} NMI".format(i),
                "Set {} F1".format(i),
            ]
            for k in k_vals:
                metric_dict["val"] += ["Set {} Recall @ {}".format(i, k)]
    else:
        metric_dict["val"] = ["Epochs", "Time", "NMI", "F1"]
        metric_dict["val"] += ["Recall @ {}".format(k) for k in k_vals]
    return metric_dict
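# Illustrative sketch: for k_vals = [1, 2, 4, 8] and any non-vehicle_id
# dataset, the returned dict is
#     {"train": ["Epochs", "Time", "Train Loss"],
#      "val": ["Epochs", "Time", "NMI", "F1",
#              "Recall @ 1", "Recall @ 2", "Recall @ 4", "Recall @ 8"]}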
@@ -633,8 +791,14 @@ def vis(model, test_dataloader, device, split, opt):
    torch.cuda.empty_cache()
    if opt.dataset == "Inaturalist":
        if opt.iter > 0:
            with open(opt.cluster_path, "rb") as clusterf:
                (
                    path2idx,
                    global_features,
                    global_pred_labels,
                    gt_labels,
                    masks,
                ) = pkl.load(clusterf)
            gt_labels = gt_labels + len(np.unique(global_pred_labels))
            idx2path = {v: k for k, v in path2idx.items()}
        else:
@@ -643,16 +807,18 @@ def vis(model, test_dataloader, device, split, opt):
            paths = [x.strip() for x in filelines]
            Lin_paths = paths[:linsize]
            masks = np.zeros(len(paths))
            masks[: len(Lin_paths)] = 0
            masks[len(Lin_paths) :] = 2
    _ = model.eval()
    path2ids = {}
    with torch.no_grad():
        ### For all test images, extract features
        target_labels, feature_coll = [], []
        final_iter = tqdm(
            test_dataloader, desc="Computing Evaluation Metrics..."
        )
        image_paths = [x[0] for x in test_dataloader.dataset.image_list]
        for i in range(len(image_paths)):
            path2ids[image_paths[i]] = i
@@ -661,9 +827,9 @@ def vis(model, test_dataloader, device, split, opt):
            target_labels.extend(target.numpy().tolist())
            out = model(input_img.to(device), feature=True)
            feature_coll.extend(out.cpu().detach().numpy().tolist())
        # pdb.set_trace()
        target_labels = np.hstack(target_labels).reshape(-1)
        feature_coll = np.vstack(feature_coll).astype("float32")
        if (opt.dataset == "Inaturalist") and "all_train" in split:
            if opt.iter > 0:
@@ -690,8 +856,8 @@ def vis(model, test_dataloader, device, split, opt):
                target_labels_new = np.zeros_like(target_labels)
                for i in range(len(paths)):
                    path = paths[i]
                    idxx = path2ids[opt.source_path + "/" + path]
                    path2ids_new[opt.source_path + "/" + path] = i
                    predicted_features[i] = feature_coll[idxx]
                    target_labels_new[i] = target_labels[idxx]
@@ -704,11 +870,13 @@ def vis(model, test_dataloader, device, split, opt):
            print("all_train not in split.")
            gtlabels = target_labels
        output_feature_path = os.path.join(
            opt.source_path, split + "_inat_features.pkl"
        )
        print("Dump features into {}.".format(output_feature_path))
        with open(output_feature_path, "wb") as f:
            pkl.dump([path2ids, feature_coll, target_labels, gtlabels, masks], f)
        print(target_labels.max())
        print("target_labels:", target_labels.shape)
        print("feature_coll:", feature_coll.shape)
\ No newline at end of file
@@ -5,12 +5,18 @@ import pickle
import warnings

from numpy.core.arrayprint import IntegerFormat

warnings.filterwarnings("ignore")
import copy
import os
import random

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

"""============================================================================"""
@@ -23,35 +29,56 @@ def give_dataloaders(dataset, trainset, testset, opt, cluster_path=""):
    Returns:
        dataloaders: dict of dataloaders for training, testing and evaluation on training.
    """
    # Dataset selection
    if opt.dataset == "Inaturalist":
        if opt.finetune:
            datasets = give_inat_datasets_finetune_1head(
                testset, cluster_path, opt
            )
        else:
            if opt.get_features:
                datasets = give_inaturalist_datasets_for_features(opt)
            else:
                datasets = give_inaturalist_datasets(opt)
    else:
        raise Exception("No Dataset >{}< available!".format(dataset))
    # Move datasets to dataloaders.
    dataloaders = {}
    for key, dataset in datasets.items():
        if (
            isinstance(dataset, TrainDatasetsmoothap)
            or isinstance(dataset, TrainDatasetsmoothap1Head)
        ) and key in ["training", "clustering"]:
            dataloaders[key] = torch.utils.data.DataLoader(
                dataset,
                batch_size=opt.bs,
                num_workers=opt.kernels,
                sampler=torch.utils.data.SequentialSampler(dataset),
                pin_memory=True,
                drop_last=True,
            )
        else:
            is_val = dataset.is_validation
            if key == "training" or key == "clustering":
                dataloaders[key] = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=opt.bs,
                    num_workers=opt.kernels,
                    shuffle=not is_val,
                    pin_memory=True,
                    drop_last=not is_val,
                )
            else:
                dataloaders[key] = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=opt.bs,
                    num_workers=6,
                    shuffle=not is_val,
                    pin_memory=True,
                    drop_last=not is_val,
                )
    return dataloaders
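# Illustrative usage sketch (the opt fields follow those referenced above; the
# loop variable names are made up):
#     dataloaders = give_dataloaders("Inaturalist", opt.trainset, opt.testset, opt)
#     for class_labels, images in dataloaders["training"]:
#         ...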
@@ -66,58 +93,66 @@ def give_inaturalist_datasets(opt):
    Returns:
        dict of PyTorch datasets for training, testing and evaluation.
    """
    # Load text-files containing classes and imagepaths.
    # Generate image_dicts of shape {class_idx:[list of paths to images belonging to this class] ...}
    train_image_dict, val_image_dict, test_image_dict = {}, {}, {}
    with open(os.path.join(opt.source_path, opt.trainset)) as f:
        FileLines = f.readlines()
        FileLines = [x.strip() for x in FileLines]
        for entry in FileLines:
            info = entry.split("/")
            if "/".join([info[-3], info[-2]]) not in train_image_dict:
                train_image_dict["/".join([info[-3], info[-2]])] = []
            train_image_dict["/".join([info[-3], info[-2]])].append(
                os.path.join(opt.source_path, entry)
            )
    with open(os.path.join(opt.source_path, opt.testset)) as f:
        FileLines = f.readlines()
        FileLines = [x.strip() for x in FileLines]
        for entry in FileLines:
            info = entry.split("/")
            if "/".join([info[-3], info[-2]]) not in val_image_dict:
                val_image_dict["/".join([info[-3], info[-2]])] = []
            val_image_dict["/".join([info[-3], info[-2]])].append(
                os.path.join(opt.source_path, entry)
            )
    with open(os.path.join(opt.source_path, opt.testset)) as f:
        FileLines = f.readlines()
        FileLines = [x.strip() for x in FileLines]
        for entry in FileLines:
            info = entry.split("/")
            if "/".join([info[-3], info[-2]]) not in test_image_dict:
                test_image_dict["/".join([info[-3], info[-2]])] = []
            test_image_dict["/".join([info[-3], info[-2]])].append(
                os.path.join(opt.source_path, entry)
            )
    new_train_dict = {}
    class_ind_ind = 0
    for cate in train_image_dict:
        new_train_dict["te/%d" % class_ind_ind] = train_image_dict[cate]
        class_ind_ind += 1
    train_image_dict = new_train_dict
    train_dataset = TrainDatasetsmoothap(train_image_dict, opt)
    val_dataset = BaseTripletDataset(val_image_dict, opt, is_validation=True)
    eval_dataset = BaseTripletDataset(test_image_dict, opt, is_validation=True)
    # train_dataset.conversion = conversion
    # val_dataset.conversion = conversion
    # eval_dataset.conversion = conversion
    return {
        "training": train_dataset,
        "testing": val_dataset,
        "evaluation": eval_dataset,
    }
def give_inaturalist_datasets_for_features(opt):
@@ -136,8 +171,14 @@ def give_inaturalist_datasets_for_features(opt):
    train_image_dict, test_image_dict, eval_image_dict = {}, {}, {}
    if opt.iter > 0:
        with open(os.path.join(opt.cluster_path), "rb") as clusterf:
            (
                path2idx,
                global_features,
                global_pred_labels,
                gt_labels,
                masks,
            ) = pickle.load(clusterf)
        gt_labels = gt_labels + len(np.unique(global_pred_labels))
        for path, idx in path2idx.items():
@@ -146,41 +187,54 @@ def give_inaturalist_datasets_for_features(opt):
                    test_image_dict["te/%d" % gt_labels[idx]] = []
                test_image_dict["te/%d" % gt_labels[idx]].append(path)
            else:
                if (
                    "te/%d" % global_pred_labels[idx]
                    not in train_image_dict
                ):
                    train_image_dict["te/%d" % global_pred_labels[idx]] = []
                train_image_dict["te/%d" % global_pred_labels[idx]].append(
                    path
                )
                if "te/%d" % global_pred_labels[idx] not in test_image_dict:
                    test_image_dict["te/%d" % global_pred_labels[idx]] = []
                test_image_dict["te/%d" % global_pred_labels[idx]].append(
                    path
                )
    else:
        with open(os.path.join(opt.source_path, opt.trainset)) as f:
            FileLines = f.readlines()
            FileLines = [x.strip() for x in FileLines]
            for entry in FileLines:
                info = entry.split("/")
                if "/".join([info[-3], info[-2]]) not in train_image_dict:
                    train_image_dict["/".join([info[-3], info[-2]])] = []
                train_image_dict["/".join([info[-3], info[-2]])].append(
                    os.path.join(opt.source_path, entry)
                )
        with open(os.path.join(opt.source_path, opt.all_trainset)) as f:
            FileLines = f.readlines()
            FileLines = [x.strip() for x in FileLines]
            for entry in FileLines:
                info = entry.split("/")
                if "/".join([info[-3], info[-2]]) not in test_image_dict:
                    test_image_dict["/".join([info[-3], info[-2]])] = []
                test_image_dict["/".join([info[-3], info[-2]])].append(
                    os.path.join(opt.source_path, entry)
                )
        with open(os.path.join(opt.source_path, opt.testset)) as f:
            FileLines = f.readlines()
            FileLines = [x.strip() for x in FileLines]
            for entry in FileLines:
                info = entry.split("/")
                if "/".join([info[-3], info[-2]]) not in eval_image_dict:
                    eval_image_dict["/".join([info[-3], info[-2]])] = []
                eval_image_dict["/".join([info[-3], info[-2]])].append(
                    os.path.join(opt.source_path, entry)
                )
    new_train_dict = {}
    class_ind_ind = 0
@@ -203,7 +257,9 @@ def give_inaturalist_datasets_for_features(opt):
        class_ind_ind += 1
    eval_image_dict = new_eval_dict
    train_dataset = BaseTripletDataset(
        train_image_dict, opt, is_validation=True
    )
    test_dataset = BaseTripletDataset(test_image_dict, opt, is_validation=True)
    eval_dataset = BaseTripletDataset(eval_image_dict, opt, is_validation=True)
@@ -211,7 +267,12 @@ def give_inaturalist_datasets_for_features(opt):
    # val_dataset.conversion = conversion
    # eval_dataset.conversion = conversion
    return {
        "training": train_dataset,
        "testing": test_dataset,
        "eval": eval_dataset,
    }
def give_inat_datasets_finetune_1head(testset, cluster_label_path, opt):
    """
@@ -226,9 +287,16 @@ def give_inat_datasets_finetune_1head(testset, cluster_label_path, opt):
    """
    # Load cluster labels from hilander results.
    import pickle

    train_image_dict, val_image_dict, cluster_image_dict = {}, {}, {}
    with open(cluster_label_path, "rb") as clusterf:
        (
            path2idx,
            global_features,
            global_pred_labels,
            gt_labels,
            masks,
        ) = pickle.load(clusterf)
    for path, idx in path2idx.items():
        if global_pred_labels[idx] == -1:
@@ -243,10 +311,12 @@ def give_inat_datasets_finetune_1head(testset, cluster_label_path, opt):
        FileLines = [x.strip() for x in FileLines]
        for entry in FileLines:
            info = entry.split("/")
            if "/".join([info[-3], info[-2]]) not in val_image_dict:
                val_image_dict["/".join([info[-3], info[-2]])] = []
            val_image_dict["/".join([info[-3], info[-2]])].append(
                os.path.join(opt.source_path, entry)
            )
    train_dataset = TrainDatasetsmoothap(train_image_dict, opt)
@@ -256,7 +326,11 @@ def give_inat_datasets_finetune_1head(testset, cluster_label_path, opt):
    # val_dataset.conversion = conversion
    # eval_dataset.conversion = conversion
    return {
        "training": train_dataset,
        "testing": val_dataset,
        "evaluation": val_dataset,
    }
################## BASIC PYTORCH DATASET USED FOR ALL DATASETS ##################################
@@ -266,7 +340,10 @@ class BaseTripletDataset(Dataset):
    This includes normalizing to ImageNet-standards, and Random & Resized cropping of shapes 224 for ResNet50 and 227 for
    GoogLeNet during Training. During validation, only resizing to 256 or center cropping to 224/227 is performed.
    """

    def __init__(
        self, image_dict, opt, samples_per_class=8, is_validation=False
    ):
        """
        Dataset Init-Function.
@@ -278,49 +355,69 @@ class BaseTripletDataset(Dataset):
        Returns:
            Nothing!
        """
        # Define length of dataset
        self.n_files = np.sum(
            [len(image_dict[key]) for key in image_dict.keys()]
        )
        self.is_validation = is_validation
        self.pars = opt
        self.image_dict = image_dict
        self.avail_classes = sorted(list(self.image_dict.keys()))
        # Convert image dictionary from classname:content to class_idx:content, because the initial indices are not necessarily from 0 - <n_classes>.
        self.image_dict = {
            i: self.image_dict[key] for i, key in enumerate(self.avail_classes)
        }
        self.avail_classes = sorted(list(self.image_dict.keys()))
        # Init. properties that are used when filling up batches.
        if not self.is_validation:
            self.samples_per_class = samples_per_class
            # Select current class to sample images from up to <samples_per_class>
            self.current_class = np.random.randint(len(self.avail_classes))
            self.classes_visited = [self.current_class, self.current_class]
            self.n_samples_drawn = 0
        # Data augmentation/processing methods.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        transf_list = []
        if not self.is_validation:
            transf_list.extend(
                [
                    transforms.RandomResizedCrop(size=224)
                    if opt.arch == "resnet50"
                    else transforms.RandomResizedCrop(size=227),
                    transforms.RandomHorizontalFlip(0.5),
                ]
            )
        else:
            transf_list.extend(
                [
                    transforms.Resize(256),
                    transforms.CenterCrop(224)
                    if opt.arch == "resnet50"
                    else transforms.CenterCrop(227),
                ]
            )
        transf_list.extend([transforms.ToTensor(), normalize])
        self.transform = transforms.Compose(transf_list)
        # Convert Image-Dict to list of (image_path, image_class). Allows for easier direct sampling.
        self.image_list = [
            [(x, key) for x in self.image_dict[key]]
            for key in self.image_dict.keys()
        ]
        self.image_list = [x for y in self.image_list for x in y]
        # Flag that denotes if dataset is called for the first time.
        self.is_init = True

    def ensure_3dim(self, img):
        """
        Function that ensures that the input img is three-dimensional.
@@ -330,11 +427,10 @@ class BaseTripletDataset(Dataset):
        Returns:
            Checked PIL.Image img.
        """
        if len(img.size) == 2:
            img = img.convert("RGB")
        return img
    def __getitem__(self, idx):
        """
        Args:
@@ -342,63 +438,99 @@ class BaseTripletDataset(Dataset):
        Returns:
            tuple of form (sample_class, torch.Tensor() of input image)
        """
        if self.pars.loss == "smoothap" or self.pars.loss == "smoothap_element":
            if self.is_init:
                # self.current_class = self.avail_classes[idx%len(self.avail_classes)]
                self.is_init = False
            if not self.is_validation:
                if self.samples_per_class == 1:
                    return self.image_list[idx][-1], self.transform(
                        self.ensure_3dim(Image.open(self.image_list[idx][0]))
                    )
                if self.n_samples_drawn == self.samples_per_class:
                    # Once enough samples per class have been drawn, we choose another class to draw samples from.
                    # Note that we ensure with self.classes_visited that no class is chosen if it had been chosen
                    # previously or one before that.
                    counter = copy.deepcopy(self.avail_classes)
                    for prev_class in self.classes_visited:
                        if prev_class in counter:
                            counter.remove(prev_class)
                    self.current_class = counter[idx % len(counter)]
                    # self.classes_visited = self.classes_visited[1:]+[self.current_class]
                    # EDIT -> there can be no class repeats
                    self.classes_visited = self.classes_visited + [
                        self.current_class
                    ]
                    self.n_samples_drawn = 0
                class_sample_idx = idx % len(
                    self.image_dict[self.current_class]
                )
                self.n_samples_drawn += 1
                out_img = self.transform(
                    self.ensure_3dim(
                        Image.open(
                            self.image_dict[self.current_class][
                                class_sample_idx
                            ]
                        )
                    )
                )
                return self.current_class, out_img
            else:
                return self.image_list[idx][-1], self.transform(
                    self.ensure_3dim(Image.open(self.image_list[idx][0]))
                )
        else:
            if self.is_init:
                self.current_class = self.avail_classes[
                    idx % len(self.avail_classes)
                ]
                self.is_init = False
            if not self.is_validation:
                if self.samples_per_class == 1:
                    return self.image_list[idx][-1], self.transform(
                        self.ensure_3dim(Image.open(self.image_list[idx][0]))
                    )
                if self.n_samples_drawn == self.samples_per_class:
                    # Once enough samples per class have been drawn, we choose another class to draw samples from.
                    # Note that we ensure with self.classes_visited that no class is chosen if it had been chosen
                    # previously or one before that.
                    counter = copy.deepcopy(self.avail_classes)
                    for prev_class in self.classes_visited:
                        if prev_class in counter:
                            counter.remove(prev_class)
                    self.current_class = counter[idx % len(counter)]
                    self.classes_visited = self.classes_visited[1:] + [
                        self.current_class
                    ]
                    self.n_samples_drawn = 0
                class_sample_idx = idx % len(
                    self.image_dict[self.current_class]
                )
                self.n_samples_drawn += 1
                out_img = self.transform(
                    self.ensure_3dim(
                        Image.open(
                            self.image_dict[self.current_class][
                                class_sample_idx
                            ]
                        )
                    )
                )
                return self.current_class, out_img
            else:
                return self.image_list[idx][-1], self.transform(
                    self.ensure_3dim(Image.open(self.image_list[idx][0]))
                )

    def __len__(self):
        return self.n_files
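# Illustrative usage sketch (the opt fields mirror the attributes referenced
# above; the variable names are made up):
#     val_set = BaseTripletDataset(val_image_dict, opt, is_validation=True)
#     val_loader = torch.utils.data.DataLoader(
#         val_set, batch_size=opt.bs, num_workers=6, shuffle=False
#     )
#     for class_idx, img_tensor in val_loader:
#         ...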
@@ -408,11 +540,13 @@ flatten = lambda l: [item for sublist in l for item in sublist]
######################## dataset for SmoothAP regular training ##################################
class TrainDatasetsmoothap(Dataset):
    """
    This dataset class allows mini-batch formation pre-epoch, for greater speed
    """

    def __init__(self, image_dict, opt):
        """
        Args:
@@ -428,33 +562,37 @@ class TrainDatasetsmoothap(Dataset):
            for instance in self.image_dict[sub]:
                newsub.append((sub, instance))
            self.image_dict[sub] = newsub
        # checks
        # provide avail_classes
        self.avail_classes = [*self.image_dict]
        # Data augmentation/processing methods.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        transf_list = []
        transf_list.extend(
            [
                transforms.RandomResizedCrop(size=224)
                if opt.arch in ["resnet50", "resnet50_mcn"]
                else transforms.RandomResizedCrop(size=227),
                transforms.RandomHorizontalFlip(0.5),
            ]
        )
        transf_list.extend([transforms.ToTensor(), normalize])
        self.transform = transforms.Compose(transf_list)
        self.reshuffle()

    def ensure_3dim(self, img):
        if len(img.size) == 2:
            img = img.convert("RGB")
        return img

    def reshuffle(self):
        image_dict = copy.deepcopy(self.image_dict)
        print("shuffling data")
        for sub in image_dict:
            random.shuffle(image_dict[sub])
@@ -465,17 +603,22 @@ class TrainDatasetsmoothap(Dataset):
        finished = 0
        while finished == 0:
            for sub_class in classes:
                if (len(image_dict[sub_class]) >= self.samples_per_class) and (
                    len(batch) < self.batch_size / self.samples_per_class
                ):
                    batch.append(
                        image_dict[sub_class][: self.samples_per_class]
                    )
                    image_dict[sub_class] = image_dict[sub_class][
                        self.samples_per_class :
                    ]
            if len(batch) == self.batch_size / self.samples_per_class:
                total_batches.append(batch)
                batch = []
            else:
                finished = 1
        random.shuffle(total_batches)
        self.dataset = flatten(flatten(total_batches))
@@ -483,16 +626,15 @@ class TrainDatasetsmoothap(Dataset):
        # we use SequentialSampler together with SuperLabelTrainDataset,
        # so idx==0 indicates the start of a new epoch
        batch_item = self.dataset[idx]
        if self.dataset_name == "Inaturalist":
            cls = int(batch_item[0].split("/")[1])
        else:
            cls = batch_item[0]
        img = Image.open(batch_item[1])
        return cls, self.transform(self.ensure_3dim(img))

    def __len__(self):
        return len(self.dataset)
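# Illustrative sketch (the loop structure is an assumption about the training
# script, which is not shown here): because batches are pre-formed per epoch
# and read back through a SequentialSampler, the dataset is presumably
# reshuffled once per epoch, e.g.
#     train_set = TrainDatasetsmoothap(train_image_dict, opt)
#     for epoch in range(n_epochs):
#         ...  # iterate the DataLoader built in give_dataloaders()
#         train_set.reshuffle()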
...@@ -502,6 +644,7 @@ class TrainDatasetsmoothap1Head(Dataset): ...@@ -502,6 +644,7 @@ class TrainDatasetsmoothap1Head(Dataset):
This dataset class allows mini-batch formation pre-epoch, for greater speed This dataset class allows mini-batch formation pre-epoch, for greater speed
""" """
def __init__(self, image_dict_L, image_dict_U, opt): def __init__(self, image_dict_L, image_dict_U, opt):
""" """
Args: Args:
...@@ -518,29 +661,34 @@ class TrainDatasetsmoothap1Head(Dataset): ...@@ -518,29 +661,34 @@ class TrainDatasetsmoothap1Head(Dataset):
for instance in self.image_dict_L[sub_L]: for instance in self.image_dict_L[sub_L]:
newsub_L.append((sub_L, instance)) newsub_L.append((sub_L, instance))
self.image_dict_L[sub_L] = newsub_L self.image_dict_L[sub_L] = newsub_L
for sub_U in self.image_dict_U: for sub_U in self.image_dict_U:
newsub_U = [] newsub_U = []
for instance in self.image_dict_U[sub_U]: for instance in self.image_dict_U[sub_U]:
newsub_U.append((sub_U, instance)) newsub_U.append((sub_U, instance))
self.image_dict_U[sub_U] = newsub_U self.image_dict_U[sub_U] = newsub_U
# checks # checks
# provide avail_classes # provide avail_classes
self.avail_classes = [*self.image_dict_L] + [*self.image_dict_U] self.avail_classes = [*self.image_dict_L] + [*self.image_dict_U]
# Data augmentation/processing methods. # Data augmentation/processing methods.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
transf_list = [] transf_list = []
transf_list.extend(
transf_list.extend([ [
transforms.RandomResizedCrop(size=224) if opt.arch in ['resnet50', 'resnet50_mcn'] else transforms.RandomResizedCrop(size=227), transforms.RandomResizedCrop(size=224)
transforms.RandomHorizontalFlip(0.5)]) if opt.arch in ["resnet50", "resnet50_mcn"]
else transforms.RandomResizedCrop(size=227),
transforms.RandomHorizontalFlip(0.5),
]
)
transf_list.extend([transforms.ToTensor(), normalize]) transf_list.extend([transforms.ToTensor(), normalize])
self.transform = transforms.Compose(transf_list) self.transform = transforms.Compose(transf_list)
self.reshuffle() self.reshuffle()
def sample_same_size(self): def sample_same_size(self):
image_dict = copy.deepcopy(self.image_dict_L) image_dict = copy.deepcopy(self.image_dict_L)
...@@ -548,7 +696,7 @@ class TrainDatasetsmoothap1Head(Dataset): ...@@ -548,7 +696,7 @@ class TrainDatasetsmoothap1Head(Dataset):
L_size = 0 L_size = 0
for sub_L in self.image_dict_L: for sub_L in self.image_dict_L:
L_size += len(self.image_dict_L[sub_L]) L_size += len(self.image_dict_L[sub_L])
U_size = 0 U_size = 0
classes_U = [*self.image_dict_U] classes_U = [*self.image_dict_U]
# while U_size < len(list(self.image_dict_U)) and U_size < L_size: # while U_size < len(list(self.image_dict_U)) and U_size < L_size:
...@@ -562,17 +710,15 @@ class TrainDatasetsmoothap1Head(Dataset): ...@@ -562,17 +710,15 @@ class TrainDatasetsmoothap1Head(Dataset):
image_dict[sub_U] = self.image_dict_U[sub_U] image_dict[sub_U] = self.image_dict_U[sub_U]
U_size += sub_U_size U_size += sub_U_size
return image_dict return image_dict
def ensure_3dim(self, img): def ensure_3dim(self, img):
if len(img.size) == 2: if len(img.size) == 2:
img = img.convert('RGB') img = img.convert("RGB")
return img return img
def reshuffle(self): def reshuffle(self):
image_dict = self.sample_same_size() image_dict = self.sample_same_size()
print('shuffling data') print("shuffling data")
for sub in image_dict: for sub in image_dict:
random.shuffle(image_dict[sub]) random.shuffle(image_dict[sub])
...@@ -583,33 +729,36 @@ class TrainDatasetsmoothap1Head(Dataset): ...@@ -583,33 +729,36 @@ class TrainDatasetsmoothap1Head(Dataset):
finished = 0 finished = 0
while finished == 0: while finished == 0:
for sub_class in classes: for sub_class in classes:
if (len(image_dict[sub_class]) >=self.samples_per_class) and (len(batch) < self.batch_size/self.samples_per_class) : if (len(image_dict[sub_class]) >= self.samples_per_class) and (
batch.append(image_dict[sub_class][:self.samples_per_class]) len(batch) < self.batch_size / self.samples_per_class
image_dict[sub_class] = image_dict[sub_class][self.samples_per_class:] ):
batch.append(
if len(batch) == self.batch_size/self.samples_per_class: image_dict[sub_class][: self.samples_per_class]
)
image_dict[sub_class] = image_dict[sub_class][
self.samples_per_class :
]
if len(batch) == self.batch_size / self.samples_per_class:
total_batches.append(batch) total_batches.append(batch)
batch = [] batch = []
else: else:
finished = 1 finished = 1
random.shuffle(total_batches) random.shuffle(total_batches)
self.dataset = flatten(flatten(total_batches)) self.dataset = flatten(flatten(total_batches))
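The loop above fills a batch with samples_per_class consecutive items from a class until batch_size / samples_per_class classes have been placed, then starts the next batch; the while loop terminates once some class can no longer supply a full group. A standalone toy run of the same rule, with made-up class names and a batch of 4 built from 2 samples per class:

import random

image_dict = {c: [(c, "img_%s_%d" % (c, i)) for i in range(4)] for c in ("a", "b", "c")}
batch_size, samples_per_class = 4, 2

total_batches, batch, finished = [], [], False
while not finished:
    for cls in list(image_dict):
        if len(image_dict[cls]) >= samples_per_class and len(batch) < batch_size / samples_per_class:
            batch.append(image_dict[cls][:samples_per_class])
            image_dict[cls] = image_dict[cls][samples_per_class:]
            if len(batch) == batch_size / samples_per_class:
                total_batches.append(batch)
                batch = []
        else:
            finished = True
random.shuffle(total_batches)
flat = [item for b in total_batches for group in b for item in group]
print(len(total_batches), "batches,", len(flat), "items")   # 3 batches, 12 items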
def __getitem__(self, idx): def __getitem__(self, idx):
# we use SequentialSampler together with SuperLabelTrainDataset, # we use SequentialSampler together with SuperLabelTrainDataset,
# so idx==0 indicates the start of a new epoch # so idx==0 indicates the start of a new epoch
batch_item = self.dataset[idx] batch_item = self.dataset[idx]
if self.dataset_name == 'Inaturalist': if self.dataset_name == "Inaturalist":
cls = int(batch_item[0].split('/')[1]) cls = int(batch_item[0].split("/")[1])
else: else:
cls = batch_item[0] cls = batch_item[0]
img = Image.open(str(batch_item[1])) img = Image.open(str(batch_item[1]))
return cls, self.transform(self.ensure_3dim(img)) return cls, self.transform(self.ensure_3dim(img))
def __len__(self): def __len__(self):
return len(self.dataset) return len(self.dataset)
...@@ -3,26 +3,31 @@ ...@@ -3,26 +3,31 @@
##################################### LIBRARIES ########################################### ##################################### LIBRARIES ###########################################
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import numpy as np, time, pickle as pkl, csv import csv
import matplotlib.pyplot as plt import pickle as pkl
import time
import auxiliaries as aux
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.multiprocessing
import torch.nn as nn
from scipy.spatial import distance from scipy.spatial import distance
from sklearn.preprocessing import normalize from sklearn.preprocessing import normalize
from tqdm import tqdm from tqdm import tqdm
import torch, torch.nn as nn torch.multiprocessing.set_sharing_strategy("file_system")
import auxiliaries as aux
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
"""==================================================================================================================""" """=================================================================================================================="""
"""==================================================================================================================""" """=================================================================================================================="""
"""=========================================================""" """========================================================="""
def evaluate(dataset, LOG, **kwargs): def evaluate(dataset, LOG, **kwargs):
""" """
Given a dataset name, applies the correct evaluation function. Given a dataset name, applies the correct evaluation function.
...@@ -34,23 +39,26 @@ def evaluate(dataset, LOG, **kwargs): ...@@ -34,23 +39,26 @@ def evaluate(dataset, LOG, **kwargs):
Returns: Returns:
(optional) Computed metrics. Are normally written directly to LOG and printed. (optional) Computed metrics. Are normally written directly to LOG and printed.
""" """
if dataset in ['Inaturalist', 'semi_fungi']: if dataset in ["Inaturalist", "semi_fungi"]:
ret = evaluate_one_dataset(LOG, **kwargs) ret = evaluate_one_dataset(LOG, **kwargs)
elif dataset in ['vehicle_id']: elif dataset in ["vehicle_id"]:
ret = evaluate_multiple_datasets(LOG, **kwargs) ret = evaluate_multiple_datasets(LOG, **kwargs)
else: else:
raise Exception('No implementation for dataset {} available!') raise Exception("No implementation for dataset {} available!".format(dataset))
return ret return ret
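The dispatcher above simply routes by dataset name and raises otherwise. A minimal standalone illustration of the same routing pattern, with the exception message actually formatted (the placeholder handlers are invented and are not the repository's evaluation functions):

def dispatch(dataset, **kwargs):
    handlers = {
        "Inaturalist": lambda **kw: "single test set",
        "semi_fungi": lambda **kw: "single test set",
        "vehicle_id": lambda **kw: "multiple test sets",
    }
    if dataset not in handlers:
        raise Exception("No implementation for dataset {} available!".format(dataset))
    return handlers[dataset](**kwargs)

print(dispatch("vehicle_id"))   # -> multiple test sets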
"""=========================================================""" """========================================================="""
class DistanceMeasure():
class DistanceMeasure:
""" """
Container class to run and log the change of distance ratios Container class to run and log the change of distance ratios
between intra-class distances and inter-class distances. between intra-class distances and inter-class distances.
""" """
def __init__(self, checkdata, opt, name='Train', update_epochs=1):
def __init__(self, checkdata, opt, name="Train", update_epochs=1):
""" """
Args: Args:
checkdata: PyTorch DataLoader, data to check distance progression. checkdata: PyTorch DataLoader, data to check distance progression.
...@@ -61,20 +69,21 @@ class DistanceMeasure(): ...@@ -61,20 +69,21 @@ class DistanceMeasure():
Nothing! Nothing!
""" """
self.update_epochs = update_epochs self.update_epochs = update_epochs
self.pars = opt self.pars = opt
self.save_path = opt.save_path self.save_path = opt.save_path
self.name = name self.name = name
self.csv_file = opt.save_path+'/distance_measures_{}.csv'.format(self.name) self.csv_file = opt.save_path + "/distance_measures_{}.csv".format(
with open(self.csv_file,'a') as csv_file: self.name
writer = csv.writer(csv_file, delimiter=',') )
writer.writerow(['Rel. Intra/Inter Distance']) with open(self.csv_file, "a") as csv_file:
writer = csv.writer(csv_file, delimiter=",")
writer.writerow(["Rel. Intra/Inter Distance"])
self.checkdata = checkdata self.checkdata = checkdata
self.mean_class_dists = [] self.mean_class_dists = []
self.epochs = [] self.epochs = []
def measure(self, model, epoch): def measure(self, model, epoch):
""" """
...@@ -86,7 +95,8 @@ class DistanceMeasure(): ...@@ -86,7 +95,8 @@ class DistanceMeasure():
Returns: Returns:
Nothing! Nothing!
""" """
if epoch%self.update_epochs: return if epoch % self.update_epochs:
return
self.epochs.append(epoch) self.epochs.append(epoch)
...@@ -94,46 +104,55 @@ class DistanceMeasure(): ...@@ -94,46 +104,55 @@ class DistanceMeasure():
_ = model.eval() _ = model.eval()
#Compute Embeddings # Compute Embeddings
with torch.no_grad(): with torch.no_grad():
feature_coll, target_coll = [],[] feature_coll, target_coll = [], []
data_iter = tqdm(self.checkdata, desc='Estimating Data Distances...') data_iter = tqdm(
self.checkdata, desc="Estimating Data Distances..."
)
for idx, data in enumerate(data_iter): for idx, data in enumerate(data_iter):
input_img, target = data[1], data[0] input_img, target = data[1], data[0]
features = model(input_img.to(self.pars.device)) features = model(input_img.to(self.pars.device))
feature_coll.extend(features.cpu().detach().numpy().tolist()) feature_coll.extend(features.cpu().detach().numpy().tolist())
target_coll.extend(target.numpy().tolist()) target_coll.extend(target.numpy().tolist())
feature_coll = np.vstack(feature_coll).astype('float32') feature_coll = np.vstack(feature_coll).astype("float32")
target_coll = np.hstack(target_coll).reshape(-1) target_coll = np.hstack(target_coll).reshape(-1)
avail_labels = np.unique(target_coll) avail_labels = np.unique(target_coll)
#Compute indices of embeddings for each class. # Compute indices of embeddings for each class.
class_positions = [] class_positions = []
for lab in avail_labels: for lab in avail_labels:
class_positions.append(np.where(target_coll==lab)[0]) class_positions.append(np.where(target_coll == lab)[0])
#Compute average intra-class distance and center of mass. # Compute average intra-class distance and center of mass.
com_class, dists_class = [],[] com_class, dists_class = [], []
for class_pos in class_positions: for class_pos in class_positions:
dists = distance.cdist(feature_coll[class_pos],feature_coll[class_pos],'cosine') dists = distance.cdist(
dists = np.sum(dists)/(len(dists)**2-len(dists)) feature_coll[class_pos], feature_coll[class_pos], "cosine"
)
dists = np.sum(dists) / (len(dists) ** 2 - len(dists))
# dists = np.linalg.norm(np.std(feature_coll_aux[class_pos],axis=0).reshape(1,-1)).reshape(-1) # dists = np.linalg.norm(np.std(feature_coll_aux[class_pos],axis=0).reshape(1,-1)).reshape(-1)
com = normalize(np.mean(feature_coll[class_pos],axis=0).reshape(1,-1)).reshape(-1) com = normalize(
np.mean(feature_coll[class_pos], axis=0).reshape(1, -1)
).reshape(-1)
dists_class.append(dists) dists_class.append(dists)
com_class.append(com) com_class.append(com)
#Compute mean inter-class distances by the class-coms. # Compute mean inter-class distances by the class-coms.
mean_inter_dist = distance.cdist(np.array(com_class), np.array(com_class), 'cosine') mean_inter_dist = distance.cdist(
mean_inter_dist = np.sum(mean_inter_dist)/(len(mean_inter_dist)**2-len(mean_inter_dist)) np.array(com_class), np.array(com_class), "cosine"
)
mean_inter_dist = np.sum(mean_inter_dist) / (
len(mean_inter_dist) ** 2 - len(mean_inter_dist)
)
#Compute distance ratio # Compute distance ratio
mean_class_dist = np.mean(np.array(dists_class)/mean_inter_dist) mean_class_dist = np.mean(np.array(dists_class) / mean_inter_dist)
self.mean_class_dists.append(mean_class_dist) self.mean_class_dists.append(mean_class_dist)
self.update(mean_class_dist) self.update(mean_class_dist)
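The quantity tracked here is the mean intra-class cosine distance divided by the mean cosine distance between L2-normalised class centres of mass. The following self-contained NumPy/SciPy snippet reproduces that computation on synthetic embeddings (all data below is invented for illustration):

import numpy as np
from scipy.spatial import distance
from sklearn.preprocessing import normalize

rng = np.random.default_rng(0)
centres = np.eye(3, 8) * 5.0        # three well-separated synthetic class centres
features = np.vstack(
    [rng.normal(loc=c, scale=0.3, size=(10, 8)) for c in centres]
).astype("float32")
labels = np.repeat([0, 1, 2], 10)

dists_class, com_class = [], []
for lab in np.unique(labels):
    pos = np.where(labels == lab)[0]
    d = distance.cdist(features[pos], features[pos], "cosine")
    dists_class.append(np.sum(d) / (len(d) ** 2 - len(d)))   # mean off-diagonal intra-class distance
    com_class.append(normalize(np.mean(features[pos], axis=0).reshape(1, -1)).reshape(-1))

inter = distance.cdist(np.array(com_class), np.array(com_class), "cosine")
mean_inter = np.sum(inter) / (len(inter) ** 2 - len(inter))
print(np.mean(np.array(dists_class) / mean_inter))   # small value: compact classes, well separated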
def update(self, mean_class_dist): def update(self, mean_class_dist):
""" """
Update Loggers. Update Loggers.
...@@ -146,7 +165,6 @@ class DistanceMeasure(): ...@@ -146,7 +165,6 @@ class DistanceMeasure():
self.update_csv(mean_class_dist) self.update_csv(mean_class_dist)
self.update_plot() self.update_plot()
def update_csv(self, mean_class_dist): def update_csv(self, mean_class_dist):
""" """
Update CSV. Update CSV.
...@@ -156,11 +174,10 @@ class DistanceMeasure(): ...@@ -156,11 +174,10 @@ class DistanceMeasure():
Returns: Returns:
Nothing! Nothing!
""" """
with open(self.csv_file, 'a') as csv_file: with open(self.csv_file, "a") as csv_file:
writer = csv.writer(csv_file, delimiter=',') writer = csv.writer(csv_file, delimiter=",")
writer.writerow([mean_class_dist]) writer.writerow([mean_class_dist])
def update_plot(self): def update_plot(self):
""" """
Update progression plot. Update progression plot.
...@@ -170,24 +187,25 @@ class DistanceMeasure(): ...@@ -170,24 +187,25 @@ class DistanceMeasure():
Returns: Returns:
Nothing! Nothing!
""" """
plt.style.use('ggplot') plt.style.use("ggplot")
f,ax = plt.subplots(1) f, ax = plt.subplots(1)
ax.set_title('Mean intra- over inter-class distances') ax.set_title("Mean intra- over inter-class distances")
ax.plot(self.epochs, self.mean_class_dists, label='Class') ax.plot(self.epochs, self.mean_class_dists, label="Class")
f.legend() f.legend()
f.set_size_inches(15,8) f.set_size_inches(15, 8)
f.savefig(self.save_path+'/distance_measures_{}.svg'.format(self.name)) f.savefig(
self.save_path + "/distance_measures_{}.svg".format(self.name)
)
class GradientMeasure:
class GradientMeasure():
""" """
Container for gradient measure functionalities. Container for gradient measure functionalities.
Measure the gradients coming from the embedding layer to the final conv. layer Measure the gradients coming from the embedding layer to the final conv. layer
to examine learning signal. to examine learning signal.
""" """
def __init__(self, opt, name='class-it'):
def __init__(self, opt, name="class-it"):
""" """
Args: Args:
opt: argparse.Namespace, contains all training-specific parameters. opt: argparse.Namespace, contains all training-specific parameters.
...@@ -195,10 +213,14 @@ class GradientMeasure(): ...@@ -195,10 +213,14 @@ class GradientMeasure():
Returns: Returns:
Nothing! Nothing!
""" """
self.pars = opt self.pars = opt
self.name = name self.name = name
self.saver = {'grad_normal_mean':[], 'grad_normal_std':[], 'grad_abs_mean':[], 'grad_abs_std':[]} self.saver = {
"grad_normal_mean": [],
"grad_normal_std": [],
"grad_abs_mean": [],
"grad_abs_std": [],
}
def include(self, params): def include(self, params):
""" """
...@@ -213,10 +235,10 @@ class GradientMeasure(): ...@@ -213,10 +235,10 @@ class GradientMeasure():
for grad in gradients: for grad in gradients:
### Shape: 128 x 2048 ### Shape: 128 x 2048
self.saver['grad_normal_mean'].append(np.mean(grad,axis=0)) self.saver["grad_normal_mean"].append(np.mean(grad, axis=0))
self.saver['grad_normal_std'].append(np.std(grad,axis=0)) self.saver["grad_normal_std"].append(np.std(grad, axis=0))
self.saver['grad_abs_mean'].append(np.mean(np.abs(grad),axis=0)) self.saver["grad_abs_mean"].append(np.mean(np.abs(grad), axis=0))
self.saver['grad_abs_std'].append(np.std(np.abs(grad),axis=0)) self.saver["grad_abs_std"].append(np.std(np.abs(grad), axis=0))
def dump(self, epoch): def dump(self, epoch):
""" """
...@@ -227,15 +249,24 @@ class GradientMeasure(): ...@@ -227,15 +249,24 @@ class GradientMeasure():
Returns: Returns:
Nothing! Nothing!
""" """
with open(self.pars.save_path+'/grad_dict_{}.pkl'.format(self.name),'ab') as f: with open(
self.pars.save_path + "/grad_dict_{}.pkl".format(self.name), "ab"
) as f:
pkl.dump([self.saver], f) pkl.dump([self.saver], f)
self.saver = {'grad_normal_mean':[], 'grad_normal_std':[], 'grad_abs_mean':[], 'grad_abs_std':[]} self.saver = {
"grad_normal_mean": [],
"grad_normal_std": [],
"grad_abs_mean": [],
"grad_abs_std": [],
}
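Each call to include() reduces a (batch x feature) gradient matrix to per-feature statistics, and dump() appends the collected dictionary to a pickle before resetting it. A toy, self-contained version of that bookkeeping (the random matrix merely stands in for the real embedding-layer gradient, whose extraction is not shown in this excerpt):

import numpy as np

saver = {"grad_normal_mean": [], "grad_normal_std": [], "grad_abs_mean": [], "grad_abs_std": []}
grad = np.random.randn(128, 2048)    # placeholder for a 128 x 2048 gradient matrix
saver["grad_normal_mean"].append(np.mean(grad, axis=0))
saver["grad_normal_std"].append(np.std(grad, axis=0))
saver["grad_abs_mean"].append(np.mean(np.abs(grad), axis=0))
saver["grad_abs_std"].append(np.std(np.abs(grad), axis=0))
print(saver["grad_normal_mean"][0].shape)   # (2048,) -> one statistic per feature dimension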
"""========================================================="""
"""=========================================================""" def evaluate_one_dataset(
def evaluate_one_dataset(LOG, dataloader, model, opt, save=True, give_return=True, epoch=0): LOG, dataloader, model, opt, save=True, give_return=True, epoch=0
):
""" """
Compute evaluation metrics, update LOGGER and print results. Compute evaluation metrics, update LOGGER and print results.
...@@ -254,23 +285,54 @@ def evaluate_one_dataset(LOG, dataloader, model, opt, save=True, give_return=Tru ...@@ -254,23 +285,54 @@ def evaluate_one_dataset(LOG, dataloader, model, opt, save=True, give_return=Tru
image_paths = np.array(dataloader.dataset.image_list) image_paths = np.array(dataloader.dataset.image_list)
with torch.no_grad(): with torch.no_grad():
#Compute Metrics # Compute Metrics
F1, NMI, recall_at_ks, feature_matrix_all = aux.eval_metrics_one_dataset(model, dataloader, device=opt.device, k_vals=opt.k_vals, opt=opt) (
#Make printable summary string. F1,
NMI,
result_str = ', '.join('@{0}: {1:.4f}'.format(k,rec) for k,rec in zip(opt.k_vals, recall_at_ks)) recall_at_ks,
result_str = 'Epoch (Test) {0}: NMI [{1:.4f}] | F1 [{2:.4f}] | Recall [{3}]'.format(epoch, NMI, F1, result_str) feature_matrix_all,
) = aux.eval_metrics_one_dataset(
model, dataloader, device=opt.device, k_vals=opt.k_vals, opt=opt
)
# Make printable summary string.
result_str = ", ".join(
"@{0}: {1:.4f}".format(k, rec)
for k, rec in zip(opt.k_vals, recall_at_ks)
)
result_str = "Epoch (Test) {0}: NMI [{1:.4f}] | F1 [{2:.4f}] | Recall [{3}]".format(
epoch, NMI, F1, result_str
)
if LOG is not None: if LOG is not None:
if save: if save:
if not len(LOG.progress_saver['val']['Recall @ 1']) or recall_at_ks[0]>np.max(LOG.progress_saver['val']['Recall @ 1']): if not len(
#Save Checkpoint LOG.progress_saver["val"]["Recall @ 1"]
print("Set checkpoint at {}.".format(LOG.prop.save_path+'/checkpoint_{}.pth.tar'.format(opt.iter))) ) or recall_at_ks[0] > np.max(
aux.set_checkpoint(model, opt, LOG.progress_saver, LOG.prop.save_path+'/checkpoint_{}.pth.tar'.format(opt.iter)) LOG.progress_saver["val"]["Recall @ 1"]
):
# Save Checkpoint
print(
"Set checkpoint at {}.".format(
LOG.prop.save_path
+ "/checkpoint_{}.pth.tar".format(opt.iter)
)
)
aux.set_checkpoint(
model,
opt,
LOG.progress_saver,
LOG.prop.save_path
+ "/checkpoint_{}.pth.tar".format(opt.iter),
)
# aux.recover_closest_one_dataset(feature_matrix_all, image_paths, LOG.prop.save_path+'/sample_recoveries.png') # aux.recover_closest_one_dataset(feature_matrix_all, image_paths, LOG.prop.save_path+'/sample_recoveries.png')
#Update logs. # Update logs.
LOG.log('val', LOG.metrics_to_log['val'], [epoch, np.round(time.time()-start), NMI, F1]+recall_at_ks) LOG.log(
"val",
LOG.metrics_to_log["val"],
[epoch, np.round(time.time() - start), NMI, F1] + recall_at_ks,
)
print(result_str) print(result_str)
if give_return: if give_return:
return recall_at_ks, NMI, F1 return recall_at_ks, NMI, F1
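For reference, the summary string built above renders as follows; the metric values in this standalone snippet are invented, only the formatting mirrors the code:

k_vals = [1, 4, 16, 32]
recall_at_ks = [0.6123, 0.7510, 0.8642, 0.9031]
recalls = ", ".join("@{0}: {1:.4f}".format(k, rec) for k, rec in zip(k_vals, recall_at_ks))
print("Epoch (Test) {0}: NMI [{1:.4f}] | F1 [{2:.4f}] | Recall [{3}]".format(0, 0.5500, 0.3100, recalls))
# Epoch (Test) 0: NMI [0.5500] | F1 [0.3100] | Recall [@1: 0.6123, @4: 0.7510, @16: 0.8642, @32: 0.9031]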
...@@ -278,10 +340,19 @@ def evaluate_one_dataset(LOG, dataloader, model, opt, save=True, give_return=Tru ...@@ -278,10 +340,19 @@ def evaluate_one_dataset(LOG, dataloader, model, opt, save=True, give_return=Tru
None None
"""=========================================================""" """========================================================="""
def evaluate_query_and_gallery_dataset(LOG, query_dataloader, gallery_dataloader, model, opt, save=True, give_return=True, epoch=0):
def evaluate_query_and_gallery_dataset(
LOG,
query_dataloader,
gallery_dataloader,
model,
opt,
save=True,
give_return=True,
epoch=0,
):
""" """
Compute evaluation metrics, update LOGGER and print results, specifically for In-Shop Clothes. Compute evaluation metrics, update LOGGER and print results, specifically for In-Shop Clothes.
...@@ -298,24 +369,65 @@ def evaluate_query_and_gallery_dataset(LOG, query_dataloader, gallery_dataloade ...@@ -298,24 +369,65 @@ def evaluate_query_and_gallery_dataset(LOG, query_dataloader, gallery_dataloade
(optional) Computed metrics. Are normally written directly to LOG and printed. (optional) Computed metrics. Are normally written directly to LOG and printed.
""" """
start = time.time() start = time.time()
query_image_paths = np.array([x[0] for x in query_dataloader.dataset.image_list]) query_image_paths = np.array(
gallery_image_paths = np.array([x[0] for x in gallery_dataloader.dataset.image_list]) [x[0] for x in query_dataloader.dataset.image_list]
)
gallery_image_paths = np.array(
[x[0] for x in gallery_dataloader.dataset.image_list]
)
with torch.no_grad(): with torch.no_grad():
#Compute Metrics. # Compute Metrics.
F1, NMI, recall_at_ks, query_feature_matrix_all, gallery_feature_matrix_all = aux.eval_metrics_query_and_gallery_dataset(model, query_dataloader, gallery_dataloader, device=opt.device, k_vals = opt.k_vals, opt=opt) (
#Generate printable summary string. F1,
result_str = ', '.join('@{0}: {1:.4f}'.format(k,rec) for k,rec in zip(opt.k_vals, recall_at_ks)) NMI,
result_str = 'Epoch (Test) {0}: NMI [{1:.4f}] | F1 [{2:.4f}] | Recall [{3}]'.format(epoch, NMI, F1, result_str) recall_at_ks,
query_feature_matrix_all,
gallery_feature_matrix_all,
) = aux.eval_metrics_query_and_gallery_dataset(
model,
query_dataloader,
gallery_dataloader,
device=opt.device,
k_vals=opt.k_vals,
opt=opt,
)
# Generate printable summary string.
result_str = ", ".join(
"@{0}: {1:.4f}".format(k, rec)
for k, rec in zip(opt.k_vals, recall_at_ks)
)
result_str = "Epoch (Test) {0}: NMI [{1:.4f}] | F1 [{2:.4f}] | Recall [{3}]".format(
epoch, NMI, F1, result_str
)
if LOG is not None: if LOG is not None:
if save: if save:
if not len(LOG.progress_saver['val']['Recall @ 1']) or recall_at_ks[0]>np.max(LOG.progress_saver['val']['Recall @ 1']): if not len(
#Save Checkpoint LOG.progress_saver["val"]["Recall @ 1"]
aux.set_checkpoint(model, opt, LOG.progress_saver, LOG.prop.save_path+'/checkpoint.pth.tar') ) or recall_at_ks[0] > np.max(
aux.recover_closest_inshop(query_feature_matrix_all, gallery_feature_matrix_all, query_image_paths, gallery_image_paths, LOG.prop.save_path+'/sample_recoveries.png') LOG.progress_saver["val"]["Recall @ 1"]
#Update logs. ):
LOG.log('val', LOG.metrics_to_log['val'], [epoch, np.round(time.time()-start), NMI, F1]+recall_at_ks) # Save Checkpoint
aux.set_checkpoint(
model,
opt,
LOG.progress_saver,
LOG.prop.save_path + "/checkpoint.pth.tar",
)
aux.recover_closest_inshop(
query_feature_matrix_all,
gallery_feature_matrix_all,
query_image_paths,
gallery_image_paths,
LOG.prop.save_path + "/sample_recoveries.png",
)
# Update logs.
LOG.log(
"val",
LOG.metrics_to_log["val"],
[epoch, np.round(time.time() - start), NMI, F1] + recall_at_ks,
)
print(result_str) print(result_str)
if give_return: if give_return:
...@@ -324,14 +436,16 @@ def evaluate_query_and_gallery_dataset(LOG, query_dataloader, gallery_dataloade ...@@ -324,14 +436,16 @@ def evaluate_query_and_gallery_dataset(LOG, query_dataloader, gallery_dataloade
None None
"""========================================================="""
"""=========================================================""" def evaluate_multiple_datasets(
def evaluate_multiple_datasets(LOG, dataloaders, model, opt, save=True, give_return=True, epoch=0): LOG, dataloaders, model, opt, save=True, give_return=True, epoch=0
):
""" """
Compute evaluation metrics, update LOGGER and print results, specifically for multi-test datasets such as PKU Vehicle ID. Compute evaluation metrics, update LOGGER and print results, specifically for multi-test datasets such as PKU Vehicle ID.
Args: Args:
LOG: aux.LOGGER-instance. Main Logging Functionality. LOG: aux.LOGGER-instance. Main Logging Functionality.
dataloaders: List of PyTorch Dataloaders, test-dataloaders to evaluate. dataloaders: List of PyTorch Dataloaders, test-dataloaders to evaluate.
model: PyTorch Network, Network to evaluate. model: PyTorch Network, Network to evaluate.
...@@ -342,36 +456,62 @@ def evaluate_multiple_datasets(LOG, dataloaders, model, opt, save=True, give_ret ...@@ -342,36 +456,62 @@ def evaluate_multiple_datasets(LOG, dataloaders, model, opt, save=True, give_ret
Returns: Returns:
(optional) Computed metrics. Are normally written directly to LOG and printed. (optional) Computed metrics. Are normally written directly to LOG and printed.
""" """
start = time.time() start = time.time()
csv_data = [epoch] csv_data = [epoch]
with torch.no_grad(): with torch.no_grad():
for i,dataloader in enumerate(dataloaders): for i, dataloader in enumerate(dataloaders):
print('Working on Set {}/{}'.format(i+1, len(dataloaders))) print("Working on Set {}/{}".format(i + 1, len(dataloaders)))
image_paths = np.array(dataloader.dataset.image_list) image_paths = np.array(dataloader.dataset.image_list)
#Compute Metrics for specific testset. # Compute Metrics for specific testset.
F1, NMI, recall_at_ks, feature_matrix_all = aux.eval_metrics_one_dataset(model, dataloader, device=opt.device, k_vals=opt.k_vals, opt=opt) (
#Generate printable summary string. F1,
result_str = ', '.join('@{0}: {1:.4f}'.format(k,rec) for k,rec in zip(opt.k_vals, recall_at_ks)) NMI,
result_str = 'SET {0}: Epoch (Test) {1}: NMI [{2:.4f}] | F1 {3:.4f}| Recall [{4}]'.format(i+1, epoch, NMI, F1, result_str) recall_at_ks,
feature_matrix_all,
) = aux.eval_metrics_one_dataset(
model, dataloader, device=opt.device, k_vals=opt.k_vals, opt=opt
)
# Generate printable summary string.
result_str = ", ".join(
"@{0}: {1:.4f}".format(k, rec)
for k, rec in zip(opt.k_vals, recall_at_ks)
)
result_str = "SET {0}: Epoch (Test) {1}: NMI [{2:.4f}] | F1 {3:.4f}| Recall [{4}]".format(
i + 1, epoch, NMI, F1, result_str
)
if LOG is not None: if LOG is not None:
if save: if save:
if not len(LOG.progress_saver['val']['Set {} Recall @ 1'.format(i)]) or recall_at_ks[0]>np.max(LOG.progress_saver['val']['Set {} Recall @ 1'.format(i)]): if not len(
#Save Checkpoint for specific test set. LOG.progress_saver["val"]["Set {} Recall @ 1".format(i)]
aux.set_checkpoint(model, opt, LOG.progress_saver, LOG.prop.save_path+'/checkpoint_set{}.pth.tar'.format(i+1)) ) or recall_at_ks[0] > np.max(
aux.recover_closest_one_dataset(feature_matrix_all, image_paths, LOG.prop.save_path+'/sample_recoveries_set{}.png'.format(i+1)) LOG.progress_saver["val"]["Set {} Recall @ 1".format(i)]
):
csv_data += [NMI, F1]+recall_at_ks # Save Checkpoint for specific test set.
aux.set_checkpoint(
model,
opt,
LOG.progress_saver,
LOG.prop.save_path
+ "/checkpoint_set{}.pth.tar".format(i + 1),
)
aux.recover_closest_one_dataset(
feature_matrix_all,
image_paths,
LOG.prop.save_path
+ "/sample_recoveries_set{}.png".format(i + 1),
)
csv_data += [NMI, F1] + recall_at_ks
print(result_str) print(result_str)
csv_data.insert(0, np.round(time.time()-start))
#Update logs.
LOG.log('val', LOG.metrics_to_log['val'], csv_data)
csv_data.insert(0, np.round(time.time() - start))
#if give_return: # Update logs.
LOG.log("val", LOG.metrics_to_log["val"], csv_data)
# if give_return:
return csv_data[2:] return csv_data[2:]
#else: # else:
# None # None
\ No newline at end of file
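The csv_data list assembled by evaluate_multiple_datasets starts with the epoch, gains [NMI, F1] + recalls per test set, has the elapsed time inserted at the front, and is returned with those first two bookkeeping entries stripped. A small standalone illustration of that layout with invented numbers (two test sets, Recall at 1 and 5):

import numpy as np

epoch = 3
csv_data = [epoch]
per_set_metrics = [(0.52, 0.30, [0.61, 0.80]), (0.48, 0.27, [0.57, 0.76])]   # (NMI, F1, recalls) per set
for nmi, f1, recalls in per_set_metrics:
    csv_data += [nmi, f1] + recalls
csv_data.insert(0, np.round(12.3))   # elapsed seconds goes first, ahead of the epoch
print(csv_data)       # time, epoch, then NMI/F1/recalls for set 1, then set 2
print(csv_data[2:])   # what the function returns: the metrics only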
import os, torch, argparse import argparse
import netlib as netlib import os
import auxiliaries as aux import auxiliaries as aux
import datasets as data import datasets as data
import evaluate as eval import evaluate as eval
import netlib as netlib
import torch
if __name__ == '__main__': if __name__ == "__main__":
################## INPUT ARGUMENTS ################### ################## INPUT ARGUMENTS ###################
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
####### Main Parameter: Dataset to use for Training ####### Main Parameter: Dataset to use for Training
parser.add_argument('--dataset', default='vehicle_id', type=str, help='Dataset to use.', parser.add_argument(
choices=['Inaturalist', 'vehicle_id']) "--dataset",
parser.add_argument('--source_path', default='/scratch/shared/beegfs/abrown/datasets', type=str, default="vehicle_id",
help='Path to training data.') type=str,
parser.add_argument('--save_path', default=os.getcwd() + '/Training_Results', type=str, help="Dataset to use.",
help='Where to save everything.') choices=["Inaturalist", "vehicle_id"],
parser.add_argument('--savename', default='', type=str, )
help='Save folder name if any special information is to be included.') parser.add_argument(
"--source_path",
default="/scratch/shared/beegfs/abrown/datasets",
type=str,
help="Path to training data.",
)
parser.add_argument(
"--save_path",
default=os.getcwd() + "/Training_Results",
type=str,
help="Where to save everything.",
)
parser.add_argument(
"--savename",
default="",
type=str,
help="Save folder name if any special information is to be included.",
)
### General Training Parameters ### General Training Parameters
parser.add_argument('--kernels', default=8, type=int, help='Number of workers for pytorch dataloader.') parser.add_argument(
parser.add_argument('--bs', default=112, type=int, help='Mini-Batchsize to use.') "--kernels",
parser.add_argument('--samples_per_class', default=4, type=int,help='Number of samples in one class drawn before choosing the next class. Set to >1 for losses other than ProxyNCA.') default=8,
parser.add_argument('--loss', default='smoothap', type=str) type=int,
help="Number of workers for pytorch dataloader.",
)
parser.add_argument(
"--bs", default=112, type=int, help="Mini-Batchsize to use."
)
parser.add_argument(
"--samples_per_class",
default=4,
type=int,
help="Number of samples in one class drawn before choosing the next class. Set to >1 for losses other than ProxyNCA.",
)
parser.add_argument("--loss", default="smoothap", type=str)
##### Evaluation Settings ##### Evaluation Settings
parser.add_argument('--k_vals', nargs='+', default=[1, 2, 4, 8], type=int, help='Recall @ Values.') parser.add_argument(
"--k_vals",
nargs="+",
default=[1, 2, 4, 8],
type=int,
help="Recall @ Values.",
)
##### Network parameters ##### Network parameters
parser.add_argument('--embed_dim', default=512, type=int, parser.add_argument(
help='Embedding dimensionality of the network. Note: in literature, dim=128 is used for ResNet50 and dim=512 for GoogLeNet.') "--embed_dim",
parser.add_argument('--arch', default='resnet50', type=str, default=512,
help='Network backend choice: resnet50, googlenet, BNinception') type=int,
parser.add_argument('--gpu', default=0, type=int, help='GPU-id for GPU to use.') help="Embedding dimensionality of the network. Note: in literature, dim=128 is used for ResNet50 and dim=512 for GoogLeNet.",
parser.add_argument('--resume', default='', type=str, help='path to where weights to be evaluated are saved.') )
parser.add_argument('--not_pretrained', action='store_true', parser.add_argument(
help='If added, the network will be trained WITHOUT ImageNet-pretrained weights.') "--arch",
default="resnet50",
parser.add_argument('--trainset', default="lin_train_set1.txt", type=str) type=str,
parser.add_argument('--testset', default="Inaturalist_test_set1.txt", type=str) help="Network backend choice: resnet50, googlenet, BNinception",
parser.add_argument('--cluster_path', default="", type=str) )
parser.add_argument('--finetune', default="false", type=str) parser.add_argument(
parser.add_argument('--class_num', default=948, type=int) "--gpu", default=0, type=int, help="GPU-id for GPU to use."
parser.add_argument('--get_features', default="false", type=str) )
parser.add_argument('--patch_size', default=16, type=int, help='vit patch size') parser.add_argument(
parser.add_argument('--pretrained_weights', default="", type=str, help='pretrained weight path') "--resume",
parser.add_argument('--use_bn_in_head', default=False, type=aux.bool_flag, default="",
help="Whether to use batch normalizations in projection head (Default: False)") type=str,
parser.add_argument("--checkpoint_key", default="teacher", type=str, help="path to where weights to be evaluated are saved.",
help='Key to use in the checkpoint (example: "teacher")') )
parser.add_argument('--drop_path_rate', default=0.1, type=float, help="stochastic depth rate") parser.add_argument(
parser.add_argument('--norm_last_layer', default=True, type=aux.bool_flag, "--not_pretrained",
help="""Whether or not to weight normalize the last layer of the DINO head. action="store_true",
help="If added, the network will be trained WITHOUT ImageNet-pretrained weights.",
)
parser.add_argument("--trainset", default="lin_train_set1.txt", type=str)
parser.add_argument(
"--testset", default="Inaturalist_test_set1.txt", type=str
)
parser.add_argument("--cluster_path", default="", type=str)
parser.add_argument("--finetune", default="false", type=str)
parser.add_argument("--class_num", default=948, type=int)
parser.add_argument("--get_features", default="false", type=str)
parser.add_argument(
"--patch_size", default=16, type=int, help="vit patch size"
)
parser.add_argument(
"--pretrained_weights",
default="",
type=str,
help="pretrained weight path",
)
parser.add_argument(
"--use_bn_in_head",
default=False,
type=aux.bool_flag,
help="Whether to use batch normalizations in projection head (Default: False)",
)
parser.add_argument(
"--checkpoint_key",
default="teacher",
type=str,
help='Key to use in the checkpoint (example: "teacher")',
)
parser.add_argument(
"--drop_path_rate",
default=0.1,
type=float,
help="stochastic depth rate",
)
parser.add_argument(
"--norm_last_layer",
default=True,
type=aux.bool_flag,
help="""Whether or not to weight normalize the last layer of the DINO head.
Not normalizing leads to better performance but can make the training unstable. Not normalizing leads to better performance but can make the training unstable.
In our experiments, we typically set this parameter to False with vit_small and True with vit_base.""") In our experiments, we typically set this parameter to False with vit_small and True with vit_base.""",
parser.add_argument('--linsize', default=29011, type=int, help="Lin data size.") )
parser.add_argument('--uinsize', default=18403, type=int, help="Uin data size.") parser.add_argument(
"--linsize", default=29011, type=int, help="Lin data size."
)
parser.add_argument(
"--uinsize", default=18403, type=int, help="Uin data size."
)
opt = parser.parse_args() opt = parser.parse_args()
"""============================================================================""" """============================================================================"""
opt.source_path += '/' + opt.dataset opt.source_path += "/" + opt.dataset
if opt.dataset == 'Inaturalist': if opt.dataset == "Inaturalist":
opt.n_epochs = 90 opt.n_epochs = 90
opt.tau = [40, 70] opt.tau = [40, 70]
opt.k_vals = [1, 4, 16, 32] opt.k_vals = [1, 4, 16, 32]
if opt.dataset == 'vehicle_id': if opt.dataset == "vehicle_id":
opt.k_vals = [1, 5] opt.k_vals = [1, 5]
if opt.finetune == 'true': if opt.finetune == "true":
opt.finetune = True opt.finetune = True
elif opt.finetune == 'false': elif opt.finetune == "false":
opt.finetune = False opt.finetune = False
if opt.get_features == 'true': if opt.get_features == "true":
opt.get_features = True opt.get_features = True
elif opt.get_features == 'false': elif opt.get_features == "false":
opt.get_features = False opt.get_features = False
metrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals) metrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals)
LOG = aux.LOGGER(opt, metrics_to_log, name='Base', start_new=True) LOG = aux.LOGGER(opt, metrics_to_log, name="Base", start_new=True)
"""============================================================================""" """============================================================================"""
##################### NETWORK SETUP ################## ##################### NETWORK SETUP ##################
opt.device = torch.device('cuda') opt.device = torch.device("cuda")
model = netlib.networkselect(opt) model = netlib.networkselect(opt)
# Push to Device # Push to Device
...@@ -96,18 +182,32 @@ if __name__ == '__main__': ...@@ -96,18 +182,32 @@ if __name__ == '__main__':
# The 'testing'-dataloader corresponds to the validation set, and the 'evaluation'-dataloader # The 'testing'-dataloader corresponds to the validation set, and the 'evaluation'-dataloader
# is simply using the training set, but running under the same rules as the 'testing' dataloader, # is simply using the training set, but running under the same rules as the 'testing' dataloader,
# i.e. no shuffling and no random cropping. # i.e. no shuffling and no random cropping.
dataloaders = data.give_dataloaders(opt.dataset, opt.trainset, opt.testset, opt) dataloaders = data.give_dataloaders(
opt.dataset, opt.trainset, opt.testset, opt
)
# Because the number of supervised classes is dataset dependent, we store them after # Because the number of supervised classes is dataset dependent, we store them after
# initializing the dataloader # initializing the dataloader
opt.num_classes = len(dataloaders['training'].dataset.avail_classes) opt.num_classes = len(dataloaders["training"].dataset.avail_classes)
if opt.dataset == 'Inaturalist': if opt.dataset == "Inaturalist":
eval_params = {'dataloader': dataloaders['testing'], 'model': model, 'opt': opt, 'epoch': 0} eval_params = {
"dataloader": dataloaders["testing"],
"model": model,
"opt": opt,
"epoch": 0,
}
elif opt.dataset == 'vehicle_id': elif opt.dataset == "vehicle_id":
eval_params = { eval_params = {
'dataloaders': [dataloaders['testing_set1'], dataloaders['testing_set2'], dataloaders['testing_set3']], "dataloaders": [
'model': model, 'opt': opt, 'epoch': 0} dataloaders["testing_set1"],
dataloaders["testing_set2"],
dataloaders["testing_set3"],
],
"model": model,
"opt": opt,
"epoch": 0,
}
"""============================================================================""" """============================================================================"""
####################evaluation ################## ####################evaluation ##################
......
...@@ -12,213 +12,363 @@ need to change all of the copyrights at the top of all of the files ...@@ -12,213 +12,363 @@ need to change all of the copyrights at the top of all of the files
#################### LIBRARIES ######################## #################### LIBRARIES ########################
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import os, numpy as np, argparse, random, matplotlib, datetime import argparse
import datetime
import os
import random
import matplotlib
import numpy as np
os.chdir(os.path.dirname(os.path.realpath(__file__))) os.chdir(os.path.dirname(os.path.realpath(__file__)))
from pathlib import Path from pathlib import Path
matplotlib.use('agg')
from tqdm import tqdm
matplotlib.use("agg")
import auxiliaries as aux import auxiliaries as aux
import datasets as data import datasets as data
import netlib as netlib
import losses as losses
import evaluate as eval import evaluate as eval
from tensorboardX import SummaryWriter import losses as losses
import netlib as netlib
import torch.multiprocessing import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system') from tensorboardX import SummaryWriter
from tqdm import tqdm
torch.multiprocessing.set_sharing_strategy("file_system")
import time import time
start = time.time() start = time.time()
################### INPUT ARGUMENTS ################### ################### INPUT ARGUMENTS ###################
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
####### Main Parameter: Dataset to use for Training ####### Main Parameter: Dataset to use for Training
parser.add_argument('--dataset', default='Inaturalist', type=str, help='Dataset to use.', choices=['Inaturalist','semi_fungi']) parser.add_argument(
"--dataset",
default="Inaturalist",
type=str,
help="Dataset to use.",
choices=["Inaturalist", "semi_fungi"],
)
### General Training Parameters ### General Training Parameters
parser.add_argument('--lr', default=0.00001, type=float, help='Learning Rate for network parameters.') parser.add_argument(
parser.add_argument('--fc_lr_mul', default=5, type=float, help='OPTIONAL: Multiply the embedding layer learning rate by this value. If set to 0, the embedding layer shares the same learning rate.') "--lr",
parser.add_argument('--n_epochs', default=400, type=int, help='Number of training epochs.') default=0.00001,
parser.add_argument('--kernels', default=8, type=int, help='Number of workers for pytorch dataloader.') type=float,
parser.add_argument('--bs', default=112 , type=int, help='Mini-Batchsize to use.') help="Learning Rate for network parameters.",
parser.add_argument('--samples_per_class', default=4, type=int, help='Number of samples in one class drawn before choosing the next class') )
parser.add_argument('--seed', default=1, type=int, help='Random seed for reproducibility.') parser.add_argument(
parser.add_argument('--scheduler', default='step', type=str, help='Type of learning rate scheduling. Currently: step & exp.') "--fc_lr_mul",
parser.add_argument('--gamma', default=0.3, type=float, help='Learning rate reduction after tau epochs.') default=5,
parser.add_argument('--decay', default=0.001, type=float, help='Weight decay for optimizer.') type=float,
parser.add_argument('--tau', default= [200,300],nargs='+',type=int,help='Stepsize(s) before reducing learning rate.') help="OPTIONAL: Multiply the embedding layer learning rate by this value. If set to 0, the embedding layer shares the same learning rate.",
parser.add_argument('--infrequent_eval', default=0,type=int, help='only compute evaluation metrics every 10 epochs') )
parser.add_argument('--opt', default = 'adam',help='adam or sgd') parser.add_argument(
"--n_epochs", default=400, type=int, help="Number of training epochs."
)
parser.add_argument(
"--kernels",
default=8,
type=int,
help="Number of workers for pytorch dataloader.",
)
parser.add_argument(
"--bs", default=112, type=int, help="Mini-Batchsize to use."
)
parser.add_argument(
"--samples_per_class",
default=4,
type=int,
help="Number of samples in one class drawn before choosing the next class",
)
parser.add_argument(
"--seed", default=1, type=int, help="Random seed for reproducibility."
)
parser.add_argument(
"--scheduler",
default="step",
type=str,
help="Type of learning rate scheduling. Currently: step & exp.",
)
parser.add_argument(
"--gamma",
default=0.3,
type=float,
help="Learning rate reduction after tau epochs.",
)
parser.add_argument(
"--decay", default=0.001, type=float, help="Weight decay for optimizer."
)
parser.add_argument(
"--tau",
default=[200, 300],
nargs="+",
type=int,
help="Stepsize(s) before reducing learning rate.",
)
parser.add_argument(
"--infrequent_eval",
default=0,
type=int,
help="only compute evaluation metrics every 10 epochs",
)
parser.add_argument("--opt", default="adam", help="adam or sgd")
##### Loss-specific Settings ##### Loss-specific Settings
parser.add_argument('--loss', default='smoothap', type=str) parser.add_argument("--loss", default="smoothap", type=str)
parser.add_argument('--sigmoid_temperature', default=0.01, type=float, help='SmoothAP: the temperature of the sigmoid used in SmoothAP loss') parser.add_argument(
"--sigmoid_temperature",
default=0.01,
type=float,
help="SmoothAP: the temperature of the sigmoid used in SmoothAP loss",
)
##### Evaluation Settings ##### Evaluation Settings
parser.add_argument('--k_vals', nargs='+', default=[1,2,4,8], type=int, help='Recall @ Values.') parser.add_argument(
parser.add_argument('--resume', default='', type=str, help='path to checkpoint to load weights from (if empty then ImageNet pre-trained weights are loaded') "--k_vals",
nargs="+",
default=[1, 2, 4, 8],
type=int,
help="Recall @ Values.",
)
parser.add_argument(
"--resume",
default="",
type=str,
help="path to checkpoint to load weights from (if empty then ImageNet pre-trained weights are loaded",
)
##### Network parameters ##### Network parameters
parser.add_argument('--embed_dim', default=512, type=int, help='Embedding dimensionality of the network') parser.add_argument(
parser.add_argument('--arch', default='resnet50', type=str, help='Network backend choice: resnet50, googlenet, BNinception') "--embed_dim",
parser.add_argument('--grad_measure', action='store_true', help='If added, gradients passed from embedding layer to the last conv-layer are stored in each iteration.') default=512,
parser.add_argument('--dist_measure', action='store_true', help='If added, the ratio between intra- and interclass distances is stored after each epoch.') type=int,
parser.add_argument('--not_pretrained', action='store_true', help='If added, the network will be trained WITHOUT ImageNet-pretrained weights.') help="Embedding dimensionality of the network",
)
parser.add_argument(
"--arch",
default="resnet50",
type=str,
help="Network backend choice: resnet50, googlenet, BNinception",
)
parser.add_argument(
"--grad_measure",
action="store_true",
help="If added, gradients passed from embedding layer to the last conv-layer are stored in each iteration.",
)
parser.add_argument(
"--dist_measure",
action="store_true",
help="If added, the ratio between intra- and interclass distances is stored after each epoch.",
)
parser.add_argument(
"--not_pretrained",
action="store_true",
help="If added, the network will be trained WITHOUT ImageNet-pretrained weights.",
)
##### Setup Parameters ##### Setup Parameters
parser.add_argument('--gpu', default=0, type=int, help='GPU-id for GPU to use.') parser.add_argument("--gpu", default=0, type=int, help="GPU-id for GPU to use.")
parser.add_argument('--savename', default='', type=str, help='Save folder name if any special information is to be included.') parser.add_argument(
"--savename",
default="",
type=str,
help="Save folder name if any special information is to be included.",
)
### Paths to datasets and storage folder ### Paths to datasets and storage folder
parser.add_argument('--source_path', default='/scratch/shared/beegfs/abrown/datasets', type=str, help='Path to data') parser.add_argument(
parser.add_argument('--save_path', default=os.getcwd()+'/Training_Results', type=str, help='Where to save the checkpoints') "--source_path",
default="/scratch/shared/beegfs/abrown/datasets",
type=str,
help="Path to data",
)
parser.add_argument(
"--save_path",
default=os.getcwd() + "/Training_Results",
type=str,
help="Where to save the checkpoints",
)
### additional parameters ### additional parameters
parser.add_argument('--trainset', default="lin_train_set1.txt", type=str) parser.add_argument("--trainset", default="lin_train_set1.txt", type=str)
parser.add_argument('--testset', default="Inaturalist_test_set1.txt", type=str) parser.add_argument("--testset", default="Inaturalist_test_set1.txt", type=str)
parser.add_argument('--cluster_path', default="", type=str) parser.add_argument("--cluster_path", default="", type=str)
parser.add_argument('--finetune', default='true', type=str) parser.add_argument("--finetune", default="true", type=str)
parser.add_argument('--class_num', default=948, type=int) parser.add_argument("--class_num", default=948, type=int)
parser.add_argument('--pretrained_weights', default="", type=str, help='pretrained weight path') parser.add_argument(
parser.add_argument('--use_bn_in_head', default=False, type=aux.bool_flag, "--pretrained_weights", default="", type=str, help="pretrained weight path"
help="Whether to use batch normalizations in projection head (Default: False)") )
parser.add_argument("--checkpoint_key", default="teacher", type=str, parser.add_argument(
help='Key to use in the checkpoint (example: "teacher")') "--use_bn_in_head",
parser.add_argument('--drop_path_rate', default=0.1, type=float, help="stochastic depth rate") default=False,
parser.add_argument('--iter', default=1, type=int) type=aux.bool_flag,
help="Whether to use batch normalizations in projection head (Default: False)",
)
parser.add_argument(
"--checkpoint_key",
default="teacher",
type=str,
help='Key to use in the checkpoint (example: "teacher")',
)
parser.add_argument(
"--drop_path_rate", default=0.1, type=float, help="stochastic depth rate"
)
parser.add_argument("--iter", default=1, type=int)
opt = parser.parse_args() opt = parser.parse_args()
"""============================================================================""" """============================================================================"""
opt.source_path += '/' + opt.dataset opt.source_path += "/" + opt.dataset
opt.save_path += '/' + opt.dataset + "_" + str(opt.embed_dim) opt.save_path += "/" + opt.dataset + "_" + str(opt.embed_dim)
if opt.dataset== 'Inaturalist': if opt.dataset == "Inaturalist":
# opt.n_epochs = 90 # opt.n_epochs = 90
opt.tau = [40, 70] opt.tau = [40, 70]
opt.k_vals = [1,4,16,32] opt.k_vals = [1, 4, 16, 32]
if opt.dataset=='semi_fungi': if opt.dataset == "semi_fungi":
opt.tau = [40,70] opt.tau = [40, 70]
opt.k_vals = [1,4,16,32] opt.k_vals = [1, 4, 16, 32]
if opt.finetune == 'true': if opt.finetune == "true":
opt.finetune = True opt.finetune = True
elif opt.finetune == 'false': elif opt.finetune == "false":
opt.finetune = False opt.finetune = False
"""===========================================================================""" """==========================================================================="""
################### TensorBoard Settings ################## ################### TensorBoard Settings ##################
timestamp = datetime.datetime.now().strftime(r"%Y-%m-%d_%H-%M-%S") timestamp = datetime.datetime.now().strftime(r"%Y-%m-%d_%H-%M-%S")
exp_name = aux.args2exp_name(opt) exp_name = aux.args2exp_name(opt)
opt.save_name = f"weights_{exp_name}" +'/'+ timestamp opt.save_name = f"weights_{exp_name}" + "/" + timestamp
random.seed(opt.seed) random.seed(opt.seed)
np.random.seed(opt.seed) np.random.seed(opt.seed)
torch.manual_seed(opt.seed) torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed); torch.cuda.manual_seed_all(opt.seed) torch.cuda.manual_seed(opt.seed)
torch.cuda.manual_seed_all(opt.seed)
tensorboard_path = Path(f"logs/logs_{exp_name}") / timestamp tensorboard_path = Path(f"logs/logs_{exp_name}") / timestamp
tensorboard_path.parent.mkdir(exist_ok=True, parents=True) tensorboard_path.parent.mkdir(exist_ok=True, parents=True)
global writer; global writer
writer = SummaryWriter(tensorboard_path) writer = SummaryWriter(tensorboard_path)
"""============================================================================""" """============================================================================"""
################### GPU SETTINGS ########################### ################### GPU SETTINGS ###########################
os.environ["CUDA_DEVICE_ORDER"] ="PCI_BUS_ID" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]= str(opt.gpu) # os.environ["CUDA_VISIBLE_DEVICES"]= str(opt.gpu)
print('using #GPUs:',torch.cuda.device_count()) print("using #GPUs:", torch.cuda.device_count())
"""============================================================================""" """============================================================================"""
#################### DATALOADER SETUPS ################## #################### DATALOADER SETUPS ##################
#Returns a dictionary containing 'training', 'testing', and 'evaluation' dataloaders. # Returns a dictionary containing 'training', 'testing', and 'evaluation' dataloaders.
#The 'testing'-dataloader corresponds to the validation set, and the 'evaluation'-dataloader # The 'testing'-dataloader corresponds to the validation set, and the 'evaluation'-dataloader
#is simply using the training set, but running under the same rules as the 'testing' dataloader, # is simply using the training set, but running under the same rules as the 'testing' dataloader,
#i.e. no shuffling and no random cropping. # i.e. no shuffling and no random cropping.
dataloaders = data.give_dataloaders(opt.dataset, opt.trainset, opt.testset, opt, cluster_path=opt.cluster_path) dataloaders = data.give_dataloaders(
#Because the number of supervised classes is dataset dependent, we store them after opt.dataset, opt.trainset, opt.testset, opt, cluster_path=opt.cluster_path
#initializing the dataloader )
opt.num_classes = len(dataloaders['training'].dataset.avail_classes) # Because the number of supervised classes is dataset dependent, we store them after
# initializing the dataloader
opt.num_classes = len(dataloaders["training"].dataset.avail_classes)
print("num_classes:", opt.num_classes) print("num_classes:", opt.num_classes)
print("train dataset size:", len(dataloaders['training'])) print("train dataset size:", len(dataloaders["training"]))
"""============================================================================""" """============================================================================"""
##################### NETWORK SETUP ################## ##################### NETWORK SETUP ##################
opt.device = torch.device('cuda') opt.device = torch.device("cuda")
model = netlib.networkselect(opt) model = netlib.networkselect(opt)
#Push to Device # Push to Device
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model) model = torch.nn.DataParallel(model)
_ = model.to(opt.device) _ = model.to(opt.device)
#Place trainable parameter in list of parameters to train: # Place trainable parameter in list of parameters to train:
if 'fc_lr_mul' in vars(opt).keys() and opt.fc_lr_mul!=0: if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
all_but_fc_params = list(filter(lambda x: 'last_linear' not in x[0],model.named_parameters())) all_but_fc_params = list(
filter(lambda x: "last_linear" not in x[0], model.named_parameters())
)
for ind, param in enumerate(all_but_fc_params): for ind, param in enumerate(all_but_fc_params):
all_but_fc_params[ind] = param[1] all_but_fc_params[ind] = param[1]
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
fc_params = model.module.model.last_linear.parameters() fc_params = model.module.model.last_linear.parameters()
else: else:
fc_params = model.model.last_linear.parameters() fc_params = model.model.last_linear.parameters()
to_optim = [{'params':all_but_fc_params,'lr':opt.lr,'weight_decay':opt.decay}, to_optim = [
{'params':fc_params,'lr':opt.lr*opt.fc_lr_mul,'weight_decay':opt.decay}] {"params": all_but_fc_params, "lr": opt.lr, "weight_decay": opt.decay},
{
"params": fc_params,
"lr": opt.lr * opt.fc_lr_mul,
"weight_decay": opt.decay,
},
]
else: else:
to_optim = [{'params':model.parameters(),'lr':opt.lr,'weight_decay':opt.decay}] to_optim = [
{"params": model.parameters(), "lr": opt.lr, "weight_decay": opt.decay}
]
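The two branches above produce the parameter-group list handed to the optimiser: when fc_lr_mul is non-zero, the last_linear (embedding) layer gets its own learning rate of lr * fc_lr_mul while everything else uses lr. A self-contained sketch with a stand-in model showing how such groups are consumed (all names below are illustrative):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))   # stand-in; the last layer plays the role of last_linear
lr, fc_lr_mul, decay = 1e-5, 5, 1e-3

backbone_params = list(model[0].parameters())
embed_params = list(model[1].parameters())
to_optim = [
    {"params": backbone_params, "lr": lr, "weight_decay": decay},
    {"params": embed_params, "lr": lr * fc_lr_mul, "weight_decay": decay},
]
optimizer = torch.optim.Adam(to_optim)
print([g["lr"] for g in optimizer.param_groups])   # [1e-05, 5e-05]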
"""============================================================================""" """============================================================================"""
#################### CREATE LOGGING FILES ############### #################### CREATE LOGGING FILES ###############
#Each dataset usually has a set of standard metrics to log. aux.metrics_to_examine() # Each dataset usually has a set of standard metrics to log. aux.metrics_to_examine()
#returns a dict which lists metrics to log for training ('train') and validation/testing ('val') # returns a dict which lists metrics to log for training ('train') and validation/testing ('val')
metrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals) metrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals)
# example output: {'train': ['Epochs', 'Time', 'Train Loss', 'Time'], # example output: {'train': ['Epochs', 'Time', 'Train Loss', 'Time'],
# 'val': ['Epochs','Time','NMI','F1', 'Recall @ 1','Recall @ 2','Recall @ 4','Recall @ 8']} # 'val': ['Epochs','Time','NMI','F1', 'Recall @ 1','Recall @ 2','Recall @ 4','Recall @ 8']}
#Using the provided metrics of interest, we generate a LOGGER instance. # Using the provided metrics of interest, we generate a LOGGER instance.
#Note that 'start_new' denotes that a new folder should be made in which everything will be stored. # Note that 'start_new' denotes that a new folder should be made in which everything will be stored.
#This includes network weights as well. # This includes network weights as well.
LOG = aux.LOGGER(opt, metrics_to_log, name='Base', start_new=True) LOG = aux.LOGGER(opt, metrics_to_log, name="Base", start_new=True)
#If graphviz is installed on the system, a computational graph of the underlying # If graphviz is installed on the system, a computational graph of the underlying
#network will be made as well. # network will be made as well.
"""============================================================================""" """============================================================================"""
#################### LOSS SETUP #################### #################### LOSS SETUP ####################
#Depending on opt.loss and opt.sampling, the respective criterion is returned, # Depending on opt.loss and opt.sampling, the respective criterion is returned,
#and if the loss has trainable parameters, to_optim is appended. # and if the loss has trainable parameters, to_optim is appended.
criterion, to_optim = losses.loss_select(opt.loss, opt, to_optim) criterion, to_optim = losses.loss_select(opt.loss, opt, to_optim)
_ = criterion.to(opt.device) _ = criterion.to(opt.device)
"""============================================================================""" """============================================================================"""
##################### OPTIONAL EVALUATIONS ##################### ##################### OPTIONAL EVALUATIONS #####################
#Store the averaged gradients returned from the embedding to the last conv. layer. # Store the averaged gradients returned from the embedding to the last conv. layer.
if opt.grad_measure: if opt.grad_measure:
grad_measure = eval.GradientMeasure(opt, name='baseline') grad_measure = eval.GradientMeasure(opt, name="baseline")
#Store the relative distances between average intra- and inter-class distance. # Store the relative distances between average intra- and inter-class distance.
if opt.dist_measure: if opt.dist_measure:
#Add a distance measure for training distance ratios # Add a distance measure for training distance ratios
distance_measure = eval.DistanceMeasure(dataloaders['evaluation'], opt, name='Train', update_epochs=1) distance_measure = eval.DistanceMeasure(
dataloaders["evaluation"], opt, name="Train", update_epochs=1
)
# #If uncommented: Do the same for the test set # #If uncommented: Do the same for the test set
# distance_measure_test = eval.DistanceMeasure(dataloaders['testing'], opt, name='Train', update_epochs=1) # distance_measure_test = eval.DistanceMeasure(dataloaders['testing'], opt, name='Train', update_epochs=1)
"""============================================================================""" """============================================================================"""
#################### OPTIM SETUP #################### #################### OPTIM SETUP ####################
#As optimizer, Adam with standard parameters is used. # As optimizer, Adam with standard parameters is used.
if opt.opt == 'adam': if opt.opt == "adam":
optimizer = torch.optim.Adam(to_optim) optimizer = torch.optim.Adam(to_optim)
elif opt.opt == 'sgd': elif opt.opt == "sgd":
optimizer = torch.optim.SGD(to_optim) optimizer = torch.optim.SGD(to_optim)
else: else:
raise Exception('unknown optimiser') raise Exception("unknown optimiser")
# for the state-of-the-art measures in the paper - need to use SGD and a 0.05 learning rate # for the state-of-the-art measures in the paper - need to use SGD and a 0.05 learning rate
#optimizer = torch.optim.Adam(to_optim) # optimizer = torch.optim.Adam(to_optim)
#optimizer = torch.optim.SGD(to_optim) # optimizer = torch.optim.SGD(to_optim)
if opt.scheduler=='exp': if opt.scheduler == "exp":
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma) scheduler = torch.optim.lr_scheduler.ExponentialLR(
elif opt.scheduler=='step': optimizer, gamma=opt.gamma
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=opt.tau, gamma=opt.gamma) )
elif opt.scheduler=='none': elif opt.scheduler == "step":
print('No scheduling used!') scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=opt.tau, gamma=opt.gamma
)
elif opt.scheduler == "none":
print("No scheduling used!")
else: else:
raise Exception('No scheduling option for input: {}'.format(opt.scheduler)) raise Exception("No scheduling option for input: {}".format(opt.scheduler))
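The comment above notes that the paper's state-of-the-art numbers need SGD with a 0.05 learning rate. Since to_optim already carries a learning rate per parameter group, that variant would look roughly like the sketch below; the override loop is an assumption, not code from this repository.

# Hypothetical sketch of the SGD variant referred to above; overrides the per-group lrs.
for group in to_optim:
    group["lr"] = 0.05
optimizer = torch.optim.SGD(to_optim)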
def same_model(model1,model2): def same_model(model1, model2):
for p1, p2 in zip(model1.parameters(), model2.parameters()): for p1, p2 in zip(model1.parameters(), model2.parameters()):
if p1.data.ne(p2.data).sum() > 0: if p1.data.ne(p2.data).sum() > 0:
return False return False
...@@ -227,7 +377,9 @@ def same_model(model1,model2): ...@@ -227,7 +377,9 @@ def same_model(model1,model2):
"""============================================================================""" """============================================================================"""
#################### TRAINER FUNCTION ############################ #################### TRAINER FUNCTION ############################
def train_one_epoch_finetune(train_dataloader, model, optimizer, criterion, opt, epoch): def train_one_epoch_finetune(
train_dataloader, model, optimizer, criterion, opt, epoch
):
""" """
This function is called every epoch to perform training of the network over one full This function is called every epoch to perform training of the network over one full
(randomized) iteration of the dataset. (randomized) iteration of the dataset.
...@@ -244,106 +396,138 @@ def train_one_epoch_finetune(train_dataloader, model, optimizer, criterion, opt, ...@@ -244,106 +396,138 @@ def train_one_epoch_finetune(train_dataloader, model, optimizer, criterion, opt,
Nothing! Nothing!
""" """
loss_collect = [] loss_collect = []
start = time.time() start = time.time()
data_iterator = tqdm(train_dataloader, desc='Epoch {} Training with ground-truth labels...'.format(epoch)) data_iterator = tqdm(
for i,(class_labels, input) in enumerate(data_iterator): train_dataloader, desc="Epoch {} Training with ground-truth labels...".format(epoch)
)
for i, (class_labels, input) in enumerate(data_iterator):
#Compute embeddings for input batch # Compute embeddings for input batch
features = model(input.to(opt.device)) features = model(input.to(opt.device))
#Compute loss. # Compute loss.
if opt.loss != 'smoothap': if opt.loss != "smoothap":
loss = criterion(features, class_labels) loss = criterion(features, class_labels)
else: else:
loss = criterion(features) loss = criterion(features)
#Ensure gradients are set to zero at beginning # Ensure gradients are set to zero at beginning
optimizer.zero_grad() optimizer.zero_grad()
#Compute gradient # Compute gradient
loss.backward() loss.backward()
train_dataloader.dataset.classes_visited = [] train_dataloader.dataset.classes_visited = []
if opt.grad_measure: if opt.grad_measure:
#If desired, save computed gradients. # If desired, save computed gradients.
grad_measure.include(model.model.last_linear) grad_measure.include(model.model.last_linear)
#Update weights using comp. gradients. # Update weights using comp. gradients.
optimizer.step() optimizer.step()
#Store loss per iteration. # Store loss per iteration.
loss_collect.append(loss.item()) loss_collect.append(loss.item())
if i==len(train_dataloader)-1: if i == len(train_dataloader) - 1:
data_iterator.set_description('Epoch (Train) {0}: Mean Loss [{1:.4f}]'.format(epoch, np.mean(loss_collect))) data_iterator.set_description(
"Epoch (Train) {0}: Mean Loss [{1:.4f}]".format(
#Save metrics epoch, np.mean(loss_collect)
LOG.log('train', LOG.metrics_to_log['train'], [epoch, np.round(time.time()-start,4), np.mean(loss_collect)]) )
writer.add_scalar('global/training_loss',np.mean(loss_collect),epoch) )
# Save metrics
LOG.log(
"train",
LOG.metrics_to_log["train"],
[epoch, np.round(time.time() - start, 4), np.mean(loss_collect)],
)
writer.add_scalar("global/training_loss", np.mean(loss_collect), epoch)
if opt.grad_measure: if opt.grad_measure:
#Dump stored gradients to Pickle-File. # Dump stored gradients to Pickle-File.
grad_measure.dump(epoch) grad_measure.dump(epoch)
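In the SmoothAP branch above the criterion is called without class labels: the loss relies on the batch being laid out as num_id contiguous blocks of samples_per_class same-class items (see the preds.view call in the loss further down this diff). A toy shape check with illustrative numbers:

# Illustration only: bs=8 with samples_per_class=4 gives num_id=2 same-class blocks.
import torch

bs, samples_per_class, embed_dim = 8, 4, 512
num_id = bs // samples_per_class
preds = torch.randn(bs, embed_dim)
print(preds.view(num_id, samples_per_class, embed_dim).shape)  # torch.Size([2, 4, 512])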
"""============================================================================""" """============================================================================"""
"""========================== MAIN TRAINING PART ==============================""" """========================== MAIN TRAINING PART =============================="""
"""============================================================================""" """============================================================================"""
################### SCRIPT MAIN ########################## ################### SCRIPT MAIN ##########################
print('\n-----\n') print("\n-----\n")
# Each dataset requires slightly different dataloaders. # Each dataset requires slightly different dataloaders.
if opt.dataset in ('Inaturalist', 'semi_fungi'): if opt.dataset in ("Inaturalist", "semi_fungi"):
eval_params = {'dataloader': dataloaders['testing'], 'model': model, 'opt': opt, 'epoch': 0} eval_params = {
"dataloader": dataloaders["testing"],
"model": model,
"opt": opt,
"epoch": 0,
}
# Compute Evaluation metrics, print them and store in LOG. # Compute Evaluation metrics, print them and store in LOG.
print('epochs -> '+str(opt.n_epochs)) print("epochs -> " + str(opt.n_epochs))
import time import time
for epoch in range(opt.n_epochs): for epoch in range(opt.n_epochs):
### Print current learning rates for all parameters ### Print current learning rates for all parameters
if opt.scheduler!='none': print('Running with learning rates {}...'.format(' | '.join('{}'.format(x) for x in scheduler.get_lr()))) if opt.scheduler != "none":
print(
"Running with learning rates {}...".format(
" | ".join("{}".format(x) for x in scheduler.get_lr())
)
)
### Train one epoch ### Train one epoch
_ = model.train() _ = model.train()
train_one_epoch_finetune(dataloaders['training'], model, optimizer, criterion, opt, epoch)
dataloaders['training'].dataset.reshuffle() train_one_epoch_finetune(
dataloaders["training"], model, optimizer, criterion, opt, epoch
)
dataloaders["training"].dataset.reshuffle()
### Evaluate ### Evaluate
_ = model.eval() _ = model.eval()
#Each dataset requires slightly different dataloaders. # Each dataset requires slightly different dataloaders.
if opt.dataset == 'Inaturalist': if opt.dataset == "Inaturalist":
eval_params = {'dataloader':dataloaders['testing'], 'model':model, 'opt':opt, 'epoch':epoch} eval_params = {
elif opt.dataset=='semi_fungi': "dataloader": dataloaders["testing"],
eval_params = {'dataloader':dataloaders['testing'], 'model':model, 'opt':opt, 'epoch':epoch} "model": model,
"opt": opt,
#Compute Evaluation metrics, print them and store in LOG. "epoch": epoch,
}
elif opt.dataset == "semi_fungi":
eval_params = {
"dataloader": dataloaders["testing"],
"model": model,
"opt": opt,
"epoch": epoch,
}
# Compute Evaluation metrics, print them and store in LOG.
if opt.infrequent_eval == 1: if opt.infrequent_eval == 1:
epoch_freq = 10 epoch_freq = 10
else: else:
epoch_freq = 1 epoch_freq = 1
if epoch%epoch_freq == 0: if epoch % epoch_freq == 0:
results = eval.evaluate(opt.dataset, LOG, save=True, **eval_params) results = eval.evaluate(opt.dataset, LOG, save=True, **eval_params)
writer.add_scalar('global/recall1',results[0][0],epoch+1) writer.add_scalar("global/recall1", results[0][0], epoch + 1)
writer.add_scalar('global/recall2',results[0][1],epoch+1) writer.add_scalar("global/recall2", results[0][1], epoch + 1)
writer.add_scalar('global/recall3',results[0][2],epoch+1) writer.add_scalar("global/recall3", results[0][2], epoch + 1)
writer.add_scalar('global/recall4',results[0][3],epoch+1) writer.add_scalar("global/recall4", results[0][3], epoch + 1)
writer.add_scalar('global/NMI',results[1],epoch+1) writer.add_scalar("global/NMI", results[1], epoch + 1)
writer.add_scalar('global/F1',results[2],epoch+1) writer.add_scalar("global/F1", results[2], epoch + 1)
#Update the Metric Plot and save it. # Update the Metric Plot and save it.
#LOG.update_info_plot() # LOG.update_info_plot()
#(optional) compute ratio of intra- to inter-class distances. # (optional) compute ratio of intra- to inter-class distances.
if opt.dist_measure: if opt.dist_measure:
distance_measure.measure(model, epoch) distance_measure.measure(model, epoch)
# distance_measure_test.measure(model, epoch) # distance_measure_test.measure(model, epoch)
### Learning Rate Scheduling Step ### Learning Rate Scheduling Step
if opt.scheduler != 'none': if opt.scheduler != "none":
scheduler.step() scheduler.step()
print('\n-----\n') print("\n-----\n")
print("Time:" ,time.time() - start) print("Time:", time.time() - start)
...@@ -12,227 +12,391 @@ need to change all of the copyrights at the top of all of the files ...@@ -12,227 +12,391 @@ need to change all of the copyrights at the top of all of the files
#################### LIBRARIES ######################## #################### LIBRARIES ########################
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import os, numpy as np, argparse, random, matplotlib, datetime import argparse
import datetime
import os
import random
import matplotlib
import numpy as np
os.chdir(os.path.dirname(os.path.realpath(__file__))) os.chdir(os.path.dirname(os.path.realpath(__file__)))
matplotlib.use('agg') matplotlib.use("agg")
import auxiliaries as aux import auxiliaries as aux
import datasets as data import datasets as data
import netlib as netlib
import losses as losses
import evaluate as eval import evaluate as eval
import losses as losses
import netlib as netlib
import torch.multiprocessing import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
torch.multiprocessing.set_sharing_strategy("file_system")
################### INPUT ARGUMENTS ################### ################### INPUT ARGUMENTS ###################
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
####### Main Parameter: Dataset to use for Training ####### Main Parameter: Dataset to use for Training
parser.add_argument('--dataset', default='Inaturalist', type=str, help='Dataset to use.', choices=['Inaturalist', 'semi_fungi']) parser.add_argument(
"--dataset",
default="Inaturalist",
type=str,
help="Dataset to use.",
choices=["Inaturalist", "semi_fungi"],
)
### General Training Parameters ### General Training Parameters
parser.add_argument('--lr', default=0.00001, type=float, help='Learning Rate for network parameters.') parser.add_argument(
parser.add_argument('--fc_lr_mul', default=5, type=float, help='OPTIONAL: Multiply the embedding layer learning rate by this value. If set to 0, the embedding layer shares the same learning rate.') "--lr",
parser.add_argument('--n_epochs', default=400, type=int, help='Number of training epochs.') default=0.00001,
parser.add_argument('--kernels', default=8, type=int, help='Number of workers for pytorch dataloader.') type=float,
parser.add_argument('--bs', default=112 , type=int, help='Mini-Batchsize to use.') help="Learning Rate for network parameters.",
parser.add_argument('--samples_per_class', default=4, type=int, help='Number of samples in one class drawn before choosing the next class') )
parser.add_argument('--seed', default=1, type=int, help='Random seed for reproducibility.') parser.add_argument(
parser.add_argument('--scheduler', default='step', type=str, help='Type of learning rate scheduling. Currently: step & exp.') "--fc_lr_mul",
parser.add_argument('--gamma', default=0.3, type=float, help='Learning rate reduction after tau epochs.') default=5,
parser.add_argument('--decay', default=0.0004, type=float, help='Weight decay for optimizer.') type=float,
parser.add_argument('--tau', default= [200,300],nargs='+',type=int,help='Stepsize(s) before reducing learning rate.') help="OPTIONAL: Multiply the embedding layer learning rate by this value. If set to 0, the embedding layer shares the same learning rate.",
parser.add_argument('--infrequent_eval', default=0,type=int, help='only compute evaluation metrics every 10 epochs') )
parser.add_argument('--opt', default = 'adam',help='adam or sgd') parser.add_argument(
"--n_epochs", default=400, type=int, help="Number of training epochs."
)
parser.add_argument(
"--kernels",
default=8,
type=int,
help="Number of workers for pytorch dataloader.",
)
parser.add_argument(
"--bs", default=112, type=int, help="Mini-Batchsize to use."
)
parser.add_argument(
"--samples_per_class",
default=4,
type=int,
help="Number of samples in one class drawn before choosing the next class",
)
parser.add_argument(
"--seed", default=1, type=int, help="Random seed for reproducibility."
)
parser.add_argument(
"--scheduler",
default="step",
type=str,
help="Type of learning rate scheduling. Currently: step & exp.",
)
parser.add_argument(
"--gamma",
default=0.3,
type=float,
help="Learning rate reduction after tau epochs.",
)
parser.add_argument(
"--decay", default=0.0004, type=float, help="Weight decay for optimizer."
)
parser.add_argument(
"--tau",
default=[200, 300],
nargs="+",
type=int,
help="Stepsize(s) before reducing learning rate.",
)
parser.add_argument(
"--infrequent_eval",
default=0,
type=int,
help="only compute evaluation metrics every 10 epochs",
)
parser.add_argument("--opt", default="adam", help="adam or sgd")
##### Loss-specific Settings ##### Loss-specific Settings
parser.add_argument('--loss', default='smoothap', type=str) parser.add_argument("--loss", default="smoothap", type=str)
parser.add_argument('--sigmoid_temperature', default=0.01, type=float, help='SmoothAP: the temperature of the sigmoid used in SmoothAP loss') parser.add_argument(
"--sigmoid_temperature",
default=0.01,
type=float,
help="SmoothAP: the temperature of the sigmoid used in SmoothAP loss",
)
##### Evaluation Settings ##### Evaluation Settings
parser.add_argument('--k_vals', nargs='+', default=[1,2,4,8], type=int, help='Recall @ Values.') parser.add_argument(
parser.add_argument('--resume', default='', type=str, help='path to checkpoint to load weights from (if empty then ImageNet pre-trained weights are loaded)') "--k_vals",
nargs="+",
default=[1, 2, 4, 8],
type=int,
help="Recall @ Values.",
)
parser.add_argument(
"--resume",
default="",
type=str,
help="path to checkpoint to load weights from (if empty then ImageNet pre-trained weights are loaded",
)
##### Network parameters ##### Network parameters
parser.add_argument('--embed_dim', default=512, type=int, help='Embedding dimensionality of the network') parser.add_argument(
parser.add_argument('--arch', default='resnet50', type=str, help='Network backend choice: resnet50, googlenet, BNinception') "--embed_dim",
parser.add_argument('--grad_measure', action='store_true', help='If added, gradients passed from embedding layer to the last conv-layer are stored in each iteration.') default=512,
parser.add_argument('--dist_measure', action='store_true', help='If added, the ratio between intra- and interclass distances is stored after each epoch.') type=int,
parser.add_argument('--not_pretrained', action='store_true', help='If added, the network will be trained WITHOUT ImageNet-pretrained weights.') help="Embedding dimensionality of the network",
)
parser.add_argument(
"--arch",
default="resnet50",
type=str,
help="Network backend choice: resnet50, googlenet, BNinception",
)
parser.add_argument(
"--grad_measure",
action="store_true",
help="If added, gradients passed from embedding layer to the last conv-layer are stored in each iteration.",
)
parser.add_argument(
"--dist_measure",
action="store_true",
help="If added, the ratio between intra- and interclass distances is stored after each epoch.",
)
parser.add_argument(
"--not_pretrained",
action="store_true",
help="If added, the network will be trained WITHOUT ImageNet-pretrained weights.",
)
##### Setup Parameters ##### Setup Parameters
parser.add_argument('--gpu', default=0, type=int, help='GPU-id for GPU to use.') parser.add_argument("--gpu", default=0, type=int, help="GPU-id for GPU to use.")
parser.add_argument('--savename', default='', type=str, help='Save folder name if any special information is to be included.') parser.add_argument(
"--savename",
default="",
type=str,
help="Save folder name if any special information is to be included.",
)
### Paths to datasets and storage folder ### Paths to datasets and storage folder
parser.add_argument('--source_path', default='/scratch/shared/beegfs/abrown/datasets', type=str, help='Path to data') parser.add_argument(
parser.add_argument('--save_path', default=os.getcwd()+'/Training_Results', type=str, help='Where to save the checkpoints') "--source_path",
default="/scratch/shared/beegfs/abrown/datasets",
type=str,
help="Path to data",
)
parser.add_argument(
"--save_path",
default=os.getcwd() + "/Training_Results",
type=str,
help="Where to save the checkpoints",
)
### additional parameters ### additional parameters
parser.add_argument('--trainset', default="lin_train_set1.txt", type=str) parser.add_argument("--trainset", default="lin_train_set1.txt", type=str)
parser.add_argument('--all_trainset', default="train_set1.txt", type=str) parser.add_argument("--all_trainset", default="train_set1.txt", type=str)
parser.add_argument('--testset', default="test_set1.txt", type=str) parser.add_argument("--testset", default="test_set1.txt", type=str)
parser.add_argument('--finetune', default='true', type=str) parser.add_argument("--finetune", default="true", type=str)
parser.add_argument('--cluster_path', default="", type=str) parser.add_argument("--cluster_path", default="", type=str)
parser.add_argument('--get_features', default="false", type=str) parser.add_argument("--get_features", default="false", type=str)
parser.add_argument('--class_num', default=948, type=int) parser.add_argument("--class_num", default=948, type=int)
parser.add_argument('--iter', default=0, type=int) parser.add_argument("--iter", default=0, type=int)
parser.add_argument('--pretrained_weights', default="", type=str, help='pretrained weight path') parser.add_argument(
parser.add_argument('--use_bn_in_head', default=False, type=aux.bool_flag, "--pretrained_weights", default="", type=str, help="pretrained weight path"
help="Whether to use batch normalizations in projection head (Default: False)") )
parser.add_argument("--checkpoint_key", default="teacher", type=str, parser.add_argument(
help='Key to use in the checkpoint (example: "teacher")') "--use_bn_in_head",
parser.add_argument('--drop_path_rate', default=0.1, type=float, help="stochastic depth rate") default=False,
parser.add_argument('--linsize', default=29011, type=int, help="Lin data size.") type=aux.bool_flag,
parser.add_argument('--uinsize', default=18403, type=int, help="Uin data size.") help="Whether to use batch normalizations in projection head (Default: False)",
)
parser.add_argument(
"--checkpoint_key",
default="teacher",
type=str,
help='Key to use in the checkpoint (example: "teacher")',
)
parser.add_argument(
"--drop_path_rate", default=0.1, type=float, help="stochastic depth rate"
)
parser.add_argument("--linsize", default=29011, type=int, help="Lin data size.")
parser.add_argument("--uinsize", default=18403, type=int, help="Uin data size.")
opt = parser.parse_args() opt = parser.parse_args()
"""============================================================================""" """============================================================================"""
opt.source_path += '/' + opt.dataset opt.source_path += "/" + opt.dataset
opt.save_path += '/' + opt.dataset + "_" + str(opt.embed_dim) opt.save_path += "/" + opt.dataset + "_" + str(opt.embed_dim)
if opt.dataset== 'Inaturalist': if opt.dataset == "Inaturalist":
opt.n_epochs = 90 opt.n_epochs = 90
opt.tau = [40,70] opt.tau = [40, 70]
opt.k_vals = [1,4,16,32] opt.k_vals = [1, 4, 16, 32]
if opt.dataset=='semi_fungi': if opt.dataset == "semi_fungi":
opt.tau = [40,70] opt.tau = [40, 70]
opt.k_vals = [1,4,16,32] opt.k_vals = [1, 4, 16, 32]
if opt.get_features == "true": if opt.get_features == "true":
opt.get_features = True opt.get_features = True
if opt.get_features == "false": if opt.get_features == "false":
opt.get_features = False opt.get_features = False
if opt.finetune == 'true': if opt.finetune == "true":
opt.finetune = True opt.finetune = True
elif opt.finetune == 'false': elif opt.finetune == "false":
opt.finetune = False opt.finetune = False
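The string-to-bool blocks above (and the similar get_features block) could also go through one helper. aux.bool_flag already exists in this repository (it is used for --use_bn_in_head), but its body is not shown here, so the helper below is an assumption rather than the project's actual implementation.

# Hypothetical helper, not from the repo: normalize "true"/"false" strings to booleans.
def str2bool(value):
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ("true", "1", "yes")

opt.finetune = str2bool(opt.finetune)
opt.get_features = str2bool(opt.get_features)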
"""===========================================================================""" """==========================================================================="""
################### TensorBoard Settings ################## ################### TensorBoard Settings ##################
timestamp = datetime.datetime.now().strftime(r"%Y-%m-%d_%H-%M-%S") timestamp = datetime.datetime.now().strftime(r"%Y-%m-%d_%H-%M-%S")
exp_name = aux.args2exp_name(opt) exp_name = aux.args2exp_name(opt)
opt.save_name = f"weights_{exp_name}" +'/'+ timestamp opt.save_name = f"weights_{exp_name}" + "/" + timestamp
random.seed(opt.seed) random.seed(opt.seed)
np.random.seed(opt.seed) np.random.seed(opt.seed)
torch.manual_seed(opt.seed) torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed); torch.cuda.manual_seed_all(opt.seed) torch.cuda.manual_seed(opt.seed)
torch.cuda.manual_seed_all(opt.seed)
"""============================================================================""" """============================================================================"""
################### GPU SETTINGS ########################### ################### GPU SETTINGS ###########################
os.environ["CUDA_DEVICE_ORDER"] ="PCI_BUS_ID" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]= str(opt.gpu) # os.environ["CUDA_VISIBLE_DEVICES"]= str(opt.gpu)
print('using #GPUs:',torch.cuda.device_count()) print("using #GPUs:", torch.cuda.device_count())
"""============================================================================""" """============================================================================"""
##################### NETWORK SETUP ################## ##################### NETWORK SETUP ##################
opt.device = torch.device('cuda') opt.device = torch.device("cuda")
model = netlib.networkselect(opt) model = netlib.networkselect(opt)
#Push to Device # Push to Device
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model) model = torch.nn.DataParallel(model)
_ = model.to(opt.device) _ = model.to(opt.device)
#Place trainable parameter in list of parameters to train: # Place trainable parameter in list of parameters to train:
if 'fc_lr_mul' in vars(opt).keys() and opt.fc_lr_mul!=0: if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
all_but_fc_params = list(filter(lambda x: 'last_linear' not in x[0],model.named_parameters())) all_but_fc_params = list(
filter(lambda x: "last_linear" not in x[0], model.named_parameters())
)
for ind, param in enumerate(all_but_fc_params): for ind, param in enumerate(all_but_fc_params):
all_but_fc_params[ind] = param[1] all_but_fc_params[ind] = param[1]
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
fc_params = model.module.model.last_linear.parameters() fc_params = model.module.model.last_linear.parameters()
else: else:
fc_params = model.model.last_linear.parameters() fc_params = model.model.last_linear.parameters()
to_optim = [{'params':all_but_fc_params,'lr':opt.lr,'weight_decay':opt.decay}, to_optim = [
{'params':fc_params,'lr':opt.lr*opt.fc_lr_mul,'weight_decay':opt.decay}] {"params": all_but_fc_params, "lr": opt.lr, "weight_decay": opt.decay},
{
"params": fc_params,
"lr": opt.lr * opt.fc_lr_mul,
"weight_decay": opt.decay,
},
]
else: else:
to_optim = [{'params':model.parameters(),'lr':opt.lr,'weight_decay':opt.decay}] to_optim = [
{"params": model.parameters(), "lr": opt.lr, "weight_decay": opt.decay}
]
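to_optim above is an ordinary torch parameter-group list, so the optimizer built later picks up the larger embedding-layer learning rate by itself. A quick, side-effect-free way to inspect it; the printed values assume the argparse defaults lr=1e-5, fc_lr_mul=5, decay=4e-4.

# Inspect the parameter groups without touching the parameter iterators themselves.
for group in to_optim:
    print(group["lr"], group["weight_decay"])
# expected with the defaults: 1e-05 0.0004 (backbone) and 5e-05 0.0004 (last_linear head)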
"""============================================================================""" """============================================================================"""
#################### DATALOADER SETUPS ################## #################### DATALOADER SETUPS ##################
#Returns a dictionary containing 'training', 'testing', and 'evaluation' dataloaders. # Returns a dictionary containing 'training', 'testing', and 'evaluation' dataloaders.
#The 'testing'-dataloader corresponds to the validation set, and the 'evaluation'-dataloader # The 'testing'-dataloader corresponds to the validation set, and the 'evaluation'-dataloader
#is simply the training set, but run under the same rules as the 'testing' dataloader, # is simply the training set, but run under the same rules as the 'testing' dataloader,
#i.e. no shuffling and no random cropping. # i.e. no shuffling and no random cropping.
dataloaders = data.give_dataloaders(opt.dataset, opt.trainset, opt.testset, opt) dataloaders = data.give_dataloaders(opt.dataset, opt.trainset, opt.testset, opt)
#Because the number of supervised classes is dataset dependent, we store them after # Because the number of supervised classes is dataset dependent, we store them after
#initializing the dataloader # initializing the dataloader
opt.num_classes = len(dataloaders['training'].dataset.avail_classes) opt.num_classes = len(dataloaders["training"].dataset.avail_classes)
"""============================================================================""" """============================================================================"""
#################### CREATE LOGGING FILES ############### #################### CREATE LOGGING FILES ###############
#Each dataset usually has a set of standard metrics to log. aux.metrics_to_examine() # Each dataset usually has a set of standard metrics to log. aux.metrics_to_examine()
#returns a dict which lists metrics to log for training ('train') and validation/testing ('val') # returns a dict which lists metrics to log for training ('train') and validation/testing ('val')
metrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals) metrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals)
# example output: {'train': ['Epochs', 'Time', 'Train Loss', 'Time'], # example output: {'train': ['Epochs', 'Time', 'Train Loss', 'Time'],
# 'val': ['Epochs','Time','NMI','F1', 'Recall @ 1','Recall @ 2','Recall @ 4','Recall @ 8']} # 'val': ['Epochs','Time','NMI','F1', 'Recall @ 1','Recall @ 2','Recall @ 4','Recall @ 8']}
#Using the provided metrics of interest, we generate a LOGGER instance. # Using the provided metrics of interest, we generate a LOGGER instance.
#Note that 'start_new' denotes that a new folder should be made in which everything will be stored. # Note that 'start_new' denotes that a new folder should be made in which everything will be stored.
#This includes network weights as well. # This includes network weights as well.
#If graphviz is installed on the system, a computational graph of the underlying # If graphviz is installed on the system, a computational graph of the underlying
#network will be made as well. # network will be made as well.
"""============================================================================""" """============================================================================"""
#################### LOSS SETUP #################### #################### LOSS SETUP ####################
#Depending on opt.loss and opt.sampling, the respective criterion is returned, # Depending on opt.loss and opt.sampling, the respective criterion is returned,
#and if the loss has trainable parameters, to_optim is appended. # and if the loss has trainable parameters, to_optim is appended.
LOG = aux.LOGGER(opt, metrics_to_log, name='Base', start_new=True) LOG = aux.LOGGER(opt, metrics_to_log, name="Base", start_new=True)
criterion, to_optim = losses.loss_select(opt.loss, opt, to_optim) criterion, to_optim = losses.loss_select(opt.loss, opt, to_optim)
_ = criterion.to(opt.device) _ = criterion.to(opt.device)
"""============================================================================""" """============================================================================"""
##################### OPTIONAL EVALUATIONS ##################### ##################### OPTIONAL EVALUATIONS #####################
#Store the averaged gradients returned from the embedding to the last conv. layer. # Store the averaged gradients returned from the embedding to the last conv. layer.
if opt.grad_measure: if opt.grad_measure:
grad_measure = eval.GradientMeasure(opt, name='baseline') grad_measure = eval.GradientMeasure(opt, name="baseline")
#Store the relative distances between average intra- and inter-class distance. # Store the relative distances between average intra- and inter-class distance.
if opt.dist_measure: if opt.dist_measure:
#Add a distance measure for training distance ratios # Add a distance measure for training distance ratios
distance_measure = eval.DistanceMeasure(dataloaders['evaluation'], opt, name='Train', update_epochs=1) distance_measure = eval.DistanceMeasure(
dataloaders["evaluation"], opt, name="Train", update_epochs=1
)
# #If uncommented: Do the same for the test set # #If uncommented: Do the same for the test set
# distance_measure_test = eval.DistanceMeasure(dataloaders['testing'], opt, name='Train', update_epochs=1) # distance_measure_test = eval.DistanceMeasure(dataloaders['testing'], opt, name='Train', update_epochs=1)
"""============================================================================""" """============================================================================"""
#################### OPTIM SETUP #################### #################### OPTIM SETUP ####################
#As optimizer, Adam with standard parameters is used. # As optimizer, Adam with standard parameters is used.
if opt.opt == 'adam': if opt.opt == "adam":
optimizer = torch.optim.Adam(to_optim) optimizer = torch.optim.Adam(to_optim)
elif opt.opt == 'sgd': elif opt.opt == "sgd":
optimizer = torch.optim.SGD(to_optim) optimizer = torch.optim.SGD(to_optim)
else: else:
raise Exception('unknown optimiser') raise Exception("unknown optimiser")
# for the state-of-the-art measures in the paper - need to use SGD and a 0.05 learning rate # for the state-of-the-art measures in the paper - need to use SGD and a 0.05 learning rate
#optimizer = torch.optim.Adam(to_optim) # optimizer = torch.optim.Adam(to_optim)
#optimizer = torch.optim.SGD(to_optim) # optimizer = torch.optim.SGD(to_optim)
if opt.scheduler=='exp': if opt.scheduler == "exp":
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma) scheduler = torch.optim.lr_scheduler.ExponentialLR(
elif opt.scheduler=='step': optimizer, gamma=opt.gamma
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=opt.tau, gamma=opt.gamma) )
elif opt.scheduler=='none': elif opt.scheduler == "step":
print('No scheduling used!') scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=opt.tau, gamma=opt.gamma
)
elif opt.scheduler == "none":
print("No scheduling used!")
else: else:
raise Exception('No scheduling option for input: {}'.format(opt.scheduler)) raise Exception("No scheduling option for input: {}".format(opt.scheduler))
def same_model(model1,model2):
def same_model(model1, model2):
for p1, p2 in zip(model1.parameters(), model2.parameters()): for p1, p2 in zip(model1.parameters(), model2.parameters()):
if p1.data.ne(p2.data).sum() > 0: if p1.data.ne(p2.data).sum() > 0:
return False return False
return True return True
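same_model just checks element-wise parameter equality; a self-contained usage sketch (the Linear module here is a toy stand-in, not the project's network):

# Illustrative usage of same_model with a toy module.
import copy
import torch

net_a = torch.nn.Linear(4, 2)
net_b = copy.deepcopy(net_a)
print(same_model(net_a, net_b))   # True: identical parameters
with torch.no_grad():
    net_b.weight.add_(1.0)
print(same_model(net_a, net_b))   # False: the weights now differ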
"""============================================================================""" """============================================================================"""
"""================================ TESTING ===================================""" """================================ TESTING ==================================="""
"""============================================================================""" """============================================================================"""
################### SCRIPT MAIN ########################## ################### SCRIPT MAIN ##########################
print('\n-----\n') print("\n-----\n")
# Compute Evaluation metrics, print them and store in LOG. # Compute Evaluation metrics, print them and store in LOG.
_ = model.eval() _ = model.eval()
aux.vis(model, dataloaders['training'], opt.device, split="T_train_iter"+str(opt.iter)+"_"+str(opt.loss), opt=opt) aux.vis(
aux.vis(model, dataloaders['testing'], opt.device, split="all_train_iter"+str(opt.iter)+"_"+str(opt.loss), opt=opt) model,
aux.vis(model, dataloaders['eval'], opt.device, split="test_iter"+str(opt.iter)+"_"+str(opt.loss), opt=opt) dataloaders["training"],
#Update the Metric Plot and save it. opt.device,
print('\n-----\n') split="T_train_iter" + str(opt.iter) + "_" + str(opt.loss),
opt=opt,
)
aux.vis(
model,
dataloaders["testing"],
opt.device,
split="all_train_iter" + str(opt.iter) + "_" + str(opt.loss),
opt=opt,
)
aux.vis(
model,
dataloaders["eval"],
opt.device,
split="test_iter" + str(opt.iter) + "_" + str(opt.loss),
opt=opt,
)
# Update the Metric Plot and save it.
print("\n-----\n")
...@@ -2,13 +2,14 @@ ...@@ -2,13 +2,14 @@
###################### LIBRARIES ################################################# ###################### LIBRARIES #################################################
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import torch, faiss import faiss
import numpy as np import numpy as np
import torch
from scipy import sparse from scipy import sparse
"""=================================================================================================""" """================================================================================================="""
############ LOSS SELECTION FUNCTION ##################### ############ LOSS SELECTION FUNCTION #####################
def loss_select(loss, opt, to_optim): def loss_select(loss, opt, to_optim):
...@@ -22,19 +23,25 @@ def loss_select(loss, opt, to_optim): ...@@ -22,19 +23,25 @@ def loss_select(loss, opt, to_optim):
Returns: Returns:
criterion (torch.nn.Module inherited), to_optim (optionally appended) criterion (torch.nn.Module inherited), to_optim (optionally appended)
""" """
if loss == 'smoothap': if loss == "smoothap":
loss_params = {'anneal':opt.sigmoid_temperature, 'batch_size':opt.bs, "num_id":int(opt.bs / opt.samples_per_class), 'feat_dims':opt.embed_dim} loss_params = {
criterion = SmoothAP(**loss_params) "anneal": opt.sigmoid_temperature,
"batch_size": opt.bs,
"num_id": int(opt.bs / opt.samples_per_class),
"feat_dims": opt.embed_dim,
}
criterion = SmoothAP(**loss_params)
else: else:
raise Exception('Loss {} not available!'.format(loss)) raise Exception("Loss {} not available!".format(loss))
return criterion, to_optim return criterion, to_optim
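A hedged usage sketch of loss_select: only the four SmoothAP hyper-parameters on opt are read, and to_optim comes back unchanged because SmoothAP has no trainable parameters. The Namespace and the _demo_loss_select wrapper below are stand-ins of mine, not code from the training scripts.

from argparse import Namespace

def _demo_loss_select():
    # names resolve at call time, after the SmoothAP class below has been defined
    demo_opt = Namespace(sigmoid_temperature=0.01, bs=8, samples_per_class=4, embed_dim=128)
    criterion, unchanged = loss_select("smoothap", demo_opt, to_optim=[])
    preds = torch.nn.functional.normalize(
        torch.randn(demo_opt.bs, demo_opt.embed_dim), dim=1
    )
    return criterion, preds  # criterion(preds.cuda()) would give the (1 - SmoothAP) loss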
"""==============================================Smooth-AP========================================""" """==============================================Smooth-AP========================================"""
def sigmoid(tensor, temp=1.0): def sigmoid(tensor, temp=1.0):
""" temperature controlled sigmoid """temperature controlled sigmoid
takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp
""" """
exponent = -tensor / temp exponent = -tensor / temp
...@@ -58,7 +65,7 @@ class BinarizedF(torch.autograd.Function): ...@@ -58,7 +65,7 @@ class BinarizedF(torch.autograd.Function):
return output return output
def backward(self, output_grad): def backward(self, output_grad):
inp, = self.saved_tensors (inp,) = self.saved_tensors
input_abs = torch.abs(inp) input_abs = torch.abs(inp)
ones = torch.ones_like(inp) ones = torch.ones_like(inp)
zeros = torch.zeros_like(inp) zeros = torch.zeros_like(inp)
...@@ -122,7 +129,7 @@ class SmoothAP(torch.nn.Module): ...@@ -122,7 +129,7 @@ class SmoothAP(torch.nn.Module):
""" """
super(SmoothAP, self).__init__() super(SmoothAP, self).__init__()
assert(batch_size%num_id==0) assert batch_size % num_id == 0
self.anneal = anneal self.anneal = anneal
self.batch_size = batch_size self.batch_size = batch_size
...@@ -130,8 +137,7 @@ class SmoothAP(torch.nn.Module): ...@@ -130,8 +137,7 @@ class SmoothAP(torch.nn.Module):
self.feat_dims = feat_dims self.feat_dims = feat_dims
def forward(self, preds): def forward(self, preds):
"""Forward pass for all input predictions: preds - (batch_size x feat_dims) """ """Forward pass for all input predictions: preds - (batch_size x feat_dims)"""
# ------ differentiable ranking of all retrieval set ------ # ------ differentiable ranking of all retrieval set ------
# compute the mask which ignores the relevance score of the query to itself # compute the mask which ignores the relevance score of the query to itself
...@@ -149,12 +155,20 @@ class SmoothAP(torch.nn.Module): ...@@ -149,12 +155,20 @@ class SmoothAP(torch.nn.Module):
# ------ differentiable ranking of only positive set in retrieval set ------ # ------ differentiable ranking of only positive set in retrieval set ------
# compute the mask which only gives non-zero weights to the positive set # compute the mask which only gives non-zero weights to the positive set
xs = preds.view(self.num_id, int(self.batch_size / self.num_id), self.feat_dims) xs = preds.view(
self.num_id, int(self.batch_size / self.num_id), self.feat_dims
)
pos_mask = 1.0 - torch.eye(int(self.batch_size / self.num_id)) pos_mask = 1.0 - torch.eye(int(self.batch_size / self.num_id))
pos_mask = pos_mask.unsqueeze(dim=0).unsqueeze(dim=0).repeat(self.num_id, int(self.batch_size / self.num_id), 1, 1) pos_mask = (
pos_mask.unsqueeze(dim=0)
.unsqueeze(dim=0)
.repeat(self.num_id, int(self.batch_size / self.num_id), 1, 1)
)
# compute the relevance scores # compute the relevance scores
sim_pos = torch.bmm(xs, xs.permute(0, 2, 1)) sim_pos = torch.bmm(xs, xs.permute(0, 2, 1))
sim_pos_repeat = sim_pos.unsqueeze(dim=2).repeat(1, 1, int(self.batch_size / self.num_id), 1) sim_pos_repeat = sim_pos.unsqueeze(dim=2).repeat(
1, 1, int(self.batch_size / self.num_id), 1
)
# compute the difference matrix # compute the difference matrix
sim_pos_diff = sim_pos_repeat - sim_pos_repeat.permute(0, 1, 3, 2) sim_pos_diff = sim_pos_repeat - sim_pos_repeat.permute(0, 1, 3, 2)
# pass through the sigmoid # pass through the sigmoid
...@@ -166,6 +180,14 @@ class SmoothAP(torch.nn.Module): ...@@ -166,6 +180,14 @@ class SmoothAP(torch.nn.Module):
ap = torch.zeros(1).cuda() ap = torch.zeros(1).cuda()
group = int(self.batch_size / self.num_id) group = int(self.batch_size / self.num_id)
for ind in range(self.num_id): for ind in range(self.num_id):
pos_divide = torch.sum(sim_pos_rk[ind] / (sim_all_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)])) pos_divide = torch.sum(
sim_pos_rk[ind]
/ (
sim_all_rk[
(ind * group) : ((ind + 1) * group),
(ind * group) : ((ind + 1) * group),
]
)
)
ap = ap + ((pos_divide / group) / self.batch_size) ap = ap + ((pos_divide / group) / self.batch_size)
return (1 - ap) return 1 - ap
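The core of the forward pass above is a sigmoid-relaxed ranking. The toy function below (the helper name and scores are mine; only the temperature/masking pattern mirrors the class) shows how the hard rank 1 + |{j : s_j > s_i}| becomes differentiable.

# Toy illustration of the differentiable rank used by SmoothAP.
import torch

def smooth_rank(scores, temp=0.01):
    diff = scores.unsqueeze(0) - scores.unsqueeze(1)   # diff[i, j] = s_j - s_i
    mask = 1.0 - torch.eye(scores.numel())             # ignore self-comparisons
    return 1.0 + (torch.sigmoid(diff / temp) * mask).sum(dim=1)

print(smooth_rank(torch.tensor([0.9, 0.1, 0.5])))      # ~[1., 3., 2.]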
...@@ -15,150 +15,263 @@ import warnings ...@@ -15,150 +15,263 @@ import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import os, numpy as np, argparse, random, matplotlib, datetime import argparse
import datetime
import os
import random
import matplotlib
import numpy as np
os.chdir(os.path.dirname(os.path.realpath(__file__))) os.chdir(os.path.dirname(os.path.realpath(__file__)))
from pathlib import Path from pathlib import Path
matplotlib.use('agg') matplotlib.use("agg")
from tqdm import tqdm
import auxiliaries as aux import auxiliaries as aux
import datasets as data import datasets as data
import netlib as netlib
import losses as losses
import evaluate as eval import evaluate as eval
from tensorboardX import SummaryWriter import losses as losses
import netlib as netlib
import torch.multiprocessing import torch.multiprocessing
from tensorboardX import SummaryWriter
from tqdm import tqdm
torch.multiprocessing.set_sharing_strategy('file_system') torch.multiprocessing.set_sharing_strategy("file_system")
################### INPUT ARGUMENTS ################### ################### INPUT ARGUMENTS ###################
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
####### Main Parameter: Dataset to use for Training ####### Main Parameter: Dataset to use for Training
parser.add_argument('--dataset', default='vehicle_id', type=str, help='Dataset to use.', parser.add_argument(
choices=['SoftInaturalist', 'Inaturalist', 'vehicle_id', 'semi_fungi']) "--dataset",
default="vehicle_id",
type=str,
help="Dataset to use.",
choices=["SoftInaturalist", "Inaturalist", "vehicle_id", "semi_fungi"],
)
### General Training Parameters ### General Training Parameters
parser.add_argument('--lr', default=0.00001, type=float, help='Learning Rate for network parameters.') parser.add_argument(
parser.add_argument('--fc_lr_mul', default=5, type=float, "--lr",
help='OPTIONAL: Multiply the embedding layer learning rate by this value. If set to 0, the embedding layer shares the same learning rate.') default=0.00001,
parser.add_argument('--n_epochs', default=400, type=int, help='Number of training epochs.') type=float,
parser.add_argument('--kernels', default=8, type=int, help='Number of workers for pytorch dataloader.') help="Learning Rate for network parameters.",
parser.add_argument('--bs', default=112, type=int, help='Mini-Batchsize to use.') )
parser.add_argument('--samples_per_class', default=4, type=int, parser.add_argument(
help='Number of samples in one class drawn before choosing the next class') "--fc_lr_mul",
parser.add_argument('--seed', default=1, type=int, help='Random seed for reproducibility.') default=5,
parser.add_argument('--scheduler', default='step', type=str, type=float,
help='Type of learning rate scheduling. Currently: step & exp.') help="OPTIONAL: Multiply the embedding layer learning rate by this value. If set to 0, the embedding layer shares the same learning rate.",
parser.add_argument('--gamma', default=0.3, type=float, help='Learning rate reduction after tau epochs.') )
parser.add_argument('--decay', default=0.0004, type=float, help='Weight decay for optimizer.') parser.add_argument(
parser.add_argument('--tau', default=[200, 300], nargs='+', type=int, help='Stepsize(s) before reducing learning rate.') "--n_epochs", default=400, type=int, help="Number of training epochs."
parser.add_argument('--infrequent_eval', default=0, type=int, help='only compute evaluation metrics every 10 epochs') )
parser.add_argument('--opt', default='adam', help='adam or sgd') parser.add_argument(
"--kernels",
default=8,
type=int,
help="Number of workers for pytorch dataloader.",
)
parser.add_argument(
"--bs", default=112, type=int, help="Mini-Batchsize to use."
)
parser.add_argument(
"--samples_per_class",
default=4,
type=int,
help="Number of samples in one class drawn before choosing the next class",
)
parser.add_argument(
"--seed", default=1, type=int, help="Random seed for reproducibility."
)
parser.add_argument(
"--scheduler",
default="step",
type=str,
help="Type of learning rate scheduling. Currently: step & exp.",
)
parser.add_argument(
"--gamma",
default=0.3,
type=float,
help="Learning rate reduction after tau epochs.",
)
parser.add_argument(
"--decay", default=0.0004, type=float, help="Weight decay for optimizer."
)
parser.add_argument(
"--tau",
default=[200, 300],
nargs="+",
type=int,
help="Stepsize(s) before reducing learning rate.",
)
parser.add_argument(
"--infrequent_eval",
default=0,
type=int,
help="only compute evaluation metrics every 10 epochs",
)
parser.add_argument("--opt", default="adam", help="adam or sgd")
##### Loss-specific Settings ##### Loss-specific Settings
parser.add_argument('--loss', default='smoothap', type=str) parser.add_argument("--loss", default="smoothap", type=str)
parser.add_argument('--sigmoid_temperature', default=0.01, type=float, parser.add_argument(
help='SmoothAP: the temperature of the sigmoid used in SmoothAP loss') "--sigmoid_temperature",
default=0.01,
type=float,
help="SmoothAP: the temperature of the sigmoid used in SmoothAP loss",
)
##### Evaluation Settings ##### Evaluation Settings
parser.add_argument('--k_vals', nargs='+', default=[1, 2, 4, 8], type=int, help='Recall @ Values.') parser.add_argument(
parser.add_argument('--resume', default='', type=str, "--k_vals",
help='path to checkpoint to load weights from (if empty then ImageNet pre-trained weights are loaded)') nargs="+",
default=[1, 2, 4, 8],
type=int,
help="Recall @ Values.",
)
parser.add_argument(
"--resume",
default="",
type=str,
help="path to checkpoint to load weights from (if empty then ImageNet pre-trained weights are loaded",
)
##### Network parameters ##### Network parameters
parser.add_argument('--embed_dim', default=512, type=int, help='Embedding dimensionality of the network') parser.add_argument(
parser.add_argument('--arch', default='resnet50', type=str, "--embed_dim",
help='Network backend choice: resnet50') default=512,
parser.add_argument('--pretrained_weights', default="", type=str, help='pretrained weight path') type=int,
parser.add_argument('--use_bn_in_head', default=False, type=aux.bool_flag, help="Embedding dimensionality of the network",
help="Whether to use batch normalizations in projection head (Default: False)") )
parser.add_argument("--checkpoint_key", default="teacher", type=str, parser.add_argument(
help='Key to use in the checkpoint (example: "teacher")') "--arch",
parser.add_argument('--drop_path_rate', default=0.1, type=float, help="stochastic depth rate") default="resnet50",
parser.add_argument('--grad_measure', action='store_true', type=str,
help='If added, gradients passed from embedding layer to the last conv-layer are stored in each iteration.') help="Network backend choice: resnet50",
parser.add_argument('--dist_measure', action='store_true', )
help='If added, the ratio between intra- and interclass distances is stored after each epoch.') parser.add_argument(
parser.add_argument('--not_pretrained', action='store_true', "--pretrained_weights", default="", type=str, help="pretrained weight path"
help='If added, the network will be trained WITHOUT ImageNet-pretrained weights.') )
parser.add_argument(
"--use_bn_in_head",
default=False,
type=aux.bool_flag,
help="Whether to use batch normalizations in projection head (Default: False)",
)
parser.add_argument(
"--checkpoint_key",
default="teacher",
type=str,
help='Key to use in the checkpoint (example: "teacher")',
)
parser.add_argument(
"--drop_path_rate", default=0.1, type=float, help="stochastic depth rate"
)
parser.add_argument(
"--grad_measure",
action="store_true",
help="If added, gradients passed from embedding layer to the last conv-layer are stored in each iteration.",
)
parser.add_argument(
"--dist_measure",
action="store_true",
help="If added, the ratio between intra- and interclass distances is stored after each epoch.",
)
parser.add_argument(
"--not_pretrained",
action="store_true",
help="If added, the network will be trained WITHOUT ImageNet-pretrained weights.",
)
##### Setup Parameters ##### Setup Parameters
parser.add_argument('--gpu', default=0, type=int, help='GPU-id for GPU to use.') parser.add_argument("--gpu", default=0, type=int, help="GPU-id for GPU to use.")
parser.add_argument('--savename', default='', type=str, parser.add_argument(
help='Save folder name if any special information is to be included.') "--savename",
default="",
type=str,
help="Save folder name if any special information is to be included.",
)
### Paths to datasets and storage folder ### Paths to datasets and storage folder
parser.add_argument('--source_path', default='/scratch/shared/beegfs/abrown/datasets', type=str, help='Path to data') parser.add_argument(
parser.add_argument('--save_path', default=os.getcwd() + '/Training_Results', type=str, "--source_path",
help='Where to save the checkpoints') default="/scratch/shared/beegfs/abrown/datasets",
type=str,
help="Path to data",
)
parser.add_argument(
"--save_path",
default=os.getcwd() + "/Training_Results",
type=str,
help="Where to save the checkpoints",
)
### additional parameters ### additional parameters
parser.add_argument('--trainset', default="lin_train_set1.txt", type=str) parser.add_argument("--trainset", default="lin_train_set1.txt", type=str)
parser.add_argument('--testset', default="Inaturalist_test_set1.txt", type=str) parser.add_argument("--testset", default="Inaturalist_test_set1.txt", type=str)
parser.add_argument('--cluster_path', default="", type=str) parser.add_argument("--cluster_path", default="", type=str)
parser.add_argument('--finetune', default="false", type=str) parser.add_argument("--finetune", default="false", type=str)
parser.add_argument('--class_num', default=948, type=int) parser.add_argument("--class_num", default=948, type=int)
parser.add_argument('--get_features', default="false", type=str) parser.add_argument("--get_features", default="false", type=str)
parser.add_argument('--linsize', default=29011, type=int, help="Lin data size.") parser.add_argument("--linsize", default=29011, type=int, help="Lin data size.")
parser.add_argument('--uinsize', default=18403, type=int, help="Uin data size.") parser.add_argument("--uinsize", default=18403, type=int, help="Uin data size.")
parser.add_argument('--iter', default=0, type=int) parser.add_argument("--iter", default=0, type=int)
opt = parser.parse_args() opt = parser.parse_args()
"""============================================================================""" """============================================================================"""
if opt.dataset == "SoftInaturalist": if opt.dataset == "SoftInaturalist":
opt.source_path += '/Inaturalist' opt.source_path += "/Inaturalist"
opt.save_path += '/Inaturalist' + "_" + str(opt.embed_dim) opt.save_path += "/Inaturalist" + "_" + str(opt.embed_dim)
else: else:
opt.source_path += '/' + opt.dataset opt.source_path += "/" + opt.dataset
opt.save_path += '/' + opt.dataset + "_" + str(opt.embed_dim) opt.save_path += "/" + opt.dataset + "_" + str(opt.embed_dim)
if opt.dataset == 'Inaturalist': if opt.dataset == "Inaturalist":
# opt.n_epochs = 90 # opt.n_epochs = 90
opt.tau = [40, 70] opt.tau = [40, 70]
opt.k_vals = [1, 4, 16, 32] opt.k_vals = [1, 4, 16, 32]
if opt.dataset == 'SoftInaturalist': if opt.dataset == "SoftInaturalist":
# opt.n_epochs = 90 # opt.n_epochs = 90
opt.tau = [40, 70] opt.tau = [40, 70]
opt.k_vals = [1, 4, 16, 32] opt.k_vals = [1, 4, 16, 32]
if opt.dataset == 'vehicle_id': if opt.dataset == "vehicle_id":
opt.k_vals = [1, 5] opt.k_vals = [1, 5]
if opt.dataset == 'semi_fungi': if opt.dataset == "semi_fungi":
opt.tau = [40, 70] opt.tau = [40, 70]
opt.k_vals = [1, 4, 16, 32] opt.k_vals = [1, 4, 16, 32]
if opt.finetune == 'true': if opt.finetune == "true":
opt.finetune = True opt.finetune = True
elif opt.finetune == 'false': elif opt.finetune == "false":
opt.finetune = False opt.finetune = False
if opt.get_features == 'true': if opt.get_features == "true":
opt.get_features = True opt.get_features = True
elif opt.get_features == 'false': elif opt.get_features == "false":
opt.get_features = False opt.get_features = False
"""===========================================================================""" """==========================================================================="""
################### TensorBoard Settings ################## ################### TensorBoard Settings ##################
timestamp = datetime.datetime.now().strftime(r"%Y-%m-%d_%H-%M-%S") timestamp = datetime.datetime.now().strftime(r"%Y-%m-%d_%H-%M-%S")
exp_name = aux.args2exp_name(opt) exp_name = aux.args2exp_name(opt)
opt.save_name = f"weights_{exp_name}" + '/' + timestamp opt.save_name = f"weights_{exp_name}" + "/" + timestamp
random.seed(opt.seed) random.seed(opt.seed)
np.random.seed(opt.seed) np.random.seed(opt.seed)
torch.manual_seed(opt.seed) torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed); torch.cuda.manual_seed(opt.seed)
torch.cuda.manual_seed_all(opt.seed) torch.cuda.manual_seed_all(opt.seed)
tensorboard_path = Path(f"logs/logs_{exp_name}") / timestamp tensorboard_path = Path(f"logs/logs_{exp_name}") / timestamp
tensorboard_path.parent.mkdir(exist_ok=True, parents=True) tensorboard_path.parent.mkdir(exist_ok=True, parents=True)
global writer; global writer
writer = SummaryWriter(tensorboard_path) writer = SummaryWriter(tensorboard_path)
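writer above is the global SummaryWriter that the training loop feeds via add_scalar; a minimal standalone illustration of that call pattern (the loss values are placeholders):

# Minimal illustration of the add_scalar pattern used by the training loop in this diff.
for epoch, mean_loss in enumerate([0.9, 0.7, 0.6]):  # placeholder loss curve
    writer.add_scalar("global/training_loss", mean_loss, epoch)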
"""============================================================================""" """============================================================================"""
################### GPU SETTINGS ########################### ################### GPU SETTINGS ###########################
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]= str(opt.gpu) # os.environ["CUDA_VISIBLE_DEVICES"]= str(opt.gpu)
print('using #GPUs:', torch.cuda.device_count()) print("using #GPUs:", torch.cuda.device_count())
"""============================================================================""" """============================================================================"""
##################### NETWORK SETUP ################## ##################### NETWORK SETUP ##################
opt.device = torch.device('cuda') opt.device = torch.device("cuda")
model = netlib.networkselect(opt) model = netlib.networkselect(opt)
# Push to Device # Push to Device
...@@ -167,9 +280,11 @@ if torch.cuda.device_count() > 1: ...@@ -167,9 +280,11 @@ if torch.cuda.device_count() > 1:
_ = model.to(opt.device) _ = model.to(opt.device)
# Place trainable parameter in list of parameters to train: # Place trainable parameter in list of parameters to train:
if 'fc_lr_mul' in vars(opt).keys() and opt.fc_lr_mul != 0: if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
all_but_fc_params = list(filter(lambda x: 'last_linear' not in x[0], model.named_parameters())) all_but_fc_params = list(
filter(lambda x: "last_linear" not in x[0], model.named_parameters())
)
for ind, param in enumerate(all_but_fc_params): for ind, param in enumerate(all_but_fc_params):
all_but_fc_params[ind] = param[1] all_but_fc_params[ind] = param[1]
...@@ -179,10 +294,18 @@ if 'fc_lr_mul' in vars(opt).keys() and opt.fc_lr_mul != 0: ...@@ -179,10 +294,18 @@ if 'fc_lr_mul' in vars(opt).keys() and opt.fc_lr_mul != 0:
else: else:
fc_params = model.model.last_linear.parameters() fc_params = model.model.last_linear.parameters()
to_optim = [{'params': all_but_fc_params, 'lr': opt.lr, 'weight_decay': opt.decay}, to_optim = [
{'params': fc_params, 'lr': opt.lr * opt.fc_lr_mul, 'weight_decay': opt.decay}] {"params": all_but_fc_params, "lr": opt.lr, "weight_decay": opt.decay},
{
"params": fc_params,
"lr": opt.lr * opt.fc_lr_mul,
"weight_decay": opt.decay,
},
]
else: else:
to_optim = [{'params': model.parameters(), 'lr': opt.lr, 'weight_decay': opt.decay}] to_optim = [
{"params": model.parameters(), "lr": opt.lr, "weight_decay": opt.decay}
]
"""============================================================================""" """============================================================================"""
#################### DATALOADER SETUPS ################## #################### DATALOADER SETUPS ##################
# Returns a dictionary containing 'training', 'testing', and 'evaluation' dataloaders. # Returns a dictionary containing 'training', 'testing', and 'evaluation' dataloaders.
...@@ -192,7 +315,7 @@ else: ...@@ -192,7 +315,7 @@ else:
dataloaders = data.give_dataloaders(opt.dataset, opt.trainset, opt.testset, opt) dataloaders = data.give_dataloaders(opt.dataset, opt.trainset, opt.testset, opt)
# Because the number of supervised classes is dataset dependent, we store them after # Because the number of supervised classes is dataset dependent, we store them after
# initializing the dataloader # initializing the dataloader
opt.num_classes = len(dataloaders['training'].dataset.avail_classes) opt.num_classes = len(dataloaders["training"].dataset.avail_classes)
"""============================================================================""" """============================================================================"""
#################### CREATE LOGGING FILES ############### #################### CREATE LOGGING FILES ###############
...@@ -206,7 +329,7 @@ metrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals) ...@@ -206,7 +329,7 @@ metrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals)
# Using the provided metrics of interest, we generate a LOGGER instance. # Using the provided metrics of interest, we generate a LOGGER instance.
# Note that 'start_new' denotes that a new folder should be made in which everything will be stored. # Note that 'start_new' denotes that a new folder should be made in which everything will be stored.
# This includes network weights as well. # This includes network weights as well.
LOG = aux.LOGGER(opt, metrics_to_log, name='Base', start_new=True) LOG = aux.LOGGER(opt, metrics_to_log, name="Base", start_new=True)
# If graphviz is installed on the system, a computational graph of the underlying # If graphviz is installed on the system, a computational graph of the underlying
# network will be made as well. # network will be made as well.
...@@ -221,34 +344,40 @@ _ = criterion.to(opt.device) ...@@ -221,34 +344,40 @@ _ = criterion.to(opt.device)
##################### OPTIONAL EVALUATIONS ##################### ##################### OPTIONAL EVALUATIONS #####################
# Store the averaged gradients returned from the embedding to the last conv. layer. # Store the averaged gradients returned from the embedding to the last conv. layer.
if opt.grad_measure: if opt.grad_measure:
grad_measure = eval.GradientMeasure(opt, name='baseline') grad_measure = eval.GradientMeasure(opt, name="baseline")
# Store the relative distances between average intra- and inter-class distance. # Store the relative distances between average intra- and inter-class distance.
if opt.dist_measure: if opt.dist_measure:
# Add a distance measure for training distance ratios # Add a distance measure for training distance ratios
distance_measure = eval.DistanceMeasure(dataloaders['evaluation'], opt, name='Train', update_epochs=1) distance_measure = eval.DistanceMeasure(
dataloaders["evaluation"], opt, name="Train", update_epochs=1
)
# #If uncommented: Do the same for the test set # #If uncommented: Do the same for the test set
# distance_measure_test = eval.DistanceMeasure(dataloaders['testing'], opt, name='Train', update_epochs=1) # distance_measure_test = eval.DistanceMeasure(dataloaders['testing'], opt, name='Train', update_epochs=1)
"""============================================================================""" """============================================================================"""
#################### OPTIM SETUP #################### #################### OPTIM SETUP ####################
# As optimizer, Adam with standard parameters is used. # As optimizer, Adam with standard parameters is used.
if opt.opt == 'adam': if opt.opt == "adam":
optimizer = torch.optim.Adam(to_optim) optimizer = torch.optim.Adam(to_optim)
elif opt.opt == 'sgd': elif opt.opt == "sgd":
optimizer = torch.optim.SGD(to_optim) optimizer = torch.optim.SGD(to_optim)
else: else:
raise Exception('unknown optimiser') raise Exception("unknown optimiser")
# for the SOA measures in the paper - need to use SGD and 0.05 learning rate # for the SOA measures in the paper - need to use SGD and 0.05 learning rate
# optimizer = torch.optim.Adam(to_optim) # optimizer = torch.optim.Adam(to_optim)
# optimizer = torch.optim.SGD(to_optim) # optimizer = torch.optim.SGD(to_optim)
if opt.scheduler == 'exp': if opt.scheduler == "exp":
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma) scheduler = torch.optim.lr_scheduler.ExponentialLR(
elif opt.scheduler == 'step': optimizer, gamma=opt.gamma
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=opt.tau, gamma=opt.gamma) )
elif opt.scheduler == 'none': elif opt.scheduler == "step":
print('No scheduling used!') scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=opt.tau, gamma=opt.gamma
)
elif opt.scheduler == "none":
print("No scheduling used!")
else: else:
raise Exception('No scheduling option for input: {}'.format(opt.scheduler)) raise Exception("No scheduling option for input: {}".format(opt.scheduler))
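The schedulers configured here are stepped once per epoch in the main loop further down. A minimal sketch of how a stepwise scheduler interacts with an optimizer (milestones and rates below are illustrative):

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.05)

# "exp" multiplies the lr by gamma every epoch; "step" only at the milestones.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1)

for epoch in range(90):
    # ... one epoch of training would run here ...
    optimizer.step()   # take at least one optimizer step before scheduler.step()
    scheduler.step()   # lr becomes 0.005 after epoch 30 and 0.0005 after epoch 60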
def same_model(model1, model2): def same_model(model1, model2):
...@@ -282,14 +411,16 @@ def train_one_epoch(train_dataloader, model, optimizer, criterion, opt, epoch): ...@@ -282,14 +411,16 @@ def train_one_epoch(train_dataloader, model, optimizer, criterion, opt, epoch):
loss_collect = [] loss_collect = []
start = time.time() start = time.time()
data_iterator = tqdm(train_dataloader, desc='Epoch {} Training...'.format(epoch)) data_iterator = tqdm(
train_dataloader, desc="Epoch {} Training...".format(epoch)
)
for i, (class_labels, input) in enumerate(data_iterator): for i, (class_labels, input) in enumerate(data_iterator):
# Compute embeddings for input batch # Compute embeddings for input batch
features = model(input.to(opt.device)) features = model(input.to(opt.device))
# Compute loss. # Compute loss.
if opt.loss != 'smoothap': if opt.loss != "smoothap":
loss = criterion(features, class_labels) loss = criterion(features, class_labels)
else: else:
loss = criterion(features) loss = criterion(features)
...@@ -311,11 +442,19 @@ def train_one_epoch(train_dataloader, model, optimizer, criterion, opt, epoch): ...@@ -311,11 +442,19 @@ def train_one_epoch(train_dataloader, model, optimizer, criterion, opt, epoch):
# Store loss per iteration. # Store loss per iteration.
loss_collect.append(loss.item()) loss_collect.append(loss.item())
if i == len(train_dataloader) - 1: if i == len(train_dataloader) - 1:
data_iterator.set_description('Epoch (Train) {0}: Mean Loss [{1:.4f}]'.format(epoch, np.mean(loss_collect))) data_iterator.set_description(
"Epoch (Train) {0}: Mean Loss [{1:.4f}]".format(
epoch, np.mean(loss_collect)
)
)
# Save metrics # Save metrics
LOG.log('train', LOG.metrics_to_log['train'], [epoch, np.round(time.time() - start, 4), np.mean(loss_collect)]) LOG.log(
writer.add_scalar('global/training_loss', np.mean(loss_collect), epoch) "train",
LOG.metrics_to_log["train"],
[epoch, np.round(time.time() - start, 4), np.mean(loss_collect)],
)
writer.add_scalar("global/training_loss", np.mean(loss_collect), epoch)
if opt.grad_measure: if opt.grad_measure:
# Dump stored gradients to Pickle-File. # Dump stored gradients to Pickle-File.
grad_measure.dump(epoch) grad_measure.dump(epoch)
...@@ -325,42 +464,77 @@ def train_one_epoch(train_dataloader, model, optimizer, criterion, opt, epoch): ...@@ -325,42 +464,77 @@ def train_one_epoch(train_dataloader, model, optimizer, criterion, opt, epoch):
"""========================== MAIN TRAINING PART ==============================""" """========================== MAIN TRAINING PART =============================="""
"""============================================================================""" """============================================================================"""
################### SCRIPT MAIN ########################## ################### SCRIPT MAIN ##########################
print('\n-----\n') print("\n-----\n")
# Each dataset requires slightly different dataloaders. # Each dataset requires slightly different dataloaders.
if opt.dataset == 'SoftInaturalist' or 'Inaturalist' or 'semi_fungi': if opt.dataset == "SoftInaturalist" or "Inaturalist" or "semi_fungi":
eval_params = {'dataloader': dataloaders['testing'], 'model': model, 'opt': opt, 'epoch': 0} eval_params = {
"dataloader": dataloaders["testing"],
"model": model,
"opt": opt,
"epoch": 0,
}
elif opt.dataset == 'vehicle_id': elif opt.dataset == "vehicle_id":
eval_params = { eval_params = {
'dataloaders': [dataloaders['testing_set1'], dataloaders['testing_set2'], dataloaders['testing_set3']], "dataloaders": [
'model': model, 'opt': opt, 'epoch': 0} dataloaders["testing_set1"],
dataloaders["testing_set2"],
dataloaders["testing_set3"],
],
"model": model,
"opt": opt,
"epoch": 0,
}
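Note that the first branch above, `opt.dataset == "SoftInaturalist" or "Inaturalist" or "semi_fungi"`, is always truthy in Python, because the bare string literals are evaluated on their own; a membership test appears to be the intended check. A self-contained illustration of the pitfall (the variable name here is only for demonstration and is not part of this commit):

dataset = "vehicle_id"

# Chained "or" with non-empty string literals is always truthy, so this branch always fires.
print(bool(dataset == "SoftInaturalist" or "Inaturalist" or "semi_fungi"))  # True

# The intended condition is a membership test.
print(dataset in ("SoftInaturalist", "Inaturalist", "semi_fungi"))  # False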
# Compute Evaluation metrics, print them and store in LOG. # Compute Evaluation metrics, print them and store in LOG.
print('epochs -> ' + str(opt.n_epochs)) print("epochs -> " + str(opt.n_epochs))
import time import time
for epoch in range(opt.n_epochs): for epoch in range(opt.n_epochs):
### Print current learning rates for all parameters ### Print current learning rates for all parameters
if opt.scheduler != 'none': print( if opt.scheduler != "none":
'Running with learning rates {}...'.format(' | '.join('{}'.format(x) for x in scheduler.get_lr()))) print(
"Running with learning rates {}...".format(
" | ".join("{}".format(x) for x in scheduler.get_lr())
)
)
### Train one epoch ### Train one epoch
_ = model.train() _ = model.train()
train_one_epoch(dataloaders['training'], model, optimizer, criterion, opt, epoch) train_one_epoch(
dataloaders["training"], model, optimizer, criterion, opt, epoch
)
dataloaders['training'].dataset.reshuffle() dataloaders["training"].dataset.reshuffle()
### Evaluate ### Evaluate
_ = model.eval() _ = model.eval()
# Each dataset requires slightly different dataloaders. # Each dataset requires slightly different dataloaders.
if opt.dataset == 'Inaturalist': if opt.dataset == "Inaturalist":
eval_params = {'dataloader': dataloaders['evaluation'], 'model': model, 'opt': opt, 'epoch': epoch} eval_params = {
elif opt.dataset == 'vehicle_id': "dataloader": dataloaders["evaluation"],
"model": model,
"opt": opt,
"epoch": epoch,
}
elif opt.dataset == "vehicle_id":
eval_params = {
"dataloaders": [
dataloaders["testing_set1"],
dataloaders["testing_set2"],
dataloaders["testing_set3"],
],
"model": model,
"opt": opt,
"epoch": epoch,
}
elif opt.dataset == "semi_fungi":
eval_params = { eval_params = {
'dataloaders': [dataloaders['testing_set1'], dataloaders['testing_set2'], dataloaders['testing_set3']], "dataloader": dataloaders["testing"],
'model': model, 'opt': opt, 'epoch': epoch} "model": model,
elif opt.dataset == 'semi_fungi': "opt": opt,
eval_params = {'dataloader': dataloaders['testing'], 'model': model, 'opt': opt, 'epoch': epoch} "epoch": epoch,
}
# Compute Evaluation metrics, print them and store in LOG. # Compute Evaluation metrics, print them and store in LOG.
if opt.infrequent_eval == 1: if opt.infrequent_eval == 1:
...@@ -368,25 +542,26 @@ for epoch in range(opt.n_epochs): ...@@ -368,25 +542,26 @@ for epoch in range(opt.n_epochs):
else: else:
epoch_freq = 1 epoch_freq = 1
if not opt.dataset == 'vehicle_id': if not opt.dataset == "vehicle_id":
if epoch % epoch_freq == 0: if epoch % epoch_freq == 0:
results = eval.evaluate(opt.dataset, LOG, save=True, **eval_params) results = eval.evaluate(opt.dataset, LOG, save=True, **eval_params)
writer.add_scalar('global/recall1', results[0][0], epoch + 1) writer.add_scalar("global/recall1", results[0][0], epoch + 1)
writer.add_scalar('global/recall2', results[0][1], epoch + 1) writer.add_scalar("global/recall2", results[0][1], epoch + 1)
writer.add_scalar('global/recall3', results[0][2], epoch + 1) writer.add_scalar("global/recall3", results[0][2], epoch + 1)
writer.add_scalar('global/recall4', results[0][3], epoch + 1) writer.add_scalar("global/recall4", results[0][3], epoch + 1)
writer.add_scalar('global/NMI', results[1], epoch + 1) writer.add_scalar("global/NMI", results[1], epoch + 1)
writer.add_scalar('global/F1', results[2], epoch + 1) writer.add_scalar("global/F1", results[2], epoch + 1)
else: else:
results = eval.evaluate(opt.dataset, LOG, save=True, **eval_params) results = eval.evaluate(opt.dataset, LOG, save=True, **eval_params)
writer.add_scalar('global/recall1', results[2], epoch + 1) writer.add_scalar("global/recall1", results[2], epoch + 1)
writer.add_scalar('global/recall2', results[3], writer.add_scalar(
epoch + 1) # writer.add_scalar('global/recall3',results[0][2],0) "global/recall2", results[3], epoch + 1
writer.add_scalar('global/recall3', results[6], epoch + 1) ) # writer.add_scalar('global/recall3',results[0][2],0)
writer.add_scalar('global/recall4', results[7], epoch + 1) writer.add_scalar("global/recall3", results[6], epoch + 1)
writer.add_scalar('global/recall5', results[10], epoch + 1) writer.add_scalar("global/recall4", results[7], epoch + 1)
writer.add_scalar('global/recall6', results[11], epoch + 1) writer.add_scalar("global/recall5", results[10], epoch + 1)
writer.add_scalar("global/recall6", results[11], epoch + 1)
# Update the Metric Plot and save it. # Update the Metric Plot and save it.
# LOG.update_info_plot() # LOG.update_info_plot()
# (optional) compute ratio of intra- to interdistances. # (optional) compute ratio of intra- to interdistances.
...@@ -395,7 +570,7 @@ for epoch in range(opt.n_epochs): ...@@ -395,7 +570,7 @@ for epoch in range(opt.n_epochs):
# distance_measure_test.measure(model, epoch) # distance_measure_test.measure(model, epoch)
### Learning Rate Scheduling Step ### Learning Rate Scheduling Step
if opt.scheduler != 'none': if opt.scheduler != "none":
scheduler.step() scheduler.step()
print('\n-----\n') print("\n-----\n")
# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines # repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines
############################ LIBRARIES ###################################### ############################ LIBRARIES ######################################
from collections import OrderedDict
import os import os
from collections import OrderedDict
import auxiliaries as aux
import pretrainedmodels as ptm
import torch import torch
import torch.nn as nn import torch.nn as nn
import pretrainedmodels as ptm
import auxiliaries as aux
"""=============================================================""" """============================================================="""
...@@ -23,7 +24,9 @@ def initialize_weights(model): ...@@ -23,7 +24,9 @@ def initialize_weights(model):
""" """
for idx, module in enumerate(model.modules()): for idx, module in enumerate(model.modules()):
if isinstance(module, nn.Conv2d): if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') nn.init.kaiming_normal_(
module.weight, mode="fan_out", nonlinearity="relu"
)
elif isinstance(module, nn.BatchNorm2d): elif isinstance(module, nn.BatchNorm2d):
nn.init.constant_(module.weight, 1) nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0) nn.init.constant_(module.bias, 0)
...@@ -62,19 +65,19 @@ def networkselect(opt): ...@@ -62,19 +65,19 @@ def networkselect(opt):
Returns: Returns:
Network of choice Network of choice
""" """
if opt.arch == 'resnet50': if opt.arch == "resnet50":
network = ResNet50(opt) network = ResNet50(opt)
else: else:
raise Exception('Network {} not available!'.format(opt.arch)) raise Exception("Network {} not available!".format(opt.arch))
if opt.resume: if opt.resume:
weights = torch.load(os.path.join(opt.save_path, opt.resume)) weights = torch.load(os.path.join(opt.save_path, opt.resume))
weights_state_dict = weights['state_dict'] weights_state_dict = weights["state_dict"]
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
encoder_state_dict = OrderedDict() encoder_state_dict = OrderedDict()
for k, v in weights_state_dict.items(): for k, v in weights_state_dict.items():
k = k.replace('module.', '') k = k.replace("module.", "")
encoder_state_dict[k] = v encoder_state_dict[k] = v
network.load_state_dict(encoder_state_dict) network.load_state_dict(encoder_state_dict)
...@@ -106,25 +109,42 @@ class ResNet50(nn.Module): ...@@ -106,25 +109,42 @@ class ResNet50(nn.Module):
self.pars = opt self.pars = opt
if not opt.not_pretrained: if not opt.not_pretrained:
print('Getting pretrained weights...') print("Getting pretrained weights...")
self.model = ptm.__dict__['resnet50'](num_classes=1000, pretrained='imagenet') self.model = ptm.__dict__["resnet50"](
print('Done.') num_classes=1000, pretrained="imagenet"
)
print("Done.")
else: else:
print('Not utilizing pretrained weights!') print("Not utilizing pretrained weights!")
self.model = ptm.__dict__['resnet50'](num_classes=1000, pretrained=None) self.model = ptm.__dict__["resnet50"](
for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): num_classes=1000, pretrained=None
)
for module in filter(
lambda m: type(m) == nn.BatchNorm2d, self.model.modules()
):
module.eval() module.eval()
module.train = lambda _: None module.train = lambda _: None
if opt.embed_dim != 2048: if opt.embed_dim != 2048:
self.model.last_linear = torch.nn.Linear(self.model.last_linear.in_features, opt.embed_dim) self.model.last_linear = torch.nn.Linear(
self.model.last_linear.in_features, opt.embed_dim
self.layer_blocks = nn.ModuleList([self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4]) )
self.layer_blocks = nn.ModuleList(
[
self.model.layer1,
self.model.layer2,
self.model.layer3,
self.model.layer4,
]
)
self.loss = opt.loss self.loss = opt.loss
self.feature = True self.feature = True
def forward(self, x, feature=False, is_init_cluster_generation=False): def forward(self, x, feature=False, is_init_cluster_generation=False):
x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) x = self.model.maxpool(
self.model.relu(self.model.bn1(self.model.conv1(x)))
)
for layerblock in self.layer_blocks: for layerblock in self.layer_blocks:
x = layerblock(x) x = layerblock(x)
...@@ -139,7 +159,7 @@ class ResNet50(nn.Module): ...@@ -139,7 +159,7 @@ class ResNet50(nn.Module):
feat = torch.nn.functional.normalize(mod_x, dim=-1) feat = torch.nn.functional.normalize(mod_x, dim=-1)
if feature or self.loss == 'smoothap': if feature or self.loss == "smoothap":
return feat return feat
else: else:
pred = self.linear(feat) pred = self.linear(feat)
......
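The `k.replace("module.", "")` loop in `networkselect` above handles checkpoints saved from an `nn.DataParallel` wrapper, whose state-dict keys carry a "module." prefix. A small self-contained sketch of the same idea (the toy module is illustrative):

from collections import OrderedDict

import torch
import torch.nn as nn

# Simulate a checkpoint saved from an nn.DataParallel wrapper: every key gets a "module." prefix.
net = nn.Linear(4, 2)
saved = OrderedDict(("module." + k, v) for k, v in net.state_dict().items())

# Strip the prefix before loading into the bare (non-DataParallel) module, as done above.
cleaned = OrderedDict((k.replace("module.", "", 1), v) for k, v in saved.items())
net.load_state_dict(cleaned)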
import numpy as np
import pickle import pickle
import dgl import numpy as np
import torch import torch
from utils import (
build_knns,
build_next_level,
decode,
density_estimation,
fast_knns2spmat,
knns2ordered_nbrs,
l2norm,
row_normalize,
sparse_mx_to_indices_values,
)
import dgl
from utils import (build_knns, fast_knns2spmat, row_normalize, knns2ordered_nbrs,
density_estimation, sparse_mx_to_indices_values, l2norm,
decode, build_next_level)
class LanderDataset(object): class LanderDataset(object):
def __init__(self, features, labels, cluster_features=None, k=10, levels=1, faiss_gpu=False): def __init__(
self,
features,
labels,
cluster_features=None,
k=10,
levels=1,
faiss_gpu=False,
):
self.k = k self.k = k
self.gs = [] self.gs = []
self.nbrs = [] self.nbrs = []
...@@ -17,7 +34,7 @@ class LanderDataset(object): ...@@ -17,7 +34,7 @@ class LanderDataset(object):
self.levels = levels self.levels = levels
# Initialize features and labels # Initialize features and labels
features = l2norm(features.astype('float32')) features = l2norm(features.astype("float32"))
global_features = features.copy() global_features = features.copy()
if cluster_features is None: if cluster_features is None:
cluster_features = features cluster_features = features
...@@ -32,28 +49,48 @@ class LanderDataset(object): ...@@ -32,28 +49,48 @@ class LanderDataset(object):
self.levels = lvl self.levels = lvl
break break
if faiss_gpu: if faiss_gpu:
knns = build_knns(features, self.k, 'faiss_gpu') knns = build_knns(features, self.k, "faiss_gpu")
else: else:
knns = build_knns(features, self.k, 'faiss') knns = build_knns(features, self.k, "faiss")
dists, nbrs = knns2ordered_nbrs(knns) dists, nbrs = knns2ordered_nbrs(knns)
self.nbrs.append(nbrs) self.nbrs.append(nbrs)
self.dists.append(dists) self.dists.append(dists)
density = density_estimation(dists, nbrs, labels) density = density_estimation(dists, nbrs, labels)
g = self._build_graph(features, cluster_features, labels, density, knns) g = self._build_graph(
features, cluster_features, labels, density, knns
)
self.gs.append(g) self.gs.append(g)
if lvl >= self.levels - 1: if lvl >= self.levels - 1:
break break
# Decode peak nodes # Decode peak nodes
new_pred_labels, peaks,\ (
global_edges, global_pred_labels, global_peaks = decode(g, 0, 'sim', True, new_pred_labels,
ids, global_edges, global_num_nodes, peaks,
global_peaks) global_edges,
global_pred_labels,
global_peaks,
) = decode(
g,
0,
"sim",
True,
ids,
global_edges,
global_num_nodes,
global_peaks,
)
ids = ids[peaks] ids = ids[peaks]
features, labels, cluster_features = build_next_level(features, labels, peaks, features, labels, cluster_features = build_next_level(
global_features, global_pred_labels, global_peaks) features,
labels,
peaks,
global_features,
global_pred_labels,
global_peaks,
)
def _build_graph(self, features, cluster_features, labels, density, knns): def _build_graph(self, features, cluster_features, labels, density, knns):
adj = fast_knns2spmat(knns, self.k) adj = fast_knns2spmat(knns, self.k)
...@@ -61,17 +98,33 @@ class LanderDataset(object): ...@@ -61,17 +98,33 @@ class LanderDataset(object):
indices, values, shape = sparse_mx_to_indices_values(adj) indices, values, shape = sparse_mx_to_indices_values(adj)
g = dgl.graph((indices[1], indices[0])) g = dgl.graph((indices[1], indices[0]))
g.ndata['features'] = torch.FloatTensor(features) g.ndata["features"] = torch.FloatTensor(features)
g.ndata['cluster_features'] = torch.FloatTensor(cluster_features) g.ndata["cluster_features"] = torch.FloatTensor(cluster_features)
g.ndata['labels'] = torch.LongTensor(labels) g.ndata["labels"] = torch.LongTensor(labels)
g.ndata['density'] = torch.FloatTensor(density) g.ndata["density"] = torch.FloatTensor(density)
g.edata['affine'] = torch.FloatTensor(values) g.edata["affine"] = torch.FloatTensor(values)
# A bipartite graph from the DGL sampler will not store global eids, so we explicitly save them here # A bipartite graph from the DGL sampler will not store global eids, so we explicitly save them here
g.edata['global_eid'] = g.edges(form='eid') g.edata["global_eid"] = g.edges(form="eid")
g.ndata['norm'] = torch.FloatTensor(adj_row_sum) g.ndata["norm"] = torch.FloatTensor(adj_row_sum)
g.apply_edges(lambda edges: {'raw_affine': edges.data['affine'] / edges.dst['norm']}) g.apply_edges(
g.apply_edges(lambda edges: {'labels_conn': (edges.src['labels'] == edges.dst['labels']).long()}) lambda edges: {
g.apply_edges(lambda edges: {'mask_conn': (edges.src['density'] > edges.dst['density']).bool()}) "raw_affine": edges.data["affine"] / edges.dst["norm"]
}
)
g.apply_edges(
lambda edges: {
"labels_conn": (
edges.src["labels"] == edges.dst["labels"]
).long()
}
)
g.apply_edges(
lambda edges: {
"mask_conn": (
edges.src["density"] > edges.dst["density"]
).bool()
}
)
return g return g
def __getitem__(self, index): def __getitem__(self, index):
......
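The `_build_graph` method above derives edge features from source/destination node data via `g.apply_edges` with small lambdas. A minimal sketch of that pattern on a toy DGL graph (node and edge values are illustrative):

import dgl
import torch

# Toy graph with 3 nodes and directed edges 0->1, 1->2, 2->0.
g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.ndata["density"] = torch.tensor([0.9, 0.1, 0.5])
g.ndata["norm"] = torch.tensor([2.0, 2.0, 2.0])
g.edata["affine"] = torch.tensor([1.0, 2.0, 3.0])

# Same idea as _build_graph: compute per-edge features from edge data and dst/src node data.
g.apply_edges(lambda edges: {"raw_affine": edges.data["affine"] / edges.dst["norm"]})
g.apply_edges(lambda edges: {"mask_conn": (edges.src["density"] > edges.dst["density"]).bool()})
print(g.edata["raw_affine"], g.edata["mask_conn"])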
from .lander import LANDER
from .graphconv import GraphConv from .graphconv import GraphConv
from .lander import LANDER
...@@ -6,25 +6,27 @@ from torch.autograd import Variable ...@@ -6,25 +6,27 @@ from torch.autograd import Variable
# Below code are based on # Below code are based on
# https://zhuanlan.zhihu.com/p/28527749 # https://zhuanlan.zhihu.com/p/28527749
class FocalLoss(nn.Module): class FocalLoss(nn.Module):
r""" r"""
This criterion is an implementation of Focal Loss, which is proposed in This criterion is an implementation of Focal Loss, which is proposed in
Focal Loss for Dense Object Detection. Focal Loss for Dense Object Detection.
Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class]) Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
The losses are averaged across observations for each minibatch. The losses are averaged across observations for each minibatch.
Args: Args:
alpha(1D Tensor, Variable) : the scalar factor for this criterion alpha(1D Tensor, Variable) : the scalar factor for this criterion
gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
putting more focus on hard, misclassified examples putting more focus on hard, misclassified examples
size_average(bool): By default, the losses are averaged over observations for each minibatch. size_average(bool): By default, the losses are averaged over observations for each minibatch.
However, if the field size_average is set to False, the losses are However, if the field size_average is set to False, the losses are
instead summed for each minibatch. instead summed for each minibatch.
""" """
def __init__(self, class_num, alpha=None, gamma=2, size_average=True): def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
super(FocalLoss, self).__init__() super(FocalLoss, self).__init__()
if alpha is None: if alpha is None:
...@@ -46,18 +48,17 @@ class FocalLoss(nn.Module): ...@@ -46,18 +48,17 @@ class FocalLoss(nn.Module):
class_mask = inputs.data.new(N, C).fill_(0) class_mask = inputs.data.new(N, C).fill_(0)
class_mask = Variable(class_mask) class_mask = Variable(class_mask)
ids = targets.view(-1, 1) ids = targets.view(-1, 1)
class_mask.scatter_(1, ids.data, 1.) class_mask.scatter_(1, ids.data, 1.0)
if inputs.is_cuda and not self.alpha.is_cuda: if inputs.is_cuda and not self.alpha.is_cuda:
self.alpha = self.alpha.cuda() self.alpha = self.alpha.cuda()
alpha = self.alpha[ids.data.view(-1)] alpha = self.alpha[ids.data.view(-1)]
probs = (P*class_mask).sum(1).view(-1,1) probs = (P * class_mask).sum(1).view(-1, 1)
log_p = probs.log() log_p = probs.log()
batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p
if self.size_average: if self.size_average:
loss = batch_loss.mean() loss = batch_loss.mean()
......
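A functional restatement of the focal-loss formula from the docstring above, Loss(x, class) = -alpha * (1 - softmax(x)[class])^gamma * log(softmax(x)[class]), kept minimal for readability (the alpha and gamma values are illustrative defaults, not taken from this repo):

import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0, size_average=True):
    # log p_t for the target class of every sample
    log_pt = F.log_softmax(logits, dim=1).gather(1, targets.view(-1, 1)).squeeze(1)
    pt = log_pt.exp()
    loss = -alpha * (1.0 - pt) ** gamma * log_pt  # -alpha * (1 - p_t)^gamma * log(p_t)
    return loss.mean() if size_average else loss.sum()

logits = torch.randn(4, 2)
targets = torch.tensor([0, 1, 1, 0])
print(focal_loss(logits, targets))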
...@@ -9,6 +9,7 @@ from torch.nn import init ...@@ -9,6 +9,7 @@ from torch.nn import init
import dgl.function as fn import dgl.function as fn
from dgl.nn.pytorch import GATConv from dgl.nn.pytorch import GATConv
class GraphConvLayer(nn.Module): class GraphConvLayer(nn.Module):
def __init__(self, in_feats, out_feats, bias=True): def __init__(self, in_feats, out_feats, bias=True):
super(GraphConvLayer, self).__init__() super(GraphConvLayer, self).__init__()
...@@ -19,25 +20,29 @@ class GraphConvLayer(nn.Module): ...@@ -19,25 +20,29 @@ class GraphConvLayer(nn.Module):
srcfeat, dstfeat = feat srcfeat, dstfeat = feat
else: else:
srcfeat = feat srcfeat = feat
dstfeat = feat[:bipartite.num_dst_nodes()] dstfeat = feat[: bipartite.num_dst_nodes()]
graph = bipartite.local_var() graph = bipartite.local_var()
graph.srcdata['h'] = srcfeat graph.srcdata["h"] = srcfeat
graph.update_all(fn.u_mul_e('h', 'affine', 'm'), graph.update_all(
fn.sum(msg='m', out='h')) fn.u_mul_e("h", "affine", "m"), fn.sum(msg="m", out="h")
)
gcn_feat = torch.cat([dstfeat, graph.dstdata['h']], dim=-1) gcn_feat = torch.cat([dstfeat, graph.dstdata["h"]], dim=-1)
out = self.mlp(gcn_feat) out = self.mlp(gcn_feat)
return out return out
class GraphConv(nn.Module): class GraphConv(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0, use_GAT = False, K = 1): def __init__(self, in_dim, out_dim, dropout=0, use_GAT=False, K=1):
super(GraphConv, self).__init__() super(GraphConv, self).__init__()
self.in_dim = in_dim self.in_dim = in_dim
self.out_dim = out_dim self.out_dim = out_dim
if use_GAT: if use_GAT:
self.gcn_layer = GATConv(in_dim, out_dim, K, allow_zero_in_degree = True) self.gcn_layer = GATConv(
in_dim, out_dim, K, allow_zero_in_degree=True
)
self.bias = nn.Parameter(torch.Tensor(K, out_dim)) self.bias = nn.Parameter(torch.Tensor(K, out_dim))
init.constant_(self.bias, 0) init.constant_(self.bias, 0)
else: else:
...@@ -50,7 +55,7 @@ class GraphConv(nn.Module): ...@@ -50,7 +55,7 @@ class GraphConv(nn.Module):
out = self.gcn_layer(bipartite, features) out = self.gcn_layer(bipartite, features)
if self.use_GAT: if self.use_GAT:
out = torch.mean(out + self.bias, dim = 1) out = torch.mean(out + self.bias, dim=1)
out = out.reshape(out.shape[0], -1) out = out.reshape(out.shape[0], -1)
out = F.relu(out) out = F.relu(out)
......
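The `update_all(fn.u_mul_e("h", "affine", "m"), fn.sum(msg="m", out="h"))` call in GraphConvLayer computes an affinity-weighted sum over incoming neighbours. A small self-contained sketch on a toy graph (values are illustrative):

import dgl
import dgl.function as fn
import torch

# Edges 0->2 and 1->2: node 2 aggregates from nodes 0 and 1.
g = dgl.graph((torch.tensor([0, 1]), torch.tensor([2, 2])))
g.ndata["h"] = torch.tensor([1.0, 2.0, 0.0])
g.edata["affine"] = torch.tensor([0.5, 0.25])

# Message = source feature * edge affinity; reduce = sum over incoming messages.
g.update_all(fn.u_mul_e("h", "affine", "m"), fn.sum(msg="m", out="h"))
print(g.ndata["h"][2])  # 1.0 * 0.5 + 2.0 * 0.25 = 1.0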
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -9,13 +8,24 @@ import torch.nn.functional as F ...@@ -9,13 +8,24 @@ import torch.nn.functional as F
import dgl import dgl
import dgl.function as fn import dgl.function as fn
from .graphconv import GraphConv
from .focal_loss import FocalLoss from .focal_loss import FocalLoss
from .graphconv import GraphConv
class LANDER(nn.Module): class LANDER(nn.Module):
def __init__(self, feature_dim, nhid, num_conv=4, dropout=0, def __init__(
use_GAT=True, K=1, balance=False, self,
use_cluster_feat = True, use_focal_loss = True, **kwargs): feature_dim,
nhid,
num_conv=4,
dropout=0,
use_GAT=True,
K=1,
balance=False,
use_cluster_feat=True,
use_focal_loss=True,
**kwargs
):
super(LANDER, self).__init__() super(LANDER, self).__init__()
nhid_half = int(nhid / 2) nhid_half = int(nhid / 2)
self.use_cluster_feat = use_cluster_feat self.use_cluster_feat = use_cluster_feat
...@@ -31,15 +41,19 @@ class LANDER(nn.Module): ...@@ -31,15 +41,19 @@ class LANDER(nn.Module):
self.conv = nn.ModuleList() self.conv = nn.ModuleList()
self.conv.append(GraphConv(self.feature_dim, nhid, dropout, use_GAT, K)) self.conv.append(GraphConv(self.feature_dim, nhid, dropout, use_GAT, K))
for i in range(1, num_conv): for i in range(1, num_conv):
self.conv.append(GraphConv(input_dim[i], output_dim[i], dropout, use_GAT, K)) self.conv.append(
GraphConv(input_dim[i], output_dim[i], dropout, use_GAT, K)
)
self.src_mlp = nn.Linear(output_dim[num_conv - 1], nhid_half) self.src_mlp = nn.Linear(output_dim[num_conv - 1], nhid_half)
self.dst_mlp = nn.Linear(output_dim[num_conv - 1], nhid_half) self.dst_mlp = nn.Linear(output_dim[num_conv - 1], nhid_half)
self.classifier_conn = nn.Sequential(nn.PReLU(nhid_half), self.classifier_conn = nn.Sequential(
nn.Linear(nhid_half, nhid_half), nn.PReLU(nhid_half),
nn.PReLU(nhid_half), nn.Linear(nhid_half, nhid_half),
nn.Linear(nhid_half, 2)) nn.PReLU(nhid_half),
nn.Linear(nhid_half, 2),
)
if self.use_focal_loss: if self.use_focal_loss:
self.loss_conn = FocalLoss(2) self.loss_conn = FocalLoss(2)
...@@ -50,75 +64,119 @@ class LANDER(nn.Module): ...@@ -50,75 +64,119 @@ class LANDER(nn.Module):
self.balance = balance self.balance = balance
def pred_conn(self, edges): def pred_conn(self, edges):
src_feat = self.src_mlp(edges.src['conv_features']) src_feat = self.src_mlp(edges.src["conv_features"])
dst_feat = self.dst_mlp(edges.dst['conv_features']) dst_feat = self.dst_mlp(edges.dst["conv_features"])
pred_conn = self.classifier_conn(src_feat + dst_feat) pred_conn = self.classifier_conn(src_feat + dst_feat)
return {'pred_conn': pred_conn} return {"pred_conn": pred_conn}
def pred_den_msg(self, edges): def pred_den_msg(self, edges):
prob = edges.data['prob_conn'] prob = edges.data["prob_conn"]
res = edges.data['raw_affine'] * (prob[:, 1] - prob[:, 0]) res = edges.data["raw_affine"] * (prob[:, 1] - prob[:, 0])
return {'pred_den_msg': res} return {"pred_den_msg": res}
def forward(self, bipartites): def forward(self, bipartites):
if isinstance(bipartites, dgl.DGLGraph): if isinstance(bipartites, dgl.DGLGraph):
bipartites = [bipartites] * len(self.conv) bipartites = [bipartites] * len(self.conv)
if self.use_cluster_feat: if self.use_cluster_feat:
neighbor_x = torch.cat([bipartites[0].ndata['features'], bipartites[0].ndata['cluster_features']], axis=1) neighbor_x = torch.cat(
[
bipartites[0].ndata["features"],
bipartites[0].ndata["cluster_features"],
],
axis=1,
)
else: else:
neighbor_x = bipartites[0].ndata['features'] neighbor_x = bipartites[0].ndata["features"]
for i in range(len(self.conv)): for i in range(len(self.conv)):
neighbor_x = self.conv[i](bipartites[i], neighbor_x) neighbor_x = self.conv[i](bipartites[i], neighbor_x)
output_bipartite = bipartites[-1] output_bipartite = bipartites[-1]
output_bipartite.ndata['conv_features'] = neighbor_x output_bipartite.ndata["conv_features"] = neighbor_x
else: else:
if self.use_cluster_feat: if self.use_cluster_feat:
neighbor_x_src = torch.cat([bipartites[0].srcdata['features'], bipartites[0].srcdata['cluster_features']], axis=1) neighbor_x_src = torch.cat(
center_x_src = torch.cat([bipartites[1].srcdata['features'], bipartites[1].srcdata['cluster_features']], axis=1) [
bipartites[0].srcdata["features"],
bipartites[0].srcdata["cluster_features"],
],
axis=1,
)
center_x_src = torch.cat(
[
bipartites[1].srcdata["features"],
bipartites[1].srcdata["cluster_features"],
],
axis=1,
)
else: else:
neighbor_x_src = bipartites[0].srcdata['features'] neighbor_x_src = bipartites[0].srcdata["features"]
center_x_src = bipartites[1].srcdata['features'] center_x_src = bipartites[1].srcdata["features"]
for i in range(len(self.conv)): for i in range(len(self.conv)):
neighbor_x_dst = neighbor_x_src[:bipartites[i].num_dst_nodes()] neighbor_x_dst = neighbor_x_src[: bipartites[i].num_dst_nodes()]
neighbor_x_src = self.conv[i](bipartites[i], (neighbor_x_src, neighbor_x_dst)) neighbor_x_src = self.conv[i](
center_x_dst = center_x_src[:bipartites[i+1].num_dst_nodes()] bipartites[i], (neighbor_x_src, neighbor_x_dst)
center_x_src = self.conv[i](bipartites[i+1], (center_x_src, center_x_dst)) )
center_x_dst = center_x_src[: bipartites[i + 1].num_dst_nodes()]
center_x_src = self.conv[i](
bipartites[i + 1], (center_x_src, center_x_dst)
)
output_bipartite = bipartites[-1] output_bipartite = bipartites[-1]
output_bipartite.srcdata['conv_features'] = neighbor_x_src output_bipartite.srcdata["conv_features"] = neighbor_x_src
output_bipartite.dstdata['conv_features'] = center_x_src output_bipartite.dstdata["conv_features"] = center_x_src
output_bipartite.apply_edges(self.pred_conn) output_bipartite.apply_edges(self.pred_conn)
output_bipartite.edata['prob_conn'] = F.softmax(output_bipartite.edata['pred_conn'], dim=1) output_bipartite.edata["prob_conn"] = F.softmax(
output_bipartite.update_all(self.pred_den_msg, fn.mean('pred_den_msg', 'pred_den')) output_bipartite.edata["pred_conn"], dim=1
)
output_bipartite.update_all(
self.pred_den_msg, fn.mean("pred_den_msg", "pred_den")
)
return output_bipartite return output_bipartite
def compute_loss(self, bipartite): def compute_loss(self, bipartite):
pred_den = bipartite.dstdata['pred_den'] pred_den = bipartite.dstdata["pred_den"]
loss_den = self.loss_den(pred_den, bipartite.dstdata['density']) loss_den = self.loss_den(pred_den, bipartite.dstdata["density"])
labels_conn = bipartite.edata['labels_conn'] labels_conn = bipartite.edata["labels_conn"]
mask_conn = bipartite.edata['mask_conn'] mask_conn = bipartite.edata["mask_conn"]
if self.balance: if self.balance:
labels_conn = bipartite.edata['labels_conn'] labels_conn = bipartite.edata["labels_conn"]
neg_check = torch.logical_and(bipartite.edata['labels_conn'] == 0, mask_conn) neg_check = torch.logical_and(
bipartite.edata["labels_conn"] == 0, mask_conn
)
num_neg = torch.sum(neg_check).item() num_neg = torch.sum(neg_check).item()
neg_indices = torch.where(neg_check)[0] neg_indices = torch.where(neg_check)[0]
pos_check = torch.logical_and(bipartite.edata['labels_conn'] == 1, mask_conn) pos_check = torch.logical_and(
bipartite.edata["labels_conn"] == 1, mask_conn
)
num_pos = torch.sum(pos_check).item() num_pos = torch.sum(pos_check).item()
pos_indices = torch.where(pos_check)[0] pos_indices = torch.where(pos_check)[0]
if num_pos > num_neg: if num_pos > num_neg:
mask_conn[pos_indices[np.random.choice(num_pos, num_pos - num_neg, replace = False)]] = 0 mask_conn[
pos_indices[
np.random.choice(
num_pos, num_pos - num_neg, replace=False
)
]
] = 0
elif num_pos < num_neg: elif num_pos < num_neg:
mask_conn[neg_indices[np.random.choice(num_neg, num_neg - num_pos, replace = False)]] = 0 mask_conn[
neg_indices[
np.random.choice(
num_neg, num_neg - num_pos, replace=False
)
]
] = 0
# In subgraph training, it may happen that all edges are masked in a batch # In subgraph training, it may happen that all edges are masked in a batch
if mask_conn.sum() > 0: if mask_conn.sum() > 0:
loss_conn = self.loss_conn(bipartite.edata['pred_conn'][mask_conn], labels_conn[mask_conn]) loss_conn = self.loss_conn(
bipartite.edata["pred_conn"][mask_conn], labels_conn[mask_conn]
)
loss = loss_den + loss_conn loss = loss_den + loss_conn
loss_den_val = loss_den.item() loss_den_val = loss_den.item()
loss_conn_val = loss_conn.item() loss_conn_val = loss_conn.item()
......
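The `balance` branch in `compute_loss` above equalizes positive and negative edges by randomly masking out the surplus from the majority class before computing the connectivity loss. The core idea in isolation, with toy labels (values are illustrative):

import numpy as np
import torch

labels_conn = torch.tensor([1, 1, 1, 1, 0, 0])
mask_conn = torch.ones_like(labels_conn, dtype=torch.bool)

pos_idx = torch.where(torch.logical_and(labels_conn == 1, mask_conn))[0]
neg_idx = torch.where(torch.logical_and(labels_conn == 0, mask_conn))[0]
num_pos, num_neg = len(pos_idx), len(neg_idx)

# Drop a random surplus from the majority class so both classes contribute equally.
if num_pos > num_neg:
    drop = torch.from_numpy(np.random.choice(num_pos, num_pos - num_neg, replace=False))
    mask_conn[pos_idx[drop]] = False
elif num_neg > num_pos:
    drop = torch.from_numpy(np.random.choice(num_neg, num_neg - num_pos, replace=False))
    mask_conn[neg_idx[drop]] = False

print(mask_conn.sum())  # 4: two positives and two negatives remain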
import argparse, time, os, pickle import argparse
import numpy as np import os
import pickle
import time
import dgl import numpy as np
import torch import torch
import torch.optim as optim import torch.optim as optim
from models import LANDER
from dataset import LanderDataset from dataset import LanderDataset
from utils import evaluation, decode, build_next_level, stop_iterating from models import LANDER
from utils import build_next_level, decode, evaluation, stop_iterating
import dgl
########### ###########
# ArgParser # ArgParser
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# Dataset # Dataset
parser.add_argument('--data_path', type=str, required=True) parser.add_argument("--data_path", type=str, required=True)
parser.add_argument('--model_filename', type=str, default='lander.pth') parser.add_argument("--model_filename", type=str, default="lander.pth")
parser.add_argument('--faiss_gpu', action='store_true') parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument('--early_stop', action='store_true') parser.add_argument("--early_stop", action="store_true")
# HyperParam # HyperParam
parser.add_argument('--knn_k', type=int, default=10) parser.add_argument("--knn_k", type=int, default=10)
parser.add_argument('--levels', type=int, default=1) parser.add_argument("--levels", type=int, default=1)
parser.add_argument('--tau', type=float, default=0.5) parser.add_argument("--tau", type=float, default=0.5)
parser.add_argument('--threshold', type=str, default='prob') parser.add_argument("--threshold", type=str, default="prob")
parser.add_argument('--metrics', type=str, default='pairwise,bcubed,nmi') parser.add_argument("--metrics", type=str, default="pairwise,bcubed,nmi")
# Model # Model
parser.add_argument('--hidden', type=int, default=512) parser.add_argument("--hidden", type=int, default=512)
parser.add_argument('--num_conv', type=int, default=4) parser.add_argument("--num_conv", type=int, default=4)
parser.add_argument('--dropout', type=float, default=0.) parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument('--gat', action='store_true') parser.add_argument("--gat", action="store_true")
parser.add_argument('--gat_k', type=int, default=1) parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument('--balance', action='store_true') parser.add_argument("--balance", action="store_true")
parser.add_argument('--use_cluster_feat', action='store_true') parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument('--use_focal_loss', action='store_true') parser.add_argument("--use_focal_loss", action="store_true")
parser.add_argument('--use_gt', action='store_true') parser.add_argument("--use_gt", action="store_true")
args = parser.parse_args() args = parser.parse_args()
########################### ###########################
# Environment Configuration # Environment Configuration
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device('cuda') device = torch.device("cuda")
else: else:
device = torch.device('cpu') device = torch.device("cpu")
################## ##################
# Data Preparation # Data Preparation
with open(args.data_path, 'rb') as f: with open(args.data_path, "rb") as f:
features, labels = pickle.load(f) features, labels = pickle.load(f)
global_features = features.copy() global_features = features.copy()
dataset = LanderDataset(features=features, labels=labels, k=args.knn_k, dataset = LanderDataset(
levels=1, faiss_gpu=args.faiss_gpu) features=features,
labels=labels,
k=args.knn_k,
levels=1,
faiss_gpu=args.faiss_gpu,
)
g = dataset.gs[0].to(device) g = dataset.gs[0].to(device)
global_labels = labels.copy() global_labels = labels.copy()
ids = np.arange(g.number_of_nodes()) ids = np.arange(g.number_of_nodes())
...@@ -63,13 +71,18 @@ global_num_nodes = g.number_of_nodes() ...@@ -63,13 +71,18 @@ global_num_nodes = g.number_of_nodes()
################## ##################
# Model Definition # Model Definition
if not args.use_gt: if not args.use_gt:
feature_dim = g.ndata['features'].shape[1] feature_dim = g.ndata["features"].shape[1]
model = LANDER(feature_dim=feature_dim, nhid=args.hidden, model = LANDER(
num_conv=args.num_conv, dropout=args.dropout, feature_dim=feature_dim,
use_GAT=args.gat, K=args.gat_k, nhid=args.hidden,
balance=args.balance, num_conv=args.num_conv,
use_cluster_feat=args.use_cluster_feat, dropout=args.dropout,
use_focal_loss=args.use_focal_loss) use_GAT=args.gat,
K=args.gat_k,
balance=args.balance,
use_cluster_feat=args.use_cluster_feat,
use_focal_loss=args.use_focal_loss,
)
model.load_state_dict(torch.load(args.model_filename)) model.load_state_dict(torch.load(args.model_filename))
model = model.to(device) model = model.to(device)
model.eval() model.eval()
...@@ -82,23 +95,54 @@ for level in range(args.levels): ...@@ -82,23 +95,54 @@ for level in range(args.levels):
if not args.use_gt: if not args.use_gt:
with torch.no_grad(): with torch.no_grad():
g = model(g) g = model(g)
new_pred_labels, peaks,\ (
global_edges, global_pred_labels, global_peaks = decode(g, args.tau, args.threshold, args.use_gt, new_pred_labels,
ids, global_edges, global_num_nodes) peaks,
global_edges,
global_pred_labels,
global_peaks,
) = decode(
g,
args.tau,
args.threshold,
args.use_gt,
ids,
global_edges,
global_num_nodes,
)
ids = ids[peaks] ids = ids[peaks]
new_global_edges_len = len(global_edges[0]) new_global_edges_len = len(global_edges[0])
num_edges_add_this_level = new_global_edges_len - global_edges_len num_edges_add_this_level = new_global_edges_len - global_edges_len
if stop_iterating(level, args.levels, args.early_stop, num_edges_add_this_level, num_edges_add_last_level, args.knn_k): if stop_iterating(
level,
args.levels,
args.early_stop,
num_edges_add_this_level,
num_edges_add_last_level,
args.knn_k,
):
break break
global_edges_len = new_global_edges_len global_edges_len = new_global_edges_len
num_edges_add_last_level = num_edges_add_this_level num_edges_add_last_level = num_edges_add_this_level
# build new dataset # build new dataset
features, labels, cluster_features = build_next_level(features, labels, peaks, features, labels, cluster_features = build_next_level(
global_features, global_pred_labels, global_peaks) features,
labels,
peaks,
global_features,
global_pred_labels,
global_peaks,
)
# After the first level, the number of nodes drops a lot, so CPU faiss is faster. # After the first level, the number of nodes drops a lot, so CPU faiss is faster.
dataset = LanderDataset(features=features, labels=labels, k=args.knn_k, dataset = LanderDataset(
levels=1, faiss_gpu=False, cluster_features = cluster_features) features=features,
labels=labels,
k=args.knn_k,
levels=1,
faiss_gpu=False,
cluster_features=cluster_features,
)
if len(dataset.gs) == 0: if len(dataset.gs) == 0:
break break
g = dataset.gs[0].to(device) g = dataset.gs[0].to(device)
......