Commit 00f5b418 authored by pibbo88, committed by GitHub

Fix the bug of sd3 controlnet training when using gradient checkpointing. (#9498)

Fix the bug of sd3 controlnet training when using gradient_checkpointing. Refer to issue #9496
parent 14f6464b
@@ -336,7 +336,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
     return custom_forward
 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-hidden_states = torch.utils.checkpoint.checkpoint(
+encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
     create_custom_forward(block),
     hidden_states,
     encoder_hidden_states,
...
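The one-line change in this hunk implies that the checkpointed block returns two values, ordered (encoder_hidden_states, hidden_states), and that the old code bound the whole pair to hidden_states. The sketch below reproduces that failure mode without PyTorch. It is an illustration under stated assumptions, not the library's code: `fake_checkpoint`, `toy_block`, `run_buggy`, and `run_fixed` are hypothetical names, and `fake_checkpoint` only mimics the fact that a checkpoint wrapper hands back whatever the wrapped function returns.

```python
# Hypothetical stand-ins for illustration only; none of these names come
# from diffusers or PyTorch.

def fake_checkpoint(fn, *args):
    """Stand-in for a checkpoint wrapper: returns fn(*args) unchanged."""
    return fn(*args)


def toy_block(hidden_states, encoder_hidden_states):
    """Toy block returning a 2-tuple ordered
    (encoder_hidden_states, hidden_states), as the fix in the diff implies."""
    new_encoder = [e + 1 for e in encoder_hidden_states]
    new_hidden = [h * 2 for h in hidden_states]
    return new_encoder, new_hidden


def run_buggy(hidden_states, encoder_hidden_states):
    # Pre-fix pattern: the whole returned tuple is bound to hidden_states,
    # and encoder_hidden_states is never updated.
    hidden_states = fake_checkpoint(toy_block, hidden_states, encoder_hidden_states)
    return hidden_states, encoder_hidden_states


def run_fixed(hidden_states, encoder_hidden_states):
    # Post-fix pattern: unpack both return values in the block's order.
    encoder_hidden_states, hidden_states = fake_checkpoint(
        toy_block, hidden_states, encoder_hidden_states
    )
    return hidden_states, encoder_hidden_states


if __name__ == "__main__":
    buggy_hidden, _ = run_buggy([1.0, 2.0], [10.0])
    print(type(buggy_hidden).__name__)  # hidden_states is now a tuple
    fixed_hidden, fixed_encoder = run_fixed([1.0, 2.0], [10.0])
    print(fixed_hidden, fixed_encoder)
```

In the buggy variant the next block would receive a tuple where it expects a tensor, which is consistent with training failing only when gradient checkpointing is enabled.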