"""
Adapted from the inference.py to demonstrate the usage of the util functions.

Usage: python inference.py IMAGE ANNO OUTPUT

IMAGE is the image to segment, ANNO is a color-coded annotation of it (an
all-black pixel means "unknown"), and OUTPUT is where the refined, color-coded
segmentation is written.
"""

import sys
import numpy as np
import pydensecrf.densecrf as dcrf

# Get im{read,write} from somewhere.
try:
    from cv2 import imread, imwrite
except ImportError:
    # Note that, sadly, skimage unconditionally imports scipy and matplotlib,
    # so you'll need them if you don't have OpenCV. But you probably have them.
    from skimage.io import imread, imsave
    imwrite = imsave
    # TODO: Use scipy instead.

from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral, create_pairwise_gaussian

# Exactly three positional arguments are required: IMAGE, ANNO and OUTPUT.
if len(sys.argv) != 4:
    print("Usage: python {} IMAGE ANNO OUTPUT".format(sys.argv[0]))
    print("")
    print("IMAGE and ANNO are inputs and OUTPUT is where the result should be written.")
    sys.exit(1)

fn_im = sys.argv[1]
fn_anno = sys.argv[2]
fn_output = sys.argv[3]

##################################
### Read images and annotation ###
##################################
img = imread(fn_im)

# Pack each annotation pixel's three channels into one 32-bit integer, with
# channel 0 in the low byte. With skimage (RGB order) that is 0xBBGGRR; OpenCV
# loads BGR, giving 0xRRGGBB. Either way `colorize` below unpacks the same
# channel order, and the output is written by the same library, so colors
# round-trip unchanged.
anno_rgb = imread(fn_anno).astype(np.uint32)
anno_lbl = anno_rgb[:,:,0] + (anno_rgb[:,:,1] << 8) + (anno_rgb[:,:,2] << 16)

# Convert the 32-bit integer colors to consecutive labels 0, 1, 2, ...
# np.unique returns the colors sorted, so all-black (0), if present, is
# colors[0] and receives label 0.
colors, labels = np.unique(anno_lbl, return_inverse=True)

# Label 0 is reserved for "unknown"/"unsure" (the unary term is built with
# zero_unsure=True). All-black pixels already map to 0. If the annotation has
# no black pixel at all, shift every label up by one so that no real class is
# mistaken for "unknown" and dropped.
HAS_UNK = 0 in colors
if HAS_UNK:
    # Remove the all-0 black, that won't exist in the MAP!
    colors = colors[1:]
else:
    labels = labels + 1

# Mapping back from MAP labels (0 .. len(colors)-1) to per-channel bytes.
colorize = np.empty((len(colors), 3), np.uint8)
colorize[:,0] = (colors & 0x0000FF)
colorize[:,1] = (colors & 0x00FF00) >> 8
colorize[:,2] = (colors & 0xFF0000) >> 16

# Number of real classes, i.e. not counting the "unknown" label 0.
n_labels = len(colors)
if n_labels == 0:
    sys.exit("The annotation contains no labelled (non-black) pixels.")
print(n_labels, " labels and \"unknown\" 0: ", set(labels.flat))

###########################
### Setup the CRF model ###
###########################
# Switch between the DenseCRF2D convenience methods (True) and the generic
# DenseCRF class fed with features from pydensecrf.utils (False). Both branches
# use the same unary term and pairwise parameters.
use_2d = False
# use_2d = True
if use_2d:
    print("Using 2D specialized functions")

    # Example using the DenseCRF2D code (width, height, number of labels).
    d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], n_labels)

    # get unary potentials (neg log probability); label 0 marks "unknown"
    # pixels, hence zero_unsure=True.
    U = unary_from_labels(labels, n_labels, gt_prob=0.7, zero_unsure=True)
    d.setUnaryEnergy(U)

    # This adds the color-independent term, features are the locations only.
    d.addPairwiseGaussian(sxy=(3, 3), compat=3, kernel=dcrf.DIAG_KERNEL,
                          normalization=dcrf.NORMALIZE_SYMMETRIC)

    # This adds the color-dependent term, i.e. features are (x,y,r,g,b).
    d.addPairwiseBilateral(sxy=(80, 80), srgb=(13, 13, 13), rgbim=img,
                           compat=10,
                           kernel=dcrf.DIAG_KERNEL,
                           normalization=dcrf.NORMALIZE_SYMMETRIC)
else:
    print("Using generic 2D functions")

    # Example using the DenseCRF class and the util functions; the model is
    # over all H*W pixels as a flat set of variables.
    d = dcrf.DenseCRF(img.shape[1] * img.shape[0], n_labels)

    # get unary potentials (neg log probability); label 0 marks "unknown"
    # pixels, hence zero_unsure=True.
    U = unary_from_labels(labels, n_labels, gt_prob=0.7, zero_unsure=True)
    d.setUnaryEnergy(U)

    # This creates the color-independent features and then add them to the CRF
    feats = create_pairwise_gaussian(sdims=(3, 3), shape=img.shape[:2])
    d.addPairwiseEnergy(feats, compat=3,
                        kernel=dcrf.DIAG_KERNEL,
                        normalization=dcrf.NORMALIZE_SYMMETRIC)

    # This creates the color-dependent features and then add them to the CRF
    feats = create_pairwise_bilateral(sdims=(80, 80), schan=(13, 13, 13),
                                      img=img, chdim=2)
    d.addPairwiseEnergy(feats, compat=10,
                        kernel=dcrf.DIAG_KERNEL,
                        normalization=dcrf.NORMALIZE_SYMMETRIC)

####################################
### Do inference and compute MAP ###
####################################

# Run five inference steps.
Q = d.inference(5)

# Find out the most probable class for each pixel.
MAP = np.argmax(Q, axis=0)

# Convert the MAP (labels) back to the corresponding colors and save the image.
# `imwrite` is bound by both branches of the im{read,write} import fallback,
# whereas `imsave` only exists when the skimage fallback was taken.
# `colorize` has exactly three channels, so reshape to (H, W, 3) explicitly
# instead of img.shape, which would not match e.g. a 4-channel input image.
MAP = colorize[MAP,:]
imwrite(fn_output, MAP.reshape(img.shape[0], img.shape[1], 3))

# Demonstrate the step-wise inference API: run five iterations by hand,
# printing the KL-divergence before each step.
Q, tmp1, tmp2 = d.startInference()
for i in range(5):
    print("KL-divergence at {}: {}".format(i, d.klDivergence(Q)))
    d.stepInference(Q, tmp1, tmp2)