"official/projects/movinet/configs/movinet_test.py" did not exist on "16e2edce7f165a65847f3ad6df1913fc84976f8c"
palette.py 4.88 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
r"""Modified from ``https://github.com/sergeyk/rayleigh''.
"""
import os
import os.path as osp
import numpy as np
from skimage.color import hsv2rgb, rgb2lab, lab2rgb
from skimage.io import imsave
from sklearn.metrics import euclidean_distances

# Public API for star-imports: only the Palette class is exported; rgb2hex and
# hex2rgb remain importable by explicit name.
__all__ = ['Palette']

def rgb2hex(rgb):
    """Format three color channels as a ``'#rrggbb'`` hex string.

    Parameters:
        rgb: iterable of exactly three numbers, nominally in [0, 1].

    Returns:
        Lowercase hex string with a leading ``#``.

    Each channel is clamped to [0, 1] before scaling to 0..255. Without the
    clamp, a value outside that range would format as a negative number or as
    more than two hex digits and yield a malformed color string.
    """
    channels = [int(round(255.0 * min(max(u, 0.0), 1.0))) for u in rgb]
    return '#%02x%02x%02x' % tuple(channels)

def hex2rgb(hex):
    """Convert a hex color string to an ``(r, g, b)`` tuple of floats in [0, 1].

    Parameters:
        hex: color such as ``'#rrggbb'`` or ``'rrggbb'``. The 3-digit
            shorthand ``'#rgb'`` is also accepted and expanded by doubling
            each digit (``'#369'`` -> ``'#336699'``). The name shadows the
            builtin ``hex`` but is kept so keyword callers keep working.

    Returns:
        Tuple of three floats, each rounded to 5 decimal places.

    Raises:
        ValueError: if a channel is missing or is not valid hexadecimal.
    """
    digits = hex.strip('#')
    if len(digits) == 3:
        # Shorthand form: every digit stands for a repeated pair.
        digits = ''.join(ch * 2 for ch in digits)

    def channel(pair):
        return round(int(pair, 16) / 255.0, 5)

    return channel(digits[:2]), channel(digits[2:4]), channel(digits[4:6])

class Palette(object):
    r"""Color palette (codebook) arranged as a 2D grid of colors.

    The grid has ``num_sat + 2 * num_light`` rows. Each of the first
    ``num_hues`` columns holds one hue; from top to bottom the rows are:

    * ``num_sat`` rows at full value with saturation increasing towards 1,
    * one row at full saturation and full value,
    * ``num_light`` rows at full saturation with decreasing value,
    * ``num_light - 1`` rows at saturation 0.4 with decreasing value.

    One extra rightmost column holds a gray ramp running from white (top row)
    to black (bottom row), one step per grid row.

    Parameters:
        num_hues: number of hue columns. Counts 8 to 11 use hard-coded hue
            tables; any other count spaces the hues evenly over [0, 1).
        num_sat: number of reduced-saturation rows above the fully saturated
            row. Must be >= 0.
        num_light: number of darkened rows at full saturation. Must be >= 1.

    Attributes:
        thumbnail: the grid as an RGB array of shape
            ``(rows, num_hues + 1, 3)``.
        rgb: all palette colors flattened to shape ``(num_colors, 3)``, hue
            columns first and the gray ramp last.
        lab: ``rgb`` converted with ``rgb2lab``.
        hex: list of ``'#rrggbb'`` strings, one per row of ``rgb``.
        lab_dists: squared Euclidean distances between every pair of rows of
            ``lab``.

    Raises:
        ValueError: if ``num_sat`` < 0 or ``num_light`` < 1.
    """

    # Hard-coded hue positions for the supported special-case column counts.
    _PRESET_HUES = {
        8: (0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85),
        9: (0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87),
        10: (0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76, 0.87),
        11: (0.0, 0.0833, 0.166, 0.25, 0.333, 0.5, 0.56333, 0.666, 0.73, 0.803, 0.916),
    }

    def __init__(self, num_hues=11, num_sat=5, num_light=4):
        # Outside these bounds the saturation/value columns built below no
        # longer have ``n`` entries and stacking them fails with an opaque
        # shape error, so reject such arguments up front.
        if num_sat < 0 or num_light < 1:
            raise ValueError(
                'num_sat must be >= 0 and num_light must be >= 1, got '
                'num_sat={!r}, num_light={!r}'.format(num_sat, num_light))

        n = num_sat + 2 * num_light

        # hues: one per column, repeated down all n rows
        if num_hues in self._PRESET_HUES:
            hue_row = np.array(self._PRESET_HUES[num_hues])
        else:
            hue_row = np.linspace(0, 1, num_hues + 1)[:-1]
        hues = np.tile(hue_row, (n, 1))

        # saturations: one per row, repeated across all hue columns
        sats = np.hstack((
            np.linspace(0, 1, num_sat + 2)[1:-1],
            1,
            [1] * num_light,
            [0.4] * (num_light - 1)))
        sats = np.tile(np.atleast_2d(sats).T, (1, num_hues))

        # lights (HSV value): one per row, repeated across all hue columns
        lights = np.hstack((
            [1] * num_sat,
            1,
            np.linspace(1, 0.2, num_light + 2)[1:-1],
            np.linspace(1, 0.2, num_light + 2)[1:-2]))
        lights = np.tile(np.atleast_2d(lights).T, (1, num_hues))

        # colors
        rgb = hsv2rgb(np.dstack([hues, sats, lights]))
        gray = np.tile(np.linspace(1, 0, n)[:, np.newaxis, np.newaxis], (1, 1, 3))
        self.thumbnail = np.hstack([rgb, gray])

        # flatten to (num_colors, 3); the transposes make the flattened order
        # column-by-column, so each hue's rows stay together
        rgb = rgb.T.reshape(3, -1).T
        gray = gray.T.reshape(3, -1).T
        self.rgb = np.vstack((rgb, gray))
        self.lab = rgb2lab(self.rgb[np.newaxis, :, :]).squeeze()
        self.hex = [rgb2hex(u) for u in self.rgb]
        self.lab_dists = euclidean_distances(self.lab, squared=True)

    def _nearest_index(self, lab):
        """Return, for each row of ``lab``, the index of the closest palette color."""
        return np.argmin(euclidean_distances(lab, self.lab, squared=True), axis=1)

    def histogram(self, rgb_img, sigma=20):
        """Compute the palette-color histogram of an image.

        Parameters:
            rgb_img: image passed to ``rgb2lab``; the result is reshaped to
                one row of 3 Lab components per pixel.
            sigma: width of the Gaussian smoothing applied over Lab distance
                between palette colors. Smoothing is skipped if ``sigma <= 0``.

        Returns:
            1D array with one entry per palette color holding the fraction of
            pixels whose nearest palette color it is. When smoothing is
            applied, entries below 1e-5 are set to 0, so the result need not
            sum to exactly 1.
        """
        # compute histogram
        lab = rgb2lab(rgb_img).reshape((-1, 3))
        min_ind = self._nearest_index(lab)
        hist = 1.0 * np.bincount(min_ind, minlength=self.lab.shape[0]) / lab.shape[0]

        # smooth histogram: each bin becomes a weighted average of all bins,
        # with weights that fall off with squared Lab distance
        if sigma > 0:
            weight = np.exp(-self.lab_dists / (2.0 * sigma ** 2))
            weight = weight / weight.sum(1)[:, np.newaxis]
            hist = (weight * hist).sum(1)
            hist[hist < 1e-5] = 0
        return hist

    def get_palette_image(self, hist, percentile=90, width=200, height=50):
        """Render the dominant histogram bins as a horizontal color bar.

        Parameters:
            hist: 1D histogram with one entry per palette color.
            percentile: only bins strictly above this percentile of ``hist``
                are drawn.
            width: width of the returned image in pixels.
            height: height of the returned image in pixels.

        Returns:
            Array of shape ``(height, width, 3)``. Selected colors appear left
            to right in decreasing order of weight, each with a width
            proportional to its share (truncated to whole pixels). Any
            remaining columns on the right are zeros (black). If no bin lies
            strictly above the threshold, for example with a uniform
            histogram, the whole image is zeros.
        """
        hist = np.asarray(hist)

        # curate histogram
        order = np.argsort(-hist)
        keep = order[hist[order] > np.percentile(hist, percentile)]
        if keep.size == 0:
            # Nothing passed the threshold; previously this fell through to
            # stacking an empty list and raised ValueError.
            return np.zeros((height, width, 3))
        share = hist[keep] / hist[keep].sum()

        # draw palette: repeat each kept color once per pixel column it owns
        nums = np.array(share * width, dtype=int)
        strip = np.repeat(self.rgb[keep], nums, axis=0)
        array = np.tile(strip[np.newaxis, :, :], (height, 1, 1))
        if array.shape[1] < width:
            array = np.concatenate([array, np.zeros((height, width - array.shape[1], 3))], axis=1)
        return array

    def quantize_image(self, rgb_img):
        """Replace every pixel of an image by its nearest palette color.

        Parameters:
            rgb_img: image passed to ``rgb2lab``; nearest colors are chosen by
                squared Euclidean distance in Lab space.

        Returns:
            Image with the same shape as ``rgb_img``, obtained by converting
            the selected palette Lab colors back with ``lab2rgb``.
        """
        lab = rgb2lab(rgb_img).reshape((-1, 3))
        quantized_lab = self.lab[self._nearest_index(lab)]
        return lab2rgb(quantized_lab.reshape(rgb_img.shape))

    def export(self, dirname):
        """Write the palette to ``dirname`` as ``palette.png`` and ``palette.html``.

        The directory is created if it does not exist. The HTML file contains
        one swatch per grid cell, one grid row per line, and each swatch is
        wrapped in an anchor whose id is the color's hex string.
        """
        # exist_ok avoids the race between checking for the directory and
        # creating it
        os.makedirs(dirname, exist_ok=True)

        # save thumbnail; convert to 8-bit explicitly instead of relying on
        # the image writer's implicit float handling
        thumb = np.rint(np.clip(self.thumbnail, 0.0, 1.0) * 255).astype(np.uint8)
        imsave(osp.join(dirname, 'palette.png'), thumb)

        # save html
        parts = ['''
            <style>
                span {
                    width: 20px;
                    height: 20px;
                    margin: 2px;
                    padding: 0px;
                    display: inline-block;
                }
            </style>
            ''']
        for row in self.thumbnail:
            for col in row:
                # span is not a void element, so it needs an explicit closing
                # tag; a self-closing "<span />" makes the swatches nest
                parts.append(
                    '<a id="{0}"><span style="background-color: {0}"></span></a>\n'.format(rgb2hex(col)))
            parts.append('<br />\n')
        with open(osp.join(dirname, 'palette.html'), 'w') as f:
            f.write(''.join(parts))