r"""Modified from ``https://github.com/sergeyk/rayleigh''.
"""
import os
import os.path as osp
import numpy as np
from skimage.color import hsv2rgb, rgb2lab, lab2rgb
from skimage.io import imsave
from sklearn.metrics import euclidean_distances
__all__ = ['Palette']
def rgb2hex(rgb):
    """Format an RGB triple of floats (0..1 scale) as a ``#rrggbb`` string.

    Each channel is scaled by 255 and rounded to the nearest integer.
    """
    channels = tuple(int(round(255.0 * c)) for c in rgb)
    return '#%02x%02x%02x' % channels
def hex2rgb(hex):
    """Parse a ``#rrggbb`` (or ``rrggbb``) string into an RGB float tuple.

    Each two-digit hex channel is divided by 255 and rounded to 5 decimals.
    """
    digits = hex.strip('#')
    return tuple(round(int(digits[i:i + 2], 16) / 255.0, 5) for i in (0, 2, 4))
class Palette(object):
    r"""Color palette (codebook) laid out as a 2D grid of colors.

    Each of the first ``num_hues`` columns holds one hue, varied down the
    rows in HSV space (top to bottom):

    * ``num_sat`` tint rows: full value, saturation rising toward the middle,
    * one row with full saturation and full value,
    * ``num_light`` shade rows: full saturation, value falling toward 0.2,
    * ``num_light - 1`` muted shade rows: saturation 0.4, value falling.

    An extra rightmost column is a gray ramp from white (top row) to black
    (bottom row), with one entry per row.

    Parameters:
        num_hues: number of hue columns. 8, 9, 10 and 11 use fixed hue
            sets; any other value uses evenly spaced hues.
        num_sat: number of tint rows above the fully saturated row.
        num_light: number of shade rows below the fully saturated row.

    Attributes:
        thumbnail: float array (rows, num_hues + 1, 3) of RGB colors.
        rgb: (K, 3) palette colors, flattened column by column from
            ``thumbnail`` (all rows of column 0 first, gray column last).
        lab: (K, 3) the same colors converted to CIELAB.
        hex: list of K ``#rrggbb`` strings for ``rgb``.
        lab_dists: (K, K) squared Euclidean distances between ``lab`` rows.
    """

    # Fixed HSV hue sets (hue in [0, 1)) for specific palette widths.
    _PRESET_HUES = {
        8: [0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85],
        9: [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87],
        10: [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76, 0.87],
        11: [0.0, 0.0833, 0.166, 0.25, 0.333, 0.5, 0.56333, 0.666, 0.73,
             0.803, 0.916],
    }

    def __init__(self, num_hues=11, num_sat=5, num_light=4):
        # Per-row HSV saturation and value, top to bottom.
        shade_values = np.linspace(1, 0.2, num_light + 2)
        sat_col = np.hstack((
            np.linspace(0, 1, num_sat + 2)[1:-1],  # tints
            1,                                      # fully saturated row
            [1] * num_light,                        # shades
            [0.4] * (num_light - 1)))               # muted shades
        val_col = np.hstack((
            [1] * num_sat,
            1,
            shade_values[1:-1],
            shade_values[1:-2]))
        # Take the row count from the columns themselves: it equals
        # num_sat + 2 * num_light for num_light >= 1, but num_sat + 1 when
        # num_light == 0 (a hard-coded formula breaks dstack in that case).
        n = sat_col.size

        hue_row = self._PRESET_HUES.get(num_hues)
        if hue_row is None:
            hue_row = np.linspace(0, 1, num_hues + 1)[:-1]
        hues = np.tile(np.asarray(hue_row, dtype=float), (n, 1))
        sats = np.tile(sat_col[:, np.newaxis], (1, num_hues))
        lights = np.tile(val_col[:, np.newaxis], (1, num_hues))

        # colors: hue grid plus a white-to-black gray column on the right
        rgb = hsv2rgb(np.dstack([hues, sats, lights]))
        gray = np.tile(np.linspace(1, 0, n)[:, np.newaxis, np.newaxis], (1, 1, 3))
        self.thumbnail = np.hstack([rgb, gray])

        # Flatten column by column: entry c * n + r is thumbnail[r, c].
        self.rgb = self.thumbnail.transpose(1, 0, 2).reshape(-1, 3)
        # Index [0] instead of squeeze() keeps shape (K, 3) even when K == 1.
        self.lab = rgb2lab(self.rgb[np.newaxis, :, :])[0]
        self.hex = [rgb2hex(u) for u in self.rgb]
        self.lab_dists = euclidean_distances(self.lab, squared=True)

    def _nearest(self, rgb_img):
        """Return, per pixel (flattened), the index of the nearest palette color in CIELAB."""
        lab = rgb2lab(rgb_img).reshape((-1, 3))
        return np.argmin(euclidean_distances(lab, self.lab, squared=True), axis=1)

    def histogram(self, rgb_img, sigma=20):
        """Return the fraction of pixels of ``rgb_img`` nearest each palette color.

        Parameters:
            rgb_img: image accepted by ``skimage.color.rgb2lab``.
            sigma: if > 0, each bin is replaced by a Gaussian-weighted
                average (in CIELAB, width ``sigma``) of all bins, and values
                below 1e-5 are zeroed.

        Returns:
            float array of length ``len(self.rgb)``.
        """
        nearest = self._nearest(rgb_img)
        hist = 1.0 * np.bincount(nearest, minlength=self.lab.shape[0]) / nearest.shape[0]

        if sigma > 0:
            weight = np.exp(-self.lab_dists / (2.0 * sigma ** 2))
            # Row-normalize so each output bin is a weighted average.
            weight = weight / weight.sum(1)[:, np.newaxis]
            hist = (weight * hist).sum(1)
            hist[hist < 1e-5] = 0
        return hist

    def get_palette_image(self, hist, percentile=90, width=200, height=50):
        """Render the dominant colors of ``hist`` as a horizontal color strip.

        Bins strictly above the given percentile of ``hist`` are drawn left
        to right in decreasing weight, each taking a share of ``width``
        proportional to its renormalized weight (truncated to whole
        columns). If no bin is strictly above the percentile (e.g. a flat
        histogram), every non-zero bin is drawn instead. Columns left over
        by truncation, or the whole image when ``hist`` has no mass, are
        black.

        Parameters:
            hist: per-color weights, one per row of ``self.rgb``.
            percentile: percentile of ``hist`` a bin must exceed to be drawn.
            width, height: output size in pixels.

        Returns:
            float array of shape (height, width, 3).
        """
        hist = np.asarray(hist, dtype=float)
        order = np.argsort(-hist)
        ind = order[hist[order] > np.percentile(hist, percentile)]
        if ind.size == 0:
            ind = order[hist[order] > 0]
        if ind.size == 0:
            return np.zeros((height, width, 3))

        weights = hist[ind] / hist[ind].sum()
        nums = np.array(weights * width, dtype=int)
        strip = np.repeat(self.rgb[ind], nums, axis=0)
        array = np.tile(strip[np.newaxis, :, :], (height, 1, 1))
        if array.shape[1] < width:
            array = np.concatenate([array, np.zeros((height, width - array.shape[1], 3))], axis=1)
        return array

    def quantize_image(self, rgb_img):
        """Replace every pixel of ``rgb_img`` by its nearest palette color.

        Nearness is squared Euclidean distance in CIELAB. The quantized LAB
        values are converted back with ``lab2rgb`` and reshaped to
        ``rgb_img.shape``.
        """
        nearest = self._nearest(rgb_img)
        return lab2rgb(self.lab[nearest].reshape(rgb_img.shape))

    def export(self, dirname):
        """Write ``palette.png`` (the thumbnail) and ``palette.html`` to ``dirname``.

        The HTML page shows one 20x20 swatch per thumbnail cell, with one
        line per thumbnail row; each swatch is an anchor whose id is the
        color's hex code. ``dirname`` is created if it does not exist.
        """
        if not osp.exists(dirname):
            os.makedirs(dirname)

        # Convert to 8-bit explicitly (rounding like rgb2hex) rather than
        # leaving the float -> uint8 conversion to the imsave plugin.
        thumb_u8 = np.round(np.clip(self.thumbnail, 0, 1) * 255).astype(np.uint8)
        imsave(osp.join(dirname, 'palette.png'), thumb_u8)

        lines = [
            '<style>',
            '    span {',
            '        width: 20px;',
            '        height: 20px;',
            '        margin: 2px;',
            '        padding: 0px;',
            '        display: inline-block;',
            '    }',
            '</style>',
        ]
        for row in self.thumbnail:
            for col in row:
                lines.append(
                    '<a id="{0}"><span style="background-color: {0}"></span></a>'.format(rgb2hex(col)))
            lines.append('<br />')
        with open(osp.join(dirname, 'palette.html'), 'w') as f:
            f.write('\n'.join(lines) + '\n')