"vscode:/vscode.git/clone" did not exist on "130d84a2defe043b3e8dcfa1675820257f92a2e5"
Unverified Commit 3376252d authored by Carolinabanana's avatar Carolinabanana Committed by GitHub
Browse files

Fix gradient checkpointing issue for Stable Diffusion 3 (#8542)


Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 16170c69
...@@ -306,7 +306,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi ...@@ -306,7 +306,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
return custom_forward return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint( encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), create_custom_forward(block),
hidden_states, hidden_states,
encoder_hidden_states, encoder_hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment