Commit ab754c8c authored by Raul Puri
Browse files

functionalized code

parent 5dd2a9ad
import torch
import os import os
import numpy as np
import argparse import argparse
import collections import collections
parser = argparse.ArgumentParser() import numpy as np
parser.add_argument('--paths', required=True, nargs='+') import torch
parser.add_argument('--eval', action='store_true')
parser.add_argument('--outdir') def process_files(args):
parser.add_argument('--prediction-name', default='test_predictions.pt') all_predictions = collections.OrderedDict()
parser.add_argument('--calc-threshold', action='store_true') all_labels = collections.OrderedDict()
parser.add_argument('--one-threshold', action='store_true') all_uid = collections.OrderedDict()
parser.add_argument('--threshold', nargs='+', default=None, type=float) for path in args.paths:
parser.add_argument('--labels',nargs='+', default=None)
args = parser.parse_args()
all_predictions = collections.OrderedDict()
all_labels = collections.OrderedDict()
all_uid = collections.OrderedDict()
for path in args.paths:
path = os.path.join(path, args.prediction_name) path = os.path.join(path, args.prediction_name)
try: try:
data = torch.load(path) data = torch.load(path)
...@@ -38,10 +29,11 @@ for path in args.paths: ...@@ -38,10 +29,11 @@ for path in args.paths:
except Exception as e: except Exception as e:
print(e) print(e)
continue continue
all_correct = 0 return all_predictions, all_labels, all_uid
count = 0
def get_threshold(all_predictions, all_labels):
if args.one_threshold: def get_threshold(all_predictions, all_labels, one_threshold=False):
if one_threshold:
all_predictons = {'combined': np.concatenate(list(all_predictions.values()))} all_predictons = {'combined': np.concatenate(list(all_predictions.values()))}
all_labels = {'combined': np.concatenate(list(all_predictions.labels()))} all_labels = {'combined': np.concatenate(list(all_predictions.labels()))}
out_thresh = [] out_thresh = []
...@@ -50,6 +42,8 @@ def get_threshold(all_predictions, all_labels): ...@@ -50,6 +42,8 @@ def get_threshold(all_predictions, all_labels):
labels = all_labels[dataset] labels = all_labels[dataset]
out_thresh.append(calc_threshold(preds,labels)) out_thresh.append(calc_threshold(preds,labels))
return out_thresh return out_thresh
def calc_threshold(p, l): def calc_threshold(p, l):
trials = [(i)*(1./100.) for i in range(100)] trials = [(i)*(1./100.) for i in range(100)]
best_acc = float('-inf') best_acc = float('-inf')
...@@ -61,6 +55,7 @@ def calc_threshold(p, l): ...@@ -61,6 +55,7 @@ def calc_threshold(p, l):
best_thresh = t best_thresh = t
return best_thresh return best_thresh
def apply_threshold(preds, t): def apply_threshold(preds, t):
assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0]))) assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
prob = preds[:,-1] prob = preds[:,-1]
...@@ -69,6 +64,7 @@ def apply_threshold(preds, t): ...@@ -69,6 +64,7 @@ def apply_threshold(preds, t):
preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1 preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
return preds return preds
def threshold_predictions(all_predictions, threshold): def threshold_predictions(all_predictions, threshold):
if len(threshold)!=len(all_predictions): if len(threshold)!=len(all_predictions):
threshold = [threshold[-1]]*(len(all_predictions)-len(threshold)) threshold = [threshold[-1]]*(len(all_predictions)-len(threshold))
...@@ -78,17 +74,25 @@ def threshold_predictions(all_predictions, threshold): ...@@ -78,17 +74,25 @@ def threshold_predictions(all_predictions, threshold):
all_predictions[dataset] = apply_threshold(preds, thresh) all_predictions[dataset] = apply_threshold(preds, thresh)
return all_predictions return all_predictions
for d in all_predictions:
def postprocess_predictions(all_predictions, all_labels, args):
    """Scale the per-dataset predictions and optionally threshold them.

    Every entry of ``all_predictions`` is divided by ``len(args.paths)`` and
    stored back under the same key, so the mapping passed in is modified in
    place. (Presumably the entries are sums accumulated over the ensemble
    members, making this an average -- the accumulation is not visible here;
    confirm against ``process_files``.)

    When ``args.calc_threshold`` is set, ``args.threshold`` is overwritten
    with the result of ``get_threshold`` and printed. When ``args.threshold``
    is not None (user supplied or just calculated), the predictions are passed
    through ``threshold_predictions``.

    NOTE(review): the nesting below was reconstructed from syntax because the
    indentation of the original was lost -- confirm the ``print`` belongs
    inside the ``calc_threshold`` branch.

    Args:
        all_predictions: mapping of dataset name -> prediction array.
        all_labels: mapping of dataset name -> labels; returned unchanged.
        args: parsed options; reads ``paths``, ``calc_threshold``,
            ``one_threshold`` and ``threshold`` (may write ``threshold``).

    Returns:
        The tuple ``(all_predictions, all_labels)``.
    """
    num_paths = len(args.paths)
    for name, values in all_predictions.items():
        # Re-assigning an existing key does not resize the dict, so updating
        # while iterating is safe.
        all_predictions[name] = values/num_paths

    if args.calc_threshold:
        args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
        print('threshold', args.threshold)

    if args.threshold is None:
        return all_predictions, all_labels
    return threshold_predictions(all_predictions, args.threshold), all_labels
def write_predictions(all_predictions, all_labels, all_uid, args):
all_correct = 0
count = 0
for dataset in all_predictions:
preds = all_predictions[dataset] preds = all_predictions[dataset]
preds = np.argmax(preds, -1) preds = np.argmax(preds, -1)
if args.eval: if args.eval:
...@@ -105,5 +109,37 @@ for dataset in all_predictions: ...@@ -105,5 +109,37 @@ for dataset in all_predictions:
with open(outpath, 'w') as f: with open(outpath, 'w') as f:
f.write('id\tlabel\n') f.write('id\tlabel\n')
f.write('\n'.join(str(uid)+'\t'+str(args.labels[p]) for uid, p in zip(all_uid[dataset], preds.tolist()))) f.write('\n'.join(str(uid)+'\t'+str(args.labels[p]) for uid, p in zip(all_uid[dataset], preds.tolist())))
if args.eval: if args.eval:
print(all_correct/count) print(all_correct/count)
def ensemble_predictions(args):
    """Run the full pipeline: load, post-process, then write predictions.

    Calls ``process_files`` to obtain the predictions, labels and uids,
    feeds the predictions and labels through ``postprocess_predictions``,
    and hands everything to ``write_predictions``.

    Args:
        args: parsed command-line options, forwarded to every stage.
    """
    predictions, labels, uids = process_files(args)
    predictions, labels = postprocess_predictions(predictions, labels, args)
    write_predictions(predictions, labels, uids, args)
def main(argv=None):
    """Parse command-line options and run the prediction ensemble.

    Args:
        argv: optional list of argument strings. When None (the default)
            argparse reads ``sys.argv[1:]``, so a bare ``main()`` call
            behaves exactly as before; passing a list allows programmatic
            invocation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--paths', required=True, nargs='+',
                        help='paths to checkpoint directories used in ensemble')
    parser.add_argument('--eval', action='store_true',
                        help='compute accuracy metrics against labels (dev set)')
    parser.add_argument('--outdir',
                        help='directory to place ensembled predictions in')
    parser.add_argument('--prediction-name', default='test_predictions.pt',
                        help='name of predictions in checkpoint directories')
    parser.add_argument('--calc-threshold', action='store_true',
                        help='calculate threshold classification')
    # Help text previously read "use on threshold" (typo for "one").
    parser.add_argument('--one-threshold', action='store_true',
                        help='use one threshold for all subdatasets')
    parser.add_argument('--threshold', nargs='+', default=None, type=float,
                        help='user supplied threshold for classification')
    parser.add_argument('--labels', nargs='+', default=None,
                        help='whitespace separated list of label names')
    args = parser.parse_args(argv)
    ensemble_predictions(args)
# Run the ensemble only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment