AP.py 4.47 KB
Newer Older
lishj6's avatar
lishj6 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import numpy as np
from .distance import chamfer_distance, frechet_distance, chamfer_distance_batch
from typing import List, Tuple, Union
from numpy.typing import NDArray

def average_precision(recalls, precisions, mode='area'):
    """Calculate average precision of a single precision-recall curve.

    Args:
        recalls (ndarray): shape (num_dets, ). Lists are also accepted.
        precisions (ndarray): shape (num_dets, ). Lists are also accepted.
        mode (str): 'area' or '11points', 'area' means calculating the area
            under precision-recall curve, '11points' means calculating
            the average precision of recalls at [0, 0.1, ..., 1]

    Returns:
        float: calculated average precision

    Raises:
        AssertionError: if ``recalls`` and ``precisions`` are not 1-D arrays
            of the same shape.
        ValueError: if ``mode`` is neither 'area' nor '11points'.
    """
    # Promote both curves to shape (1, num_dets) so the row-wise code below
    # can index them as [0, ...].
    recalls = np.asarray(recalls)[np.newaxis, :]
    precisions = np.asarray(precisions)[np.newaxis, :]

    assert recalls.shape == precisions.shape and recalls.ndim == 2
    num_scales = recalls.shape[0]

    if mode == 'area':
        zeros = np.zeros((num_scales, 1), dtype=recalls.dtype)
        ones = np.ones((num_scales, 1), dtype=recalls.dtype)
        # Pad the curve with sentinel points at recall 0 and recall 1.
        mrec = np.hstack((zeros, recalls, ones))
        mpre = np.hstack((zeros, precisions, zeros))
        # Precision envelope: each point takes the maximum precision found at
        # any equal-or-higher position, i.e. a running max from the right.
        mpre = np.maximum.accumulate(mpre[:, ::-1], axis=1)[:, ::-1]

        # Sum rectangle areas only where recall actually changes.
        ind = np.where(mrec[0, 1:] != mrec[0, :-1])[0]
        ap = np.sum((mrec[0, ind + 1] - mrec[0, ind]) * mpre[0, ind + 1])

    elif mode == '11points':
        ap = 0.
        for thr in np.arange(0, 1 + 1e-3, 0.1):
            # Best precision among points whose recall reaches this threshold;
            # 0 when no point reaches it.
            precs = precisions[0, recalls[0, :] >= thr]
            ap += precs.max() if precs.size > 0 else 0
        ap /= 11
    else:
        raise ValueError(
            'Unrecognized mode, only "area" and "11points" are supported')

    return ap

def instance_match(pred_lines: NDArray,
                   scores: NDArray,
                   gt_lines: NDArray,
                   thresholds: Union[Tuple, List],
                   metric: str='chamfer') -> List:
    """Compute whether detected lines are true positive or false positive.

    Predictions are visited in descending score order. Each one is compared
    against its single nearest GT line. It counts as a true positive when that
    distance is <= the threshold and the GT has not already been claimed by a
    higher-scoring prediction; otherwise it counts as a false positive.

    Args:
        pred_lines (array): Detected lines of a sample, of shape (M, INTERP_NUM, 2 or 3).
        scores (array): Confidence score of each line, of shape (M, ).
        gt_lines (array): GT lines of a sample, of shape (N, INTERP_NUM, 2 or 3).
        thresholds (list or tuple of float): Distance thresholds to evaluate.
        metric (str): Distance function for lines matching, 'chamfer' or
            'frechet'. Default: 'chamfer'.

    Returns:
        list_of_tp_fp (list): one ``(tp, fp)`` pair per threshold, in the
        order of ``thresholds``. Both are float32 arrays of shape (M, ) with
        1 marking a true positive / false positive respectively.

    Raises:
        ValueError: if ``metric`` is not 'chamfer' or 'frechet'.
    """
    if metric not in ('chamfer', 'frechet'):
        raise ValueError(f'unknown distance function {metric}')

    num_preds = pred_lines.shape[0]
    num_gts = gt_lines.shape[0]

    # Degenerate samples. With no GT lines every prediction is a false
    # positive. With no predictions both arrays are empty, and np.ones(0) is
    # empty too. Each threshold gets its own copies.
    if num_gts == 0 or num_preds == 0:
        tp = np.zeros(num_preds, dtype=np.float32)
        fp = np.ones(num_preds, dtype=np.float32)
        return [(tp.copy(), fp.copy()) for _ in thresholds]

    assert pred_lines.shape[1] == gt_lines.shape[1], \
        "sample points num should be the same"

    # Distance matrix: M x N.
    if metric == 'chamfer':
        # Batched counterpart of chamfer_distance; assumed to return the same
        # per-pair values (see .distance).
        matrix = chamfer_distance_batch(pred_lines, gt_lines)
    else:
        # No batched Frechet helper is available, so evaluate every pair.
        matrix = np.array([[frechet_distance(pred, gt) for gt in gt_lines]
                           for pred in pred_lines])

    # For each det: distance to, and index of, its closest GT.
    matrix_min = matrix.min(axis=1)
    matrix_argmin = matrix.argmin(axis=1)
    # Visit dets in descending score order so higher-confidence dets claim GTs first.
    sort_inds = np.argsort(-np.asarray(scores))

    tp_fp_list = []
    for thr in thresholds:
        tp = np.zeros(num_preds, dtype=np.float32)
        fp = np.zeros(num_preds, dtype=np.float32)
        gt_covered = np.zeros(num_gts, dtype=bool)
        for i in sort_inds:
            matched_gt = matrix_argmin[i]
            if matrix_min[i] <= thr and not gt_covered[matched_gt]:
                gt_covered[matched_gt] = True
                tp[i] = 1
            else:
                # Too far from every GT, or the nearest GT is already taken.
                fp[i] = 1
        tp_fp_list.append((tp, fp))

    return tp_fp_list