Unverified Commit ba824141 authored by Álvaro Somoza, committed by GitHub
Browse files

[docs] Add controlnet example to marigold (#8289)



* initial doc

* fix wrong LCM sentence

* implement binary colormap without requiring matplotlib
update section about Marigold for ControlNet
update formatting of marigold_usage.md

* fix indentation

---------
Co-authored-by: anton <anton.obukhov@gmail.com>
parent fe5f035f
...@@ -373,7 +373,7 @@ with imageio.get_reader(path_in) as reader: ...@@ -373,7 +373,7 @@ with imageio.get_reader(path_in) as reader:
latents = 0.9 * latents + 0.1 * last_frame_latent latents = 0.9 * latents + 0.1 * last_frame_latent
depth = pipe( depth = pipe(
frame, match_input_resolution=False, latents=latents, output_latent=True, frame, match_input_resolution=False, latents=latents, output_latent=True
) )
last_frame_latent = depth.latent last_frame_latent = depth.latent
out.append(pipe.image_processor.visualize_depth(depth.prediction)[0]) out.append(pipe.image_processor.visualize_depth(depth.prediction)[0])
...@@ -396,4 +396,71 @@ The result is much more stable now: ...@@ -396,4 +396,71 @@ The result is much more stable now:
</div> </div>
</div> </div>
Hopefully, you will find Marigold useful for solving your downstream tasks, be it a part of a more broad generative workflow, or a broader perception task, such as 3D reconstruction. ## Marigold for ControlNet
\ No newline at end of file
A very common application for depth prediction with diffusion models comes in conjunction with ControlNet.
Depth crispness plays a crucial role in obtaining high-quality results from ControlNet.
As seen in comparisons with other methods above, Marigold excels at that task.
The snippet below demonstrates how to load an image, compute depth, and pass it into ControlNet in a compatible format:
```python
import torch
import diffusers
device = "cuda"
generator = torch.Generator(device=device).manual_seed(2024)
image = diffusers.utils.load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_depth_source.png"
)
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
"prs-eth/marigold-lcm-v1-0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
depth_image = pipe(image, generator=generator).prediction
depth_image = pipe.image_processor.visualize_depth(depth_image, color_map="binary")
depth_image[0].save("motorcycle_controlnet_depth.png")
controlnet = diffusers.ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipe = diffusers.StableDiffusionXLControlNetPipeline.from_pretrained(
"SG161222/RealVisXL_V4.0", torch_dtype=torch.float16, variant="fp16", controlnet=controlnet
).to("cuda")
pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
controlnet_out = pipe(
prompt="high quality photo of a sports bike, city",
negative_prompt="",
guidance_scale=6.5,
num_inference_steps=25,
image=depth_image,
controlnet_conditioning_scale=0.7,
control_guidance_end=0.7,
generator=generator,
).images
controlnet_out[0].save("motorcycle_controlnet_out.png")
```
<div class="flex gap-4">
<div style="flex: 1 1 33%; max-width: 33%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_depth_source.png"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">
Input image
</figcaption>
</div>
<div style="flex: 1 1 33%; max-width: 33%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/8e61e31f9feb7756c0404ceff26f3f0e5d3fe610/marigold/motorcycle_controlnet_depth.png"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">
Depth in the format compatible with ControlNet
</figcaption>
</div>
<div style="flex: 1 1 33%; max-width: 33%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/8e61e31f9feb7756c0404ceff26f3f0e5d3fe610/marigold/motorcycle_controlnet_out.png"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">
ControlNet generation, conditioned on depth and prompt: "high quality photo of a sports bike, city"
</figcaption>
</div>
</div>
Hopefully, you will find Marigold useful for solving your downstream tasks, be it a part of a more broad generative workflow, or a perception task, such as 3D reconstruction.
...@@ -245,9 +245,9 @@ class MarigoldImageProcessor(ConfigMixin): ...@@ -245,9 +245,9 @@ class MarigoldImageProcessor(ConfigMixin):
) -> Union[np.ndarray, torch.Tensor]: ) -> Union[np.ndarray, torch.Tensor]:
""" """
Converts a monochrome image into an RGB image by applying the specified colormap. This function mimics the Converts a monochrome image into an RGB image by applying the specified colormap. This function mimics the
behavior of matplotlib.colormaps, but allows the user to use the most discriminative color map "Spectral" behavior of matplotlib.colormaps, but allows the user to use the most discriminative color maps ("Spectral",
without having to install or import matplotlib. For all other cases, the function will attempt to use the "binary") without having to install or import matplotlib. For all other cases, the function will attempt to use
native implementation. the native implementation.
Args: Args:
image: 2D tensor of values between 0 and 1, either as np.ndarray or torch.Tensor. image: 2D tensor of values between 0 and 1, either as np.ndarray or torch.Tensor.
...@@ -255,7 +255,7 @@ class MarigoldImageProcessor(ConfigMixin): ...@@ -255,7 +255,7 @@ class MarigoldImageProcessor(ConfigMixin):
bytes: Whether to return the output as uint8 or floating point image. bytes: Whether to return the output as uint8 or floating point image.
_force_method: _force_method:
Can be used to specify whether to use the native implementation (`"matplotlib"`), the efficient custom Can be used to specify whether to use the native implementation (`"matplotlib"`), the efficient custom
implementation of the "Spectral" color map (`"custom"`), or rely on autodetection (`None`, default). implementation of the select color maps (`"custom"`), or rely on autodetection (`None`, default).
Returns: Returns:
An RGB-colorized tensor corresponding to the input image. An RGB-colorized tensor corresponding to the input image.
...@@ -265,6 +265,26 @@ class MarigoldImageProcessor(ConfigMixin): ...@@ -265,6 +265,26 @@ class MarigoldImageProcessor(ConfigMixin):
if _force_method not in (None, "matplotlib", "custom"): if _force_method not in (None, "matplotlib", "custom"):
raise ValueError("_force_method must be either `None`, `'matplotlib'` or `'custom'`.") raise ValueError("_force_method must be either `None`, `'matplotlib'` or `'custom'`.")
supported_cmaps = {
"binary": [
(1.0, 1.0, 1.0),
(0.0, 0.0, 0.0),
],
"Spectral": [ # Taken from matplotlib/_cm.py
(0.61960784313725492, 0.003921568627450980, 0.25882352941176473), # 0.0 -> [0]
(0.83529411764705885, 0.24313725490196078, 0.30980392156862746),
(0.95686274509803926, 0.42745098039215684, 0.2627450980392157),
(0.99215686274509807, 0.68235294117647061, 0.38039215686274508),
(0.99607843137254903, 0.8784313725490196, 0.54509803921568623),
(1.0, 1.0, 0.74901960784313726),
(0.90196078431372551, 0.96078431372549022, 0.59607843137254901),
(0.6705882352941176, 0.8666666666666667, 0.64313725490196083),
(0.4, 0.76078431372549016, 0.6470588235294118),
(0.19607843137254902, 0.53333333333333333, 0.74117647058823533),
(0.36862745098039218, 0.30980392156862746, 0.63529411764705879), # 1.0 -> [K-1]
],
}
def method_matplotlib(image, cmap, bytes=False): def method_matplotlib(image, cmap, bytes=False):
if is_matplotlib_available(): if is_matplotlib_available():
import matplotlib import matplotlib
...@@ -298,24 +318,19 @@ class MarigoldImageProcessor(ConfigMixin): ...@@ -298,24 +318,19 @@ class MarigoldImageProcessor(ConfigMixin):
else: else:
image = image.float() image = image.float()
if cmap != "Spectral": is_cmap_reversed = cmap.endswith("_r")
raise ValueError("Only 'Spectral' color map is available without installing matplotlib.") if is_cmap_reversed:
cmap = cmap[:-2]
_Spectral_data = ( # Taken from matplotlib/_cm.py if cmap not in supported_cmaps:
(0.61960784313725492, 0.003921568627450980, 0.25882352941176473), # 0.0 -> [0] raise ValueError(
(0.83529411764705885, 0.24313725490196078, 0.30980392156862746), f"Only {list(supported_cmaps.keys())} color maps are available without installing matplotlib."
(0.95686274509803926, 0.42745098039215684, 0.2627450980392157),
(0.99215686274509807, 0.68235294117647061, 0.38039215686274508),
(0.99607843137254903, 0.8784313725490196, 0.54509803921568623),
(1.0, 1.0, 0.74901960784313726),
(0.90196078431372551, 0.96078431372549022, 0.59607843137254901),
(0.6705882352941176, 0.8666666666666667, 0.64313725490196083),
(0.4, 0.76078431372549016, 0.6470588235294118),
(0.19607843137254902, 0.53333333333333333, 0.74117647058823533),
(0.36862745098039218, 0.30980392156862746, 0.63529411764705879), # 1.0 -> [K-1]
) )
cmap = torch.tensor(_Spectral_data, dtype=torch.float, device=image.device) # [K,3] cmap = supported_cmaps[cmap]
if is_cmap_reversed:
cmap = cmap[::-1]
cmap = torch.tensor(cmap, dtype=torch.float, device=image.device) # [K,3]
K = cmap.shape[0] K = cmap.shape[0]
pos = image.clamp(min=0, max=1) * (K - 1) pos = image.clamp(min=0, max=1) * (K - 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment