Unverified Commit ba824141 authored by Álvaro Somoza's avatar Álvaro Somoza Committed by GitHub
Browse files

[docs] Add controlnet example to marigold (#8289)



* initial doc

* fix wrong LCM sentence

* implement binary colormap without requiring matplotlib
update section about Marigold for ControlNet
update formatting of marigold_usage.md

* fix indentation

---------
Co-authored-by: default avataranton <anton.obukhov@gmail.com>
parent fe5f035f
......@@ -245,9 +245,9 @@ class MarigoldImageProcessor(ConfigMixin):
) -> Union[np.ndarray, torch.Tensor]:
"""
Converts a monochrome image into an RGB image by applying the specified colormap. This function mimics the
behavior of matplotlib.colormaps, but allows the user to use the most discriminative color map "Spectral"
without having to install or import matplotlib. For all other cases, the function will attempt to use the
native implementation.
behavior of matplotlib.colormaps, but allows the user to use the most discriminative color maps ("Spectral",
"binary") without having to install or import matplotlib. For all other cases, the function will attempt to use
the native implementation.
Args:
image: 2D tensor of values between 0 and 1, either as np.ndarray or torch.Tensor.
......@@ -255,7 +255,7 @@ class MarigoldImageProcessor(ConfigMixin):
bytes: Whether to return the output as uint8 or floating point image.
_force_method:
Can be used to specify whether to use the native implementation (`"matplotlib"`), the efficient custom
implementation of the "Spectral" color map (`"custom"`), or rely on autodetection (`None`, default).
implementation of the select color maps (`"custom"`), or rely on autodetection (`None`, default).
Returns:
An RGB-colorized tensor corresponding to the input image.
......@@ -265,6 +265,26 @@ class MarigoldImageProcessor(ConfigMixin):
if _force_method not in (None, "matplotlib", "custom"):
raise ValueError("_force_method must be either `None`, `'matplotlib'` or `'custom'`.")
supported_cmaps = {
"binary": [
(1.0, 1.0, 1.0),
(0.0, 0.0, 0.0),
],
"Spectral": [ # Taken from matplotlib/_cm.py
(0.61960784313725492, 0.003921568627450980, 0.25882352941176473), # 0.0 -> [0]
(0.83529411764705885, 0.24313725490196078, 0.30980392156862746),
(0.95686274509803926, 0.42745098039215684, 0.2627450980392157),
(0.99215686274509807, 0.68235294117647061, 0.38039215686274508),
(0.99607843137254903, 0.8784313725490196, 0.54509803921568623),
(1.0, 1.0, 0.74901960784313726),
(0.90196078431372551, 0.96078431372549022, 0.59607843137254901),
(0.6705882352941176, 0.8666666666666667, 0.64313725490196083),
(0.4, 0.76078431372549016, 0.6470588235294118),
(0.19607843137254902, 0.53333333333333333, 0.74117647058823533),
(0.36862745098039218, 0.30980392156862746, 0.63529411764705879), # 1.0 -> [K-1]
],
}
def method_matplotlib(image, cmap, bytes=False):
if is_matplotlib_available():
import matplotlib
......@@ -298,24 +318,19 @@ class MarigoldImageProcessor(ConfigMixin):
else:
image = image.float()
if cmap != "Spectral":
raise ValueError("Only 'Spectral' color map is available without installing matplotlib.")
is_cmap_reversed = cmap.endswith("_r")
if is_cmap_reversed:
cmap = cmap[:-2]
_Spectral_data = ( # Taken from matplotlib/_cm.py
(0.61960784313725492, 0.003921568627450980, 0.25882352941176473), # 0.0 -> [0]
(0.83529411764705885, 0.24313725490196078, 0.30980392156862746),
(0.95686274509803926, 0.42745098039215684, 0.2627450980392157),
(0.99215686274509807, 0.68235294117647061, 0.38039215686274508),
(0.99607843137254903, 0.8784313725490196, 0.54509803921568623),
(1.0, 1.0, 0.74901960784313726),
(0.90196078431372551, 0.96078431372549022, 0.59607843137254901),
(0.6705882352941176, 0.8666666666666667, 0.64313725490196083),
(0.4, 0.76078431372549016, 0.6470588235294118),
(0.19607843137254902, 0.53333333333333333, 0.74117647058823533),
(0.36862745098039218, 0.30980392156862746, 0.63529411764705879), # 1.0 -> [K-1]
)
if cmap not in supported_cmaps:
raise ValueError(
f"Only {list(supported_cmaps.keys())} color maps are available without installing matplotlib."
)
cmap = torch.tensor(_Spectral_data, dtype=torch.float, device=image.device) # [K,3]
cmap = supported_cmaps[cmap]
if is_cmap_reversed:
cmap = cmap[::-1]
cmap = torch.tensor(cmap, dtype=torch.float, device=image.device) # [K,3]
K = cmap.shape[0]
pos = image.clamp(min=0, max=1) * (K - 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment