ensemble_classifier.py
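# Ensemble saved classifier predictions from multiple runs: sum the per-example
# class probabilities found in each run directory, average them, optionally apply a
# decision threshold, and write the final predictions to per-dataset TSV files.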
import torch
import os
import numpy as np
import argparse
import collections

parser = argparse.ArgumentParser()
parser.add_argument('--paths', required=True, nargs='+',
                    help='directories containing saved prediction files to ensemble')
parser.add_argument('--eval', action='store_true',
                    help='compute accuracy against the saved labels')
parser.add_argument('--outdir', help='directory to write the ensembled TSV predictions')
parser.add_argument('--prediction-name', default='test_predictions.pt',
                    help='filename of the saved predictions inside each path')
parser.add_argument('--calc-threshold', action='store_true',
                    help='grid-search decision threshold(s) using the saved labels (requires --eval)')
parser.add_argument('--one-threshold', action='store_true',
                    help='fit a single threshold shared across all datasets')
parser.add_argument('--threshold', nargs='+', default=None, type=float,
                    help='decision threshold(s) applied to the positive-class (last column) probability')
parser.add_argument('--labels', nargs='+', default=None,
                    help='label names to write out (defaults to class indices)')
args = parser.parse_args()

all_predictions = collections.OrderedDict()
all_labels = collections.OrderedDict()
all_uid = collections.OrderedDict()
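# Load each run's saved predictions and sum the per-class probabilities per dataset.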
for path in args.paths:
    path = os.path.join(path, args.prediction_name)
    try:
        data = torch.load(path)
        for dataset in data:
            name, d = dataset
            predictions, labels, uid = d
            if name not in all_predictions:
                all_predictions[name] = np.array(predictions)
                if args.labels is None:
                    args.labels = [i for i in range(all_predictions[name].shape[1])]
                if args.eval:
                    all_labels[name] = np.array(labels)
                all_uid[name] = np.array(uid)
            else:
                all_predictions[name] += np.array(predictions)
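                # Sanity check: every run must cover the same examples in the same order.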
                assert np.allclose(all_uid[name], np.array(uid))
    except Exception as e:
        print(e)
        continue
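
# Running totals for an overall accuracy report across all datasets (used with --eval).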
all_correct = 0
count = 0
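
# Choose a decision threshold per dataset (or a single shared one with --one-threshold)
# by grid-searching the value that maximizes accuracy against the provided labels.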
def get_threshold(all_predictions, all_labels):
    if args.one_threshold:
        # Pool every dataset together so a single threshold is fit on all of them.
        all_predictions = {'combined': np.concatenate(list(all_predictions.values()))}
        all_labels = {'combined': np.concatenate(list(all_labels.values()))}
    out_thresh = []
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        labels = all_labels[dataset]
        out_thresh.append(calc_threshold(preds,labels))
    return out_thresh

def calc_threshold(p, l):
    # Grid-search thresholds in [0, 1) and keep the one with the highest accuracy.
    trials = [i / 100. for i in range(100)]
    best_acc = float('-inf')
    best_thresh = 0
    for t in trials:
        acc = ((apply_threshold(p, t).argmax(-1) == l).astype(float)).mean()
        if acc > best_acc:
            best_acc = acc
            best_thresh = t
    return best_thresh

def apply_threshold(preds, t):
    # Rows of preds are expected to be probability distributions (each sums to 1).
    assert np.allclose(preds.sum(-1), np.ones(preds.shape[0]))
    # Threshold the positive-class (last column) probability and emit one-hot decisions.
    prob = preds[:, -1]
    thresholded = (prob >= t).astype(int)
    preds = np.zeros_like(preds)
    preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
    return preds

def threshold_predictions(all_predictions, threshold):
    if len(threshold) != len(all_predictions):
        # Pad with the last supplied threshold so every dataset gets one.
        threshold = threshold + [threshold[-1]] * (len(all_predictions) - len(threshold))
    for i, dataset in enumerate(all_predictions):
        thresh = threshold[i]
        preds = all_predictions[dataset]
        all_predictions[dataset] = apply_threshold(preds, thresh)
    return all_predictions

# Average the summed class probabilities over the number of ensemble members.
for d in all_predictions:
    all_predictions[d] = all_predictions[d] / len(args.paths)

if args.calc_threshold:
    args.threshold = get_threshold(all_predictions, all_labels)
    print('threshold', args.threshold)

if args.threshold is not None:
    all_predictions = threshold_predictions(all_predictions, args.threshold)

# Write one TSV of predictions per dataset and, with --eval, report per-dataset accuracy.
for dataset in all_predictions:
    preds = all_predictions[dataset]
    preds = np.argmax(preds, -1)
    if args.eval:
        correct = (preds == all_labels[dataset]).sum()
        num = len(all_labels[dataset])
        count += num
        all_correct += correct
        accuracy = correct / num
        print(accuracy)
    os.makedirs(os.path.join(args.outdir, dataset), exist_ok=True)
    outpath = os.path.join(args.outdir, dataset, os.path.splitext(args.prediction_name)[0]+'.tsv')
    with open(outpath, 'w') as f:
        f.write('id\tlabel\n')
        f.write('\n'.join(str(uid)+'\t'+str(args.labels[p]) for uid, p in zip(all_uid[dataset], preds.tolist())))
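
# Overall accuracy pooled across every dataset.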
if args.eval:
    print(all_correct/count)