evaluation_am2k.py 2.67 KB
Newer Older
bailuo's avatar
init  
bailuo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import cv2
import sys
import numpy as np
sys.path.insert(0, './utils')
from evaluate import compute_sad_loss, compute_mse_loss, compute_mad_loss, compute_gradient_loss, compute_connectivity_error
import argparse
from tqdm import tqdm

def load_gray(path):
    """Read ``path`` as a single-channel image and return it as float32.

    ``cv2.imread`` reports failure by returning ``None`` rather than raising,
    so a missing or undecodable file used to surface as an opaque
    ``AttributeError`` on ``.astype``. Fail with the offending path instead.

    Args:
        path: Image file to read.

    Returns:
        The grayscale pixel values as a float32 array (values are not rescaled).

    Raises:
        OSError: If OpenCV cannot read or decode ``path``.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise OSError(f'cv2 could not read image: {path}')
    return image.astype(np.float32)


def detail_region_trimap(detailmap):
    """Return a float32 copy of ``detailmap`` with every nonzero pixel set to 128.

    Zero pixels stay 0; ``detailmap`` itself is not modified. Per this script's
    intent, the compute_* metrics presumably score only pixels whose trimap
    value is 128 ("unknown"), so this restricts them to the detail region.

    NOTE(review): for a genuine 0/128/255 trimap this also relabels the 255
    (foreground) pixels as unknown, i.e. it is *not* the same as keeping only
    ``trimap == 128``; confirm this is intended for --detailmap-dir.
    """
    trimap = np.array(detailmap, dtype=np.float32)  # np.array copies by default
    trimap[trimap > 0] = 128
    return trimap


def whole_image_trimap(reference):
    """Return a float32 trimap shaped like ``reference`` that is 128 everywhere.

    Marking every pixel as unknown makes the trimap-based metrics cover the
    whole image.
    """
    return np.full(np.shape(reference), 128, dtype=np.float32)


def format_fields(*fields):
    """Join ``fields`` with single spaces, exactly as ``print(*fields)`` would."""
    return ' '.join(map(str, fields))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--pred-dir', type=str, default='path/to/outputs/am2k', help="pred alpha dir")
    parser.add_argument('--label-dir', type=str, default='path/to/AM2k/validation/mask/', help="GT alpha dir")
    parser.add_argument('--detailmap-dir', type=str, default='path/to/AM2k/validation/trimap/', help="trimap dir")

    args = parser.parse_args()

    # Per-image metrics computed over the whole image.
    mse_loss = []
    sad_loss = []
    mad_loss = []
    grad_loss = []
    conn_loss = []
    # Per-image metrics restricted to the detail region (see
    # detail_region_trimap), the way trimap-based methods report them.
    mse_loss_unknown = []
    sad_loss_unknown = []

    # Each GT file drives one evaluation; the prediction and the detail map are
    # looked up under the same file name. Sorting makes the per-image log order
    # reproducible, since os.listdir order is arbitrary.
    names = sorted(os.listdir(args.label_dir))
    if not names:
        # Without this the averages below would be NaN (mean of empty list).
        parser.error(f'no files found in --label-dir {args.label_dir}')

    for img in tqdm(names):
        # tqdm.write keeps per-image logging from tearing the progress bar.
        tqdm.write(img)
        pred = load_gray(os.path.join(args.pred_dir, img))
        label = load_gray(os.path.join(args.label_dir, img))
        detailmap = load_gray(os.path.join(args.detailmap_dir, img))

        detail_trimap = detail_region_trimap(detailmap)
        mse_loss_unknown_ = compute_mse_loss(pred, label, detail_trimap)
        # compute_sad_loss returns an indexable result; its first entry is the
        # SAD value reported here.
        sad_loss_unknown_ = compute_sad_loss(pred, label, detail_trimap)[0]

        full_trimap = whole_image_trimap(detailmap)
        mse_loss_ = compute_mse_loss(pred, label, full_trimap)
        sad_loss_ = compute_sad_loss(pred, label, full_trimap)[0]
        mad_loss_ = compute_mad_loss(pred, label, full_trimap)
        grad_loss_ = compute_gradient_loss(pred, label, full_trimap)
        conn_loss_ = compute_connectivity_error(pred, label, full_trimap)

        tqdm.write(format_fields('Whole Image: MSE:', mse_loss_, ' SAD:', sad_loss_, ' MAD:', mad_loss_, 'Grad:', grad_loss_, ' Conn:', conn_loss_))
        tqdm.write(format_fields('Detail Region: MSE:', mse_loss_unknown_, ' SAD:', sad_loss_unknown_))

        mse_loss_unknown.append(mse_loss_unknown_)
        sad_loss_unknown.append(sad_loss_unknown_)

        mse_loss.append(mse_loss_)
        sad_loss.append(sad_loss_)
        mad_loss.append(mad_loss_)
        grad_loss.append(grad_loss_)
        conn_loss.append(conn_loss_)

    print('Average:')
    print('Whole Image: MSE:', np.mean(mse_loss), ' SAD:', np.mean(sad_loss), ' MAD:', np.mean(mad_loss), ' Grad:', np.mean(grad_loss), ' Conn:', np.mean(conn_loss))
    print('Detail Region: MSE:', np.mean(mse_loss_unknown), ' SAD:', np.mean(sad_loss_unknown))