"...text-generation-inference.git" did not exist on "dbdc587dddf0f16b7c05b28fb632acf9f65f185f"
Commit ab754c8c authored by Raul Puri's avatar Raul Puri
Browse files

functionalized code

parent 5dd2a9ad
import argparse
import collections
import os

import numpy as np
import torch
def process_files(args):
    """Load and accumulate predictions from every checkpoint directory in args.paths.

    Each directory is expected to contain a torch-saved file named
    ``args.prediction_name`` holding an iterable of
    ``(name, (predictions, labels, uid))`` entries.  Predictions for the same
    dataset name are summed element-wise across paths; labels and uids are
    taken from the first occurrence.

    Returns ``(all_predictions, all_labels, all_uid)`` as OrderedDicts keyed
    by dataset name.  Side effect: populates ``args.labels`` with default
    class indices when the caller did not supply ``--labels``; ``all_labels``
    is only filled when ``args.eval`` is set.
    """
    all_predictions = collections.OrderedDict()
    all_labels = collections.OrderedDict()
    all_uid = collections.OrderedDict()
    for path in args.paths:
        # keep the original directory name intact; build the file path separately
        pred_path = os.path.join(path, args.prediction_name)
        try:
            data = torch.load(pred_path)
            for name, (predictions, labels, uid) in data:
                if name not in all_predictions:
                    all_predictions[name] = np.array(predictions)
                    if args.labels is None:
                        # default label names are the class indices
                        args.labels = list(range(all_predictions[name].shape[1]))
                    if args.eval:
                        all_labels[name] = np.array(labels)
                    all_uid[name] = np.array(uid)
                else:
                    all_predictions[name] += np.array(predictions)
                    # every ensemble member must score the same examples in the same order
                    assert np.allclose(all_uid[name], np.array(uid))
        except Exception as e:
            # deliberate best-effort: report and skip unreadable/missing prediction files
            print(e)
            continue
    return all_predictions, all_labels, all_uid
assert np.allclose(all_uid[name], np.array(uid)) def get_threshold(all_predictions, all_labels, one_threshold=False):
except Exception as e: if one_threshold:
print(e)
continue
all_correct = 0
count = 0
def get_threshold(all_predictions, all_labels):
if args.one_threshold:
all_predictons = {'combined': np.concatenate(list(all_predictions.values()))} all_predictons = {'combined': np.concatenate(list(all_predictions.values()))}
all_labels = {'combined': np.concatenate(list(all_predictions.labels()))} all_labels = {'combined': np.concatenate(list(all_predictions.labels()))}
out_thresh = [] out_thresh = []
...@@ -50,6 +42,8 @@ def get_threshold(all_predictions, all_labels): ...@@ -50,6 +42,8 @@ def get_threshold(all_predictions, all_labels):
labels = all_labels[dataset] labels = all_labels[dataset]
out_thresh.append(calc_threshold(preds,labels)) out_thresh.append(calc_threshold(preds,labels))
return out_thresh return out_thresh
def calc_threshold(p, l): def calc_threshold(p, l):
trials = [(i)*(1./100.) for i in range(100)] trials = [(i)*(1./100.) for i in range(100)]
best_acc = float('-inf') best_acc = float('-inf')
...@@ -61,6 +55,7 @@ def calc_threshold(p, l): ...@@ -61,6 +55,7 @@ def calc_threshold(p, l):
best_thresh = t best_thresh = t
return best_thresh return best_thresh
def apply_threshold(preds, t): def apply_threshold(preds, t):
assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0]))) assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
prob = preds[:,-1] prob = preds[:,-1]
...@@ -69,6 +64,7 @@ def apply_threshold(preds, t): ...@@ -69,6 +64,7 @@ def apply_threshold(preds, t):
preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1 preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
return preds return preds
def threshold_predictions(all_predictions, threshold): def threshold_predictions(all_predictions, threshold):
if len(threshold)!=len(all_predictions): if len(threshold)!=len(all_predictions):
threshold = [threshold[-1]]*(len(all_predictions)-len(threshold)) threshold = [threshold[-1]]*(len(all_predictions)-len(threshold))
...@@ -78,32 +74,72 @@ def threshold_predictions(all_predictions, threshold): ...@@ -78,32 +74,72 @@ def threshold_predictions(all_predictions, threshold):
all_predictions[dataset] = apply_threshold(preds, thresh) all_predictions[dataset] = apply_threshold(preds, thresh)
return all_predictions return all_predictions
def postprocess_predictions(all_predictions, all_labels, args):
    """Average summed predictions over ensemble members and optionally threshold them.

    Divides every prediction array by the number of ensemble paths
    (NOTE(review): assumes every path loaded successfully — paths skipped by
    process_files still count in the divisor; confirm that is intended), then,
    when requested, computes (--calc-threshold) and/or applies (--threshold) a
    classification threshold.  Returns the (possibly thresholded) predictions
    together with the untouched labels.
    """
    num_members = len(args.paths)
    for dataset in all_predictions:
        all_predictions[dataset] = all_predictions[dataset] / num_members
    if args.calc_threshold:
        # derive a threshold from the labels and remember it on args for reuse
        args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
        print('threshold', args.threshold)
    if args.threshold is not None:
        all_predictions = threshold_predictions(all_predictions, args.threshold)
    return all_predictions, all_labels
def write_predictions(all_predictions, all_labels, all_uid, args):
    """Write argmax predictions for each dataset to a TSV under args.outdir.

    One file per dataset: ``<outdir>/<dataset>/<prediction_name stem>.tsv``
    with an ``id\\tlabel`` header and one ``uid\\tlabel`` row per example,
    labels looked up through ``args.labels``.  When ``args.eval`` is set,
    prints per-dataset accuracy and a final overall accuracy.
    """
    all_correct = 0
    count = 0
    for dataset in all_predictions:
        preds = np.argmax(all_predictions[dataset], -1)
        if args.eval:
            correct = (preds == all_labels[dataset]).sum()
            num = len(all_labels[dataset])
            count += num
            all_correct += correct
            # single accuracy computation (the original computed it twice,
            # once via correct/num and again via .mean(), discarding the first)
            print(correct / num)
        dataset_dir = os.path.join(args.outdir, dataset)
        # exist_ok avoids the exists()/makedirs() race of the original
        os.makedirs(dataset_dir, exist_ok=True)
        outpath = os.path.join(
            dataset_dir, os.path.splitext(args.prediction_name)[0] + '.tsv')
        with open(outpath, 'w') as f:
            f.write('id\tlabel\n')
            f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
                              for uid, p in zip(all_uid[dataset], preds.tolist())))
    if args.eval:
        print(all_correct / count)
def ensemble_predictions(args):
    """Run the full ensembling pipeline: load, combine, and write predictions."""
    predictions, labels, uids = process_files(args)
    predictions, labels = postprocess_predictions(predictions, labels, args)
    write_predictions(predictions, labels, uids, args)
def main():
    """Parse command-line arguments and run prediction ensembling."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--paths', required=True, nargs='+',
                        help='paths to checkpoint directories used in ensemble')
    parser.add_argument('--eval', action='store_true',
                        help='compute accuracy metrics against labels (dev set)')
    parser.add_argument('--outdir',
                        help='directory to place ensembled predictions in')
    parser.add_argument('--prediction-name', default='test_predictions.pt',
                        help='name of predictions in checkpoint directories')
    parser.add_argument('--calc-threshold', action='store_true',
                        help='calculate threshold classification')
    parser.add_argument('--one-threshold', action='store_true',
                        # fixed help-text typo: "use on threshold" -> "use one threshold"
                        help='use one threshold for all subdatasets')
    parser.add_argument('--threshold', nargs='+', default=None, type=float,
                        help='user supplied threshold for classification')
    parser.add_argument('--labels', nargs='+', default=None,
                        help='whitespace separated list of label names')
    args = parser.parse_args()
    ensemble_predictions(args)


if __name__ == '__main__':
    main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment