Unverified Commit cfc99adf authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Add global pooling to controlnet (#3121)

parent 807f69b3
......@@ -119,6 +119,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
):
super().__init__()
......@@ -566,6 +567,12 @@ class ControlNetModel(ModelMixin, ConfigMixin):
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample *= conditioning_scale
if self.config.global_pool_conditions:
down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
]
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
if not return_dict:
return (down_block_res_samples, mid_block_res_sample)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment