import logging
import math
from abc import ABC, abstractmethod
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import torch
from einops import rearrange, repeat

from ...util import append_dims, default, instantiate_from_config


class Guider(ABC):
    """Interface for classifier-free-guidance strategies used by the samplers.

    A guider does two things:
      * ``prepare_inputs`` batches the (conditional, unconditional) model
        inputs so a single denoiser forward pass covers both branches;
      * ``__call__`` combines the resulting predictions into one guided output.
    """

    @abstractmethod
    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        pass

    def prepare_inputs(
        self, x: torch.Tensor, s: float, c: Dict, uc: Dict
    ) -> Tuple[torch.Tensor, float, Dict]:
        # Intentionally not abstract: subclasses override this; a default
        # no-op keeps older subclasses that only implement __call__ working.
        pass


class VanillaCFG:
    """
    implements parallelized CFG
    """

    def __init__(self, scale, dyn_thresh_config=None):
        """
        :param scale: classifier-free-guidance scale (constant over steps).
        :param dyn_thresh_config: optional instantiation config for a dynamic
            thresholding strategy; defaults to NoDynamicThresholding.
        """

        # Constant schedule: the guidance scale is independent of sigma/step.
        # (A closure replaces the original lambda-plus-partial indirection.)
        def scale_schedule(sigma):
            return scale

        self.scale_schedule = scale_schedule
        self.dyn_thresh = instantiate_from_config(
            default(
                dyn_thresh_config,
                {
                    "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
                },
            )
        )

    def __call__(self, x, sigma, step=None, num_steps=None, **kwargs):
        """Combine the stacked (uncond, cond) predictions into a guided one.

        ``x`` holds the unconditional and conditional model outputs
        concatenated along the batch dimension (see ``prepare_inputs``).
        """
        x_u, x_c = x.chunk(2)
        scale_value = self.scale_schedule(sigma)
        x_pred = self.dyn_thresh(x_u, x_c, scale_value, step=step, num_steps=num_steps)
        return x_pred

    def prepare_inputs(self, x, s, c, uc):
        """Stack unconditional and conditional inputs along the batch axis.

        Tensor-valued conditionings ("vector", "crossattn", "concat") are
        concatenated as (uc, c); any other key must be identical in both
        dicts and is passed through unchanged.
        """
        c_out = dict()

        for k in c:
            if k in ["vector", "crossattn", "concat"]:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out


class IdentityGuider:
    """No-op guider: returns the model prediction unchanged (no CFG)."""

    def __call__(self, x, sigma, **kwargs):
        return x

    def prepare_inputs(self, x, s, c, uc):
        # Only the conditional branch is needed; uc is ignored entirely.
        c_out = dict()

        for k in c:
            c_out[k] = c[k]

        return x, s, c_out


class LinearPredictionGuider(Guider):
    """CFG with a guidance scale that ramps linearly across video frames.

    Frame 0 uses ``min_scale`` and the last frame uses ``max_scale``; this is
    the per-frame guidance scheme used for video diffusion (e.g. SVD).
    """

    def __init__(
        self,
        max_scale: float,
        num_frames: int,
        min_scale: float = 1.0,
        additional_cond_keys: Optional[Union[List[str], str]] = None,
    ):
        """
        :param max_scale: guidance scale applied to the last frame.
        :param num_frames: number of frames per video sample; the flat batch
            is assumed to be laid out as (b * num_frames, ...).
        :param min_scale: guidance scale applied to the first frame.
        :param additional_cond_keys: extra conditioning keys (beyond the
            standard "vector"/"crossattn"/"concat") that should also be
            batched as (uc, c) in ``prepare_inputs``.
        """
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.num_frames = num_frames
        # Shape (1, num_frames): one scale per frame, broadcast over batch.
        self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)

        additional_cond_keys = default(additional_cond_keys, [])
        if isinstance(additional_cond_keys, str):
            additional_cond_keys = [additional_cond_keys]
        self.additional_cond_keys = additional_cond_keys

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor, **kwargs) -> torch.Tensor:
        x_u, x_c = x.chunk(2)

        # Unflatten (b * t, ...) -> (b, t, ...) so the per-frame scale can
        # be broadcast along the frame axis.
        x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
        x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
        scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
        scale = append_dims(scale, x_u.ndim).to(device=x_u.device, dtype=x_u.dtype)

        # Standard CFG formula, applied with a per-frame scale, then
        # re-flattened to the original (b * t, ...) layout.
        return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")

    def prepare_inputs(
        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Stack (uc, c) along the batch axis, like VanillaCFG, but also
        batch any keys listed in ``additional_cond_keys``."""
        c_out = dict()

        for k in c:
            if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out