Unverified Commit c2e87869 authored by Bingxin Ke's avatar Bingxin Ke Committed by GitHub
Browse files

[Community pipeline] Marigold depth estimation update -- align with marigold v0.1.5 (#7524)

* add resample option; check denoise_step; update ckpt path

* Add seeding in pipeline to increase reproducibility

* fix typo

* fix typo
parent ca61287d
...@@ -85,14 +85,25 @@ This depth estimation pipeline processes a single input image through multiple d ...@@ -85,14 +85,25 @@ This depth estimation pipeline processes a single input image through multiple d
```python ```python
import numpy as np import numpy as np
import torch
from PIL import Image from PIL import Image
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from diffusers.utils import load_image from diffusers.utils import load_image
# Original DDIM version (higher quality)
pipe = DiffusionPipeline.from_pretrained(
"prs-eth/marigold-v1-0",
custom_pipeline="marigold_depth_estimation"
# torch_dtype=torch.float16, # (optional) Run with half-precision (16-bit float).
# variant="fp16", # (optional) Use with `torch_dtype=torch.float16`, to directly load fp16 checkpoint
)
# (New) LCM version (faster speed)
pipe = DiffusionPipeline.from_pretrained( pipe = DiffusionPipeline.from_pretrained(
"Bingxin/Marigold", "prs-eth/marigold-lcm-v1-0",
custom_pipeline="marigold_depth_estimation" custom_pipeline="marigold_depth_estimation"
# torch_dtype=torch.float16, # (optional) Run with half-precision (16-bit float). # torch_dtype=torch.float16, # (optional) Run with half-precision (16-bit float).
# variant="fp16", # (optional) Use with `torch_dtype=torch.float16`, to directly load fp16 checkpoint
) )
pipe.to("cuda") pipe.to("cuda")
...@@ -101,12 +112,21 @@ img_path_or_url = "https://share.phys.ethz.ch/~pf/bingkedata/marigold/pipeline_e ...@@ -101,12 +112,21 @@ img_path_or_url = "https://share.phys.ethz.ch/~pf/bingkedata/marigold/pipeline_e
image: Image.Image = load_image(img_path_or_url) image: Image.Image = load_image(img_path_or_url)
pipeline_output = pipe( pipeline_output = pipe(
image, # Input image. image, # Input image.
# ----- recommended setting for DDIM version -----
# denoising_steps=10, # (optional) Number of denoising steps of each inference pass. Default: 10. # denoising_steps=10, # (optional) Number of denoising steps of each inference pass. Default: 10.
# ensemble_size=10, # (optional) Number of inference passes in the ensemble. Default: 10. # ensemble_size=10, # (optional) Number of inference passes in the ensemble. Default: 10.
# ------------------------------------------------
# ----- recommended setting for LCM version ------
# denoising_steps=4,
# ensemble_size=5,
# -------------------------------------------------
# processing_res=768, # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768. # processing_res=768, # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768.
# match_input_res=True, # (optional) Resize depth prediction to match input resolution. # match_input_res=True, # (optional) Resize depth prediction to match input resolution.
# batch_size=0, # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0. # batch_size=0, # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0.
# seed=2024, # (optional) Random seed can be set to ensure additional reproducibility. Default: None (unseeded). Note: forcing --batch_size 1 helps to increase reproducibility. To ensure full reproducibility, deterministic mode needs to be used.
# color_map="Spectral", # (optional) Colormap used to colorize the depth map. Defaults to "Spectral". Set to `None` to skip colormap generation. # color_map="Spectral", # (optional) Colormap used to colorize the depth map. Defaults to "Spectral". Set to `None` to skip colormap generation.
# show_progress_bar=True, # (optional) If true, will show progress bars of the inference progress. # show_progress_bar=True, # (optional) If true, will show progress bars of the inference progress.
) )
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
# -------------------------------------------------------------------------- # --------------------------------------------------------------------------
import logging
import math import math
from typing import Dict, Union from typing import Dict, Union
...@@ -25,6 +26,7 @@ import matplotlib ...@@ -25,6 +26,7 @@ import matplotlib
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from PIL.Image import Resampling
from scipy.optimize import minimize from scipy.optimize import minimize
from torch.utils.data import DataLoader, TensorDataset from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -34,13 +36,14 @@ from diffusers import ( ...@@ -34,13 +36,14 @@ from diffusers import (
AutoencoderKL, AutoencoderKL,
DDIMScheduler, DDIMScheduler,
DiffusionPipeline, DiffusionPipeline,
LCMScheduler,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.utils import BaseOutput, check_min_version from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.28.0.dev0") check_min_version("0.25.0")
class MarigoldDepthOutput(BaseOutput): class MarigoldDepthOutput(BaseOutput):
...@@ -61,6 +64,19 @@ class MarigoldDepthOutput(BaseOutput): ...@@ -61,6 +64,19 @@ class MarigoldDepthOutput(BaseOutput):
uncertainty: Union[None, np.ndarray] uncertainty: Union[None, np.ndarray]
def get_pil_resample_method(method_str: str) -> Resampling:
    """Map a resampling-method name to the corresponding PIL enum member.

    Args:
        method_str (`str`): One of `"bilinear"`, `"bicubic"` or `"nearest"`.

    Returns:
        `PIL.Image.Resampling`: The matching resampling filter.

    Raises:
        ValueError: If `method_str` is not a recognized method name.
    """
    resample_method_dict = {
        "bilinear": Resampling.BILINEAR,
        "bicubic": Resampling.BICUBIC,
        "nearest": Resampling.NEAREST,
    }
    resample_method = resample_method_dict.get(method_str, None)
    if resample_method is None:
        # Report the offending input name; the original interpolated the
        # lookup result, which is always None on this path.
        raise ValueError(f"Unknown resampling method: {method_str}")
    return resample_method
class MarigoldPipeline(DiffusionPipeline): class MarigoldPipeline(DiffusionPipeline):
""" """
Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io. Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
...@@ -113,7 +129,9 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -113,7 +129,9 @@ class MarigoldPipeline(DiffusionPipeline):
ensemble_size: int = 10, ensemble_size: int = 10,
processing_res: int = 768, processing_res: int = 768,
match_input_res: bool = True, match_input_res: bool = True,
resample_method: str = "bilinear",
batch_size: int = 0, batch_size: int = 0,
seed: Union[int, None] = None,
color_map: str = "Spectral", color_map: str = "Spectral",
show_progress_bar: bool = True, show_progress_bar: bool = True,
ensemble_kwargs: Dict = None, ensemble_kwargs: Dict = None,
...@@ -129,7 +147,9 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -129,7 +147,9 @@ class MarigoldPipeline(DiffusionPipeline):
If set to 0: will not resize at all. If set to 0: will not resize at all.
match_input_res (`bool`, *optional*, defaults to `True`): match_input_res (`bool`, *optional*, defaults to `True`):
Resize depth prediction to match input resolution. Resize depth prediction to match input resolution.
Only valid if `limit_input_res` is not None. Only valid if `processing_res` > 0.
resample_method: (`str`, *optional*, defaults to `bilinear`):
Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`.
denoising_steps (`int`, *optional*, defaults to `10`): denoising_steps (`int`, *optional*, defaults to `10`):
Number of diffusion denoising steps (DDIM) during inference. Number of diffusion denoising steps (DDIM) during inference.
ensemble_size (`int`, *optional*, defaults to `10`): ensemble_size (`int`, *optional*, defaults to `10`):
...@@ -137,6 +157,8 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -137,6 +157,8 @@ class MarigoldPipeline(DiffusionPipeline):
batch_size (`int`, *optional*, defaults to `0`): batch_size (`int`, *optional*, defaults to `0`):
Inference batch size, no bigger than `num_ensemble`. Inference batch size, no bigger than `num_ensemble`.
If set to 0, the script will automatically decide the proper batch size. If set to 0, the script will automatically decide the proper batch size.
seed (`int`, *optional*, defaults to `None`):
Reproducibility seed.
show_progress_bar (`bool`, *optional*, defaults to `True`): show_progress_bar (`bool`, *optional*, defaults to `True`):
Display a progress bar of diffusion denoising. Display a progress bar of diffusion denoising.
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation): color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
...@@ -146,8 +168,7 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -146,8 +168,7 @@ class MarigoldPipeline(DiffusionPipeline):
Returns: Returns:
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including: `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1] - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
- **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
values in [0, 1]. None if `color_map` is `None`
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation) - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
coming from ensembling. None if `ensemble_size = 1` coming from ensembling. None if `ensemble_size = 1`
""" """
...@@ -158,13 +179,21 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -158,13 +179,21 @@ class MarigoldPipeline(DiffusionPipeline):
if not match_input_res: if not match_input_res:
assert processing_res is not None, "Value error: `resize_output_back` is only valid with " assert processing_res is not None, "Value error: `resize_output_back` is only valid with "
assert processing_res >= 0 assert processing_res >= 0
assert denoising_steps >= 1
assert ensemble_size >= 1 assert ensemble_size >= 1
# Check if denoising step is reasonable
self._check_inference_step(denoising_steps)
resample_method: Resampling = get_pil_resample_method(resample_method)
# ----------------- Image Preprocess ----------------- # ----------------- Image Preprocess -----------------
# Resize image # Resize image
if processing_res > 0: if processing_res > 0:
input_image = self.resize_max_res(input_image, max_edge_resolution=processing_res) input_image = self.resize_max_res(
input_image,
max_edge_resolution=processing_res,
resample_method=resample_method,
)
# Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel # Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
input_image = input_image.convert("RGB") input_image = input_image.convert("RGB")
image = np.asarray(input_image) image = np.asarray(input_image)
...@@ -203,9 +232,10 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -203,9 +232,10 @@ class MarigoldPipeline(DiffusionPipeline):
rgb_in=batched_img, rgb_in=batched_img,
num_inference_steps=denoising_steps, num_inference_steps=denoising_steps,
show_pbar=show_progress_bar, show_pbar=show_progress_bar,
seed=seed,
) )
depth_pred_ls.append(depth_pred_raw.detach().clone()) depth_pred_ls.append(depth_pred_raw.detach())
depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze() depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
torch.cuda.empty_cache() # clear vram cache for ensembling torch.cuda.empty_cache() # clear vram cache for ensembling
# ----------------- Test-time ensembling ----------------- # ----------------- Test-time ensembling -----------------
...@@ -227,7 +257,7 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -227,7 +257,7 @@ class MarigoldPipeline(DiffusionPipeline):
# Resize back to original resolution # Resize back to original resolution
if match_input_res: if match_input_res:
pred_img = Image.fromarray(depth_pred) pred_img = Image.fromarray(depth_pred)
pred_img = pred_img.resize(input_size) pred_img = pred_img.resize(input_size, resample=resample_method)
depth_pred = np.asarray(pred_img) depth_pred = np.asarray(pred_img)
# Clip output range # Clip output range
...@@ -243,12 +273,32 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -243,12 +273,32 @@ class MarigoldPipeline(DiffusionPipeline):
depth_colored_img = Image.fromarray(depth_colored_hwc) depth_colored_img = Image.fromarray(depth_colored_hwc)
else: else:
depth_colored_img = None depth_colored_img = None
return MarigoldDepthOutput( return MarigoldDepthOutput(
depth_np=depth_pred, depth_np=depth_pred,
depth_colored=depth_colored_img, depth_colored=depth_colored_img,
uncertainty=pred_uncert, uncertainty=pred_uncert,
) )
def _check_inference_step(self, n_step: int):
    """
    Check if the number of denoising steps is reasonable for the loaded scheduler.

    Args:
        n_step (`int`): Number of denoising steps.

    Raises:
        ValueError: If `n_step` is smaller than 1.
        RuntimeError: If `self.scheduler` is neither a `DDIMScheduler` nor an `LCMScheduler`.
    """
    # Explicit exception instead of `assert`, which is stripped under `python -O`.
    if n_step < 1:
        raise ValueError(f"Invalid number of denoising steps: {n_step}; must be >= 1.")

    if isinstance(self.scheduler, DDIMScheduler):
        if n_step < 10:
            # DDIM needs more steps for good quality; point users at the LCM checkpoint.
            logging.warning(
                "Too few denoising steps: %d. Recommended to use the LCM checkpoint for few-step inference.",
                n_step,
            )
    elif isinstance(self.scheduler, LCMScheduler):
        if not 1 <= n_step <= 4:
            logging.warning(
                "Non-optimal setting of denoising steps: %d. Recommended setting is 1-4 steps.",
                n_step,
            )
    else:
        raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
def _encode_empty_text(self): def _encode_empty_text(self):
""" """
Encode text embedding for empty prompt. Encode text embedding for empty prompt.
...@@ -265,7 +315,13 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -265,7 +315,13 @@ class MarigoldPipeline(DiffusionPipeline):
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
@torch.no_grad() @torch.no_grad()
def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool) -> torch.Tensor: def single_infer(
self,
rgb_in: torch.Tensor,
num_inference_steps: int,
seed: Union[int, None],
show_pbar: bool,
) -> torch.Tensor:
""" """
Perform an individual depth prediction without ensembling. Perform an individual depth prediction without ensembling.
...@@ -286,10 +342,20 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -286,10 +342,20 @@ class MarigoldPipeline(DiffusionPipeline):
timesteps = self.scheduler.timesteps # [T] timesteps = self.scheduler.timesteps # [T]
# Encode image # Encode image
rgb_latent = self._encode_rgb(rgb_in) rgb_latent = self.encode_rgb(rgb_in)
# Initial depth map (noise) # Initial depth map (noise)
depth_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype) # [B, 4, h, w] if seed is None:
rand_num_generator = None
else:
rand_num_generator = torch.Generator(device=device)
rand_num_generator.manual_seed(seed)
depth_latent = torch.randn(
rgb_latent.shape,
device=device,
dtype=self.dtype,
generator=rand_num_generator,
) # [B, 4, h, w]
# Batched empty text embedding # Batched empty text embedding
if self.empty_text_embed is None: if self.empty_text_embed is None:
...@@ -314,9 +380,9 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -314,9 +380,9 @@ class MarigoldPipeline(DiffusionPipeline):
noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample # [B, 4, h, w] noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample # [B, 4, h, w]
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample depth_latent = self.scheduler.step(noise_pred, t, depth_latent, generator=rand_num_generator).prev_sample
torch.cuda.empty_cache()
depth = self._decode_depth(depth_latent) depth = self.decode_depth(depth_latent)
# clip prediction # clip prediction
depth = torch.clip(depth, -1.0, 1.0) depth = torch.clip(depth, -1.0, 1.0)
...@@ -325,7 +391,7 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -325,7 +391,7 @@ class MarigoldPipeline(DiffusionPipeline):
return depth return depth
def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor: def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
""" """
Encode RGB image into latent. Encode RGB image into latent.
...@@ -344,7 +410,7 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -344,7 +410,7 @@ class MarigoldPipeline(DiffusionPipeline):
rgb_latent = mean * self.rgb_latent_scale_factor rgb_latent = mean * self.rgb_latent_scale_factor
return rgb_latent return rgb_latent
def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor: def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
""" """
Decode depth latent into depth map. Decode depth latent into depth map.
...@@ -365,7 +431,7 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -365,7 +431,7 @@ class MarigoldPipeline(DiffusionPipeline):
return depth_mean return depth_mean
@staticmethod @staticmethod
def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image: def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image:
""" """
Resize image to limit maximum edge length while keeping aspect ratio. Resize image to limit maximum edge length while keeping aspect ratio.
...@@ -374,6 +440,8 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -374,6 +440,8 @@ class MarigoldPipeline(DiffusionPipeline):
Image to be resized. Image to be resized.
max_edge_resolution (`int`): max_edge_resolution (`int`):
Maximum edge length (pixel). Maximum edge length (pixel).
resample_method (`PIL.Image.Resampling`):
Resampling method used to resize images.
Returns: Returns:
`Image.Image`: Resized image. `Image.Image`: Resized image.
...@@ -384,7 +452,7 @@ class MarigoldPipeline(DiffusionPipeline): ...@@ -384,7 +452,7 @@ class MarigoldPipeline(DiffusionPipeline):
new_width = int(original_width * downscale_factor) new_width = int(original_width * downscale_factor)
new_height = int(original_height * downscale_factor) new_height = int(original_height * downscale_factor)
resized_img = img.resize((new_width, new_height)) resized_img = img.resize((new_width, new_height), resample=resample_method)
return resized_img return resized_img
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment