Unverified Commit 05aca98d authored by Tianyue Cao's avatar Tianyue Cao Committed by GitHub
Browse files

[Example] Progressive Sample Selection (PSS). (#4263)



* upload PSS

* upload PSS

* upload PSS

* pss code reformat

* fix bug

* update README

* update train bash

* remove vit

* update README

* delete InfoPlotter

* delete Smooth_AP_loss.py

* update README

* update README
Co-authored-by: default avatarTianjun Xiao <xiaotj1990327@gmail.com>
parent 5fc1d0c8
# PSS
Code for the ECCV '22 submission "PSS: Progressive Sample Selection for Open-World Visual Representation Learning".
## Dependencies
We use Python 3.7 and require CUDA 10.2. In addition to DGL 0.6.1, several other packages are needed. To install the dependencies using conda:
```commandline
conda create -n pss python=3.7 # create env
conda activate pss # activate env
conda install pytorch==1.7.0 torchvision==0.8.0 cudatoolkit=10.2 -c pytorch # install pytorch 1.7 version
conda install -y cudatoolkit=10.2 faiss-gpu=1.6.5 -c pytorch # install faiss gpu version matching cuda 10.2
pip install dgl-cu102 # install dgl for cuda 10.2
pip install tqdm # install tqdm
pip install matplotlib # install matplotlib
pip install pandas # install pandas
pip install pretrainedmodels # install pretrainedmodels
pip install tensorboardX # install tensorboardX
pip install seaborn # install seaborn
pip install scikit-learn
cd ..
git clone https://github.com/yjxiong/clustering-benchmark.git # install clustering-benchmark for evaluation
cd clustering-benchmark
python setup.py install
cd ../PSS
```
## Data
We use the iNaturalist 2018 dataset.
- download link: https://www.kaggle.com/c/inaturalist-2018/data;
- annotations are in `Smooth_AP/data/Inaturalist`;
- annotation txt files for different data splits are in [S3 link]|[[Google Drive](https://drive.google.com/drive/folders/1xrWogJGef4Ex5OGjiImgA06bAnk2MDrK?usp=sharing)]|[[Baidu Netdisk](https://pan.baidu.com/s/14S0Fns29a4o7kFDlNyyPjA?pwd=uwsg)] (password:uwsg).
Download `train_val2018.tar.gz` and the data split txt files to `data/Inaturalist/` folder. Extract the `tar.gz` files.
The data folder has the following structure:
```bash
PSS
|- data
| |- Inaturalist
| |- train2018.json.tar.gz
| |- train_val2018.tar.gz
| |- val2018.json.tar.gz
| |- train_val2018
| | |- Actinopterygii
| | |- ...
| |- lin_train_set1.txt
| |- train_set1.txt
| |- uin_train_set1.txt
| |- uout_train_set1.txt
| |- in_train_set1.txt
| |- Inaturalist_test_set1.txt
|-...
```
## Training
Run `bash train.sh` to train the model.
## Test
Run `bash test.sh` to evaluate on the test set.
\ No newline at end of file
# Smooth_AP
Referenced from the ECCV '20 paper ["Smooth-AP: Smoothing the Path Towards Large-Scale Image Retrieval"](https://www.robots.ox.ac.uk/~vgg/research/smooth-ap/), reference code is from https://github.com/Andrew-Brown1/Smooth_AP.
![teaser](https://github.com/Andrew-Brown1/Smooth_AP/blob/master/ims/teaser.png)
# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines
################## LIBRARIES ##############################
import warnings
warnings.filterwarnings("ignore")
import numpy as np, os, csv, datetime, torch, faiss
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn import metrics
import pickle as pkl
from torch import nn
"""============================================================================================================="""
################### TensorBoard Settings ###################
def args2exp_name(args):
    """Build a descriptive experiment name from the training hyper-parameters in *args*."""
    parts = [
        str(args.dataset),
        str(args.loss),
        str(args.lr),
        'bs{}'.format(args.bs),
        'spc{}'.format(args.samples_per_class),
        'embed{}'.format(args.embed_dim),
        'arch{}'.format(args.arch),
        'decay{}'.format(args.decay),
        'fclr{}'.format(args.fc_lr_mul),
        'anneal{}'.format(args.sigmoid_temperature),
    ]
    return '_'.join(parts)
################# ACQUIRE NUMBER OF WEIGHTS #################
def gimme_params(model):
    """
    Provide number of trainable parameters (i.e. those requiring gradient computation) for input network.
    Args:
        model: PyTorch Network
    Returns:
        int, number of parameters.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += np.prod(param.size())
    return total
################# SAVE TRAINING PARAMETERS IN NICE STRING #################
def gimme_save_string(opt):
    """
    Taking the set of parameters and convert it to easy-to-read string, which can be stored later.
    Args:
        opt: argparse.Namespace, contains all training-specific parameters.
    Returns:
        string, returns string summary of parameters.
    """
    chunks = []
    for key, value in vars(opt).items():
        chunks.append(str(key))
        if isinstance(value, dict):
            # Dict-valued options are expanded one sub-entry per line.
            for sub_key, sub_item in value.items():
                chunks.append('\n\t' + str(sub_key) + ': ' + str(sub_item))
        else:
            chunks.append('\n\t' + str(value))
        chunks.append('\n\n')
    return ''.join(chunks)
def f1_score(model_generated_cluster_labels, target_labels, feature_coll, computed_centroids):
    """
    Compute a pairwise-counting F1 score between the predicted clustering and
    the ground-truth class labels.

    NOTE: MOSTLY ADAPTED FROM https://github.com/wzzheng/HDML on Hardness-Aware Deep Metric Learning.
    Args:
        model_generated_cluster_labels: np.ndarray [n_samples x 1], Cluster labels computed on top of data embeddings.
        target_labels: np.ndarray [n_samples x 1], ground truth labels for each data sample.
        feature_coll: np.ndarray [n_samples x embed_dim], total data embedding made by network.
        computed_centroids: np.ndarray [num_cluster=num_classes x embed_dim], cluster coordinates
    Returns:
        float, F1-score
    """
    from scipy.special import comb
    # Distance of every sample to the centroid of its assigned cluster.
    d = np.zeros(len(feature_coll))
    for i in range(len(feature_coll)):
        d[i] = np.linalg.norm(feature_coll[i,:] - computed_centroids[model_generated_cluster_labels[i],:])
    # Re-label each cluster by the index of its member closest to the
    # centroid, so cluster ids become sample indices ("medoid" ids).
    labels_pred = np.zeros(len(feature_coll))
    for i in np.unique(model_generated_cluster_labels):
        index = np.where(model_generated_cluster_labels == i)[0]
        ind = np.argmin(d[index])
        cid = index[ind]
        labels_pred[index] = cid
    N = len(target_labels)
    #Cluster n_labels
    avail_labels = np.unique(target_labels)
    n_labels = len(avail_labels)
    #Count the number of objects in each cluster
    # (despite the name, this counts members per ground-truth label).
    count_cluster = np.zeros(n_labels)
    for i in range(n_labels):
        count_cluster[i] = len(np.where(target_labels == avail_labels[i])[0])
    #Build a mapping from item_id to item index
    keys = np.unique(labels_pred)
    num_item = len(keys)
    values = range(num_item)
    item_map = dict()
    for i in range(len(keys)):
        item_map.update([(keys[i], values[i])])
    #Count the number of objects of each item (predicted cluster)
    count_item = np.zeros(num_item)
    for i in range(N):
        index = item_map[labels_pred[i]]
        count_item[index] = count_item[index] + 1
    #Compute True Positive (TP) plus False Positive (FP) count:
    # all sample pairs sharing a ground-truth label.
    tp_fp = 0
    for k in range(n_labels):
        if count_cluster[k] > 1:
            tp_fp = tp_fp + comb(count_cluster[k], 2)
    #Compute True Positive (TP) count: pairs sharing both the ground-truth
    # label and the predicted cluster.
    tp = 0
    for k in range(n_labels):
        member = np.where(target_labels == avail_labels[k])[0]
        member_ids = labels_pred[member]
        count = np.zeros(num_item)
        for j in range(len(member)):
            index = item_map[member_ids[j]]
            count[index] = count[index] + 1
        for i in range(num_item):
            if count[i] > 1:
                tp = tp + comb(count[i], 2)
    #Compute False Positive (FP) count
    fp = tp_fp - tp
    #Compute False Negative (FN) count: same-cluster pairs that do not share
    # a ground-truth label.
    count = 0
    for j in range(num_item):
        if count_item[j] > 1:
            count = count + comb(count_item[j], 2)
    fn = count - tp
    # compute F measure (beta = 1 -> standard harmonic-mean F1).
    # NOTE(review): no zero guard on (tp + fp) / (tp + fn) / the F1
    # denominator; a degenerate clustering can yield a division warning/NaN.
    beta = 1
    P = tp / (tp + fp)
    R = tp / (tp + fn)
    F1 = (beta*beta + 1) * P * R / (beta*beta * P + R)
    return F1
"""============================================================================================================="""
def eval_metrics_one_dataset(model, test_dataloader, device, k_vals, opt):
    """
    Compute evaluation metrics on test-dataset, e.g. NMI, F1 and Recall @ k.
    Args:
        model: PyTorch network, network to compute evaluation metrics for.
        test_dataloader: PyTorch Dataloader, dataloader for test dataset, should have no shuffling and correct processing.
        device: torch.device, Device to run inference on.
        k_vals: list of int, Recall values to compute
        opt: argparse.Namespace, contains all training-specific parameters.
    Returns:
        F1 score (float), NMI score (float), recall_at_k (list of float), data embedding (np.ndarray)
    """
    torch.cuda.empty_cache()
    _ = model.eval()
    n_classes = len(test_dataloader.dataset.avail_classes)
    with torch.no_grad():
        ### For all test images, extract features
        target_labels, feature_coll = [],[]
        final_iter = tqdm(test_dataloader, desc='Computing Evaluation Metrics...')
        image_paths= [x[0] for x in test_dataloader.dataset.image_list]
        for idx, inp in enumerate(final_iter):
            # Batches are (label, ..., image): label first, image tensor last.
            input_img, target = inp[-1], inp[0]
            target_labels.extend(target.numpy().tolist())
            # feature=True asks the model for the embedding head output.
            out = model(input_img.to(device), feature=True)
            feature_coll.extend(out.cpu().detach().numpy().tolist())
        #pdb.set_trace()
        target_labels = np.hstack(target_labels).reshape(-1,1)
        feature_coll = np.vstack(feature_coll).astype('float32')
        torch.cuda.empty_cache()
        ### Set Faiss CPU Cluster index
        cpu_cluster_index = faiss.IndexFlatL2(feature_coll.shape[-1])
        kmeans = faiss.Clustering(feature_coll.shape[-1], n_classes)
        kmeans.niter = 20
        # Effectively disable faiss' centroid-balancing limits.
        kmeans.min_points_per_centroid = 1
        kmeans.max_points_per_centroid = 1000000000
        ### Train Kmeans
        kmeans.train(feature_coll, cpu_cluster_index)
        computed_centroids = faiss.vector_float_to_array(kmeans.centroids).reshape(n_classes, feature_coll.shape[-1])
        ### Assign feature points to clusters
        faiss_search_index = faiss.IndexFlatL2(computed_centroids.shape[-1])
        faiss_search_index.add(computed_centroids)
        _, model_generated_cluster_labels = faiss_search_index.search(feature_coll, 1)
        ### Compute NMI
        NMI = metrics.cluster.normalized_mutual_info_score(model_generated_cluster_labels.reshape(-1), target_labels.reshape(-1))
        ### Recover max(k_vals) neighbours to use for recall computation
        # +1 because each point's nearest neighbour is itself, stripped via [:,1:].
        faiss_search_index = faiss.IndexFlatL2(feature_coll.shape[-1])
        faiss_search_index.add(feature_coll)
        _, k_closest_points = faiss_search_index.search(feature_coll, int(np.max(k_vals)+1))
        k_closest_classes = target_labels.reshape(-1)[k_closest_points[:,1:]]
        print('computing recalls')
        ### Compute Recall
        recall_all_k = []
        for k in k_vals:
            recall_at_k = np.sum([1 for target, recalled_predictions in zip(target_labels, k_closest_classes) if target in recalled_predictions[:k]])/len(target_labels)
            recall_all_k.append(recall_at_k)
        print('finished recalls')
        print('computing F1')
        ### Compute F1 Score
        # F1 computation is deliberately disabled; 0 is returned instead.
        F1 = 0
        # F1 = f1_score(model_generated_cluster_labels, target_labels, feature_coll, computed_centroids)
        print('finished computing f1')
    return F1, NMI, recall_all_k, feature_coll
def eval_metrics_query_and_gallery_dataset(model, query_dataloader, gallery_dataloader, device, k_vals, opt):
    """
    Compute evaluation metrics on test-dataset, e.g. NMI, F1 and Recall @ k.
    Args:
        model: PyTorch network, network to compute evaluation metrics for.
        query_dataloader: PyTorch Dataloader, dataloader for query dataset, for which nearest neighbours in the gallery dataset are retrieved.
        gallery_dataloader: PyTorch Dataloader, dataloader for gallery dataset, provides target samples which are to be retrieved in correspondance to the query dataset.
        device: torch.device, Device to run inference on.
        k_vals: list of int, Recall values to compute
        opt: argparse.Namespace, contains all training-specific parameters.
    Returns:
        F1 score (float), NMI score (float), recall_at_ks (list of float), query data embedding (np.ndarray), gallery data embedding (np.ndarray)
    """
    torch.cuda.empty_cache()
    _ = model.eval()
    n_classes = len(query_dataloader.dataset.avail_classes)
    with torch.no_grad():
        ### For all query test images, extract features
        query_target_labels, query_feature_coll = [],[]
        query_image_paths = [x[0] for x in query_dataloader.dataset.image_list]
        query_iter = tqdm(query_dataloader, desc='Extraction Query Features')
        for idx,inp in enumerate(query_iter):
            # Batches are (label, ..., image): label first, image tensor last.
            input_img,target = inp[-1], inp[0]
            query_target_labels.extend(target.numpy().tolist())
            out = model(input_img.to(device), feature=True)
            query_feature_coll.extend(out.cpu().detach().numpy().tolist())
        ### For all gallery test images, extract features
        gallery_target_labels, gallery_feature_coll = [],[]
        gallery_image_paths = [x[0] for x in gallery_dataloader.dataset.image_list]
        gallery_iter = tqdm(gallery_dataloader, desc='Extraction Gallery Features')
        for idx,inp in enumerate(gallery_iter):
            input_img,target = inp[-1], inp[0]
            gallery_target_labels.extend(target.numpy().tolist())
            out = model(input_img.to(device), feature=True)
            gallery_feature_coll.extend(out.cpu().detach().numpy().tolist())
        query_target_labels, query_feature_coll = np.hstack(query_target_labels).reshape(-1,1), np.vstack(query_feature_coll).astype('float32')
        gallery_target_labels, gallery_feature_coll = np.hstack(gallery_target_labels).reshape(-1,1), np.vstack(gallery_feature_coll).astype('float32')
        torch.cuda.empty_cache()
        ### Set CPU Cluster index
        # NMI/F1 are computed over the joint query+gallery embedding.
        stackset = np.concatenate([query_feature_coll, gallery_feature_coll],axis=0)
        stacklabels = np.concatenate([query_target_labels, gallery_target_labels],axis=0)
        cpu_cluster_index = faiss.IndexFlatL2(stackset.shape[-1])
        kmeans = faiss.Clustering(stackset.shape[-1], n_classes)
        kmeans.niter = 20
        # Effectively disable faiss' centroid-balancing limits.
        kmeans.min_points_per_centroid = 1
        kmeans.max_points_per_centroid = 1000000000
        ### Train Kmeans
        kmeans.train(stackset, cpu_cluster_index)
        computed_centroids = faiss.vector_float_to_array(kmeans.centroids).reshape(n_classes, stackset.shape[-1])
        ### Assign feature points to clusters
        faiss_search_index = faiss.IndexFlatL2(computed_centroids.shape[-1])
        faiss_search_index.add(computed_centroids)
        _, model_generated_cluster_labels = faiss_search_index.search(stackset, 1)
        ### Compute NMI
        NMI = metrics.cluster.normalized_mutual_info_score(model_generated_cluster_labels.reshape(-1), stacklabels.reshape(-1))
        ### Recover max(k_vals) nearest neighbours to use for recall computation
        # No self-match stripping here: queries are searched against the gallery only.
        faiss_search_index = faiss.IndexFlatL2(gallery_feature_coll.shape[-1])
        faiss_search_index.add(gallery_feature_coll)
        _, k_closest_points = faiss_search_index.search(query_feature_coll, int(np.max(k_vals)))
        k_closest_classes = gallery_target_labels.reshape(-1)[k_closest_points]
        ### Compute Recall
        recall_all_k = []
        for k in k_vals:
            recall_at_k = np.sum([1 for target, recalled_predictions in zip(query_target_labels, k_closest_classes) if target in recalled_predictions[:k]])/len(query_target_labels)
            recall_all_k.append(recall_at_k)
        # Human-readable summary of the recall values (built but not printed here).
        recall_str = ', '.join('@{0}: {1:.4f}'.format(k,rec) for k,rec in zip(k_vals, recall_all_k))
        ### Compute F1 score
        F1 = f1_score(model_generated_cluster_labels, stacklabels, stackset, computed_centroids)
    return F1, NMI, recall_all_k, query_feature_coll, gallery_feature_coll
"""============================================================================================================="""
####### RECOVER CLOSEST EXAMPLE IMAGES #######
def recover_closest_one_dataset(feature_matrix_all, image_paths, save_path, n_image_samples=10, n_closest=3):
    """
    Provide sample recoveries: for a random set of anchor images, plot the
    anchor (red bar) followed by its n_closest nearest neighbours (green bars)
    and save the resulting figure grid.
    Args:
        feature_matrix_all: np.ndarray [n_samples x embed_dim], full data embedding of test samples.
        image_paths: list [n_samples], list of datapaths corresponding to <feature_matrix_all>
        save_path: str, where to store sample image.
        n_image_samples: Number of sample recoveries.
        n_closest: Number of closest recoveries to show.
    Returns:
        Nothing!
    """
    # NOTE(review): takes x[0] of each entry, i.e. expects (path, label)-style
    # tuples rather than bare path strings -- confirm against the caller.
    image_paths = np.array([x[0] for x in image_paths])
    sample_idxs = np.random.choice(np.arange(len(feature_matrix_all)), n_image_samples)
    # Search n_closest+1 neighbours: the closest hit of every point is itself.
    faiss_search_index = faiss.IndexFlatL2(feature_matrix_all.shape[-1])
    faiss_search_index.add(feature_matrix_all)
    _, closest_feature_idxs = faiss_search_index.search(feature_matrix_all, n_closest+1)
    # Each row: [anchor, neighbour_1, ..., neighbour_n_closest] for a sampled anchor.
    sample_paths = image_paths[closest_feature_idxs][sample_idxs]
    f,axes = plt.subplots(n_image_samples, n_closest+1)
    for i,(ax,plot_path) in enumerate(zip(axes.reshape(-1), sample_paths.reshape(-1))):
        ax.imshow(np.array(Image.open(plot_path)))
        ax.set_xticks([])
        ax.set_yticks([])
        if i%(n_closest+1):
            ax.axvline(x=0, color='g', linewidth=13)  # neighbour column
        else:
            ax.axvline(x=0, color='r', linewidth=13)  # anchor column
    f.set_size_inches(10,20)
    f.tight_layout()
    f.savefig(save_path)
    plt.close()
####### RECOVER CLOSEST EXAMPLE IMAGES #######
def recover_closest_inshop(query_feature_matrix_all, gallery_feature_matrix_all, query_image_paths, gallery_image_paths, save_path, n_image_samples=10, n_closest=3):
    """
    Provide sample recoveries for a query/gallery split: for a random set of
    query images, plot the query (red bar) followed by its n_closest nearest
    gallery images (green bars) and save the figure grid.
    Args:
        query_feature_matrix_all: np.ndarray [n_query_samples x embed_dim], full data embedding of query samples.
        gallery_feature_matrix_all: np.ndarray [n_gallery_samples x embed_dim], full data embedding of gallery samples.
        query_image_paths: list [n_samples], list of datapaths corresponding to <query_feature_matrix_all>
        gallery_image_paths: list [n_samples], list of datapaths corresponding to <gallery_feature_matrix_all>
        save_path: str, where to store sample image.
        n_image_samples: Number of sample recoveries.
        n_closest: Number of closest recoveries to show.
    Returns:
        Nothing!
    """
    query_image_paths, gallery_image_paths = np.array(query_image_paths), np.array(gallery_image_paths)
    sample_idxs = np.random.choice(np.arange(len(query_feature_matrix_all)), n_image_samples)
    # Retrieve the n_closest gallery neighbours of every query embedding.
    faiss_search_index = faiss.IndexFlatL2(gallery_feature_matrix_all.shape[-1])
    faiss_search_index.add(gallery_feature_matrix_all)
    _, closest_feature_idxs = faiss_search_index.search(query_feature_matrix_all, n_closest)
    image_paths = gallery_image_paths[closest_feature_idxs]
    # Each row now holds [query_path, neighbour_1, ..., neighbour_n_closest].
    image_paths = np.concatenate([query_image_paths.reshape(-1,1), image_paths], axis=-1)
    # BUGFIX: the original used `image_paths[closest_feature_idxs][sample_idxs]`,
    # which fancy-indexes the already-assembled per-query rows a second time and
    # yields a [n_samples x n_closest x (n_closest+1)] array of mostly wrong
    # paths. Rows are already aligned with queries, so select them directly.
    sample_paths = image_paths[sample_idxs]
    f, axes = plt.subplots(n_image_samples, n_closest+1)
    for i, (ax, plot_path) in enumerate(zip(axes.reshape(-1), sample_paths.reshape(-1))):
        ax.imshow(np.array(Image.open(plot_path)))
        ax.set_xticks([])
        ax.set_yticks([])
        if i % (n_closest+1):
            ax.axvline(x=0, color='g', linewidth=13)  # retrieved gallery image
        else:
            ax.axvline(x=0, color='r', linewidth=13)  # query image
    f.set_size_inches(10,20)
    f.tight_layout()
    f.savefig(save_path)
    plt.close()
"""============================================================================================================="""
################## SET NETWORK TRAINING CHECKPOINT #####################
def set_checkpoint(model, opt, progress_saver, savepath):
    """
    Store relevant parameters (model and progress saver, as well as parameter-namespace).
    Can be easily extended for other stuff.
    Args:
        model: PyTorch network, network whose parameters are to be saved.
        opt: argparse.Namespace, includes all training-specific parameters
        progress_saver: subclass of LOGGER-class, contains a running memory of all training metrics.
        savepath: str, where to save checkpoint.
    Returns:
        Nothing!
    """
    checkpoint = {
        'state_dict': model.state_dict(),
        'opt': opt,
        'progress': progress_saver,
    }
    torch.save(checkpoint, savepath)
"""============================================================================================================="""
################## WRITE TO CSV FILE #####################
class CSV_Writer():
    """
    Append newly computed training metrics to a csv file for data logging.
    Used together with the LOGGER class.
    """
    def __init__(self, save_path, columns):
        """
        Args:
            save_path: str, where to store the csv file
            columns: list of str, name of csv columns under which the resp. metrics are stored.
        Returns:
            Nothing!
        """
        self.save_path = save_path
        self.columns = columns
        # Write the header row immediately (append mode, like later log calls).
        self._write_row(self.columns)

    def log(self, inputs):
        """
        log one set of entries to the csv.
        Args:
            inputs: [list of int/str/float], values to append to the csv. Has to be of the same length as self.columns.
        Returns:
            Nothing!
        """
        self._write_row(inputs)

    def _write_row(self, row):
        # The file is re-opened per write so interrupted runs still leave a
        # readable log behind.
        with open(self.save_path, "a") as csv_file:
            csv.writer(csv_file, delimiter=",").writerow(row)
################## GENERATE LOGGING FOLDER/FILES #######################
def set_logging(opt):
    """
    Create the folder in which all logging output is saved and persist the
    training parameters there.

    The save folder is ``opt.save_path + '/' + str(opt.iter)``, i.e. one
    sub-folder per PSS iteration. ``opt.save_path`` is updated in place to
    point at the newly created folder. The full parameter namespace is stored
    both as human-readable text (``Parameter_Info.txt``) and as a pickle
    (``hypa.pkl``).

    Args:
        opt: argparse.Namespace, contains all training-specific parameters.
             Must provide ``save_path`` and ``iter``.
    Returns:
        Nothing! (Mutates ``opt.save_path``.)
    """
    # The original computed this path twice (once redundantly inside an
    # opt.savename branch that also built an unused timestamp) and carried
    # commented-out dead code; both branches produced this same folder.
    checkfolder = opt.save_path + '/' + str(opt.iter)
    # Create the folder if it does not exist yet; an existing folder is reused
    # so repeated runs of the same iteration write into the same place.
    if not os.path.exists(checkfolder):
        os.makedirs(checkfolder)
    opt.save_path = checkfolder
    # Store training parameters as text and pickle in said folder.
    with open(opt.save_path + '/Parameter_Info.txt', 'w') as f:
        f.write(gimme_save_string(opt))
    pkl.dump(opt, open(opt.save_path + "/hypa.pkl", "wb"))
import pdb
class LOGGER():
    """
    This class provides a collection of logging properties that are useful for training.
    These include setting the save folder, in which progression of training/testing metrics is visualized,
    csv log-files are stored, sample recoveries are plotted and an internal data saver.
    """
    def __init__(self, opt, metrics_to_log, name='Basic', start_new=True):
        """
        Args:
            opt: argparse.Namespace, contains all training-specific parameters.
            metrics_to_log: dict, dictionary which shows in what structure the data should be saved.
                            is given as the output of aux.metrics_to_examine. Example:
                            {'train': ['Epochs', 'Time', 'Train Loss', 'Time'],
                             'val': ['Epochs','Time','NMI','F1', 'Recall @ 1','Recall @ 2','Recall @ 4','Recall @ 8']}
            name: Name of this logger. Will be used to distinguish logged files from other LOGGER instances.
            start_new: If set to true, a new save folder will be created initially.
        Returns:
            Nothing!
        """
        self.prop = opt  # full parameter namespace, kept for later lookups
        self.metrics_to_log = metrics_to_log
        ### Make Logging Directories (also mutates opt.save_path, see set_logging)
        if start_new: set_logging(opt)
        ### Set Progress Saver Dict
        self.progress_saver = self.provide_progress_saver(metrics_to_log)
        ### Set CSV Writters (one csv file per mode, e.g. 'train'/'val')
        self.csv_loggers= {mode:CSV_Writer(opt.save_path+'/log_'+mode+'_'+name+'.csv', lognames) for mode, lognames in metrics_to_log.items()}

    def provide_progress_saver(self, metrics_to_log):
        """
        Provide Progress Saver dictionary.
        Args:
            metrics_to_log: see __init__(). Describes the structure of Progress_Saver.
        Returns:
            dict with the same two-level structure, every leaf an empty list.
        """
        Progress_Saver = {key:{sub_key:[] for sub_key in metrics_to_log[key]} for key in metrics_to_log.keys()}
        return Progress_Saver

    def log(self, main_keys, metric_keys, values):
        """
        Actually log new values in csv and Progress Saver dict internally.
        Args:
            main_keys: Main key in which data will be stored. Normally is either 'train' for training metrics or 'val' for validation metrics.
            metric_keys: Needs to follow the list length of self.progress_saver[main_key(s)]. List of metric keys that are extended with new values.
            values: Needs to be a list of the same structure as metric_keys. Actual values that are appended.
        """
        # Promote scalar arguments to one-element lists so a single metric
        # can be logged without wrapping.
        if not isinstance(main_keys, list): main_keys = [main_keys]
        if not isinstance(metric_keys, list): metric_keys = [metric_keys]
        if not isinstance(values, list): values = [values]
        #Log data to progress saver dict.
        for main_key in main_keys:
            for value, metric_key in zip(values, metric_keys):
                self.progress_saver[main_key][metric_key].append(value)
            #Append data to csv.
            self.csv_loggers[main_key].log(values)

    def update_info_plot(self):
        """
        Create a new updated version of training/metric progression plot.

        NOTE(review): relies on ``self.info_plot``, which is never assigned in
        this class (the InfoPlotter was removed from this codebase); calling
        this method as-is raises AttributeError unless ``info_plot`` is
        attached externally -- confirm before use.
        Args:
            None
        Returns:
            Nothing!
        """
        t_epochs = self.progress_saver['val']['Epochs']
        t_loss_list = [self.progress_saver['train']['Train Loss']]
        t_legend_handles = ['Train Loss']
        v_epochs = self.progress_saver['val']['Epochs']
        #Because Vehicle-ID normally uses three different test sets, a distinction has to be made.
        if self.prop.dataset != 'vehicle_id':
            # Plot title summarises the best value seen so far for each metric.
            title = ' | '.join(key+': {0:3.3f}'.format(np.max(item)) for key,item in self.progress_saver['val'].items() if key not in ['Time', 'Epochs'])
            self.info_plot.title = title
            v_metric_list = [self.progress_saver['val'][key] for key in self.progress_saver['val'].keys() if key not in ['Time', 'Epochs']]
            v_legend_handles = [key for key in self.progress_saver['val'].keys() if key not in ['Time', 'Epochs']]
            self.info_plot.make_plot(t_epochs, v_epochs, t_loss_list, v_metric_list, t_legend_handles, v_legend_handles)
        else:
            #Iterate over all test sets.
            for i in range(3):
                title = ' | '.join(key+': {0:3.3f}'.format(np.max(item)) for key,item in self.progress_saver['val'].items() if key not in ['Time', 'Epochs'] and 'Set {}'.format(i) in key)
                self.info_plot['Set {}'.format(i)].title = title
                v_metric_list = [self.progress_saver['val'][key] for key in self.progress_saver['val'].keys() if key not in ['Time', 'Epochs'] and 'Set {}'.format(i) in key]
                v_legend_handles = [key for key in self.progress_saver['val'].keys() if key not in ['Time', 'Epochs'] and 'Set {}'.format(i) in key]
                self.info_plot['Set {}'.format(i)].make_plot(t_epochs, v_epochs, t_loss_list, v_metric_list, t_legend_handles, v_legend_handles, appendix='set_{}'.format(i))
def metrics_to_examine(dataset, k_vals):
    """
    Build the storing structure for LOGGER.progress_saver.

    Please only use either of the following keys:
    -> Epochs, Time, Train Loss for training
    -> Epochs, Time, NMI, F1 & Recall @ k for validation
    Args:
        dataset: str, dataset for which a storing structure for LOGGER.progress_saver is to be made.
        k_vals: list of int, Recall @ k - values.
    Returns:
        metric_dict: Dictionary representing the storing structure for LOGGER.progress_saver. See LOGGER.__init__() for an example.
    """
    metric_dict = {'train': ['Epochs', 'Time', 'Train Loss']}
    if dataset == 'vehicle_id':
        # Vehicle_ID uses three test sets, so every metric is logged per set.
        val_cols = ['Epochs', 'Time']
        for set_idx in range(3):
            val_cols.append('Set {} NMI'.format(set_idx))
            val_cols.append('Set {} F1'.format(set_idx))
            val_cols.extend('Set {} Recall @ {}'.format(set_idx, k) for k in k_vals)
        metric_dict['val'] = val_cols
    else:
        metric_dict['val'] = ['Epochs', 'Time', 'NMI', 'F1'] + ['Recall @ {}'.format(k) for k in k_vals]
    return metric_dict
def bool_flag(s):
    """
    Parse boolean arguments from the command line.

    Args:
        s: str, command-line value; matched case-insensitively.
    Returns:
        bool, True for "on"/"true"/"1", False for "off"/"false"/"0".
    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    # Local import: argparse is not imported at module level, so the original
    # error path raised a NameError instead of the intended ArgumentTypeError.
    import argparse
    FALSY_STRINGS = {"off", "false", "0"}
    TRUTHY_STRINGS = {"on", "true", "1"}
    if s.lower() in FALSY_STRINGS:
        return False
    elif s.lower() in TRUTHY_STRINGS:
        return True
    else:
        raise argparse.ArgumentTypeError("invalid value for a boolean flag")
def vis(model, test_dataloader, device, split, opt):
    """
    Extract embeddings for every image in *test_dataloader* and dump them
    (with path-to-index mapping, labels and selection masks) to a pickle file
    consumed by the PSS clustering / sample-selection stage.

    Args:
        model: PyTorch network; called with feature=True to obtain embeddings.
        test_dataloader: PyTorch Dataloader over the split to featurize.
        device: torch.device to run inference on.
        split: str, split name; used in the output filename and to decide
               whether features are re-ordered ("all_train" splits only).
        opt: argparse.Namespace; reads linsize, dataset, iter, cluster_path
             and source_path.
    Returns:
        Nothing! (Writes <split>_inat_features.pkl into opt.source_path.)
    """
    linsize = opt.linsize
    torch.cuda.empty_cache()
    if opt.dataset == "Inaturalist":
        if opt.iter > 0:
            # Later PSS iterations: reuse ordering and masks from the previous
            # clustering result.
            with open(opt.cluster_path, 'rb') as clusterf:
                path2idx, global_features, global_pred_labels, gt_labels, masks = pkl.load(clusterf)
                # Offset ground-truth labels past the pseudo-label id range.
                gt_labels = gt_labels + len(np.unique(global_pred_labels))
            idx2path = {v: k for k, v in path2idx.items()}
        else:
            # First iteration: canonical order comes from the split file.
            with open(os.path.join(opt.source_path, "train_set1.txt")) as f:
                filelines = f.readlines()
                paths = [x.strip() for x in filelines]
            Lin_paths = paths[:linsize]
            # NOTE(review): mask values appear to mean 0 = labelled subset
            # (first linsize paths), 2 = unlabelled pool, 1 = selected in the
            # previous round (reset below) -- confirm against selection code.
            masks = np.zeros(len(paths))
            masks[:len(Lin_paths)] = 0  # redundant: array is zero-initialised
            masks[len(Lin_paths):] = 2
    _ = model.eval()
    path2ids = {}
    with torch.no_grad():
        ### For all test images, extract features
        target_labels, feature_coll = [],[]
        final_iter = tqdm(test_dataloader, desc='Computing Evaluation Metrics...')
        image_paths = [x[0] for x in test_dataloader.dataset.image_list]
        # Map every image path to its position in the dataloader order.
        for i in range(len(image_paths)):
            path2ids[image_paths[i]] = i
        for idx, inp in enumerate(final_iter):
            # Batches are (label, ..., image): label first, image tensor last.
            input_img, target = inp[-1], inp[0]
            target_labels.extend(target.numpy().tolist())
            out = model(input_img.to(device), feature=True)
            feature_coll.extend(out.cpu().detach().numpy().tolist())
        #pdb.set_trace()
        target_labels = np.hstack(target_labels).reshape(-1)
        feature_coll = np.vstack(feature_coll).astype('float32')
    if (opt.dataset == "Inaturalist") and "all_train" in split:
        if opt.iter > 0:
            # Re-order features/labels into the canonical index order of the
            # previous clustering file (idx2path).
            predicted_features = np.zeros_like(feature_coll)
            path2ids_new = {}
            target_labels_new = np.zeros_like(target_labels)
            for i in range(len(idx2path.keys())):
                path = idx2path[i]
                idxx = path2ids[path]
                path2ids_new[path] = i
                predicted_features[i] = feature_coll[idxx]
                target_labels_new[i] = target_labels[idxx]
            path2ids = path2ids_new
            feature_coll = predicted_features
            target_labels = target_labels_new
            gtlabels = target_labels
            # Samples selected in the previous round (mask == 1) rejoin the
            # labelled set (mask == 0).
            lastuselected = np.where(masks == 1)
            masks[lastuselected] = 0
            print(len(np.where(masks == 0)[0]))
        else:
            # Re-order features/labels to match the split-file order.
            predicted_features = np.zeros_like(feature_coll)
            path2ids_new = {}
            target_labels_new = np.zeros_like(target_labels)
            for i in range(len(paths)):
                path = paths[i]
                idxx = path2ids[opt.source_path+'/'+path]
                path2ids_new[opt.source_path+'/'+path] = i
                predicted_features[i] = feature_coll[idxx]
                target_labels_new[i] = target_labels[idxx]
            path2ids = path2ids_new
            feature_coll = predicted_features
            target_labels = target_labels_new
            gtlabels = target_labels
    if "all_train" not in split:
        print("all_train not in split.")
        gtlabels = target_labels
    # NOTE(review): masks (and gtlabels on some paths) are only bound inside
    # the "Inaturalist" branches above; other datasets would raise NameError.
    output_feature_path = os.path.join(opt.source_path,split+"_inat_features.pkl")
    print("Dump features into {}.".format(output_feature_path))
    with open(output_feature_path, "wb") as f:
        pkl.dump([path2ids, feature_coll, target_labels, gtlabels, masks], f)
    print(target_labels.max())
    print("target_labels:", target_labels.shape)
    print("feature_coll:", feature_coll.shape)
\ No newline at end of file
# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines
################# LIBRARIES ###############################
import pickle
import warnings
from numpy.core.arrayprint import IntegerFormat
warnings.filterwarnings("ignore")
import numpy as np, pandas as pd, copy, torch, random, os
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
"""============================================================================"""
################ FUNCTION TO RETURN ALL DATALOADERS NECESSARY ####################
def give_dataloaders(dataset, trainset, testset, opt, cluster_path=""):
    """
    Build all PyTorch dataloaders required for training/testing.
    Args:
        dataset: string, name of dataset for which the dataloaders should be returned.
        trainset: unused here; split names are read from opt.
        testset: forwarded to the finetune dataset factory.
        opt: argparse.Namespace, contains all training-specific parameters.
        cluster_path: str, pickle with pseudo-labels, used in the finetune stage.
    Returns:
        dataloaders: dict of dataloaders for training, testing and evaluation on training.
    """
    #Dataset selection
    if opt.dataset=='Inaturalist':
        if opt.finetune:
            datasets = give_inat_datasets_finetune_1head(testset, cluster_path, opt)
        else:
            if opt.get_features:
                datasets = give_inaturalist_datasets_for_features(opt)
            else:
                datasets = give_inaturalist_datasets(opt)
    else:
        raise Exception('No Dataset >{}< available!'.format(dataset))
    #Move datasets to dataloaders.
    dataloaders = {}
    for key, dataset in datasets.items():
        if (isinstance(dataset, TrainDatasetsmoothap) or isinstance(dataset, TrainDatasetsmoothap1Head))\
                and key in ['training', 'clustering']:
            # Smooth-AP style training sets are pre-arranged into batches, so
            # a SequentialSampler is used instead of shuffling.
            dataloaders[key] = torch.utils.data.DataLoader(dataset, batch_size=opt.bs,
                num_workers=opt.kernels, sampler=torch.utils.data.SequentialSampler(dataset),
                pin_memory=True, drop_last=True)
        else:
            # Validation loaders keep every sample (no shuffle, no drop_last).
            is_val = dataset.is_validation
            if key == 'training' or key == 'clustering':
                dataloaders[key] = torch.utils.data.DataLoader(dataset, batch_size=opt.bs,
                    num_workers=opt.kernels, shuffle=not is_val, pin_memory=True, drop_last=not is_val)
            else:
                # Evaluation loaders use a fixed, smaller worker count.
                dataloaders[key] = torch.utils.data.DataLoader(dataset, batch_size=opt.bs,
                    num_workers=6, shuffle=not is_val, pin_memory=True, drop_last=not is_val)
    return dataloaders
def give_inaturalist_datasets(opt):
    """
    This function generates a training, testing and evaluation dataloader for Metric Learning on the Inaturalist 2018 dataset.
    For Metric Learning, training and test sets are provided by given text files. Will define a train and test split,
    so no random shuffling of classes.
    Args:
        opt: argparse.Namespace, contains all training-specific parameters.
    Returns:
        dict of PyTorch datasets for training, testing and evaluation.
    """
    #Load text-files containing classes and imagepaths.
    #Generate image_dicts of shape {class_idx:[list of paths to images belong to this class] ...}
    train_image_dict, val_image_dict, test_image_dict = {},{},{}
    with open(os.path.join(opt.source_path, opt.trainset)) as f:
        FileLines = f.readlines()
        FileLines = [x.strip() for x in FileLines]
        # The class key is "<supercategory>/<species>", taken from the last
        # two directory components of each image path.
        for entry in FileLines:
            info = entry.split('/')
            if '/'.join([info[-3],info[-2]]) not in train_image_dict:
                train_image_dict['/'.join([info[-3],info[-2]])] = []
            train_image_dict['/'.join([info[-3],info[-2]])].append(os.path.join(opt.source_path,entry))
    # NOTE(review): both the validation and the evaluation dict are built from
    # opt.testset, so 'testing' and 'evaluation' cover the same images.
    with open(os.path.join(opt.source_path, opt.testset)) as f:
        FileLines = f.readlines()
        FileLines = [x.strip() for x in FileLines]
        for entry in FileLines:
            info = entry.split('/')
            if '/'.join([info[-3],info[-2]]) not in val_image_dict:
                val_image_dict['/'.join([info[-3], info[-2]])] = []
            val_image_dict['/'.join([info[-3], info[-2]])].append(os.path.join(opt.source_path,entry))
    with open(os.path.join(opt.source_path, opt.testset)) as f:
        FileLines = f.readlines()
        FileLines = [x.strip() for x in FileLines]
        for entry in FileLines:
            info = entry.split('/')
            if '/'.join([info[-3],info[-2]]) not in test_image_dict:
                test_image_dict['/'.join([info[-3],info[-2]])] = []
            test_image_dict['/'.join([info[-3],info[-2]])].append(os.path.join(opt.source_path,entry))
    # Rename training classes to contiguous ids "te/0", "te/1", ...
    new_train_dict = {}
    class_ind_ind = 0
    for cate in train_image_dict:
        new_train_dict["te/%d"%class_ind_ind] = train_image_dict[cate]
        class_ind_ind += 1
    train_image_dict = new_train_dict
    train_dataset = TrainDatasetsmoothap(train_image_dict, opt)
    val_dataset = BaseTripletDataset(val_image_dict, opt, is_validation=True)
    eval_dataset = BaseTripletDataset(test_image_dict, opt, is_validation=True)
    #train_dataset.conversion = conversion
    #val_dataset.conversion = conversion
    #eval_dataset.conversion = conversion
    return {'training':train_dataset, 'testing':val_dataset, 'evaluation':eval_dataset}
def give_inaturalist_datasets_for_features(opt):
    """
    This function generates datasets used for feature extraction on the Inaturalist 2018 dataset.

    For iterations > 0, the splits are rebuilt from the previous clustering round
    (opt.cluster_path): samples with pseudo-label -1 (unclustered) go only to the
    test dict, keyed by a shifted ground-truth label; clustered samples go to both
    the train and test dicts, keyed by their pseudo-label. For iteration 0 the
    splits are read from the usual txt files.
    Args:
        opt: argparse.Namespace, contains all training-specific parameters
            (uses opt.iter, opt.cluster_path, opt.source_path, opt.trainset,
             opt.all_trainset, opt.testset).
    Returns:
        dict of PyTorch datasets for training, testing and evaluation
        (all constructed with is_validation=True: feature extraction only).
    """
    def _read_split(filename):
        # Parse a split txt file into {"<super>/<class>": [absolute image paths]}.
        image_dict = {}
        with open(os.path.join(opt.source_path, filename)) as f:
            for line in f:
                entry = line.strip()
                info = entry.split('/')
                key = '/'.join([info[-3], info[-2]])
                image_dict.setdefault(key, []).append(os.path.join(opt.source_path, entry))
        return image_dict

    def _reindex(image_dict):
        # Re-key classes to consecutive labels "te/0", "te/1", ... preserving order.
        return {"te/%d" % i: paths for i, paths in enumerate(image_dict.values())}

    train_image_dict, test_image_dict, eval_image_dict = {}, {}, {}
    if opt.iter > 0:
        with open(os.path.join(opt.cluster_path), 'rb') as clusterf:
            path2idx, global_features, global_pred_labels, gt_labels, masks = pickle.load(clusterf)
        # Shift ground-truth labels past the predicted-label range so the two
        # key spaces cannot collide inside test_image_dict.
        gt_labels = gt_labels + len(np.unique(global_pred_labels))
        for path, idx in path2idx.items():
            pred = global_pred_labels[idx]
            if pred == -1:
                # Unclustered sample: only visible at feature-extraction/test time.
                test_image_dict.setdefault("te/%d" % gt_labels[idx], []).append(path)
            else:
                train_image_dict.setdefault("te/%d" % pred, []).append(path)
                test_image_dict.setdefault("te/%d" % pred, []).append(path)
    else:
        train_image_dict = _read_split(opt.trainset)
        test_image_dict = _read_split(opt.all_trainset)
        eval_image_dict = _read_split(opt.testset)

    train_image_dict = _reindex(train_image_dict)
    test_image_dict = _reindex(test_image_dict)
    eval_image_dict = _reindex(eval_image_dict)

    train_dataset = BaseTripletDataset(train_image_dict, opt, is_validation=True)
    test_dataset = BaseTripletDataset(test_image_dict, opt, is_validation=True)
    eval_dataset = BaseTripletDataset(eval_image_dict, opt, is_validation=True)
    return {'training': train_dataset, 'testing': test_dataset, 'eval': eval_dataset}
def give_inat_datasets_finetune_1head(testset, cluster_label_path, opt):
    """
    Build datasets for single-head fine-tuning on the Inaturalist 2018 dataset.

    Training classes come from the clustering results stored at `cluster_label_path`;
    samples whose predicted cluster label is -1 are treated as unclustered and skipped.
    Validation data is read from the `testset` split file.
    Args:
        testset: filename of the test split (relative to opt.source_path).
        cluster_label_path: pickle file holding
            (path2idx, features, pred_labels, gt_labels, masks).
        opt: argparse.Namespace, contains all training-specific parameters.
    Returns:
        dict of PyTorch datasets; 'testing' and 'evaluation' share the same dataset.
    """
    import pickle
    train_image_dict, val_image_dict = {}, {}
    # Load cluster labels produced by the hilander step.
    with open(cluster_label_path, 'rb') as clusterf:
        path2idx, global_features, global_pred_labels, gt_labels, masks = pickle.load(clusterf)
    for path, idx in path2idx.items():
        label = global_pred_labels[idx]
        if label == -1:
            # Unclustered sample: excluded from fine-tuning.
            continue
        val_key = "te/%d" % label
        if val_key not in train_image_dict:
            train_image_dict[val_key] = []
        train_image_dict[val_key].append(path)
    # Build the validation dict keyed by "<super>/<class>" from the split file.
    with open(os.path.join(opt.source_path, testset)) as f:
        for line in f:
            entry = line.strip()
            info = entry.split('/')
            cls_key = '/'.join([info[-3], info[-2]])
            if cls_key not in val_image_dict:
                val_image_dict[cls_key] = []
            val_image_dict[cls_key].append(os.path.join(opt.source_path, entry))
    train_dataset = TrainDatasetsmoothap(train_image_dict, opt)
    val_dataset = BaseTripletDataset(val_image_dict, opt, is_validation=True)
    return {'training': train_dataset, 'testing': val_dataset, 'evaluation': val_dataset}
################## BASIC PYTORCH DATASET USED FOR ALL DATASETS ##################################
################## BASIC PYTORCH DATASET USED FOR ALL DATASETS ##################################
class BaseTripletDataset(Dataset):
    """
    Dataset class to provide (augmented) correctly prepared training samples corresponding to standard DML literature.
    This includes normalizing to ImageNet-standards, and Random & Resized cropping of shapes 224 for ResNet50 and 227 for
    GoogLeNet during Training. During validation, only resizing to 256 or center cropping to 224/227 is performed.

    NOTE(review): during training, __getitem__ keeps sampling state on `self`
    (current_class, classes_visited, n_samples_drawn), so the class returned for
    a given idx depends on call order. This assumes a single-process, sequential
    access pattern — with num_workers > 0 each worker holds its own copy of this
    state; verify that this is the intended behavior.
    """
    def __init__(self, image_dict, opt, samples_per_class=8, is_validation=False):
        """
        Dataset Init-Function.
        Args:
            image_dict: dict, Dictionary of shape {class_idx:[list of paths to images belong to this class] ...} providing all the training paths and classes.
            opt: argparse.Namespace, contains all training-specific parameters.
            samples_per_class: Number of samples to draw from one class before moving to the next when filling the batch.
            is_validation: If is true, dataset properties for validation/testing are used instead of ones for training.
        Returns:
            Nothing!
        """
        #Define length of dataset
        self.n_files = np.sum([len(image_dict[key]) for key in image_dict.keys()])
        self.is_validation = is_validation
        self.pars = opt
        self.image_dict = image_dict
        self.avail_classes = sorted(list(self.image_dict.keys()))
        #Convert image dictionary from classname:content to class_idx:content, because the initial indices are not necessarily from 0 - <n_classes>.
        self.image_dict = {i:self.image_dict[key] for i,key in enumerate(self.avail_classes)}
        self.avail_classes = sorted(list(self.image_dict.keys()))
        #Init. properties that are used when filling up batches.
        if not self.is_validation:
            self.samples_per_class = samples_per_class
            #Select current class to sample images from up to <samples_per_class>
            self.current_class = np.random.randint(len(self.avail_classes))
            # Seed the visited list with the start class twice so the first class
            # switch cannot re-pick either of the two "most recent" classes.
            self.classes_visited = [self.current_class, self.current_class]
            self.n_samples_drawn = 0
        #Data augmentation/processing methods.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
        transf_list = []
        if not self.is_validation:
            # Train-time augmentation: random crop (size depends on backbone) + horizontal flip.
            transf_list.extend([transforms.RandomResizedCrop(size=224) if opt.arch=='resnet50' else transforms.RandomResizedCrop(size=227),
                                transforms.RandomHorizontalFlip(0.5)])
        else:
            # Eval-time: deterministic resize + center crop.
            transf_list.extend([transforms.Resize(256),
                                transforms.CenterCrop(224) if opt.arch=='resnet50' else transforms.CenterCrop(227)])
        transf_list.extend([transforms.ToTensor(), normalize])
        self.transform = transforms.Compose(transf_list)
        #Convert Image-Dict to list of (image_path, image_class). Allows for easier direct sampling.
        self.image_list = [[(x,key) for x in self.image_dict[key]] for key in self.image_dict.keys()]
        self.image_list = [x for y in self.image_list for x in y]
        #Flag that denotes if dataset is called for the first time.
        self.is_init = True
    def ensure_3dim(self, img):
        """
        Function that ensures that the input img is three-dimensional.
        Args:
            img: PIL.Image, image which is to be checked for three-dimensionality (i.e. if some images are black-and-white in an otherwise coloured dataset).
        Returns:
            Checked PIL.Image img.
        """
        if len(img.size)==2:
            img = img.convert('RGB')
        return img
    def __getitem__(self, idx):
        """
        Args:
            idx: Sample idx for training sample
        Returns:
            tuple of form (sample_class, torch.Tensor() of input image)
        """
        if self.pars.loss == 'smoothap' or self.pars.loss == 'smoothap_element':
            # SmoothAP branch: classes are NEVER allowed to repeat within an epoch
            # (classes_visited grows monotonically, see EDIT note below).
            if self.is_init:
                #self.current_class = self.avail_classes[idx%len(self.avail_classes)]
                self.is_init = False
            if not self.is_validation:
                if self.samples_per_class==1:
                    return self.image_list[idx][-1], self.transform(self.ensure_3dim(Image.open(self.image_list[idx][0])))
                if self.n_samples_drawn==self.samples_per_class:
                    #Once enough samples per class have been drawn, we choose another class to draw samples from.
                    #Note that we ensure with self.classes_visited that no class is chosen if it had been chosen
                    #previously or one before that.
                    counter = copy.deepcopy(self.avail_classes)
                    for prev_class in self.classes_visited:
                        if prev_class in counter: counter.remove(prev_class)
                    self.current_class = counter[idx%len(counter)]
                    #self.classes_visited = self.classes_visited[1:]+[self.current_class]
                    # EDIT -> there can be no class repeats
                    # NOTE(review): classes_visited is never truncated here, so `counter`
                    # shrinks on every class switch; once all classes have been visited,
                    # idx%len(counter) would raise ZeroDivisionError. Presumably the epoch
                    # ends before that happens — confirm against epoch length.
                    self.classes_visited = self.classes_visited+[self.current_class]
                    self.n_samples_drawn = 0
                class_sample_idx = idx%len(self.image_dict[self.current_class])
                self.n_samples_drawn += 1
                out_img = self.transform(self.ensure_3dim(Image.open(self.image_dict[self.current_class][class_sample_idx])))
                return self.current_class,out_img
            else:
                # Validation: direct indexed access, no class-balanced sampling.
                return self.image_list[idx][-1], self.transform(self.ensure_3dim(Image.open(self.image_list[idx][0])))
        else:
            # Non-SmoothAP branch: a class may be re-visited after two other classes
            # have been drawn in between (sliding window of the two most recent classes).
            if self.is_init:
                self.current_class = self.avail_classes[idx%len(self.avail_classes)]
                self.is_init = False
            if not self.is_validation:
                if self.samples_per_class==1:
                    return self.image_list[idx][-1], self.transform(self.ensure_3dim(Image.open(self.image_list[idx][0])))
                if self.n_samples_drawn==self.samples_per_class:
                    #Once enough samples per class have been drawn, we choose another class to draw samples from.
                    #Note that we ensure with self.classes_visited that no class is chosen if it had been chosen
                    #previously or one before that.
                    counter = copy.deepcopy(self.avail_classes)
                    for prev_class in self.classes_visited:
                        if prev_class in counter: counter.remove(prev_class)
                    self.current_class = counter[idx%len(counter)]
                    self.classes_visited = self.classes_visited[1:]+[self.current_class]
                    self.n_samples_drawn = 0
                class_sample_idx = idx%len(self.image_dict[self.current_class])
                self.n_samples_drawn += 1
                out_img = self.transform(self.ensure_3dim(Image.open(self.image_dict[self.current_class][class_sample_idx])))
                return self.current_class, out_img
            else:
                return self.image_list[idx][-1], self.transform(self.ensure_3dim(Image.open(self.image_list[idx][0])))
    def __len__(self):
        # Total number of images across all classes.
        return self.n_files
flatten = lambda l: [item for sublist in l for item in sublist]
######################## dataset for SmoothAP regular training ##################################
class TrainDatasetsmoothap(Dataset):
    """
    This dataset class allows mini-batch formation pre-epoch, for greater speed.

    Batches are pre-formed by reshuffle(): each batch holds `samples_per_class`
    consecutive samples from each of `bs / samples_per_class` distinct classes,
    so it must be consumed with a SequentialSampler (see give_dataloaders).
    """
    def __init__(self, image_dict, opt):
        """
        Args:
            image_dict: dict mapping class label -> list of image paths for that
                class. (The docstring previously described a two-level super-dict,
                but the code below iterates a single level.)
            opt: argparse.Namespace; uses opt.dataset, opt.bs, opt.samples_per_class, opt.arch.
        """
        self.image_dict = image_dict
        self.dataset_name = opt.dataset
        self.batch_size = opt.bs
        self.samples_per_class = opt.samples_per_class
        # Rewrite each class entry as (class_label, path) tuples.
        # NOTE(review): this mutates the caller's image_dict in place — verify
        # that callers do not reuse the dict afterwards.
        for sub in self.image_dict:
            newsub = []
            for instance in self.image_dict[sub]:
                newsub.append((sub, instance))
            self.image_dict[sub] = newsub
        # checks
        # provide avail_classes
        self.avail_classes = [*self.image_dict]
        # Data augmentation/processing methods.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
        transf_list = []
        transf_list.extend([
            transforms.RandomResizedCrop(size=224) if opt.arch in ['resnet50', 'resnet50_mcn'] else transforms.RandomResizedCrop(size=227),
            transforms.RandomHorizontalFlip(0.5)])
        transf_list.extend([transforms.ToTensor(), normalize])
        self.transform = transforms.Compose(transf_list)
        # Pre-build the first epoch's batch sequence.
        self.reshuffle()
    def ensure_3dim(self, img):
        # Convert grayscale PIL images to RGB so the transform pipeline always
        # receives 3 channels.
        if len(img.size) == 2:
            img = img.convert('RGB')
        return img
    def reshuffle(self):
        """
        Rebuild the per-epoch batch sequence: shuffle samples within each class and
        the class order, then greedily take `samples_per_class` samples per class
        until no class has enough samples left. Leftover samples are dropped for
        the epoch. Result is flattened into self.dataset as a list of
        (class_label, path) items in batch order.
        """
        # Work on a deep copy so the source dict survives for the next epoch.
        image_dict = copy.deepcopy(self.image_dict)
        print('shuffling data')
        for sub in image_dict:
            random.shuffle(image_dict[sub])
        classes = [*image_dict]
        random.shuffle(classes)
        total_batches = []
        batch = []
        finished = 0
        while finished == 0:
            for sub_class in classes:
                if (len(image_dict[sub_class]) >=self.samples_per_class) and (len(batch) < self.batch_size/self.samples_per_class) :
                    batch.append(image_dict[sub_class][:self.samples_per_class])
                    image_dict[sub_class] = image_dict[sub_class][self.samples_per_class:]
            # A pass that could not complete a batch means no class has enough
            # samples remaining -> stop.
            if len(batch) == self.batch_size/self.samples_per_class:
                total_batches.append(batch)
                batch = []
            else:
                finished = 1
        random.shuffle(total_batches)
        self.dataset = flatten(flatten(total_batches))
    def __getitem__(self, idx):
        # we use SequentialSampler together with SuperLabelTrainDataset,
        # so idx==0 indicates the start of a new epoch
        batch_item = self.dataset[idx]
        if self.dataset_name == 'Inaturalist':
            # Keys look like "te/<int>"; the numeric part is the class id.
            cls = int(batch_item[0].split('/')[1])
        else:
            cls = batch_item[0]
        img = Image.open(batch_item[1])
        return cls, self.transform(self.ensure_3dim(img))
    def __len__(self):
        # Number of pre-batched samples for the current epoch (may be smaller
        # than the raw dataset because reshuffle() drops leftovers).
        return len(self.dataset)
class TrainDatasetsmoothap1Head(Dataset):
    """
    This dataset class allows mini-batch formation pre-epoch, for greater speed.

    Single-head variant that mixes a labelled dict (image_dict_L) with an
    unlabelled/pseudo-labelled dict (image_dict_U); batches are pre-formed per
    epoch and must be read sequentially (see give_dataloaders).
    """
    def __init__(self, image_dict_L, image_dict_U, opt):
        """
        Args:
            image_dict_L: dict mapping labelled class -> list of image paths.
            image_dict_U: dict mapping pseudo-labelled class -> list of image paths.
            opt: argparse.Namespace; uses opt.dataset, opt.bs, opt.samples_per_class, opt.arch.
        """
        self.image_dict_L = image_dict_L
        self.image_dict_U = image_dict_U
        self.dataset_name = opt.dataset
        self.batch_size = opt.bs
        self.samples_per_class = opt.samples_per_class
        # Rewrite entries of both dicts as (class_label, path) tuples.
        # NOTE(review): this mutates the callers' dicts in place.
        for sub_L in self.image_dict_L:
            newsub_L = []
            for instance in self.image_dict_L[sub_L]:
                newsub_L.append((sub_L, instance))
            self.image_dict_L[sub_L] = newsub_L
        for sub_U in self.image_dict_U:
            newsub_U = []
            for instance in self.image_dict_U[sub_U]:
                newsub_U.append((sub_U, instance))
            self.image_dict_U[sub_U] = newsub_U
        # checks
        # provide avail_classes
        self.avail_classes = [*self.image_dict_L] + [*self.image_dict_U]
        # Data augmentation/processing methods.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
        transf_list = []
        transf_list.extend([
            transforms.RandomResizedCrop(size=224) if opt.arch in ['resnet50', 'resnet50_mcn'] else transforms.RandomResizedCrop(size=227),
            transforms.RandomHorizontalFlip(0.5)])
        transf_list.extend([transforms.ToTensor(), normalize])
        self.transform = transforms.Compose(transf_list)
        # Pre-build the first epoch's batch sequence.
        self.reshuffle()
    def sample_same_size(self):
        """
        Merge the unlabelled dict into a copy of the labelled dict.

        NOTE(review): despite the name (and the commented-out size condition
        below), ALL unlabelled classes are always merged in — the while loop only
        stops when classes_U is exhausted. Also, `image_dict[sub_U] =
        self.image_dict_U[sub_U]` inserts the list by reference; reshuffle()
        later shuffles it in place, mutating self.image_dict_U. This looks
        benign since order is re-randomized each epoch, but verify.
        """
        image_dict = copy.deepcopy(self.image_dict_L)
        L_size = 0
        for sub_L in self.image_dict_L:
            L_size += len(self.image_dict_L[sub_L])
        U_size = 0
        classes_U = [*self.image_dict_U]
        # while U_size < len(list(self.image_dict_U)) and U_size < L_size:
        while len(classes_U) != 0:
            sub_U = random.choice(classes_U)
            classes_U.remove(sub_U)
            sub_U_size = len(self.image_dict_U[sub_U])
            if sub_U in [*image_dict]:
                image_dict[sub_U].extend(self.image_dict_U[sub_U])
            else:
                image_dict[sub_U] = self.image_dict_U[sub_U]
            U_size += sub_U_size
        return image_dict
    def ensure_3dim(self, img):
        # Convert grayscale PIL images to RGB so the transform pipeline always
        # receives 3 channels.
        if len(img.size) == 2:
            img = img.convert('RGB')
        return img
    def reshuffle(self):
        """
        Rebuild the per-epoch batch sequence from the merged labelled+unlabelled
        dict: same greedy batch-filling scheme as TrainDatasetsmoothap.reshuffle.
        """
        image_dict = self.sample_same_size()
        print('shuffling data')
        for sub in image_dict:
            random.shuffle(image_dict[sub])
        classes = [*image_dict]
        random.shuffle(classes)
        total_batches = []
        batch = []
        finished = 0
        while finished == 0:
            for sub_class in classes:
                if (len(image_dict[sub_class]) >=self.samples_per_class) and (len(batch) < self.batch_size/self.samples_per_class) :
                    batch.append(image_dict[sub_class][:self.samples_per_class])
                    image_dict[sub_class] = image_dict[sub_class][self.samples_per_class:]
            # A pass that could not complete a batch means no class has enough
            # samples remaining -> stop.
            if len(batch) == self.batch_size/self.samples_per_class:
                total_batches.append(batch)
                batch = []
            else:
                finished = 1
        random.shuffle(total_batches)
        self.dataset = flatten(flatten(total_batches))
    def __getitem__(self, idx):
        # we use SequentialSampler together with SuperLabelTrainDataset,
        # so idx==0 indicates the start of a new epoch
        batch_item = self.dataset[idx]
        if self.dataset_name == 'Inaturalist':
            # Keys look like "te/<int>"; the numeric part is the class id.
            cls = int(batch_item[0].split('/')[1])
        else:
            cls = batch_item[0]
        img = Image.open(str(batch_item[1]))
        return cls, self.transform(self.ensure_3dim(img))
    def __len__(self):
        # Number of pre-batched samples for the current epoch.
        return len(self.dataset)
# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines
##################################### LIBRARIES ###########################################
import warnings
warnings.filterwarnings("ignore")
import numpy as np, time, pickle as pkl, csv
import matplotlib.pyplot as plt
from scipy.spatial import distance
from sklearn.preprocessing import normalize
from tqdm import tqdm
import torch, torch.nn as nn
import auxiliaries as aux
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
"""=================================================================================================================="""
"""=================================================================================================================="""
"""========================================================="""
def evaluate(dataset, LOG, **kwargs):
    """
    Given a dataset name, applies the correct evaluation function.
    Args:
        dataset: str, name of dataset.
        LOG: aux.LOGGER instance, main logging class.
        **kwargs: Input Argument Dict, depends on dataset.
    Returns:
        (optional) Computed metrics. Are normally written directly to LOG and printed.
    Raises:
        Exception: if no evaluation routine exists for `dataset`.
    """
    if dataset in ['Inaturalist', 'semi_fungi']:
        ret = evaluate_one_dataset(LOG, **kwargs)
    elif dataset in ['vehicle_id']:
        ret = evaluate_multiple_datasets(LOG, **kwargs)
    else:
        # Bug fix: the dataset name was previously never substituted into the
        # message (missing .format(dataset)).
        raise Exception('No implementation for dataset {} available!'.format(dataset))
    return ret
"""========================================================="""
class DistanceMeasure():
    """
    Container class to run and log the change of distance ratios
    between intra-class distances and inter-class distances.
    """
    def __init__(self, checkdata, opt, name='Train', update_epochs=1):
        """
        Args:
            checkdata: PyTorch DataLoader, data to check distance progression.
            opt: argparse.Namespace, contains all training-specific parameters.
            name: str, Name of instance. Important for savenames.
            update_epochs: int, Only compute distance ratios every said epoch.
        Returns:
            Nothing!
        """
        self.update_epochs = update_epochs
        self.pars = opt
        self.save_path = opt.save_path
        self.name = name
        # CSV file is opened in append mode: a header row is added even if the
        # file already exists from a previous run.
        self.csv_file = opt.save_path+'/distance_measures_{}.csv'.format(self.name)
        with open(self.csv_file,'a') as csv_file:
            writer = csv.writer(csv_file, delimiter=',')
            writer.writerow(['Rel. Intra/Inter Distance'])
        self.checkdata = checkdata
        # Running history used by update_plot().
        self.mean_class_dists = []
        self.epochs = []
    def measure(self, model, epoch):
        """
        Compute distance ratios of intra- and interclass distance.
        Args:
            model: PyTorch Network, network that produces the resp. embeddings.
            epoch: Current epoch.
        Returns:
            Nothing!
        """
        # Only run on every update_epochs-th epoch (epoch % update_epochs == 0).
        if epoch%self.update_epochs: return
        self.epochs.append(epoch)
        torch.cuda.empty_cache()
        _ = model.eval()
        #Compute Embeddings
        with torch.no_grad():
            feature_coll, target_coll = [],[]
            data_iter = tqdm(self.checkdata, desc='Estimating Data Distances...')
            for idx, data in enumerate(data_iter):
                input_img, target = data[1], data[0]
                features = model(input_img.to(self.pars.device))
                feature_coll.extend(features.cpu().detach().numpy().tolist())
                target_coll.extend(target.numpy().tolist())
        feature_coll = np.vstack(feature_coll).astype('float32')
        target_coll = np.hstack(target_coll).reshape(-1)
        avail_labels = np.unique(target_coll)
        #Compute indixes of embeddings for each class.
        class_positions = []
        for lab in avail_labels:
            class_positions.append(np.where(target_coll==lab)[0])
        #Compute average intra-class distance and center of mass.
        com_class, dists_class = [],[]
        for class_pos in class_positions:
            dists = distance.cdist(feature_coll[class_pos],feature_coll[class_pos],'cosine')
            # Mean of off-diagonal entries: sum over n^2 entries minus the n zero
            # self-distances on the diagonal.
            dists = np.sum(dists)/(len(dists)**2-len(dists))
            # dists = np.linalg.norm(np.std(feature_coll_aux[class_pos],axis=0).reshape(1,-1)).reshape(-1)
            # L2-normalized class mean = class "center of mass" on the unit sphere.
            com = normalize(np.mean(feature_coll[class_pos],axis=0).reshape(1,-1)).reshape(-1)
            dists_class.append(dists)
            com_class.append(com)
        #Compute mean inter-class distances by the class-coms.
        mean_inter_dist = distance.cdist(np.array(com_class), np.array(com_class), 'cosine')
        mean_inter_dist = np.sum(mean_inter_dist)/(len(mean_inter_dist)**2-len(mean_inter_dist))
        #Compute distance ratio
        mean_class_dist = np.mean(np.array(dists_class)/mean_inter_dist)
        self.mean_class_dists.append(mean_class_dist)
        self.update(mean_class_dist)
    def update(self, mean_class_dist):
        """
        Update Loggers.
        Args:
            mean_class_dist: float, Distance Ratio
        Returns:
            Nothing!
        """
        self.update_csv(mean_class_dist)
        self.update_plot()
    def update_csv(self, mean_class_dist):
        """
        Update CSV.
        Args:
            mean_class_dist: float, Distance Ratio
        Returns:
            Nothing!
        """
        with open(self.csv_file, 'a') as csv_file:
            writer = csv.writer(csv_file, delimiter=',')
            writer.writerow([mean_class_dist])
    def update_plot(self):
        """
        Update progression plot.
        Args:
            None.
        Returns:
            Nothing!
        """
        plt.style.use('ggplot')
        f,ax = plt.subplots(1)
        ax.set_title('Mean Intra- over Interclassdistances')
        ax.plot(self.epochs, self.mean_class_dists, label='Class')
        f.legend()
        f.set_size_inches(15,8)
        f.savefig(self.save_path+'/distance_measures_{}.svg'.format(self.name))
class GradientMeasure():
    """
    Container for gradient measure functionalities.
    Measure the gradients coming from the embedding layer to the final conv. layer
    to examine learning signal.
    """
    def __init__(self, opt, name='class-it'):
        """
        Args:
            opt: argparse.Namespace, contains all training-specific parameters.
            name: Name of class instance. Important for the savename.
        Returns:
            Nothing!
        """
        self.pars = opt
        self.name = name
        # Accumulates per-call gradient statistics until dump() is invoked.
        self.saver = {'grad_normal_mean': [], 'grad_normal_std': [],
                      'grad_abs_mean': [], 'grad_abs_std': []}
    def include(self, params):
        """
        Include the gradients for a set of parameters, normally the final embedding layer.
        Args:
            params: PyTorch Network layer after .backward() was called.
        Returns:
            Nothing!
        """
        ### Shape: 128 x 2048
        grad = params.weight.grad.detach().cpu().numpy()
        abs_grad = np.abs(grad)
        self.saver['grad_normal_mean'].append(np.mean(grad, axis=0))
        self.saver['grad_normal_std'].append(np.std(grad, axis=0))
        self.saver['grad_abs_mean'].append(np.mean(abs_grad, axis=0))
        self.saver['grad_abs_std'].append(np.std(abs_grad, axis=0))
    def dump(self, epoch):
        """
        Append all gradients to a pickle file.
        Args:
            epoch: Current epoch
        Returns:
            Nothing!
        """
        dump_file = self.pars.save_path+'/grad_dict_{}.pkl'.format(self.name)
        with open(dump_file, 'ab') as f:
            pkl.dump([self.saver], f)
        # Reset the accumulator for the next logging interval.
        self.saver = {'grad_normal_mean': [], 'grad_normal_std': [],
                      'grad_abs_mean': [], 'grad_abs_std': []}
"""========================================================="""
def evaluate_one_dataset(LOG, dataloader, model, opt, save=True, give_return=True, epoch=0):
    """
    Compute evaluation metrics, update LOGGER and print results.
    Args:
        LOG: aux.LOGGER-instance. Main Logging Functionality. May be None to skip logging.
        dataloader: PyTorch Dataloader, Testdata to be evaluated.
        model: PyTorch Network, Network to evaluate.
        opt: argparse.Namespace, contains all training-specific parameters.
        save: bool, if True, Checkpoints are saved when testing metrics (specifically Recall @ 1) improve.
        give_return: bool, if True, return computed metrics.
        epoch: int, current epoch, required for logger.
    Returns:
        (recall_at_ks, NMI, F1) if give_return, else None.
    """
    start = time.time()
    # Kept for the (currently commented-out) nearest-neighbour recovery plot below.
    image_paths = np.array(dataloader.dataset.image_list)
    with torch.no_grad():
        #Compute Metrics
        F1, NMI, recall_at_ks, feature_matrix_all = aux.eval_metrics_one_dataset(model, dataloader, device=opt.device, k_vals=opt.k_vals, opt=opt)
        #Make printable summary string.
        result_str = ', '.join('@{0}: {1:.4f}'.format(k,rec) for k,rec in zip(opt.k_vals, recall_at_ks))
        result_str = 'Epoch (Test) {0}: NMI [{1:.4f}] | F1 [{2:.4f}] | Recall [{3}]'.format(epoch, NMI, F1, result_str)
        if LOG is not None:
            if save:
                # Checkpoint only when Recall@1 improves over the best logged value
                # (or when nothing has been logged yet).
                if not len(LOG.progress_saver['val']['Recall @ 1']) or recall_at_ks[0]>np.max(LOG.progress_saver['val']['Recall @ 1']):
                    #Save Checkpoint
                    print("Set checkpoint at {}.".format(LOG.prop.save_path+'/checkpoint_{}.pth.tar'.format(opt.iter)))
                    aux.set_checkpoint(model, opt, LOG.progress_saver, LOG.prop.save_path+'/checkpoint_{}.pth.tar'.format(opt.iter))
                    # aux.recover_closest_one_dataset(feature_matrix_all, image_paths, LOG.prop.save_path+'/sample_recoveries.png')
            #Update logs.
            LOG.log('val', LOG.metrics_to_log['val'], [epoch, np.round(time.time()-start), NMI, F1]+recall_at_ks)
        print(result_str)
    if give_return:
        return recall_at_ks, NMI, F1
    # Fix: was a bare `None` expression statement (dead code); make the
    # no-metrics return explicit.
    return None
"""========================================================="""
def evaluate_query_and_gallery_dataset(LOG, query_dataloader, gallery_dataloader, model, opt, save=True, give_return=True, epoch=0):
    """
    Compute evaluation metrics, update LOGGER and print results, specifically for In-Shop Clothes.
    Args:
        LOG: aux.LOGGER-instance. Main Logging Functionality. May be None to skip logging.
        query_dataloader: PyTorch Dataloader, Query-testdata to be evaluated.
        gallery_dataloader: PyTorch Dataloader, Gallery-testdata to be evaluated.
        model: PyTorch Network, Network to evaluate.
        opt: argparse.Namespace, contains all training-specific parameters.
        save: bool, if True, Checkpoints are saved when testing metrics (specifically Recall @ 1) improve.
        give_return: bool, if True, return computed metrics.
        epoch: int, current epoch, required for logger.
    Returns:
        (recall_at_ks, NMI, F1) if give_return, else None.
    """
    start = time.time()
    query_image_paths = np.array([x[0] for x in query_dataloader.dataset.image_list])
    gallery_image_paths = np.array([x[0] for x in gallery_dataloader.dataset.image_list])
    with torch.no_grad():
        #Compute Metrics.
        F1, NMI, recall_at_ks, query_feature_matrix_all, gallery_feature_matrix_all = aux.eval_metrics_query_and_gallery_dataset(model, query_dataloader, gallery_dataloader, device=opt.device, k_vals = opt.k_vals, opt=opt)
        #Generate printable summary string.
        result_str = ', '.join('@{0}: {1:.4f}'.format(k,rec) for k,rec in zip(opt.k_vals, recall_at_ks))
        result_str = 'Epoch (Test) {0}: NMI [{1:.4f}] | F1 [{2:.4f}] | Recall [{3}]'.format(epoch, NMI, F1, result_str)
        if LOG is not None:
            if save:
                # Checkpoint + recovery plot only on Recall@1 improvement
                # (or when nothing has been logged yet).
                if not len(LOG.progress_saver['val']['Recall @ 1']) or recall_at_ks[0]>np.max(LOG.progress_saver['val']['Recall @ 1']):
                    #Save Checkpoint
                    aux.set_checkpoint(model, opt, LOG.progress_saver, LOG.prop.save_path+'/checkpoint.pth.tar')
                    aux.recover_closest_inshop(query_feature_matrix_all, gallery_feature_matrix_all, query_image_paths, gallery_image_paths, LOG.prop.save_path+'/sample_recoveries.png')
            #Update logs.
            LOG.log('val', LOG.metrics_to_log['val'], [epoch, np.round(time.time()-start), NMI, F1]+recall_at_ks)
        print(result_str)
    if give_return:
        return recall_at_ks, NMI, F1
    # Fix: was a bare `None` expression statement (dead code); make the
    # no-metrics return explicit.
    return None
"""========================================================="""
def evaluate_multiple_datasets(LOG, dataloaders, model, opt, save=True, give_return=True, epoch=0):
    """
    Compute evaluation metrics, update LOGGER and print results, specifically for Multi-test datasets s.a. PKU Vehicle ID.
    Args:
        LOG: aux.LOGGER-instance. Main Logging Functionality. May be None to skip logging.
        dataloaders: List of PyTorch Dataloaders, test-dataloaders to evaluate.
        model: PyTorch Network, Network to evaluate.
        opt: argparse.Namespace, contains all training-specific parameters.
        save: bool, if True, Checkpoints are saved when testing metrics (specifically Recall @ 1) improve.
        give_return: bool, currently ignored — metrics are always returned.
        epoch: int, current epoch, required for logger.
    Returns:
        List of computed metrics over all test sets ([NMI, F1] + recalls per set).
    """
    start = time.time()
    csv_data = [epoch]
    with torch.no_grad():
        for i,dataloader in enumerate(dataloaders):
            print('Working on Set {}/{}'.format(i+1, len(dataloaders)))
            image_paths = np.array(dataloader.dataset.image_list)
            #Compute Metrics for specific testset.
            F1, NMI, recall_at_ks, feature_matrix_all = aux.eval_metrics_one_dataset(model, dataloader, device=opt.device, k_vals=opt.k_vals, opt=opt)
            #Generate printable summary string.
            result_str = ', '.join('@{0}: {1:.4f}'.format(k,rec) for k,rec in zip(opt.k_vals, recall_at_ks))
            result_str = 'SET {0}: Epoch (Test) {1}: NMI [{2:.4f}] | F1 {3:.4f}| Recall [{4}]'.format(i+1, epoch, NMI, F1, result_str)
            if LOG is not None:
                if save:
                    # Per-set checkpoint on Recall@1 improvement.
                    if not len(LOG.progress_saver['val']['Set {} Recall @ 1'.format(i)]) or recall_at_ks[0]>np.max(LOG.progress_saver['val']['Set {} Recall @ 1'.format(i)]):
                        #Save Checkpoint for specific test set.
                        aux.set_checkpoint(model, opt, LOG.progress_saver, LOG.prop.save_path+'/checkpoint_set{}.pth.tar'.format(i+1))
                        aux.recover_closest_one_dataset(feature_matrix_all, image_paths, LOG.prop.save_path+'/sample_recoveries_set{}.png'.format(i+1))
            csv_data += [NMI, F1]+recall_at_ks
            print(result_str)
        csv_data.insert(0, np.round(time.time()-start))
        #Update logs.
        # Fix: previously LOG.log was called unconditionally, crashing when LOG
        # is None — guard it like the single-dataset variant does.
        if LOG is not None:
            LOG.log('val', LOG.metrics_to_log['val'], csv_data)
    # Skip the leading [elapsed_time, epoch] entries.
    return csv_data[2:]
\ No newline at end of file
import os, torch, argparse
import netlib as netlib
import auxiliaries as aux
import datasets as data
import evaluate as eval
if __name__ == '__main__':
    # Standalone evaluation entry point: parses the CLI flags, builds the model
    # and test dataloader(s), then runs eval.evaluate once (epoch 0) with
    # checkpoint saving enabled.
    ################## INPUT ARGUMENTS ###################
    parser = argparse.ArgumentParser()
    ####### Main Parameter: Dataset to use for Training
    parser.add_argument('--dataset', default='vehicle_id', type=str, help='Dataset to use.',
                        choices=['Inaturalist', 'vehicle_id'])
    parser.add_argument('--source_path', default='/scratch/shared/beegfs/abrown/datasets', type=str,
                        help='Path to training data.')
    parser.add_argument('--save_path', default=os.getcwd() + '/Training_Results', type=str,
                        help='Where to save everything.')
    parser.add_argument('--savename', default='', type=str,
                        help='Save folder name if any special information is to be included.')
    ### General Training Parameters
    parser.add_argument('--kernels', default=8, type=int, help='Number of workers for pytorch dataloader.')
    parser.add_argument('--bs', default=112, type=int, help='Mini-Batchsize to use.')
    parser.add_argument('--samples_per_class', default=4, type=int,help='Number of samples in one class drawn before choosing the next class. Set to >1 for losses other than ProxyNCA.')
    parser.add_argument('--loss', default='smoothap', type=str)
    ##### Evaluation Settings
    parser.add_argument('--k_vals', nargs='+', default=[1, 2, 4, 8], type=int, help='Recall @ Values.')
    ##### Network parameters
    parser.add_argument('--embed_dim', default=512, type=int,
                        help='Embedding dimensionality of the network. Note: in literature, dim=128 is used for ResNet50 and dim=512 for GoogLeNet.')
    parser.add_argument('--arch', default='resnet50', type=str,
                        help='Network backend choice: resnet50, googlenet, BNinception')
    parser.add_argument('--gpu', default=0, type=int, help='GPU-id for GPU to use.')
    parser.add_argument('--resume', default='', type=str, help='path to where weights to be evaluated are saved.')
    parser.add_argument('--not_pretrained', action='store_true',
                        help='If added, the network will be trained WITHOUT ImageNet-pretrained weights.')
    parser.add_argument('--trainset', default="lin_train_set1.txt", type=str)
    parser.add_argument('--testset', default="Inaturalist_test_set1.txt", type=str)
    parser.add_argument('--cluster_path', default="", type=str)
    parser.add_argument('--finetune', default="false", type=str)
    parser.add_argument('--class_num', default=948, type=int)
    parser.add_argument('--get_features', default="false", type=str)
    parser.add_argument('--patch_size', default=16, type=int, help='vit patch size')
    parser.add_argument('--pretrained_weights', default="", type=str, help='pretrained weight path')
    parser.add_argument('--use_bn_in_head', default=False, type=aux.bool_flag,
                        help="Whether to use batch normalizations in projection head (Default: False)")
    parser.add_argument("--checkpoint_key", default="teacher", type=str,
                        help='Key to use in the checkpoint (example: "teacher")')
    parser.add_argument('--drop_path_rate', default=0.1, type=float, help="stochastic depth rate")
    parser.add_argument('--norm_last_layer', default=True, type=aux.bool_flag,
                        help="""Whether or not to weight normalize the last layer of the DINO head.
                        Not normalizing leads to better performance but can make the training unstable.
                        In our experiments, we typically set this paramater to False with vit_small and True with vit_base.""")
    parser.add_argument('--linsize', default=29011, type=int, help="Lin data size.")
    parser.add_argument('--uinsize', default=18403, type=int, help="Uin data size.")
    opt = parser.parse_args()
    """============================================================================"""
    # Dataset-specific overrides: each dataset uses its own Recall@k values.
    opt.source_path += '/' + opt.dataset
    if opt.dataset == 'Inaturalist':
        opt.n_epochs = 90
        opt.tau = [40, 70]
        opt.k_vals = [1, 4, 16, 32]
    if opt.dataset == 'vehicle_id':
        opt.k_vals = [1, 5]
    # argparse delivers these flags as strings; convert to real booleans.
    if opt.finetune == 'true':
        opt.finetune = True
    elif opt.finetune == 'false':
        opt.finetune = False
    if opt.get_features == 'true':
        opt.get_features = True
    elif opt.get_features == 'false':
        opt.get_features = False
    metrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals)
    LOG = aux.LOGGER(opt, metrics_to_log, name='Base', start_new=True)
    """============================================================================"""
    ##################### NETWORK SETUP ##################
    opt.device = torch.device('cuda')
    model = netlib.networkselect(opt)
    # Push to Device
    _ = model.to(opt.device)
    """============================================================================"""
    #################### DATALOADER SETUPS ##################
    # Returns a dictionary containing 'training', 'testing', and 'evaluation' dataloaders.
    # The 'testing'-dataloader corresponds to the validation set, and the 'evaluation'-dataloader
    # Is simply using the training set, however running under the same rules as 'testing' dataloader,
    # i.e. no shuffling and no random cropping.
    dataloaders = data.give_dataloaders(opt.dataset, opt.trainset, opt.testset, opt)
    # Because the number of supervised classes is dataset dependent, we store them after
    # initializing the dataloader
    opt.num_classes = len(dataloaders['training'].dataset.avail_classes)
    # vehicle_id ships three separate test splits; Inaturalist has a single one,
    # so the kwargs passed to eval.evaluate differ per dataset.
    if opt.dataset == 'Inaturalist':
        eval_params = {'dataloader': dataloaders['testing'], 'model': model, 'opt': opt, 'epoch': 0}
    elif opt.dataset == 'vehicle_id':
        eval_params = {
            'dataloaders': [dataloaders['testing_set1'], dataloaders['testing_set2'], dataloaders['testing_set3']],
            'model': model, 'opt': opt, 'epoch': 0}
    """============================================================================"""
    ####################evaluation ##################
    results = eval.evaluate(opt.dataset, LOG, save=True, **eval_params)
# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines
"""to do:
clean all of the files - particularly the main.py and also the losses and dataset files and the file for doing the dataloading
-- fast loading etc
need to change all of the copyrights at the top of all of the files
"""
#################### LIBRARIES ########################
import warnings
warnings.filterwarnings("ignore")
import os, numpy as np, argparse, random, matplotlib, datetime
os.chdir(os.path.dirname(os.path.realpath(__file__)))
from pathlib import Path
matplotlib.use('agg')
from tqdm import tqdm
import auxiliaries as aux
import datasets as data
import netlib as netlib
import losses as losses
import evaluate as eval
from tensorboardX import SummaryWriter
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
import time
# Wall-clock timer for the full training run (reported at the very end).
start = time.time()
################### INPUT ARGUMENTS ###################
parser = argparse.ArgumentParser()
####### Main Parameter: Dataset to use for Training
parser.add_argument('--dataset', default='Inaturalist', type=str, help='Dataset to use.', choices=['Inaturalist','semi_fungi'])
### General Training Parameters
parser.add_argument('--lr', default=0.00001, type=float, help='Learning Rate for network parameters.')
parser.add_argument('--fc_lr_mul', default=5, type=float, help='OPTIONAL: Multiply the embedding layer learning rate by this value. If set to 0, the embedding layer shares the same learning rate.')
parser.add_argument('--n_epochs', default=400, type=int, help='Number of training epochs.')
parser.add_argument('--kernels', default=8, type=int, help='Number of workers for pytorch dataloader.')
parser.add_argument('--bs', default=112 , type=int, help='Mini-Batchsize to use.')
parser.add_argument('--samples_per_class', default=4, type=int, help='Number of samples in one class drawn before choosing the next class')
parser.add_argument('--seed', default=1, type=int, help='Random seed for reproducibility.')
parser.add_argument('--scheduler', default='step', type=str, help='Type of learning rate scheduling. Currently: step & exp.')
parser.add_argument('--gamma', default=0.3, type=float, help='Learning rate reduction after tau epochs.')
parser.add_argument('--decay', default=0.001, type=float, help='Weight decay for optimizer.')
parser.add_argument('--tau', default= [200,300],nargs='+',type=int,help='Stepsize(s) before reducing learning rate.')
parser.add_argument('--infrequent_eval', default=0,type=int, help='only compute evaluation metrics every 10 epochs')
parser.add_argument('--opt', default = 'adam',help='adam or sgd')
##### Loss-specific Settings
parser.add_argument('--loss', default='smoothap', type=str)
parser.add_argument('--sigmoid_temperature', default=0.01, type=float, help='SmoothAP: the temperature of the sigmoid used in SmoothAP loss')
##### Evaluation Settings
parser.add_argument('--k_vals', nargs='+', default=[1,2,4,8], type=int, help='Recall @ Values.')
parser.add_argument('--resume', default='', type=str, help='path to checkpoint to load weights from (if empty then ImageNet pre-trained weights are loaded')
##### Network parameters
parser.add_argument('--embed_dim', default=512, type=int, help='Embedding dimensionality of the network')
parser.add_argument('--arch', default='resnet50', type=str, help='Network backend choice: resnet50, googlenet, BNinception')
parser.add_argument('--grad_measure', action='store_true', help='If added, gradients passed from embedding layer to the last conv-layer are stored in each iteration.')
parser.add_argument('--dist_measure', action='store_true', help='If added, the ratio between intra- and interclass distances is stored after each epoch.')
parser.add_argument('--not_pretrained', action='store_true', help='If added, the network will be trained WITHOUT ImageNet-pretrained weights.')
##### Setup Parameters
parser.add_argument('--gpu', default=0, type=int, help='GPU-id for GPU to use.')
parser.add_argument('--savename', default='', type=str, help='Save folder name if any special information is to be included.')
### Paths to datasets and storage folder
parser.add_argument('--source_path', default='/scratch/shared/beegfs/abrown/datasets', type=str, help='Path to data')
parser.add_argument('--save_path', default=os.getcwd()+'/Training_Results', type=str, help='Where to save the checkpoints')
### additional parameters
parser.add_argument('--trainset', default="lin_train_set1.txt", type=str)
parser.add_argument('--testset', default="Inaturalist_test_set1.txt", type=str)
parser.add_argument('--cluster_path', default="", type=str)
parser.add_argument('--finetune', default='true', type=str)
parser.add_argument('--class_num', default=948, type=int)
parser.add_argument('--pretrained_weights', default="", type=str, help='pretrained weight path')
parser.add_argument('--use_bn_in_head', default=False, type=aux.bool_flag,
                    help="Whether to use batch normalizations in projection head (Default: False)")
parser.add_argument("--checkpoint_key", default="teacher", type=str,
                    help='Key to use in the checkpoint (example: "teacher")')
parser.add_argument('--drop_path_rate', default=0.1, type=float, help="stochastic depth rate")
# PSS iteration index; used to tag outputs of successive sample-selection rounds.
parser.add_argument('--iter', default=1, type=int)
opt = parser.parse_args()
"""============================================================================"""
opt.source_path += '/' + opt.dataset
opt.save_path += '/' + opt.dataset + "_" + str(opt.embed_dim)
if opt.dataset== 'Inaturalist':
# opt.n_epochs = 90
opt.tau = [40, 70]
opt.k_vals = [1,4,16,32]
if opt.dataset=='semi_fungi':
opt.tau = [40,70]
opt.k_vals = [1,4,16,32]
if opt.finetune == 'true':
opt.finetune = True
elif opt.finetune == 'false':
opt.finetune = False
"""==========================================================================="""
################### TensorBoard Settings ##################
timestamp = datetime.datetime.now().strftime(r"%Y-%m-%d_%H-%M-%S")
exp_name = aux.args2exp_name(opt)
opt.save_name = f"weights_{exp_name}" +'/'+ timestamp
random.seed(opt.seed)
np.random.seed(opt.seed)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed); torch.cuda.manual_seed_all(opt.seed)
tensorboard_path = Path(f"logs/logs_{exp_name}") / timestamp
tensorboard_path.parent.mkdir(exist_ok=True, parents=True)
global writer;
writer = SummaryWriter(tensorboard_path)
"""============================================================================"""
################### GPU SETTINGS ###########################
os.environ["CUDA_DEVICE_ORDER"] ="PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]= str(opt.gpu)
print('using #GPUs:',torch.cuda.device_count())
"""============================================================================"""
#################### DATALOADER SETUPS ##################
#Returns a dictionary containing 'training', 'testing', and 'evaluation' dataloaders.
#The 'testing'-dataloader corresponds to the validation set, and the 'evaluation'-dataloader
#Is simply using the training set, however running under the same rules as 'testing' dataloader,
#i.e. no shuffling and no random cropping.
dataloaders = data.give_dataloaders(opt.dataset, opt.trainset, opt.testset, opt, cluster_path=opt.cluster_path)
#Because the number of supervised classes is dataset dependent, we store them after
#initializing the dataloader
opt.num_classes = len(dataloaders['training'].dataset.avail_classes)
print("num_classes:", opt.num_classes)
print("train dataset size:", len(dataloaders['training']))
"""============================================================================"""
##################### NETWORK SETUP ##################
opt.device = torch.device('cuda')
model = netlib.networkselect(opt)
#Push to Device
if torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)
_ = model.to(opt.device)
#Place trainable parameter in list of parameters to train:
if 'fc_lr_mul' in vars(opt).keys() and opt.fc_lr_mul!=0:
all_but_fc_params = list(filter(lambda x: 'last_linear' not in x[0],model.named_parameters()))
for ind, param in enumerate(all_but_fc_params):
all_but_fc_params[ind] = param[1]
if torch.cuda.device_count() > 1:
fc_params = model.module.model.last_linear.parameters()
else:
fc_params = model.model.last_linear.parameters()
to_optim = [{'params':all_but_fc_params,'lr':opt.lr,'weight_decay':opt.decay},
{'params':fc_params,'lr':opt.lr*opt.fc_lr_mul,'weight_decay':opt.decay}]
else:
to_optim = [{'params':model.parameters(),'lr':opt.lr,'weight_decay':opt.decay}]
"""============================================================================"""
#################### CREATE LOGGING FILES ###############
#Each dataset usually has a set of standard metrics to log. aux.metrics_to_examine()
#returns a dict which lists metrics to log for training ('train') and validation/testing ('val')
metrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals)
# example output: {'train': ['Epochs', 'Time', 'Train Loss', 'Time'],
# 'val': ['Epochs','Time','NMI','F1', 'Recall @ 1','Recall @ 2','Recall @ 4','Recall @ 8']}
#Using the provided metrics of interest, we generate a LOGGER instance.
#Note that 'start_new' denotes that a new folder should be made in which everything will be stored.
#This includes network weights as well.
LOG = aux.LOGGER(opt, metrics_to_log, name='Base', start_new=True)
#If graphviz is installed on the system, a computational graph of the underlying
#network will be made as well.
"""============================================================================"""
#################### LOSS SETUP ####################
#Depending on opt.loss and opt.sampling, the respective criterion is returned,
#and if the loss has trainable parameters, to_optim is appended.
criterion, to_optim = losses.loss_select(opt.loss, opt, to_optim)
_ = criterion.to(opt.device)
"""============================================================================"""
##################### OPTIONAL EVALUATIONS #####################
#Store the averaged gradients returned from the embedding to the last conv. layer.
if opt.grad_measure:
grad_measure = eval.GradientMeasure(opt, name='baseline')
#Store the relative distances between average intra- and inter-class distance.
if opt.dist_measure:
#Add a distance measure for training distance ratios
distance_measure = eval.DistanceMeasure(dataloaders['evaluation'], opt, name='Train', update_epochs=1)
# #If uncommented: Do the same for the test set
# distance_measure_test = eval.DistanceMeasure(dataloaders['testing'], opt, name='Train', update_epochs=1)
"""============================================================================"""
#################### OPTIM SETUP ####################
#As optimizer, Adam with standard parameters is used.
if opt.opt == 'adam':
optimizer = torch.optim.Adam(to_optim)
elif opt.opt == 'sgd':
optimizer = torch.optim.SGD(to_optim)
else:
raise Exception('unknown optimiser')
# for the SOA measures in the paper - need to use SGD and 0.05 learning rate
#optimizer = torch.optim.Adam(to_optim)
#optimizer = torch.optim.SGD(to_optim)
if opt.scheduler=='exp':
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma)
elif opt.scheduler=='step':
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=opt.tau, gamma=opt.gamma)
elif opt.scheduler=='none':
print('No scheduling used!')
else:
raise Exception('No scheduling option for input: {}'.format(opt.scheduler))
def same_model(model1,model2):
    """Return True iff the two models hold element-wise identical parameters.

    Parameters are compared pairwise in iteration order; the scan stops at the
    first pair containing any differing element.
    """
    return not any(
        (p_a.data != p_b.data).any()
        for p_a, p_b in zip(model1.parameters(), model2.parameters())
    )
"""============================================================================"""
#################### TRAINER FUNCTION ############################
def train_one_epoch_finetune(train_dataloader, model, optimizer, criterion, opt, epoch):
    """
    This function is called every epoch to perform training of the network over one full
    (randomized) iteration of the dataset.

    Relies on module-level state: LOG (metric CSV logger), writer (TensorBoard
    SummaryWriter) and, when opt.grad_measure is set, grad_measure.

    Args:
        train_dataloader: torch.utils.data.DataLoader, returns (augmented) training data.
        model: Network to train.
        optimizer: Optimizer to use for training.
        criterion: criterion to use during training.
        opt: argparse.Namespace, Contains all relevant parameters.
        epoch: int, Current epoch.
    Returns:
        Nothing!
    """
    loss_collect = []
    start = time.time()
    data_iterator = tqdm(train_dataloader, desc='Epoch {} Training gt labels...'.format(epoch))
    for i,(class_labels, input) in enumerate(data_iterator):
        #Compute embeddings for input batch
        features = model(input.to(opt.device))
        #Compute loss. The 'smoothap' criterion is called with the features only
        #(no labels); all other losses also receive the class labels.
        if opt.loss != 'smoothap':
            loss = criterion(features, class_labels)
        else:
            loss = criterion(features)
        #Ensure gradients are set to zero at beginning
        optimizer.zero_grad()
        #Compute gradient
        loss.backward()
        #Reset the dataset's per-batch class-sampling bookkeeping — presumably
        #consumed by the class-balanced batch construction; confirm in datasets.py.
        train_dataloader.dataset.classes_visited = []
        if opt.grad_measure:
            #If desired, save computed gradients.
            grad_measure.include(model.model.last_linear)
        #Update weights using comp. gradients.
        optimizer.step()
        #Store loss per iteration.
        loss_collect.append(loss.item())
        #On the last batch, show the epoch's mean loss in the progress bar.
        if i==len(train_dataloader)-1:
            data_iterator.set_description('Epoch (Train) {0}: Mean Loss [{1:.4f}]'.format(epoch, np.mean(loss_collect)))
    #Save metrics
    LOG.log('train', LOG.metrics_to_log['train'], [epoch, np.round(time.time()-start,4), np.mean(loss_collect)])
    writer.add_scalar('global/training_loss',np.mean(loss_collect),epoch)
    if opt.grad_measure:
        #Dump stored gradients to Pickle-File.
        grad_measure.dump(epoch)
"""============================================================================"""
"""========================== MAIN TRAINING PART =============================="""
"""============================================================================"""
################### SCRIPT MAIN ##########################
print('\n-----\n')
# Each dataset requires slightly different dataloaders.
if opt.dataset == 'Inaturalist' or 'semi_fungi':
eval_params = {'dataloader': dataloaders['testing'], 'model': model, 'opt': opt, 'epoch': 0}
# Compute Evaluation metrics, print them and store in LOG.
print('epochs -> '+str(opt.n_epochs))
import time
for epoch in range(opt.n_epochs):
### Print current learning rates for all parameters
if opt.scheduler!='none': print('Running with learning rates {}...'.format(' | '.join('{}'.format(x) for x in scheduler.get_lr())))
### Train one epoch
_ = model.train()
train_one_epoch_finetune(dataloaders['training'], model, optimizer, criterion, opt, epoch)
dataloaders['training'].dataset.reshuffle()
### Evaluate
_ = model.eval()
#Each dataset requires slightly different dataloaders.
if opt.dataset == 'Inaturalist':
eval_params = {'dataloader':dataloaders['testing'], 'model':model, 'opt':opt, 'epoch':epoch}
elif opt.dataset=='semi_fungi':
eval_params = {'dataloader':dataloaders['testing'], 'model':model, 'opt':opt, 'epoch':epoch}
#Compute Evaluation metrics, print them and store in LOG.
if opt.infrequent_eval == 1:
epoch_freq = 10
else:
epoch_freq = 1
if epoch%epoch_freq == 0:
results = eval.evaluate(opt.dataset, LOG, save=True, **eval_params)
writer.add_scalar('global/recall1',results[0][0],epoch+1)
writer.add_scalar('global/recall2',results[0][1],epoch+1)
writer.add_scalar('global/recall3',results[0][2],epoch+1)
writer.add_scalar('global/recall4',results[0][3],epoch+1)
writer.add_scalar('global/NMI',results[1],epoch+1)
writer.add_scalar('global/F1',results[2],epoch+1)
#Update the Metric Plot and save it.
#LOG.update_info_plot()
#(optional) compute ratio of intra- to interdistances.
if opt.dist_measure:
distance_measure.measure(model, epoch)
# distance_measure_test.measure(model, epoch)
### Learning Rate Scheduling Step
if opt.scheduler != 'none':
scheduler.step()
print('\n-----\n')
print("Time:" ,time.time() - start)
# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines
"""to do:
clean all of the files - particularly the main.py and also the losses and dataset files and the file for doing the dataloading
-- fast loading etc
need to change all of the copyrights at the top of all of the files
"""
#################### LIBRARIES ########################
import warnings
warnings.filterwarnings("ignore")
import os, numpy as np, argparse, random, matplotlib, datetime
os.chdir(os.path.dirname(os.path.realpath(__file__)))
matplotlib.use('agg')
import auxiliaries as aux
import datasets as data
import netlib as netlib
import losses as losses
import evaluate as eval
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
################### INPUT ARGUMENTS ###################
# Feature-extraction / testing variant of the training CLI: mostly the same
# flags, plus --all_trainset / --get_features / --iter for PSS feature dumps.
parser = argparse.ArgumentParser()
####### Main Parameter: Dataset to use for Training
parser.add_argument('--dataset', default='Inaturalist', type=str, help='Dataset to use.', choices=['Inaturalist', 'semi_fungi'])
### General Training Parameters
parser.add_argument('--lr', default=0.00001, type=float, help='Learning Rate for network parameters.')
parser.add_argument('--fc_lr_mul', default=5, type=float, help='OPTIONAL: Multiply the embedding layer learning rate by this value. If set to 0, the embedding layer shares the same learning rate.')
parser.add_argument('--n_epochs', default=400, type=int, help='Number of training epochs.')
parser.add_argument('--kernels', default=8, type=int, help='Number of workers for pytorch dataloader.')
parser.add_argument('--bs', default=112 , type=int, help='Mini-Batchsize to use.')
parser.add_argument('--samples_per_class', default=4, type=int, help='Number of samples in one class drawn before choosing the next class')
parser.add_argument('--seed', default=1, type=int, help='Random seed for reproducibility.')
parser.add_argument('--scheduler', default='step', type=str, help='Type of learning rate scheduling. Currently: step & exp.')
parser.add_argument('--gamma', default=0.3, type=float, help='Learning rate reduction after tau epochs.')
parser.add_argument('--decay', default=0.0004, type=float, help='Weight decay for optimizer.')
parser.add_argument('--tau', default= [200,300],nargs='+',type=int,help='Stepsize(s) before reducing learning rate.')
parser.add_argument('--infrequent_eval', default=0,type=int, help='only compute evaluation metrics every 10 epochs')
parser.add_argument('--opt', default = 'adam',help='adam or sgd')
##### Loss-specific Settings
parser.add_argument('--loss', default='smoothap', type=str)
parser.add_argument('--sigmoid_temperature', default=0.01, type=float, help='SmoothAP: the temperature of the sigmoid used in SmoothAP loss')
##### Evaluation Settings
parser.add_argument('--k_vals', nargs='+', default=[1,2,4,8], type=int, help='Recall @ Values.')
parser.add_argument('--resume', default='', type=str, help='path to checkpoint to load weights from (if empty then ImageNet pre-trained weights are loaded')
##### Network parameters
parser.add_argument('--embed_dim', default=512, type=int, help='Embedding dimensionality of the network')
parser.add_argument('--arch', default='resnet50', type=str, help='Network backend choice: resnet50, googlenet, BNinception')
parser.add_argument('--grad_measure', action='store_true', help='If added, gradients passed from embedding layer to the last conv-layer are stored in each iteration.')
parser.add_argument('--dist_measure', action='store_true', help='If added, the ratio between intra- and interclass distances is stored after each epoch.')
parser.add_argument('--not_pretrained', action='store_true', help='If added, the network will be trained WITHOUT ImageNet-pretrained weights.')
##### Setup Parameters
parser.add_argument('--gpu', default=0, type=int, help='GPU-id for GPU to use.')
parser.add_argument('--savename', default='', type=str, help='Save folder name if any special information is to be included.')
### Paths to datasets and storage folder
parser.add_argument('--source_path', default='/scratch/shared/beegfs/abrown/datasets', type=str, help='Path to data')
parser.add_argument('--save_path', default=os.getcwd()+'/Training_Results', type=str, help='Where to save the checkpoints')
### additional parameters
parser.add_argument('--trainset', default="lin_train_set1.txt", type=str)
parser.add_argument('--all_trainset', default="train_set1.txt", type=str)
parser.add_argument('--testset', default="test_set1.txt", type=str)
parser.add_argument('--finetune', default='true', type=str)
parser.add_argument('--cluster_path', default="", type=str)
parser.add_argument('--get_features', default="false", type=str)
parser.add_argument('--class_num', default=948, type=int)
parser.add_argument('--iter', default=0, type=int)
parser.add_argument('--pretrained_weights', default="", type=str, help='pretrained weight path')
parser.add_argument('--use_bn_in_head', default=False, type=aux.bool_flag,
                    help="Whether to use batch normalizations in projection head (Default: False)")
parser.add_argument("--checkpoint_key", default="teacher", type=str,
                    help='Key to use in the checkpoint (example: "teacher")')
parser.add_argument('--drop_path_rate', default=0.1, type=float, help="stochastic depth rate")
parser.add_argument('--linsize', default=29011, type=int, help="Lin data size.")
parser.add_argument('--uinsize', default=18403, type=int, help="Uin data size.")
opt = parser.parse_args()
"""============================================================================"""
opt.source_path += '/' + opt.dataset
opt.save_path += '/' + opt.dataset + "_" + str(opt.embed_dim)
if opt.dataset== 'Inaturalist':
opt.n_epochs = 90
opt.tau = [40,70]
opt.k_vals = [1,4,16,32]
if opt.dataset=='semi_fungi':
opt.tau = [40,70]
opt.k_vals = [1,4,16,32]
if opt.get_features == "true":
opt.get_features = True
if opt.get_features == "false":
opt.get_features = False
if opt.finetune == 'true':
opt.finetune = True
elif opt.finetune == 'false':
opt.finetune = False
"""==========================================================================="""
################### TensorBoard Settings ##################
timestamp = datetime.datetime.now().strftime(r"%Y-%m-%d_%H-%M-%S")
exp_name = aux.args2exp_name(opt)
opt.save_name = f"weights_{exp_name}" +'/'+ timestamp
random.seed(opt.seed)
np.random.seed(opt.seed)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed); torch.cuda.manual_seed_all(opt.seed)
"""============================================================================"""
################### GPU SETTINGS ###########################
os.environ["CUDA_DEVICE_ORDER"] ="PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]= str(opt.gpu)
print('using #GPUs:',torch.cuda.device_count())
"""============================================================================"""
##################### NETWORK SETUP ##################
opt.device = torch.device('cuda')
model = netlib.networkselect(opt)
#Push to Device
if torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)
_ = model.to(opt.device)
#Place trainable parameter in list of parameters to train:
if 'fc_lr_mul' in vars(opt).keys() and opt.fc_lr_mul!=0:
all_but_fc_params = list(filter(lambda x: 'last_linear' not in x[0],model.named_parameters()))
for ind, param in enumerate(all_but_fc_params):
all_but_fc_params[ind] = param[1]
if torch.cuda.device_count() > 1:
fc_params = model.module.model.last_linear.parameters()
else:
fc_params = model.model.last_linear.parameters()
to_optim = [{'params':all_but_fc_params,'lr':opt.lr,'weight_decay':opt.decay},
{'params':fc_params,'lr':opt.lr*opt.fc_lr_mul,'weight_decay':opt.decay}]
else:
to_optim = [{'params':model.parameters(),'lr':opt.lr,'weight_decay':opt.decay}]
"""============================================================================"""
#################### DATALOADER SETUPS ##################
#Returns a dictionary containing 'training', 'testing', and 'evaluation' dataloaders.
#The 'testing'-dataloader corresponds to the validation set, and the 'evaluation'-dataloader
#Is simply using the training set, however running under the same rules as 'testing' dataloader,
#i.e. no shuffling and no random cropping.
dataloaders = data.give_dataloaders(opt.dataset, opt.trainset, opt.testset, opt)
#Because the number of supervised classes is dataset dependent, we store them after
#initializing the dataloader
opt.num_classes = len(dataloaders['training'].dataset.avail_classes)
"""============================================================================"""
#################### CREATE LOGGING FILES ###############
#Each dataset usually has a set of standard metrics to log. aux.metrics_to_examine()
#returns a dict which lists metrics to log for training ('train') and validation/testing ('val')
metrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals)
# example output: {'train': ['Epochs', 'Time', 'Train Loss', 'Time'],
# 'val': ['Epochs','Time','NMI','F1', 'Recall @ 1','Recall @ 2','Recall @ 4','Recall @ 8']}
#Using the provided metrics of interest, we generate a LOGGER instance.
#Note that 'start_new' denotes that a new folder should be made in which everything will be stored.
#This includes network weights as well.
#If graphviz is installed on the system, a computational graph of the underlying
#network will be made as well.
"""============================================================================"""
#################### LOSS SETUP ####################
#Depending on opt.loss and opt.sampling, the respective criterion is returned,
#and if the loss has trainable parameters, to_optim is appended.
LOG = aux.LOGGER(opt, metrics_to_log, name='Base', start_new=True)
# loss_select may extend to_optim with parameter groups owned by the loss itself.
criterion, to_optim = losses.loss_select(opt.loss, opt, to_optim)
_ = criterion.to(opt.device)
"""============================================================================"""
##################### OPTIONAL EVALUATIONS #####################
#Store the averaged gradients returned from the embedding to the last conv. layer.
if opt.grad_measure:
    grad_measure = eval.GradientMeasure(opt, name='baseline')
#Store the relative distances between average intra- and inter-class distance.
if opt.dist_measure:
    #Add a distance measure for training distance ratios
    distance_measure = eval.DistanceMeasure(dataloaders['evaluation'], opt, name='Train', update_epochs=1)
    # #If uncommented: Do the same for the test set
    # distance_measure_test = eval.DistanceMeasure(dataloaders['testing'], opt, name='Train', update_epochs=1)
"""============================================================================"""
#################### OPTIM SETUP ####################
#As optimizer, Adam with standard parameters is used.
if opt.opt == 'adam':
    optimizer = torch.optim.Adam(to_optim)
elif opt.opt == 'sgd':
    optimizer = torch.optim.SGD(to_optim)
else:
    raise Exception('unknown optimiser')
# for the SOA measures in the paper - need to use SGD and 0.05 learning rate
#optimizer = torch.optim.Adam(to_optim)
#optimizer = torch.optim.SGD(to_optim)
# Learning-rate schedule: exponential decay, step decay at opt.tau, or none.
if opt.scheduler=='exp':
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma)
elif opt.scheduler=='step':
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=opt.tau, gamma=opt.gamma)
elif opt.scheduler=='none':
    print('No scheduling used!')
else:
    raise Exception('No scheduling option for input: {}'.format(opt.scheduler))
def same_model(model1, model2):
    """Return True iff every corresponding parameter tensor of the two models is element-wise equal."""
    paired = zip(model1.parameters(), model2.parameters())
    return all(bool(a.data.ne(b.data).sum() == 0) for a, b in paired)
"""============================================================================"""
"""================================ TESTING ==================================="""
"""============================================================================"""
################### SCRIPT MAIN ##########################
print('\n-----\n')
# Compute Evaluation metrics, print them and store in LOG.
_ = model.eval()
aux.vis(model, dataloaders['training'], opt.device, split="T_train_iter"+str(opt.iter)+"_"+str(opt.loss), opt=opt)
aux.vis(model, dataloaders['testing'], opt.device, split="all_train_iter"+str(opt.iter)+"_"+str(opt.loss), opt=opt)
aux.vis(model, dataloaders['eval'], opt.device, split="test_iter"+str(opt.iter)+"_"+str(opt.loss), opt=opt)
#Update the Metric Plot and save it.
print('\n-----\n')
# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines
###################### LIBRARIES #################################################
import warnings
warnings.filterwarnings("ignore")
import torch, faiss
import numpy as np
from scipy import sparse
"""================================================================================================="""
############ LOSS SELECTION FUNCTION #####################
def loss_select(loss, opt, to_optim):
    """
    Resolve the training criterion for the requested loss name.

    Args:
        loss:     str, name of the loss function to construct.
        opt:      argparse.Namespace, contains all training-specific parameters.
        to_optim: list of trainable parameter groups; extended when the loss
                  itself carries trainable parameters.
    Returns:
        criterion (torch.nn.Module inherited), to_optim (optionally appended)
    """
    if loss != 'smoothap':
        raise Exception('Loss {} not available!'.format(loss))
    criterion = SmoothAP(
        anneal=opt.sigmoid_temperature,
        batch_size=opt.bs,
        num_id=int(opt.bs / opt.samples_per_class),
        feat_dims=opt.embed_dim,
    )
    return criterion, to_optim
"""==============================================Smooth-AP========================================"""
def sigmoid(tensor, temp=1.0):
    """Temperature-controlled sigmoid.

    Computes 1 / (1 + exp(-tensor / temp)). The scaled logits are clamped to
    [-50, 50] before exponentiation for numerical stability.
    """
    scaled = torch.clamp(tensor / temp, min=-50, max=50)
    return 1.0 / (1.0 + torch.exp(-scaled))
def compute_aff(x):
    """Affinity (inner-product similarity) matrix of x against itself."""
    return torch.matmul(x, x.transpose(0, 1))
class BinarizedF(torch.autograd.Function):
    """Heaviside step (x > 0 -> 1, else 0) with a pass-through gradient.

    Rewritten to the modern static-method autograd.Function API — the legacy
    instance-method form (``def forward(self, inp)``) raises a RuntimeError on
    current PyTorch. Use as ``BinarizedF.apply(x)``.

    The true derivative of a step function is zero almost everywhere, which
    would block all learning; the backward pass instead passes the incoming
    gradient through unchanged wherever the input is non-zero.
    """

    @staticmethod
    def forward(ctx, inp):
        # Stash the input so backward can rebuild its pass-through mask.
        ctx.save_for_backward(inp)
        ones = torch.ones_like(inp)
        zeros = torch.zeros_like(inp)
        return torch.where(inp > 0, ones, zeros)

    @staticmethod
    def backward(ctx, grad_output):
        inp, = ctx.saved_tensors
        mask = torch.where(torch.abs(inp) > 0, torch.ones_like(inp), torch.zeros_like(inp))
        # Bug fix vs. the legacy version: apply the chain rule by scaling the
        # incoming gradient with the mask instead of returning the raw mask.
        return grad_output * mask
class BinarizedModule(torch.nn.Module):
    """nn.Module wrapper around the BinarizedF autograd function.

    Bug fix: the previous implementation instantiated the autograd.Function
    (``self.BF = BinarizedF()``) and called the instance, which raises a
    RuntimeError on modern PyTorch. Custom functions must be invoked through
    ``Function.apply``.
    """

    def __init__(self):
        super(BinarizedModule, self).__init__()
        # Keep the BF attribute for backward compatibility; it is now the
        # callable `apply` entry point rather than a Function instance.
        self.BF = BinarizedF.apply

    def forward(self, inp):
        # Binarize: 1 where inp > 0, else 0 (gradient handled by BinarizedF).
        output = self.BF(inp)
        return output
class SmoothAP(torch.nn.Module):
"""PyTorch implementation of the Smooth-AP loss.
implementation of the Smooth-AP loss. Takes as input the mini-batch of CNN-produced feature embeddings and returns
the value of the Smooth-AP loss. The mini-batch must be formed of a defined number of classes. Each class must
have the same number of instances represented in the mini-batch and must be ordered sequentially by class.
e.g. the labels for a mini-batch with batch size 9, and 3 represented classes (A,B,C) must look like:
labels = ( A, A, A, B, B, B, C, C, C)
(the order of the classes however does not matter)
For each instance in the mini-batch, the loss computes the Smooth-AP when it is used as the query and the rest of the
mini-batch is used as the retrieval set. The positive set is formed of the other instances in the batch from the
same class. The loss returns the average Smooth-AP across all instances in the mini-batch.
Args:
anneal : float
the temperature of the sigmoid that is used to smooth the ranking function. A low value of the temperature
results in a steep sigmoid, that tightly approximates the heaviside step function in the ranking function.
batch_size : int
the batch size being used during training.
num_id : int
the number of different classes that are represented in the batch.
feat_dims : int
the dimension of the input feature embeddings
Shape:
- Input (preds): (batch_size, feat_dims) (must be a cuda torch float tensor)
- Output: scalar
Examples::
>>> loss = SmoothAP(0.01, 60, 6, 256)
>>> input = torch.randn(60, 256, requires_grad=True).cuda()
>>> output = loss(input)
>>> output.backward()
"""
def __init__(self, anneal, batch_size, num_id, feat_dims):
"""
Parameters
----------
anneal : float
the temperature of the sigmoid that is used to smooth the ranking function
batch_size : int
the batch size being used
num_id : int
the number of different classes that are represented in the batch
feat_dims : int
the dimension of the input feature embeddings
"""
super(SmoothAP, self).__init__()
assert(batch_size%num_id==0)
self.anneal = anneal
self.batch_size = batch_size
self.num_id = num_id
self.feat_dims = feat_dims
def forward(self, preds):
"""Forward pass for all input predictions: preds - (batch_size x feat_dims) """
# ------ differentiable ranking of all retrieval set ------
# compute the mask which ignores the relevance score of the query to itself
mask = 1.0 - torch.eye(self.batch_size)
mask = mask.unsqueeze(dim=0).repeat(self.batch_size, 1, 1)
# compute the relevance scores via cosine similarity of the CNN-produced embedding vectors
sim_all = compute_aff(preds)
sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.batch_size, 1)
# compute the difference matrix
sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
# pass through the sigmoid
sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask.cuda()
# compute the rankings
sim_all_rk = torch.sum(sim_sg, dim=-1) + 1
# ------ differentiable ranking of only positive set in retrieval set ------
# compute the mask which only gives non-zero weights to the positive set
xs = preds.view(self.num_id, int(self.batch_size / self.num_id), self.feat_dims)
pos_mask = 1.0 - torch.eye(int(self.batch_size / self.num_id))
pos_mask = pos_mask.unsqueeze(dim=0).unsqueeze(dim=0).repeat(self.num_id, int(self.batch_size / self.num_id), 1, 1)
# compute the relevance scores
sim_pos = torch.bmm(xs, xs.permute(0, 2, 1))
sim_pos_repeat = sim_pos.unsqueeze(dim=2).repeat(1, 1, int(self.batch_size / self.num_id), 1)
# compute the difference matrix
sim_pos_diff = sim_pos_repeat - sim_pos_repeat.permute(0, 1, 3, 2)
# pass through the sigmoid
sim_pos_sg = sigmoid(sim_pos_diff, temp=self.anneal) * pos_mask.cuda()
# compute the rankings of the positive set
sim_pos_rk = torch.sum(sim_pos_sg, dim=-1) + 1
# sum the values of the Smooth-AP for all instances in the mini-batch
ap = torch.zeros(1).cuda()
group = int(self.batch_size / self.num_id)
for ind in range(self.num_id):
pos_divide = torch.sum(sim_pos_rk[ind] / (sim_all_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)]))
ap = ap + ((pos_divide / group) / self.batch_size)
return (1 - ap)
# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines
"""to do:
clean all of the files - particularly the main.py and also the losses and dataset files and the file for doing the dataloading
-- fast loading etc
need to change all of the copyrights at the top of all of the files
"""
#################### LIBRARIES ########################
import warnings
warnings.filterwarnings("ignore")
import os, numpy as np, argparse, random, matplotlib, datetime
os.chdir(os.path.dirname(os.path.realpath(__file__)))
from pathlib import Path
matplotlib.use('agg')
from tqdm import tqdm
import auxiliaries as aux
import datasets as data
import netlib as netlib
import losses as losses
import evaluate as eval
from tensorboardX import SummaryWriter
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
################### INPUT ARGUMENTS ###################
# Command-line interface for the PSS training entry point.
parser = argparse.ArgumentParser()
####### Main Parameter: Dataset to use for Training
parser.add_argument('--dataset', default='vehicle_id', type=str, help='Dataset to use.',
                    choices=['SoftInaturalist', 'Inaturalist', 'vehicle_id', 'semi_fungi'])
### General Training Parameters
parser.add_argument('--lr', default=0.00001, type=float, help='Learning Rate for network parameters.')
parser.add_argument('--fc_lr_mul', default=5, type=float,
                    help='OPTIONAL: Multiply the embedding layer learning rate by this value. If set to 0, the embedding layer shares the same learning rate.')
parser.add_argument('--n_epochs', default=400, type=int, help='Number of training epochs.')
parser.add_argument('--kernels', default=8, type=int, help='Number of workers for pytorch dataloader.')
parser.add_argument('--bs', default=112, type=int, help='Mini-Batchsize to use.')
parser.add_argument('--samples_per_class', default=4, type=int,
                    help='Number of samples in one class drawn before choosing the next class')
parser.add_argument('--seed', default=1, type=int, help='Random seed for reproducibility.')
parser.add_argument('--scheduler', default='step', type=str,
                    help='Type of learning rate scheduling. Currently: step & exp.')
parser.add_argument('--gamma', default=0.3, type=float, help='Learning rate reduction after tau epochs.')
parser.add_argument('--decay', default=0.0004, type=float, help='Weight decay for optimizer.')
parser.add_argument('--tau', default=[200, 300], nargs='+', type=int, help='Stepsize(s) before reducing learning rate.')
parser.add_argument('--infrequent_eval', default=0, type=int, help='only compute evaluation metrics every 10 epochs')
parser.add_argument('--opt', default='adam', help='adam or sgd')
##### Loss-specific Settings
parser.add_argument('--loss', default='smoothap', type=str)
parser.add_argument('--sigmoid_temperature', default=0.01, type=float,
                    help='SmoothAP: the temperature of the sigmoid used in SmoothAP loss')
##### Evaluation Settings
parser.add_argument('--k_vals', nargs='+', default=[1, 2, 4, 8], type=int, help='Recall @ Values.')
parser.add_argument('--resume', default='', type=str,
                    help='path to checkpoint to load weights from (if empty then ImageNet pre-trained weights are loaded')
##### Network parameters
parser.add_argument('--embed_dim', default=512, type=int, help='Embedding dimensionality of the network')
parser.add_argument('--arch', default='resnet50', type=str,
                    help='Network backend choice: resnet50')
parser.add_argument('--pretrained_weights', default="", type=str, help='pretrained weight path')
parser.add_argument('--use_bn_in_head', default=False, type=aux.bool_flag,
                    help="Whether to use batch normalizations in projection head (Default: False)")
parser.add_argument("--checkpoint_key", default="teacher", type=str,
                    help='Key to use in the checkpoint (example: "teacher")')
parser.add_argument('--drop_path_rate', default=0.1, type=float, help="stochastic depth rate")
parser.add_argument('--grad_measure', action='store_true',
                    help='If added, gradients passed from embedding layer to the last conv-layer are stored in each iteration.')
parser.add_argument('--dist_measure', action='store_true',
                    help='If added, the ratio between intra- and interclass distances is stored after each epoch.')
parser.add_argument('--not_pretrained', action='store_true',
                    help='If added, the network will be trained WITHOUT ImageNet-pretrained weights.')
##### Setup Parameters
parser.add_argument('--gpu', default=0, type=int, help='GPU-id for GPU to use.')
parser.add_argument('--savename', default='', type=str,
                    help='Save folder name if any special information is to be included.')
### Paths to datasets and storage folder
parser.add_argument('--source_path', default='/scratch/shared/beegfs/abrown/datasets', type=str, help='Path to data')
parser.add_argument('--save_path', default=os.getcwd() + '/Training_Results', type=str,
                    help='Where to save the checkpoints')
### additional parameters (PSS data splits and iteration bookkeeping)
parser.add_argument('--trainset', default="lin_train_set1.txt", type=str)
parser.add_argument('--testset', default="Inaturalist_test_set1.txt", type=str)
parser.add_argument('--cluster_path', default="", type=str)
# String flags ('true'/'false'); converted to real booleans after parsing.
parser.add_argument('--finetune', default="false", type=str)
parser.add_argument('--class_num', default=948, type=int)
parser.add_argument('--get_features', default="false", type=str)
parser.add_argument('--linsize', default=29011, type=int, help="Lin data size.")
parser.add_argument('--uinsize', default=18403, type=int, help="Uin data size.")
parser.add_argument('--iter', default=0, type=int)
opt = parser.parse_args()
"""============================================================================"""
if opt.dataset == "SoftInaturalist":
opt.source_path += '/Inaturalist'
opt.save_path += '/Inaturalist' + "_" + str(opt.embed_dim)
else:
opt.source_path += '/' + opt.dataset
opt.save_path += '/' + opt.dataset + "_" + str(opt.embed_dim)
if opt.dataset == 'Inaturalist':
# opt.n_epochs = 90
opt.tau = [40, 70]
opt.k_vals = [1, 4, 16, 32]
if opt.dataset == 'SoftInaturalist':
# opt.n_epochs = 90
opt.tau = [40, 70]
opt.k_vals = [1, 4, 16, 32]
if opt.dataset == 'vehicle_id':
opt.k_vals = [1, 5]
if opt.dataset == 'semi_fungi':
opt.tau = [40, 70]
opt.k_vals = [1, 4, 16, 32]
if opt.finetune == 'true':
opt.finetune = True
elif opt.finetune == 'false':
opt.finetune = False
if opt.get_features == 'true':
opt.get_features = True
elif opt.get_features == 'false':
opt.get_features = False
"""==========================================================================="""
################### TensorBoard Settings ##################
timestamp = datetime.datetime.now().strftime(r"%Y-%m-%d_%H-%M-%S")
exp_name = aux.args2exp_name(opt)
opt.save_name = f"weights_{exp_name}" + '/' + timestamp
random.seed(opt.seed)
np.random.seed(opt.seed)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed);
torch.cuda.manual_seed_all(opt.seed)
tensorboard_path = Path(f"logs/logs_{exp_name}") / timestamp
tensorboard_path.parent.mkdir(exist_ok=True, parents=True)
global writer;
writer = SummaryWriter(tensorboard_path)
"""============================================================================"""
################### GPU SETTINGS ###########################
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]= str(opt.gpu)
print('using #GPUs:', torch.cuda.device_count())
"""============================================================================"""
##################### NETWORK SETUP ##################
opt.device = torch.device('cuda')
model = netlib.networkselect(opt)
# Push to Device
if torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)
_ = model.to(opt.device)
# Place trainable parameter in list of parameters to train:
if 'fc_lr_mul' in vars(opt).keys() and opt.fc_lr_mul != 0:
all_but_fc_params = list(filter(lambda x: 'last_linear' not in x[0], model.named_parameters()))
for ind, param in enumerate(all_but_fc_params):
all_but_fc_params[ind] = param[1]
if torch.cuda.device_count() > 1:
fc_params = model.module.model.last_linear.parameters()
else:
fc_params = model.model.last_linear.parameters()
to_optim = [{'params': all_but_fc_params, 'lr': opt.lr, 'weight_decay': opt.decay},
{'params': fc_params, 'lr': opt.lr * opt.fc_lr_mul, 'weight_decay': opt.decay}]
else:
to_optim = [{'params': model.parameters(), 'lr': opt.lr, 'weight_decay': opt.decay}]
"""============================================================================"""
#################### DATALOADER SETUPS ##################
# Returns a dictionary containing 'training', 'testing', and 'evaluation' dataloaders.
# The 'testing'-dataloader corresponds to the validation set, and the 'evaluation'-dataloader
# Is simply using the training set, however running under the same rules as 'testing' dataloader,
# i.e. no shuffling and no random cropping.
dataloaders = data.give_dataloaders(opt.dataset, opt.trainset, opt.testset, opt)
# Because the number of supervised classes is dataset dependent, we store them after
# initializing the dataloader
opt.num_classes = len(dataloaders['training'].dataset.avail_classes)
"""============================================================================"""
#################### CREATE LOGGING FILES ###############
# Each dataset usually has a set of standard metrics to log. aux.metrics_to_examine()
# returns a dict which lists metrics to log for training ('train') and validation/testing ('val')
metrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals)
# example output: {'train': ['Epochs', 'Time', 'Train Loss', 'Time'],
# 'val': ['Epochs','Time','NMI','F1', 'Recall @ 1','Recall @ 2','Recall @ 4','Recall @ 8']}
# Using the provided metrics of interest, we generate a LOGGER instance.
# Note that 'start_new' denotes that a new folder should be made in which everything will be stored.
# This includes network weights as well.
LOG = aux.LOGGER(opt, metrics_to_log, name='Base', start_new=True)
# If graphviz is installed on the system, a computational graph of the underlying
# network will be made as well.
"""============================================================================"""
#################### LOSS SETUP ####################
# Depending on opt.loss and opt.sampling, the respective criterion is returned,
# and if the loss has trainable parameters, to_optim is appended.
criterion, to_optim = losses.loss_select(opt.loss, opt, to_optim)
_ = criterion.to(opt.device)
"""============================================================================"""
##################### OPTIONAL EVALUATIONS #####################
# Store the averaged gradients returned from the embedding to the last conv. layer.
if opt.grad_measure:
grad_measure = eval.GradientMeasure(opt, name='baseline')
# Store the relative distances between average intra- and inter-class distance.
if opt.dist_measure:
# Add a distance measure for training distance ratios
distance_measure = eval.DistanceMeasure(dataloaders['evaluation'], opt, name='Train', update_epochs=1)
# #If uncommented: Do the same for the test set
# distance_measure_test = eval.DistanceMeasure(dataloaders['testing'], opt, name='Train', update_epochs=1)
"""============================================================================"""
#################### OPTIM SETUP ####################
# As optimizer, Adam with standard parameters is used.
if opt.opt == 'adam':
optimizer = torch.optim.Adam(to_optim)
elif opt.opt == 'sgd':
optimizer = torch.optim.SGD(to_optim)
else:
raise Exception('unknown optimiser')
# for the SOA measures in the paper - need to use SGD and 0.05 learning rate
# optimizer = torch.optim.Adam(to_optim)
# optimizer = torch.optim.SGD(to_optim)
if opt.scheduler == 'exp':
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma)
elif opt.scheduler == 'step':
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=opt.tau, gamma=opt.gamma)
elif opt.scheduler == 'none':
print('No scheduling used!')
else:
raise Exception('No scheduling option for input: {}'.format(opt.scheduler))
def same_model(model1, model2):
    """Check whether two models hold identical parameter values, pair by pair."""
    pairs = zip(model1.parameters(), model2.parameters())
    return not any(bool(p.data.ne(q.data).any()) for p, q in pairs)
"""============================================================================"""
#################### TRAINER FUNCTION ############################
def train_one_epoch(train_dataloader, model, optimizer, criterion, opt, epoch):
"""
This function is called every epoch to perform training of the network over one full
(randomized) iteration of the dataset.
Args:
train_dataloader: torch.utils.data.DataLoader, returns (augmented) training data.
model: Network to train.
optimizer: Optimizer to use for training.
criterion: criterion to use during training.
opt: argparse.Namespace, Contains all relevant parameters.
epoch: int, Current epoch.
Returns:
Nothing!
"""
loss_collect = []
start = time.time()
data_iterator = tqdm(train_dataloader, desc='Epoch {} Training...'.format(epoch))
for i, (class_labels, input) in enumerate(data_iterator):
# Compute embeddings for input batch
features = model(input.to(opt.device))
# Compute loss.
if opt.loss != 'smoothap':
loss = criterion(features, class_labels)
else:
loss = criterion(features)
# Ensure gradients are set to zero at beginning
optimizer.zero_grad()
# Compute gradient
loss.backward()
train_dataloader.dataset.classes_visited = []
if opt.grad_measure:
# If desired, save computed gradients.
grad_measure.include(model.model.last_linear)
# Update weights using comp. gradients.
optimizer.step()
# Store loss per iteration.
loss_collect.append(loss.item())
if i == len(train_dataloader) - 1:
data_iterator.set_description('Epoch (Train) {0}: Mean Loss [{1:.4f}]'.format(epoch, np.mean(loss_collect)))
# Save metrics
LOG.log('train', LOG.metrics_to_log['train'], [epoch, np.round(time.time() - start, 4), np.mean(loss_collect)])
writer.add_scalar('global/training_loss', np.mean(loss_collect), epoch)
if opt.grad_measure:
# Dump stored gradients to Pickle-File.
grad_measure.dump(epoch)
"""============================================================================"""
"""========================== MAIN TRAINING PART =============================="""
"""============================================================================"""
################### SCRIPT MAIN ##########################
print('\n-----\n')
# Each dataset requires slightly different dataloaders.
if opt.dataset == 'SoftInaturalist' or 'Inaturalist' or 'semi_fungi':
eval_params = {'dataloader': dataloaders['testing'], 'model': model, 'opt': opt, 'epoch': 0}
elif opt.dataset == 'vehicle_id':
eval_params = {
'dataloaders': [dataloaders['testing_set1'], dataloaders['testing_set2'], dataloaders['testing_set3']],
'model': model, 'opt': opt, 'epoch': 0}
# Compute Evaluation metrics, print them and store in LOG.
print('epochs -> ' + str(opt.n_epochs))
import time
for epoch in range(opt.n_epochs):
### Print current learning rates for all parameters
if opt.scheduler != 'none': print(
'Running with learning rates {}...'.format(' | '.join('{}'.format(x) for x in scheduler.get_lr())))
### Train one epoch
_ = model.train()
train_one_epoch(dataloaders['training'], model, optimizer, criterion, opt, epoch)
dataloaders['training'].dataset.reshuffle()
### Evaluate
_ = model.eval()
# Each dataset requires slightly different dataloaders.
if opt.dataset == 'Inaturalist':
eval_params = {'dataloader': dataloaders['evaluation'], 'model': model, 'opt': opt, 'epoch': epoch}
elif opt.dataset == 'vehicle_id':
eval_params = {
'dataloaders': [dataloaders['testing_set1'], dataloaders['testing_set2'], dataloaders['testing_set3']],
'model': model, 'opt': opt, 'epoch': epoch}
elif opt.dataset == 'semi_fungi':
eval_params = {'dataloader': dataloaders['testing'], 'model': model, 'opt': opt, 'epoch': epoch}
# Compute Evaluation metrics, print them and store in LOG.
if opt.infrequent_eval == 1:
epoch_freq = 5
else:
epoch_freq = 1
if not opt.dataset == 'vehicle_id':
if epoch % epoch_freq == 0:
results = eval.evaluate(opt.dataset, LOG, save=True, **eval_params)
writer.add_scalar('global/recall1', results[0][0], epoch + 1)
writer.add_scalar('global/recall2', results[0][1], epoch + 1)
writer.add_scalar('global/recall3', results[0][2], epoch + 1)
writer.add_scalar('global/recall4', results[0][3], epoch + 1)
writer.add_scalar('global/NMI', results[1], epoch + 1)
writer.add_scalar('global/F1', results[2], epoch + 1)
else:
results = eval.evaluate(opt.dataset, LOG, save=True, **eval_params)
writer.add_scalar('global/recall1', results[2], epoch + 1)
writer.add_scalar('global/recall2', results[3],
epoch + 1) # writer.add_scalar('global/recall3',results[0][2],0)
writer.add_scalar('global/recall3', results[6], epoch + 1)
writer.add_scalar('global/recall4', results[7], epoch + 1)
writer.add_scalar('global/recall5', results[10], epoch + 1)
writer.add_scalar('global/recall6', results[11], epoch + 1)
# Update the Metric Plot and save it.
# LOG.update_info_plot()
# (optional) compute ratio of intra- to interdistances.
if opt.dist_measure:
distance_measure.measure(model, epoch)
# distance_measure_test.measure(model, epoch)
### Learning Rate Scheduling Step
if opt.scheduler != 'none':
scheduler.step()
print('\n-----\n')
# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines
############################ LIBRARIES ######################################
from collections import OrderedDict
import os
import torch
import torch.nn as nn
import pretrainedmodels as ptm
import auxiliaries as aux
"""============================================================="""
def initialize_weights(model):
    """
    Initialize network weights in place.

    Conv layers get Kaiming-normal weights, batch-norm layers are reset to
    scale 1 / bias 0, and linear layers get N(0, 0.01) weights with zero bias.
    NOTE: NOT USED IN MAIN SCRIPT.

    Args:
        model: PyTorch Network
    Returns:
        Nothing!
    """
    for layer in model.modules():
        if isinstance(layer, nn.Conv2d):
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(layer, nn.BatchNorm2d):
            nn.init.constant_(layer.weight, 1)
            nn.init.constant_(layer.bias, 0)
        elif isinstance(layer, nn.Linear):
            layer.weight.data.normal_(0, 0.01)
            layer.bias.data.zero_()
"""=================================================================================================================================="""
### ATTRIBUTE CHANGE HELPER
def rename_attr(model, attr, name):
    """
    Move an attribute of an object to a new name. Simple helper function.

    Args:
        model: General Class for which attributes should be renamed.
        attr: str, Name of target attribute.
        name: str, New attribute name.
    """
    value = getattr(model, attr)
    delattr(model, attr)
    setattr(model, name, value)
"""=================================================================================================================================="""
### NETWORK SELECTION FUNCTION
def networkselect(opt):
    """
    Selection function for available networks.

    Args:
        opt: argparse.Namespace, contains all training-specific training parameters.
    Returns:
        Network of choice
    """
    if opt.arch != 'resnet50':
        raise Exception('Network {} not available!'.format(opt.arch))
    network = ResNet50(opt)
    if opt.resume:
        # Restore weights from an earlier run. Multi-GPU checkpoints carry a
        # 'module.' prefix (from DataParallel) that must be stripped first.
        checkpoint = torch.load(os.path.join(opt.save_path, opt.resume))
        state_dict = checkpoint['state_dict']
        if torch.cuda.device_count() > 1:
            stripped = OrderedDict(
                (key.replace('module.', ''), value) for key, value in state_dict.items())
            network.load_state_dict(stripped)
        else:
            network.load_state_dict(state_dict)
    return network
"""============================================================="""
class ResNet50(nn.Module):
    """
    Container for ResNet50 s.t. it can be used for metric learning.
    The Network has been broken down to allow for higher modularity, if one wishes
    to target specific layers/blocks directly.
    """
    def __init__(self, opt, list_style=False, no_norm=False):
        super(ResNet50, self).__init__()
        self.pars = opt
        if not opt.not_pretrained:
            print('Getting pretrained weights...')
            self.model = ptm.__dict__['resnet50'](num_classes=1000, pretrained='imagenet')
            print('Done.')
        else:
            print('Not utilizing pretrained weights!')
            self.model = ptm.__dict__['resnet50'](num_classes=1000, pretrained=None)
        # Freeze all batch-norm layers: force eval mode and neutralize any later
        # .train() call so their running statistics stay fixed during training.
        for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()):
            module.eval()
            module.train = lambda _: None
        # Replace the 2048-dim ImageNet classifier head with an embedding
        # projection whenever a different embedding size is requested.
        if opt.embed_dim != 2048:
            self.model.last_linear = torch.nn.Linear(self.model.last_linear.in_features, opt.embed_dim)
        self.layer_blocks = nn.ModuleList([self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4])
        self.loss = opt.loss
        self.feature = True
    def forward(self, x, feature=False, is_init_cluster_generation=False):
        """Return L2-normalized embeddings (is_init_cluster_generation is currently unused)."""
        # Stem: conv -> bn -> relu -> maxpool, then the four residual stages.
        x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x))))
        for layerblock in self.layer_blocks:
            x = layerblock(x)
        x = self.model.avgpool(x)
        x = x.view(x.size(0), -1)
        if self.pars.embed_dim != 2048:
            mod_x = self.model.last_linear(x)
        else:
            mod_x = x
        # L2-normalize so embeddings live on the unit hypersphere.
        feat = torch.nn.functional.normalize(mod_x, dim=-1)
        if feature or self.loss == 'smoothap':
            return feat
        else:
            # NOTE(review): self.linear is never defined on this class, so this
            # branch raises AttributeError if reached; only the 'smoothap' /
            # feature=True paths appear to be exercised. Confirm intent.
            pred = self.linear(feat)
            return pred
# Evaluate a trained Smooth-AP checkpoint on the iNaturalist splits.
# $CHECKPOINT_PATH must point to the weights file to load via --resume.
python Smooth_AP/src/evaluate_model.py \
--dataset Inaturalist \
--bs 384 \
--source_path ~/code/Smooth_AP/data/ --embed_dim 128 \
--resume $CHECKPOINT_PATH \
--class_num 948 --loss smoothap \
--trainset lin_train_set1.txt \
--testset Inaturalist_test_set1.txt \
--linsize 29011 --uinsize 18403
\ No newline at end of file
import argparse, time, os, pickle
import random
import sys
sys.path.append("..")
from utils.deduce import get_edge_dist
import numpy as np
import shutil
import dgl
import torch
import torch.optim as optim
from models import LANDER
from dataset import LanderDataset
from utils import evaluation, decode, build_next_level, stop_iterating
from matplotlib import pyplot as plt
import seaborn
# Global debug switch (not read anywhere in this script).
STATISTIC = False
###########
# ArgParser
parser = argparse.ArgumentParser()
# Dataset
parser.add_argument('--data_path', type=str, required=True)  # pickle with (path2idx, features, pred_labels, labels, masks)
parser.add_argument('--model_filename', type=str, default='lander.pth')  # trained LANDER checkpoint
parser.add_argument('--faiss_gpu', action='store_true')
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--output_filename', type=str, default='data/features.pkl')  # where new pseudo labels are dumped
# HyperParam
parser.add_argument('--knn_k', type=int, default=10)  # k of the kNN graph
parser.add_argument('--levels', type=int, default=1)  # max clustering hierarchy depth
parser.add_argument('--tau', type=float, default=0.5)  # connection threshold passed to decode()
parser.add_argument('--threshold', type=str, default='prob')
parser.add_argument('--metrics', type=str, default='pairwise,bcubed,nmi')  # comma-separated evaluation metrics
parser.add_argument('--early_stop', action='store_true')
# Model
parser.add_argument('--hidden', type=int, default=512)
parser.add_argument('--num_conv', type=int, default=4)
parser.add_argument('--dropout', type=float, default=0.)
parser.add_argument('--gat', action='store_true')
parser.add_argument('--gat_k', type=int, default=1)
parser.add_argument('--balance', action='store_true')
parser.add_argument('--use_cluster_feat', action='store_true')
parser.add_argument('--use_focal_loss', action='store_true')
parser.add_argument('--use_gt', action='store_true')  # use ground-truth connectivity instead of model predictions
# Subgraph
parser.add_argument('--batch_size', type=int, default=4096)
parser.add_argument('--mode', type=str, default="1head")  # e.g. selectbydensity / recluster / donothing
parser.add_argument('--midpoint', type=str, default="false")
parser.add_argument('--linsize', type=int, default=29011)  # number of labeled-in (L_in) samples
parser.add_argument('--uinsize', type=int, default=18403)  # number of unlabeled-in (U_in) samples
parser.add_argument('--inclasses', type=int, default=948)  # number of in-distribution classes
parser.add_argument('--thresh', type=float, default=1.0)  # predicted-density threshold for sample selection
parser.add_argument('--draw', type=str, default='false')  # 'true' to plot density-vs-distance figures
parser.add_argument('--density_distance_pkl', type=str, default="density_distance.pkl")
parser.add_argument('--density_lindistance_jpg', type=str, default="density_lindistance.jpg")
args = parser.parse_args()
print(args)
MODE = args.mode
linsize = args.linsize
uinsize = args.uinsize
inclasses = args.inclasses
# Convert the string flag to a real bool; any other value is kept as-is.
if args.draw == 'false':
    args.draw = False
elif args.draw == 'true':
    args.draw = True
###########################
# Environment Configuration
# Run on the GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
##################
# Data Preparation
# The pickle holds: path->index map, embedding features, last iteration's
# pseudo labels, ground-truth labels, and per-sample masks. From their usage
# below, mask values appear to mean 0 = labeled (L_in), 1 = selected last
# iteration, 2 = not selected — confirm against the producer script.
with open(args.data_path, 'rb') as f:
    loaded_data = pickle.load(f)
    path2idx, features, pred_labels, labels, masks = loaded_data
    idx2path = {v: k for k, v in path2idx.items()}
gtlabels = labels
orifeatures = features
orilabels = gtlabels
if MODE == "selectbydensity":
    # Demote last iteration's selections (mask 1 -> 2), then cluster every
    # sample outside the labeled set.
    lastusim = np.where(masks == 1)
    masks[lastusim] = 2
    selectedidx = np.where(masks != 0)
    features = features[selectedidx]
    labels = gtlabels[selectedidx]
    selectmasks = masks[selectedidx]
    print("filtered features:", len(features))
    print("mask0:", len(np.where(masks == 0)[0]))
    print("mask1:", len(np.where(masks == 1)[0]))
    print("mask2:", len(np.where(masks == 2)[0]))
elif MODE == "recluster":
    # Only re-cluster the samples selected in the previous iteration.
    selectedidx = np.where(masks == 1)
    features = features[selectedidx]
    labels = gtlabels[selectedidx]
    labelspred = pred_labels[selectedidx]
    selectmasks = masks[selectedidx]
    gtlabels = gtlabels[selectedidx]
    print("filtered features:", len(features))
else:
    # Default: cluster every sample outside the labeled set.
    selectedidx = np.where(masks != 0)
    features = features[selectedidx]
    labels = gtlabels[selectedidx]
    labelspred = pred_labels[selectedidx]
    selectmasks = masks[selectedidx]
    gtlabels = gtlabels[selectedidx]
    print("filtered features:", len(features))
# Build the level-0 kNN graph over the filtered features and allocate the
# per-node/per-edge buffers that will receive the model's density and
# connectivity predictions.
global_features = features.copy()  # global features
dataset = LanderDataset(features=features, labels=labels, k=args.knn_k,
                        levels=1, faiss_gpu=False)
g = dataset.gs[0]
g.ndata['pred_den'] = torch.zeros((g.number_of_nodes()))
g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
global_labels = labels.copy()
ids = np.arange(g.number_of_nodes())
global_edges = ([], [])
# np.int64 instead of np.long: np.long was a deprecated alias of the builtin
# int (removed in NumPy 1.24) and resolved to int64 on 64-bit platforms anyway.
global_peaks = np.array([], dtype=np.int64)
global_edges_len = len(global_edges[0])
global_num_nodes = g.number_of_nodes()
# Sorted ground-truth densities of the labeled (L_in) prefix, kept for plotting.
global_densities = g.ndata['density'][:linsize]
global_densities = np.sort(global_densities)
xs = np.arange(len(global_densities))
# Neighbor sampler: k-1 neighbors per hop, one hop per conv layer (+1).
fanouts = [args.knn_k - 1 for i in range(args.num_conv + 1)]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
# fix the number of edges
test_loader = dgl.dataloading.NodeDataLoader(
    g, torch.arange(g.number_of_nodes()), sampler,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers
)
##################
# Model Definition
if not args.use_gt:
    # Rebuild LANDER with the training-time hyper-parameters, load the
    # checkpoint, and switch to eval mode (inference only).
    feature_dim = g.ndata['features'].shape[1]
    model = LANDER(feature_dim=feature_dim, nhid=args.hidden,
                   num_conv=args.num_conv, dropout=args.dropout,
                   use_GAT=args.gat, K=args.gat_k,
                   balance=args.balance,
                   use_cluster_feat=args.use_cluster_feat,
                   use_focal_loss=args.use_focal_loss)
    model.load_state_dict(torch.load(args.model_filename))
    model = model.to(device)
    model.eval()
# number of edges added is the indicator for early stopping
# (np.inf instead of the alias np.Inf, which was removed in NumPy 2.0)
num_edges_add_last_level = np.inf
##################################
# Predict connectivity and density
# Hierarchical decoding: at each level, run LANDER on the current graph,
# decode cluster peaks, then build the next (coarser) graph from the peaks.
for level in range(args.levels):
    print("level:", level)
    if not args.use_gt:
        total_batches = len(test_loader)
        for batch, minibatch in enumerate(test_loader):
            input_nodes, sub_g, bipartites = minibatch
            sub_g = sub_g.to(device)
            bipartites = [b.to(device) for b in bipartites]
            with torch.no_grad():
                output_bipartite = model(bipartites)
            global_nid = output_bipartite.dstdata[dgl.NID]
            global_eid = output_bipartite.edata['global_eid']
            # Scatter the batch predictions back into the full-graph buffers.
            g.ndata['pred_den'][global_nid] = output_bipartite.dstdata['pred_den'].to('cpu')
            g.edata['prob_conn'][global_eid] = output_bipartite.edata['prob_conn'].to('cpu')
            torch.cuda.empty_cache()
            if (batch + 1) % 10 == 0:
                print('Batch %d / %d for inference' % (batch, total_batches))
    # Decode cluster assignments and peak nodes from the predicted densities
    # and connection probabilities; updates the global edge/label state.
    new_pred_labels, peaks, \
    global_edges, global_pred_labels, global_peaks = decode(g, args.tau, args.threshold, args.use_gt,
                                                            ids, global_edges, global_num_nodes,
                                                            global_peaks)
    if level == 0:
        # Keep the level-0 densities for the density-based selection below.
        global_pred_densities = g.ndata['pred_den']
        global_densities = g.ndata['density']
        g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
    ids = ids[peaks]
    new_global_edges_len = len(global_edges[0])
    num_edges_add_this_level = new_global_edges_len - global_edges_len
    if stop_iterating(level, args.levels, args.early_stop, num_edges_add_this_level, num_edges_add_last_level,
                      args.knn_k):
        break
    global_edges_len = new_global_edges_len
    num_edges_add_last_level = num_edges_add_this_level
    # build new dataset
    features, labels, cluster_features = build_next_level(features, labels, peaks,
                                                          global_features, global_pred_labels, global_peaks)
    # After the first level, the number of nodes reduce a lot. Using cpu faiss is faster.
    dataset = LanderDataset(features=features, labels=labels, k=args.knn_k,
                            levels=1, faiss_gpu=False, cluster_features=cluster_features)
    g = dataset.gs[0]
    g.ndata['pred_den'] = torch.zeros((g.number_of_nodes()))
    g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
    test_loader = dgl.dataloading.NodeDataLoader(
        g, torch.arange(g.number_of_nodes()), sampler,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_workers
    )
if MODE == "selectbydensity":
    # Progressive sample selection: keep samples whose predicted density
    # exceeds `thresh`; everything else is marked unclustered (-1).
    thresh = args.thresh
    global_pred_densities = np.array(global_pred_densities).astype(float)
    global_densities = np.array(global_densities).astype(float)
    distance = np.abs(global_pred_densities - global_densities)
    print("densities shape", global_pred_densities.shape)
    print(global_pred_densities.max(), global_pred_densities.min())
    selectidx = np.where(global_pred_densities > thresh)[0]
    selected_pred_densities = global_pred_densities[selectidx]
    selected_densities = global_densities[selectidx]
    # Mean |predicted - ground-truth| density among the selected samples.
    selected_distance = np.abs(selected_pred_densities - selected_densities)
    print(np.mean(selected_distance))
    print("number of selected samples:", len(selectidx))
    notselectidx = np.where(global_pred_densities <= thresh)
    print("not selected:", len(notselectidx[0]))
    # Rejected samples get the "unclustered" label.
    global_pred_labels[notselectidx] = -1
    # Rebuild a full-length label array: initialised to -1 everywhere.
    global_pred_labels_new = np.zeros_like(orilabels)
    global_pred_labels_new[:] = -1
    # Tidx: the labeled set (everything not demoted to mask 2 above).
    Tidx = np.where(masks != 2)
    print("T:", len(Tidx[0]))
    l_in_gt = orilabels[Tidx]
    l_in_features = orifeatures[Tidx]
    l_in_gt_new = np.zeros_like(l_in_gt)
    l_in_unique = np.unique(l_in_gt)
    # Remap the labeled set's ground-truth labels to the range 0..C-1.
    for i in range(len(l_in_unique)):
        l_in = l_in_unique[i]
        l_in_idx = np.where(l_in_gt == l_in)
        l_in_gt_new[l_in_idx] = i
    print("len(l_in_unique)", len(l_in_unique))
if args.draw:
prototypes = np.zeros((len(l_in_unique), features.shape[1]))
for i in range(len(l_in_unique)):
idx = np.where(l_in_gt_new == i)
prototypes[i] = np.mean(l_in_features[idx], axis=0)
similarity_matrix = torch.mm(torch.from_numpy(global_features.astype(np.float32)),
torch.from_numpy(prototypes.astype(np.float32)).t())
similarity_matrix = (1 - similarity_matrix) / 2
minvalues, selected_pred_labels = torch.min(similarity_matrix, 1)
# far-close ratio
closeidx = np.where(minvalues < 0.15)
faridx = np.where(minvalues >= 0.15)
print("far:", len(faridx[0]))
print("close:", len(closeidx[0]))
cutidx = np.where(global_pred_densities >= 0.5)
draw_minvalues = minvalues[cutidx]
draw_densities = global_pred_densities[cutidx]
with open(args.density_distance_pkl, 'wb') as f:
pickle.dump((global_pred_densities, minvalues), f)
print("dumped.")
plt.clf()
fig, ax = plt.subplots()
import random
if len(draw_densities) > 10000:
samples_idx = random.sample(range(len(draw_minvalues)), 10000)
ax.plot(draw_densities[random], draw_minvalues[random], color='tab:blue', marker='*', linestyle="None",
markersize=1)
else:
ax.plot(draw_densities[random], draw_minvalues[random], color='tab:blue', marker='*', linestyle="None",
markersize=1)
plt.savefig(args.density_lindistance_jpg)
global_pred_labels_new[Tidx] = l_in_gt_new
global_pred_labels[selectidx] = global_pred_labels[selectidx] + len(l_in_unique)
global_pred_labels_new[selectedidx] = global_pred_labels
global_pred_labels = global_pred_labels_new
linunique = np.unique(global_pred_labels[Tidx])
uunique = np.unique(global_pred_labels[selectedidx])
allnique = np.unique(global_pred_labels)
print("labels")
print(len(linunique), len(uunique), len(allnique))
global_masks = np.zeros_like(masks)
global_masks[:] = 1
global_masks[np.array(selectedidx[0])[notselectidx]] = 2
Tidx = np.where(masks != 2)
global_masks[Tidx] = 0
print("mask0", len(np.where(global_masks == 0)[0]))
print("mask1", len(np.where(global_masks == 1)[0]))
print("mask2", len(np.where(global_masks == 2)[0]))
print("all", len(masks), len(orilabels), len(orifeatures))
global_gt_labels = orilabels
if MODE == "recluster":
    # Re-clustering: labeled samples keep remapped GT labels; previously
    # selected samples receive fresh (offset) cluster labels; masks unchanged.
    global_pred_labels_new = np.zeros_like(orilabels)
    global_pred_labels_new[:] = -1
    Tidx = np.where(masks == 0)
    print("T:", len(Tidx[0]))
    l_in_gt = orilabels[Tidx]
    l_in_features = orifeatures[Tidx]
    l_in_gt_new = np.zeros_like(l_in_gt)
    l_in_unique = np.unique(l_in_gt)
    # Remap the labeled set's ground-truth labels to the range 0..C-1.
    for i in range(len(l_in_unique)):
        l_in = l_in_unique[i]
        l_in_idx = np.where(l_in_gt == l_in)
        l_in_gt_new[l_in_idx] = i
    print("len(l_in_unique)", len(l_in_unique))
    global_pred_labels_new[Tidx] = l_in_gt_new
    print(len(global_pred_labels))
    print(len(selectedidx[0]))
    # Offset pseudo labels past the GT label range to avoid collisions.
    global_pred_labels_new[selectedidx[0]] = global_pred_labels + len(l_in_unique)
    global_pred_labels = global_pred_labels_new
    global_masks = masks
    print("mask0", len(np.where(global_masks == 0)[0]))
    print("mask1", len(np.where(global_masks == 1)[0]))
    print("mask2", len(np.where(global_masks == 2)[0]))
    print("all", len(masks), len(orilabels), len(orifeatures))
    global_gt_labels = orilabels
if MODE == "donothing":
    global_masks = masks
    pass
# Evaluate clustering quality on each split of the data. Samples labeled -1
# (not selected) are excluded from the U_* splits; the sample order in the
# pickle is assumed to be [L_in | U_in | U_out] by position.
print("##################### L_in ########################")
print(linsize)
if len(global_pred_labels) >= linsize:
    evaluation(global_pred_labels[:linsize], global_gt_labels[:linsize], args.metrics)
else:
    print("No samples in L_in!")
print("##################### U_in ########################")
uinidx = np.where(global_pred_labels[linsize:linsize + uinsize] != -1)[0]
uinidx = uinidx + linsize
print(len(uinidx))
if len(uinidx):
    evaluation(global_pred_labels[uinidx], global_gt_labels[uinidx], args.metrics)
else:
    print("No samples in U_in!")
print("##################### U_out ########################")
uoutidx = np.where(global_pred_labels[linsize + uinsize:] != -1)[0]
uoutidx = uoutidx + linsize + uinsize
print(len(uoutidx))
if len(uoutidx):
    evaluation(global_pred_labels[uoutidx], global_gt_labels[uoutidx], args.metrics)
else:
    print("No samples in U_out!")
print("##################### U ########################")
uidx = np.where(global_pred_labels[linsize:] != -1)[0]
uidx = uidx + linsize
print(len(uidx))
if len(uidx):
    evaluation(global_pred_labels[uidx], global_gt_labels[uidx], args.metrics)
else:
    print("No samples in U!")
print("##################### L+U ########################")
luidx = np.where(global_pred_labels != -1)[0]
print(len(luidx))
evaluation(global_pred_labels[luidx], global_gt_labels[luidx], args.metrics)
print("##################### new selected samples ########################")
sidx = np.where(global_masks == 1)[0]
print(len(sidx))
if len(sidx) != 0:
    evaluation(global_pred_labels[sidx], global_gt_labels[sidx], args.metrics)
print("##################### not selected samples ########################")
nsidx = np.where(global_masks == 2)[0]
print(len(nsidx))
if len(nsidx) != 0:
    evaluation(global_pred_labels[nsidx], global_gt_labels[nsidx], args.metrics)
# Dump the updated pseudo labels and masks for the next PSS iteration.
with open(args.output_filename, 'wb') as f:
    print(orifeatures.shape)
    print(global_pred_labels.shape)
    print(global_gt_labels.shape)
    print(global_masks.shape)
    pickle.dump([path2idx, orifeatures, global_pred_labels, global_gt_labels, global_masks], f)
#!/bin/bash
# PSS training pipeline on iNaturalist: alternate between Smooth-AP
# representation learning and Hi-LANDER clustering for pseudo labels.

# -p: do not fail when the checkpoint directory already exists (re-runs).
mkdir -p hilander_checkpoint

####################### ITER 0 #######################
# iter 0 (supervised baseline) - train Smooth-AP
# `export` is required: a plain assignment on its own line is shell-local
# and would never reach the python child processes below.
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python Smooth_AP/src/main.py \
--dataset Inaturalist --lr 1e-5 --fc_lr_mul 1 \
--n_epochs 400 --bs 384 \
--source_path "../../data/" --embed_dim 128 \
--class_num 948 --loss smoothap --infrequent_eval 1 \
--trainset lin_train_set1.txt --testset Inaturalist_test_set1.txt
# iter 0 (supervised baseline) - get feature
python Smooth_AP/src/get_features.py \
--dataset Inaturalist --lr 1e-5 --fc_lr_mul 1 \
--n_epochs 400 --bs 384 \
--source_path "../../data/" --embed_dim 128 \
--resume "0/checkpoint_0.pth.tar" \
--finetune false --get_features true --iter 0 \
--class_num 948 --loss smoothap \
--trainset lin_train_set1.txt \
--all_trainset train_set1.txt \
--testset Inaturalist_test_set1.txt \
--linsize 29011
# iter 0 (supervised baseline) - train hi-lander
python train_subg_inat.py \
--data_path "/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/data/Inaturalist/T_train_iter0_smoothap_inat_features.pkl" \
--model_filename '/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/hilander_checkpoint/inat_l_smoothap_iter0.pth' \
--knn_k 10,5,3 --levels 2,3,4 \
--hidden 512 --epochs 1000 --lr 0.01 \
--batch_size 4096 --num_conv 1 --gat --balance
# iter 0 (supervised baseline) - get pseudo labels
python test_subg_inat.py \
--data_path '/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/data/Inaturalist/all_train_iter0_smoothap_inat_features.pkl' \
--model_filename '/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/hilander_checkpoint/inat_l_smoothap_iter0.pth' --knn_k 10 \
--tau 0.9 --level 10 --threshold prob \
--hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop \
--mode selectbydensity --thresh 0.8 \
--linsize 29011 --uinsize 18403 --inclasses 948 \
--output_filename 'data/inat_hilander_l_smoothap_train_selectbydensity_iter0.pkl'

####################### ITER 1..4 #######################
# Each iteration fine-tunes Smooth-AP on the previous pseudo labels, then
# retrains Hi-LANDER and regenerates pseudo labels.
for i in {1..4} ; do
  last_iter=`expr $i - 1`
  echo ${last_iter}
  # iter i - train Smooth-AP
  python Smooth_AP/src/finetune_1head.py \
  --dataset Inaturalist --lr 1e-5 --fc_lr_mul 1 \
  --n_epochs 400 --bs 384 --class_num 1024 \
  --source_path "../../data/" --embed_dim 128 \
  --trainset lin_train_set1.txt --testset Inaturalist_test_set1.txt \
  --cluster_path "../../data/inat_hilander_l_smoothap_train_selectbydensity_iter${last_iter}.pkl" \
  --finetune true --loss smoothap --infrequent_eval 1 --iter ${i}
  # iter i - get feature
  python Smooth_AP/src/get_features.py \
  --dataset Inaturalist --lr 1e-5 --fc_lr_mul 1 \
  --n_epochs 400 --bs 384 \
  --source_path "../../data/" --embed_dim 128 \
  --resume "${i}/checkpoint_${i}.pth.tar" \
  --finetune false --get_features true --iter ${i} \
  --class_num 948 --loss smoothap \
  --trainset lin_train_set1.txt \
  --all_trainset train_set1.txt \
  --testset Inaturalist_test_set1.txt \
  --linsize 29011 --uinsize 18403 \
  --cluster_path "../../data/inat_hilander_l_smoothap_train_selectbydensity_iter${last_iter}.pkl"
  # iter i - train hi-lander
  python train_subg_inat.py \
  --data_path "/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/data/Inaturalist/T_train_iter${i}_smoothap_inat_features.pkl" \
  --model_filename "/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/hilander_checkpoint/inat_l_smoothap_iter${i}.pth" \
  --knn_k 10,5,3 --levels 2,3,4 \
  --hidden 512 --epochs 1000 --lr 0.01 \
  --batch_size 4096 --num_conv 1 --gat --balance
  # iter i - get pseudo labels
  python test_subg_inat.py \
  --data_path "/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/data/Inaturalist/all_train_iter${i}_smoothap_inat_features.pkl" \
  --model_filename "/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/hilander_checkpoint/inat_l_smoothap_iter${i}.pth" --knn_k 10 \
  --tau 0.9 --level 10 --threshold prob \
  --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop \
  --mode selectbydensity --thresh 0.8 \
  --linsize 29011 --uinsize 18403 --inclasses 948 \
  --output_filename "data/inat_hilander_l_smoothap_train_selectbydensity_iter${i}.pkl"
done
import argparse, time, os, pickle
import random
import numpy as np
import dgl
import torch
import torch.optim as optim
import sys
sys.path.append("..")
from models import LANDER
from dataset import LanderDataset
###########
# ArgParser
parser = argparse.ArgumentParser()
# Dataset
parser.add_argument('--data_path', type=str, required=True)  # pickle with (path2idx, features, labels, _, masks)
parser.add_argument('--levels', type=str, default='1')  # comma-separated hierarchy depths, paired with --knn_k
parser.add_argument('--faiss_gpu', action='store_true')
parser.add_argument('--model_filename', type=str, default='lander.pth')  # checkpoint output path
# KNN
parser.add_argument('--knn_k', type=str, default='10')  # comma-separated k values, one per dataset
parser.add_argument('--num_workers', type=int, default=0)
# Model
parser.add_argument('--hidden', type=int, default=512)
parser.add_argument('--num_conv', type=int, default=1)
parser.add_argument('--dropout', type=float, default=0.)
parser.add_argument('--gat', action='store_true')
parser.add_argument('--gat_k', type=int, default=1)
parser.add_argument('--balance', action='store_true')
parser.add_argument('--use_cluster_feat', action='store_true')
parser.add_argument('--use_focal_loss', action='store_true')
# Training
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=1024)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight_decay', type=float, default=1e-5)
args = parser.parse_args()
print(args)
###########################
# Environment Configuration
# Train on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def setup_seed(seed):
    """Seed every RNG in use (torch CPU/GPU, numpy, python) for reproducibility."""
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all,
               np.random.seed, random.seed)
    for seeder in seeders:
        seeder(seed)
    # Force deterministic cuDNN kernels (may slow training down).
    torch.backends.cudnn.deterministic = True
# setup_seed(20)
##################
# Data Preparation
with open(args.data_path, 'rb') as f:
    # Pickle layout: (path2idx, features, labels, pred_labels (unused), masks).
    path2idx, features, labels, _, masks = pickle.load(f)
    # lidx = np.where(masks==0)
    # features = features[lidx]
    # labels = labels[lidx]
print("features.shape:", features.shape)
print("labels.shape:", labels.shape)
# Build one LanderDataset (and its graph hierarchy) per (k, levels) pair;
# the flattened lists hold every level graph with its matching k.
k_list = [int(k) for k in args.knn_k.split(',')]
lvl_list = [int(l) for l in args.levels.split(',')]
gs = []
nbrs = []
ks = []
datasets = []
for k, l in zip(k_list, lvl_list):
    print("k:", k)
    print("levels:", l)
    dataset = LanderDataset(features=features, labels=labels, k=k,
                            levels=l, faiss_gpu=args.faiss_gpu)
    gs += [g for g in dataset.gs]
    ks += [k for g in dataset.gs]  # same k repeated once per level graph
    nbrs += [nbr for nbr in dataset.nbrs]
    datasets.append(dataset)
# with open("./dataset.pkl", 'rb') as f:
#     datasets = pickle.load(f)
#     for i in range(len(datasets)):
#         dataset = datasets[i]
#         k = k_list[i]
#         gs += [g for g in dataset.gs]
#         ks += [k for g in dataset.gs]
#         nbrs += [nbr for nbr in dataset.nbrs]
# Cache the constructed datasets so the block above can be skipped on re-runs.
with open("./dataset.pkl", 'wb') as f:
    pickle.dump(datasets, f)
print('Dataset Prepared.')
def set_train_sampler_loader(g, k):
    """Return a shuffling NodeDataLoader over all nodes of `g`.

    Samples k-1 neighbors per hop, with one hop per conv layer (+1).
    """
    fan_out = [k - 1] * (args.num_conv + 1)
    neighbor_sampler = dgl.dataloading.MultiLayerNeighborSampler(fan_out)
    # fix the number of edges
    return dgl.dataloading.NodeDataLoader(
        g,
        torch.arange(g.number_of_nodes()),
        neighbor_sampler,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
    )
# One training dataloader per level graph, each with its own fan-out k.
train_loaders = [set_train_sampler_loader(level_g, level_k)
                 for level_g, level_k in zip(gs, ks)]
##################
# Model Definition
# A single LANDER model is shared across all level graphs; the input
# feature dimension is taken from the first graph.
feature_dim = gs[0].ndata['features'].shape[1]
print("feature dimension:", feature_dim)
model = LANDER(feature_dim=feature_dim, nhid=args.hidden,
               num_conv=args.num_conv, dropout=args.dropout,
               use_GAT=args.gat, K=args.gat_k,
               balance=args.balance,
               use_cluster_feat=args.use_cluster_feat,
               use_focal_loss=args.use_focal_loss)
model = model.to(device)
model.train()
#################
# Hyperparameters
opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                weight_decay=args.weight_decay)
# keep num_batch_per_loader the same for every sub_dataloader
num_batch_per_loader = len(train_loaders[0])
train_loaders = [iter(train_loader) for train_loader in train_loaders]
num_loaders = len(train_loaders)
# Cosine decay spread over the total number of optimisation steps.
scheduler = optim.lr_scheduler.CosineAnnealingLR(opt,
                                                 T_max=args.epochs * num_batch_per_loader * num_loaders,
                                                 eta_min=1e-5)
print('Start Training.')
###############
# Training Loop
# Each epoch round-robins over the per-graph loaders so that every level
# graph receives the same number of optimisation steps.
for epoch in range(args.epochs):
    loss_den_val_total = []
    loss_conn_val_total = []
    loss_val_total = []
    for batch in range(num_batch_per_loader):
        for loader_id in range(num_loaders):
            try:
                minibatch = next(train_loaders[loader_id])
            except StopIteration:
                # Loader exhausted: rebuild it for a fresh shuffled pass.
                # (Was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit and masked real errors.)
                train_loaders[loader_id] = iter(set_train_sampler_loader(gs[loader_id], ks[loader_id]))
                minibatch = next(train_loaders[loader_id])
            input_nodes, sub_g, bipartites = minibatch
            sub_g = sub_g.to(device)
            bipartites = [b.to(device) for b in bipartites]
            # get the feature for the input_nodes
            opt.zero_grad()
            output_bipartite = model(bipartites)
            loss, loss_den_val, loss_conn_val = model.compute_loss(output_bipartite)
            loss_den_val_total.append(loss_den_val)
            loss_conn_val_total.append(loss_conn_val)
            loss_val_total.append(loss.item())
            loss.backward()
            opt.step()
            if (batch + 1) % 10 == 0:
                print('epoch: %d, batch: %d / %d, loader_id : %d / %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f'%
                      (epoch, batch, num_batch_per_loader, loader_id, num_loaders,
                       loss.item(), loss_den_val, loss_conn_val))
        scheduler.step()
    print('epoch: %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f'%
          (epoch, np.array(loss_val_total).mean(),
           np.array(loss_den_val_total).mean(), np.array(loss_conn_val_total).mean()))
    # Checkpoint after every epoch so a crash does not lose all progress.
    torch.save(model.state_dict(), args.model_filename)
torch.save(model.state_dict(), args.model_filename)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment