from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from ...models.controlnet import ControlNetModel, ControlNetOutput
from ...models.modeling_utils import ModelMixin


class MultiControlNetModel(ModelMixin):
    r"""
    Multiple `ControlNetModel` wrapper class for Multi-ControlNet

    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
    compatible with `ControlNetModel`.

    Args:
        controlnets (`List[ControlNetModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `ControlNetModel` as a list.
    """

    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
        super().__init__()
        # nn.ModuleList registers every controlnet as a submodule of this wrapper.
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.Tensor],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        r"""
        Run every wrapped controlnet and sum their residuals.

        The i-th controlnet receives the i-th entry of `controlnet_cond` and of `conditioning_scale`; all other
        arguments are passed unchanged to every controlnet. The three sequences are iterated in lockstep with
        `zip`, so entries beyond the shortest of them are ignored.

        Args:
            sample: Forwarded to every controlnet.
            timestep: Forwarded to every controlnet.
            encoder_hidden_states: Forwarded to every controlnet.
            controlnet_cond: One conditioning input per controlnet.
            conditioning_scale: One scale per controlnet.
            class_labels: Forwarded to every controlnet.
            timestep_cond: Forwarded to every controlnet.
            attention_mask: Forwarded to every controlnet.
            cross_attention_kwargs: Forwarded to every controlnet.
            guess_mode: Forwarded to every controlnet.
            return_dict: Whether to wrap the merged result in a `ControlNetOutput` instead of returning a plain
                tuple.

        Returns:
            The merged `(down_block_res_samples, mid_block_res_sample)` pair, as a tuple when `return_dict` is
            `False` and as a `ControlNetOutput` otherwise.

        Raises:
            ValueError: If no controlnet was run because `controlnet_cond`, `conditioning_scale` or the list of
                wrapped controlnets is empty.
        """
        down_block_res_samples = None
        mid_block_res_sample = None

        for image, scale, controlnet in zip(controlnet_cond, conditioning_scale, self.nets):
            # Always request a plain tuple from the sub-network (last positional argument): the result is
            # unpacked and summed here, and `return_dict` is honoured only for the merged output below.
            down_samples, mid_sample = controlnet(
                sample,
                timestep,
                encoder_hidden_states,
                image,
                scale,
                class_labels,
                timestep_cond,
                attention_mask,
                cross_attention_kwargs,
                guess_mode,
                False,
            )

            # merge samples
            if mid_block_res_sample is None:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                # Out-of-place add so the tensor returned by the first controlnet is never mutated.
                mid_block_res_sample = mid_block_res_sample + mid_sample

        if mid_block_res_sample is None:
            raise ValueError(
                "MultiControlNetModel.forward() needs at least one controlnet together with a matching "
                "`controlnet_cond` and `conditioning_scale` entry."
            )

        if not return_dict:
            return down_block_res_samples, mid_block_res_sample

        # NOTE(review): field names follow `ControlNetOutput` as defined in models/controlnet.py - confirm.
        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
        )