Unverified commit 96220390, authored by linjiapro, committed by GitHub
Browse files

Fix a bug for SD35 control net training and improve control net block index (#10065)



* wip

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 73dac0c4
...@@ -393,6 +393,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal ...@@ -393,6 +393,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
return custom_forward return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
if self.context_embedder is not None:
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), create_custom_forward(block),
hidden_states, hidden_states,
...@@ -400,6 +401,11 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal ...@@ -400,6 +401,11 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
temb, temb,
**ckpt_kwargs, **ckpt_kwargs,
) )
else:
                        # The SD3.5 8b controlnet uses a single transformer block, which does not take `encoder_hidden_states`
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), hidden_states, temb, **ckpt_kwargs
)
else: else:
if self.context_embedder is not None: if self.context_embedder is not None:
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -424,8 +423,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi ...@@ -424,8 +423,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
# controlnet residual # controlnet residual
if block_controlnet_hidden_states is not None and block.context_pre_only is False: if block_controlnet_hidden_states is not None and block.context_pre_only is False:
interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states) interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
interval_control = int(np.ceil(interval_control)) hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]
hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states) hidden_states = self.proj_out(hidden_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment