Unverified Commit 96220390 authored by linjiapro's avatar linjiapro Committed by GitHub
Browse files

Fix a bug for SD35 control net training and improve control net block index (#10065)



* wip

---------
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 73dac0c4
......@@ -393,6 +393,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
if self.context_embedder is not None:
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
......@@ -400,6 +401,11 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
temb,
**ckpt_kwargs,
)
else:
# SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), hidden_states, temb, **ckpt_kwargs
)
else:
if self.context_embedder is not None:
......
......@@ -15,7 +15,6 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -424,8 +423,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
# controlnet residual
if block_controlnet_hidden_states is not None and block.context_pre_only is False:
interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment