Unverified Commit c4d4ac21 authored by Aryan, committed by GitHub

Refactor gradient checkpointing (#10611)

* update

* remove unused fn

* apply suggestions based on review

* update + cleanup 🧹

* more cleanup 🧹

* make fix-copies

* update test
parent f295e2ee
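For orientation before the diff: the pattern this refactor converges on is that every block's hand-rolled `create_custom_forward` closure, `ckpt_kwargs` / `is_torch_version` guard, and per-model `_set_gradient_checkpointing` override is replaced by a single shared `self._gradient_checkpointing_func(module, *args)` call provided by the model base class. The sketch below is a minimal illustration of that idea only; the mixin and model classes are illustrative assumptions, not the actual diffusers implementation.

```python
# Minimal sketch of the refactored pattern (the mixin and model classes here are
# hypothetical stand-ins, not the real diffusers ModelMixin).
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint


class CheckpointingMixin(nn.Module):
    """Hypothetical base class that owns one shared checkpointing callable."""

    gradient_checkpointing = False

    def enable_gradient_checkpointing(self) -> None:
        self.gradient_checkpointing = True
        # One shared wrapper replaces every per-block create_custom_forward closure.
        self._gradient_checkpointing_func = partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=False
        )


class TinyBlockStack(CheckpointingMixin):
    def __init__(self) -> None:
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # New style: hand the module and its inputs straight to the shared helper.
                hidden_states = self._gradient_checkpointing_func(block, hidden_states)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


model = TinyBlockStack()
model.enable_gradient_checkpointing()
out = model(torch.randn(2, 8, requires_grad=True))
out.sum().backward()  # activations inside each block are recomputed here
```

Because the non-reentrant checkpoint path (`use_reentrant=False`) has been available since PyTorch 1.11, the old per-call version check is no longer needed, which is why the `is_torch_version` imports disappear throughout the diff below.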
@@ -80,7 +80,6 @@ from diffusers.utils import (
     USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
-    is_torch_version,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
@@ -869,23 +868,7 @@ class CrossAttnDownBlock2D(nn.Module):
         for i, (resnet, attn) in enumerate(blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1030,17 +1013,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1049,12 +1021,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = attn(
                     hidden_states,
@@ -1192,23 +1159,7 @@ class CrossAttnUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1282,10 +1233,6 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
             ]
         )
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1365,19 +1312,8 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
         # Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -1385,7 +1321,6 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
@@ -2724,10 +2659,6 @@ class MatryoshkaUNet2DConditionModel(
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
...
@@ -8,7 +8,6 @@ from diffusers.models import PixArtTransformer2DModel
 from diffusers.models.attention import BasicTransformerBlock
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
-from diffusers.utils.torch_utils import is_torch_version
 class PixArtControlNetAdapterBlock(nn.Module):
@@ -151,10 +150,6 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
         self.transformer = transformer
         self.controlnet = controlnet
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -220,18 +215,8 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
                 print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
                 exit(1)
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -239,7 +224,6 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     None,
-                    **ckpt_kwargs,
                 )
             else:
                 # the control nets are only used for the blocks 1 to self.blocks_num
...
@@ -138,10 +138,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
         self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (Encoder, Decoder)):
-            module.gradient_checkpointing = value
     def enable_tiling(self, use_tiling: bool = True):
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
...
@@ -507,19 +507,12 @@ class AllegroEncoder3D(nn.Module):
         sample = sample + residual
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-                return custom_forward
             # Down blocks
             for down_block in self.down_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+                sample = self._gradient_checkpointing_func(down_block, sample)
             # Mid block
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = self._gradient_checkpointing_func(self.mid_block, sample)
         else:
             # Down blocks
             for down_block in self.down_blocks:
@@ -647,19 +640,12 @@ class AllegroDecoder3D(nn.Module):
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-                return custom_forward
             # Mid block
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = self._gradient_checkpointing_func(self.mid_block, sample)
             # Up blocks
             for up_block in self.up_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
+                sample = self._gradient_checkpointing_func(up_block, sample)
         else:
             # Mid block
@@ -809,10 +795,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
             sample_size - self.tile_overlap_w,
         )
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
-            module.gradient_checkpointing = value
     def enable_tiling(self) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
...
@@ -421,15 +421,8 @@ class CogVideoXDownBlock3D(nn.Module):
             conv_cache_key = f"resnet_{i}"
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-                    return create_forward
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet,
                     hidden_states,
                     temb,
                     zq,
@@ -523,15 +516,8 @@ class CogVideoXMidBlock3D(nn.Module):
             conv_cache_key = f"resnet_{i}"
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-                    return create_forward
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet, hidden_states, temb, zq, conv_cache.get(conv_cache_key)
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -637,15 +623,8 @@ class CogVideoXUpBlock3D(nn.Module):
             conv_cache_key = f"resnet_{i}"
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-                    return create_forward
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet,
                     hidden_states,
                     temb,
                     zq,
@@ -774,18 +753,11 @@ class CogVideoXEncoder3D(nn.Module):
         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-                return custom_forward
             # 1. Down
             for i, down_block in enumerate(self.down_blocks):
                 conv_cache_key = f"down_block_{i}"
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    down_block,
                     hidden_states,
                     temb,
                     None,
@@ -793,8 +765,8 @@ class CogVideoXEncoder3D(nn.Module):
                 )
             # 2. Mid
-            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block),
+            hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
+                self.mid_block,
                 hidden_states,
                 temb,
                 None,
@@ -940,16 +912,9 @@ class CogVideoXDecoder3D(nn.Module):
         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-                return custom_forward
             # 1. Mid
-            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block),
+            hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
+                self.mid_block,
                 hidden_states,
                 temb,
                 sample,
@@ -959,8 +924,8 @@ class CogVideoXDecoder3D(nn.Module):
             # 2. Up
             for i, up_block in enumerate(self.up_blocks):
                 conv_cache_key = f"up_block_{i}"
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    up_block,
                     hidden_states,
                     temb,
                     sample,
@@ -1122,10 +1087,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.tile_overlap_factor_height = 1 / 6
         self.tile_overlap_factor_width = 1 / 5
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
-            module.gradient_checkpointing = value
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 import numpy as np
 import torch
@@ -21,7 +21,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
 from ..attention_processor import Attention
@@ -252,21 +252,7 @@ class HunyuanVideoMidBlock3D(nn.Module):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-                return custom_forward
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs
-            )
+            hidden_states = self._gradient_checkpointing_func(self.resnets[0], hidden_states)
             for attn, resnet in zip(self.attentions, self.resnets[1:]):
                 if attn is not None:
@@ -278,9 +264,7 @@ class HunyuanVideoMidBlock3D(nn.Module):
                     hidden_states = attn(hidden_states, attention_mask=attention_mask)
                     hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
         else:
             hidden_states = self.resnets[0](hidden_states)
@@ -350,22 +334,8 @@ class HunyuanVideoDownBlock3D(nn.Module):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-                return custom_forward
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
             for resnet in self.resnets:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
         else:
             for resnet in self.resnets:
                 hidden_states = resnet(hidden_states)
@@ -426,22 +396,8 @@ class HunyuanVideoUpBlock3D(nn.Module):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-                return custom_forward
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
             for resnet in self.resnets:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
         else:
             for resnet in self.resnets:
@@ -545,26 +501,10 @@ class HunyuanVideoEncoder3D(nn.Module):
         hidden_states = self.conv_in(hidden_states)
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-                return custom_forward
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
             for down_block in self.down_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
-            )
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
         else:
             for down_block in self.down_blocks:
                 hidden_states = down_block(hidden_states)
@@ -667,26 +607,10 @@ class HunyuanVideoDecoder3D(nn.Module):
         hidden_states = self.conv_in(hidden_states)
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-                return custom_forward
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
-            )
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
             for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
         else:
             hidden_states = self.mid_block(hidden_states)
@@ -800,10 +724,6 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
         self.tile_sample_stride_width = 192
         self.tile_sample_stride_num_frames = 12
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
-            module.gradient_checkpointing = value
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
...
@@ -338,16 +338,7 @@ class LTXVideoDownBlock3D(nn.Module):
         for i, resnet in enumerate(self.resnets):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-                    return create_forward
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, generator
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
             else:
                 hidden_states = resnet(hidden_states, temb, generator)
@@ -438,16 +429,7 @@ class LTXVideoMidBlock3d(nn.Module):
         for i, resnet in enumerate(self.resnets):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-                    return create_forward
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, generator
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
             else:
                 hidden_states = resnet(hidden_states, temb, generator)
@@ -573,16 +555,7 @@ class LTXVideoUpBlock3d(nn.Module):
         for i, resnet in enumerate(self.resnets):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-                    return create_forward
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, generator
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
             else:
                 hidden_states = resnet(hidden_states, temb, generator)
@@ -697,17 +670,10 @@ class LTXVideoEncoder3d(nn.Module):
         hidden_states = self.conv_in(hidden_states)
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module):
-                def create_forward(*inputs):
-                    return module(*inputs)
-                return create_forward
             for down_block in self.down_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), hidden_states)
+                hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
-            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
         else:
             for down_block in self.down_blocks:
                 hidden_states = down_block(hidden_states)
@@ -838,19 +804,10 @@ class LTXVideoDecoder3d(nn.Module):
         hidden_states = self.conv_in(hidden_states)
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module):
-                def create_forward(*inputs):
-                    return module(*inputs)
-                return create_forward
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb
-            )
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
             for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb)
+                hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb)
         else:
             hidden_states = self.mid_block(hidden_states, temb)
@@ -1017,10 +974,6 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.tile_sample_stride_width = 448
         self.tile_sample_stride_num_frames = 8
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
-            module.gradient_checkpointing = value
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
...
@@ -207,15 +207,8 @@ class MochiDownBlock3D(nn.Module):
             conv_cache_key = f"resnet_{i}"
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-                    return create_forward
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet,
                     hidden_states,
                     conv_cache=conv_cache.get(conv_cache_key),
                 )
@@ -312,15 +305,8 @@ class MochiMidBlock3D(nn.Module):
             conv_cache_key = f"resnet_{i}"
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-                    return create_forward
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -393,15 +379,8 @@ class MochiUpBlock3D(nn.Module):
             conv_cache_key = f"resnet_{i}"
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-                    return create_forward
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet,
                     hidden_states,
                     conv_cache=conv_cache.get(conv_cache_key),
                 )
@@ -531,21 +510,14 @@ class MochiEncoder3D(nn.Module):
         hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module):
-                def create_forward(*inputs):
-                    return module(*inputs)
-                return create_forward
-            hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in")
+            hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
+                self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
             )
             for i, down_block in enumerate(self.down_blocks):
                 conv_cache_key = f"down_block_{i}"
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    down_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
                 )
         else:
             hidden_states, new_conv_cache["block_in"] = self.block_in(
@@ -648,21 +620,14 @@ class MochiDecoder3D(nn.Module):
         # 1. Mid
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module):
-                def create_forward(*inputs):
-                    return module(*inputs)
-                return create_forward
-            hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in")
+            hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
+                self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
             )
             for i, up_block in enumerate(self.up_blocks):
                 conv_cache_key = f"up_block_{i}"
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    up_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
                 )
         else:
             hidden_states, new_conv_cache["block_in"] = self.block_in(
@@ -819,10 +784,6 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin):
         self.tile_sample_stride_height = 192
         self.tile_sample_stride_width = 192
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (MochiEncoder3D, MochiDecoder3D)):
-            module.gradient_checkpointing = value
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
...
@@ -18,7 +18,6 @@ import torch
 import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version
 from ...utils.accelerate_utils import apply_forward_hook
 from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
 from ..modeling_outputs import AutoencoderKLOutput
@@ -97,35 +96,9 @@ class TemporalDecoder(nn.Module):
         upscale_dtype = next(itertools.chain(self.up_blocks.parameters(), self.up_blocks.buffers())).dtype
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-                return custom_forward
-            if is_torch_version(">=", "1.11.0"):
-                # middle
-                sample = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.mid_block),
-                    sample,
-                    image_only_indicator,
-                    use_reentrant=False,
-                )
-                sample = sample.to(upscale_dtype)
-                # up
-                for up_block in self.up_blocks:
-                    sample = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(up_block),
-                        sample,
-                        image_only_indicator,
-                        use_reentrant=False,
-                    )
-            else:
             # middle
-            sample = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block),
+            sample = self._gradient_checkpointing_func(
+                self.mid_block,
                 sample,
                 image_only_indicator,
             )
@@ -133,8 +106,8 @@ class TemporalDecoder(nn.Module):
             # up
             for up_block in self.up_blocks:
-                sample = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block),
+                sample = self._gradient_checkpointing_func(
+                    up_block,
                     sample,
                     image_only_indicator,
                 )
@@ -229,10 +202,6 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
         self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (Encoder, TemporalDecoder)):
-            module.gradient_checkpointing = value
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
...
@@ -154,10 +154,6 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         self.register_to_config(block_out_channels=decoder_block_out_channels)
         self.register_to_config(force_upcast=False)
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (EncoderTiny, DecoderTiny)):
-            module.gradient_checkpointing = value
     def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
         """raw latents -> [0, 1]"""
         return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
...
@@ -18,7 +18,7 @@ import numpy as np
 import torch
 import torch.nn as nn
-from ...utils import BaseOutput, is_torch_version
+from ...utils import BaseOutput
 from ...utils.torch_utils import randn_tensor
 from ..activations import get_activation
 from ..attention_processor import SpatialNorm
@@ -156,28 +156,11 @@ class Encoder(nn.Module):
         sample = self.conv_in(sample)
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-                return custom_forward
             # down
-            if is_torch_version(">=", "1.11.0"):
-                for down_block in self.down_blocks:
-                    sample = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(down_block), sample, use_reentrant=False
-                    )
-                # middle
-                sample = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.mid_block), sample, use_reentrant=False
-                )
-            else:
             for down_block in self.down_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+                sample = self._gradient_checkpointing_func(down_block, sample)
             # middle
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = self._gradient_checkpointing_func(self.mid_block, sample)
         else:
             # down
@@ -305,41 +288,13 @@ class Decoder(nn.Module):
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-                return custom_forward
-            if is_torch_version(">=", "1.11.0"):
             # middle
-                sample = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.mid_block),
-                    sample,
-                    latent_embeds,
-                    use_reentrant=False,
-                )
+            sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
             sample = sample.to(upscale_dtype)
             # up
             for up_block in self.up_blocks:
-                sample = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block),
-                    sample,
-                    latent_embeds,
-                    use_reentrant=False,
-                )
+                sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
-            else:
-                # middle
-                sample = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.mid_block), sample, latent_embeds
-                )
-                sample = sample.to(upscale_dtype)
-                # up
-                for up_block in self.up_blocks:
-                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
         else:
             # middle
             sample = self.mid_block(sample, latent_embeds)
@@ -558,59 +513,15 @@ class MaskConditionDecoder(nn.Module):
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-                return custom_forward
-            if is_torch_version(">=", "1.11.0"):
-                # middle
-                sample = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.mid_block),
-                    sample,
-                    latent_embeds,
-                    use_reentrant=False,
-                )
-                sample = sample.to(upscale_dtype)
-                # condition encoder
-                if image is not None and mask is not None:
-                    masked_image = (1 - mask) * image
-                    im_x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(self.condition_encoder),
-                        masked_image,
-                        mask,
-                        use_reentrant=False,
-                    )
-                # up
-                for up_block in self.up_blocks:
-                    if image is not None and mask is not None:
-                        sample_ = im_x[str(tuple(sample.shape))]
-                        mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
-                        sample = sample * mask_ + sample_ * (1 - mask_)
-                    sample = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(up_block),
-                        sample,
-                        latent_embeds,
-                        use_reentrant=False,
-                    )
-                if image is not None and mask is not None:
-                    sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
-            else:
             # middle
-            sample = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), sample, latent_embeds
-            )
+            sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
             sample = sample.to(upscale_dtype)
             # condition encoder
             if image is not None and mask is not None:
                 masked_image = (1 - mask) * image
-                im_x = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.condition_encoder),
+                im_x = self._gradient_checkpointing_func(
+                    self.condition_encoder,
                     masked_image,
                     mask,
                 )
@@ -621,7 +532,7 @@ class MaskConditionDecoder(nn.Module):
                     sample_ = im_x[str(tuple(sample.shape))]
                     mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
                     sample = sample * mask_ + sample_ * (1 - mask_)
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
+                sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
             if image is not None and mask is not None:
                 sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
         else:
@@ -890,17 +801,7 @@ class EncoderTiny(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `EncoderTiny` class."""
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-                return custom_forward
-            if is_torch_version(">=", "1.11.0"):
-                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
-            else:
-                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
+            x = self._gradient_checkpointing_func(self.layers, x)
         else:
             # scale image from [-1, 1] to [0, 1] to match TAESD convention
@@ -976,18 +877,7 @@ class DecoderTiny(nn.Module):
         x = torch.tanh(x / 3) * 3
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-                return custom_forward
-            if is_torch_version(">=", "1.11.0"):
-                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
-            else:
-                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
+            x = self._gradient_checkpointing_func(self.layers, x)
         else:
             x = self.layers(x)
...
@@ -31,8 +31,6 @@ from ..attention_processor import (
 from ..embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
 from ..unets.unet_2d_blocks import (
-    CrossAttnDownBlock2D,
-    DownBlock2D,
     UNetMidBlock2D,
     UNetMidBlock2DCrossAttn,
     get_down_block,
@@ -659,10 +657,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
-            module.gradient_checkpointing = value
     def forward(
         self,
         sample: torch.Tensor,
...
@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...models.attention_processor import AttentionProcessor
 from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
 from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -178,10 +178,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     @classmethod
     def from_transformer(
         cls,
@@ -330,24 +326,12 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         block_samples = ()
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
@@ -364,23 +348,11 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         single_block_samples = ()
         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     temb,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
...
...@@ -21,7 +21,7 @@ import torch.nn as nn ...@@ -21,7 +21,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import JointTransformerBlock from ..attention import JointTransformerBlock
from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
...@@ -262,10 +262,6 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal ...@@ -262,10 +262,6 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
if self.original_attn_processors is not None: if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors) self.set_attn_processor(self.original_attn_processors)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
# Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer # Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer
# we should have handled this in conversion script # we should have handled this in conversion script
def _get_pos_embed_from_transformer(self, transformer): def _get_pos_embed_from_transformer(self, transformer):
...@@ -382,30 +378,16 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal ...@@ -382,30 +378,16 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
for block in self.transformer_blocks: for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing: if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
if self.context_embedder is not None: if self.context_embedder is not None:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
)

encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
**ckpt_kwargs,
)
else: else:
# SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states` # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb)

hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), hidden_states, temb, **ckpt_kwargs
)
else: else:
if self.context_embedder is not None: if self.context_embedder is not None:
......
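One detail worth noting in the SD3 controlnet change: `torch.utils.checkpoint.checkpoint` simply returns whatever the wrapped callable returns, so the `(encoder_hidden_states, hidden_states)` tuple can be unpacked from the checkpointed call exactly as from the direct call. A small self-contained sketch (toy module, illustrative names only):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyJointBlock(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.to_ctx = nn.Linear(dim, dim)
        self.to_hid = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states, temb):
        # Returns a tuple, like the joint transformer blocks in the diff.
        return self.to_ctx(encoder_hidden_states) + temb, self.to_hid(hidden_states) + temb


def ckpt_func(module, *args):
    return torch.utils.checkpoint.checkpoint(module.__call__, *args, use_reentrant=False)


block = ToyJointBlock()
h = torch.randn(2, 8, requires_grad=True)
ctx = torch.randn(2, 8, requires_grad=True)
temb = torch.randn(2, 8)

# Tuple output passes straight through the checkpointing helper.
encoder_hidden_states, hidden_states = ckpt_func(block, h, ctx, temb)
(encoder_hidden_states.sum() + hidden_states.sum()).backward()
```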
...@@ -590,10 +590,6 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -590,10 +590,6 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for module in self.children(): for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size) fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)):
module.gradient_checkpointing = value
def forward( def forward(
self, self,
sample: torch.Tensor, sample: torch.Tensor,
......
...@@ -29,8 +29,6 @@ from ..attention_processor import ( ...@@ -29,8 +29,6 @@ from ..attention_processor import (
from ..embeddings import TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps from ..embeddings import TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..unets.unet_2d_blocks import ( from ..unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn, UNetMidBlock2DCrossAttn,
get_down_block, get_down_block,
) )
...@@ -599,10 +597,6 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -599,10 +597,6 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for module in self.children(): for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size) fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
def forward( def forward(
self, self,
sample: torch.Tensor, sample: torch.Tensor,
......
...@@ -20,7 +20,7 @@ import torch.utils.checkpoint ...@@ -20,7 +20,7 @@ import torch.utils.checkpoint
from torch import Tensor, nn from torch import Tensor, nn
from ...configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, is_torch_version, logging from ...utils import BaseOutput, logging
from ...utils.torch_utils import apply_freeu from ...utils.torch_utils import apply_freeu
from ..attention_processor import ( from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS, ADDED_KV_ATTENTION_PROCESSORS,
...@@ -864,10 +864,6 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): ...@@ -864,10 +864,6 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
for u in self.up_blocks: for u in self.up_blocks:
u.freeze_base_params() u.freeze_base_params()
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@property @property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]: def attn_processors(self) -> Dict[str, AttentionProcessor]:
...@@ -1450,15 +1446,6 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module): ...@@ -1450,15 +1446,6 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
base_blocks = list(zip(self.base_resnets, self.base_attentions)) base_blocks = list(zip(self.base_resnets, self.base_attentions))
ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions)) ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions))
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip( for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(
base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base
): ):
...@@ -1468,13 +1455,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module): ...@@ -1468,13 +1455,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
# apply base subblock # apply base subblock
if torch.is_grad_enabled() and self.gradient_checkpointing: if torch.is_grad_enabled() and self.gradient_checkpointing:
h_base = self._gradient_checkpointing_func(b_res, h_base, temb)

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
h_base = torch.utils.checkpoint.checkpoint(
create_custom_forward(b_res),
h_base,
temb,
**ckpt_kwargs,
)
else: else:
h_base = b_res(h_base, temb) h_base = b_res(h_base, temb)
...@@ -1491,13 +1472,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module): ...@@ -1491,13 +1472,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
# apply ctrl subblock # apply ctrl subblock
if apply_control: if apply_control:
if torch.is_grad_enabled() and self.gradient_checkpointing: if torch.is_grad_enabled() and self.gradient_checkpointing:
h_ctrl = self._gradient_checkpointing_func(c_res, h_ctrl, temb)

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
h_ctrl = torch.utils.checkpoint.checkpoint(
create_custom_forward(c_res),
h_ctrl,
temb,
**ckpt_kwargs,
)
else: else:
h_ctrl = c_res(h_ctrl, temb) h_ctrl = c_res(h_ctrl, temb)
if c_attn is not None: if c_attn is not None:
...@@ -1862,15 +1837,6 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module): ...@@ -1862,15 +1837,6 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
and getattr(self, "b2", None) and getattr(self, "b2", None)
) )
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
def maybe_apply_freeu_to_subblock(hidden_states, res_h_base): def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
# FreeU: Only operate on the first two stages # FreeU: Only operate on the first two stages
if is_freeu_enabled: if is_freeu_enabled:
...@@ -1900,13 +1866,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module): ...@@ -1900,13 +1866,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
hidden_states = torch.cat([hidden_states, res_h_base], dim=1) hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing: if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else: else:
hidden_states = resnet(hidden_states, temb) hidden_states = resnet(hidden_states, temb)
......
...@@ -21,12 +21,13 @@ import json ...@@ -21,12 +21,13 @@ import json
import os import os
import re import re
from collections import OrderedDict from collections import OrderedDict
from functools import partial, wraps from functools import wraps
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import safetensors import safetensors
import torch import torch
import torch.utils.checkpoint
from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
from huggingface_hub.utils import validate_hf_hub_args from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn from torch import Tensor, nn
...@@ -168,6 +169,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -168,6 +169,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self._gradient_checkpointing_func = None
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> Any:
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
...@@ -193,14 +196,35 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -193,14 +196,35 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
""" """
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
def enable_gradient_checkpointing(self) -> None: def enable_gradient_checkpointing(self, gradient_checkpointing_func: Optional[Callable] = None) -> None:
""" """
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
*checkpoint activations* in other frameworks). *checkpoint activations* in other frameworks).
Args:
gradient_checkpointing_func (`Callable`, *optional*):
The function to use for gradient checkpointing. If `None`, the default PyTorch checkpointing function
is used (`torch.utils.checkpoint.checkpoint`).
""" """
if not self._supports_gradient_checkpointing: if not self._supports_gradient_checkpointing:
raise ValueError(
f"{self.__class__.__name__} does not support gradient checkpointing. Please make sure to set the boolean attribute "
f"`_supports_gradient_checkpointing` to `True` in the class definition."
)

raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
self.apply(partial(self._set_gradient_checkpointing, value=True))
if gradient_checkpointing_func is None:
def _gradient_checkpointing_func(module, *args):
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
return torch.utils.checkpoint.checkpoint(
module.__call__,
*args,
**ckpt_kwargs,
)
gradient_checkpointing_func = _gradient_checkpointing_func
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
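The new `gradient_checkpointing_func` parameter makes the checkpointing backend swappable. A hedged usage sketch, assuming a diffusers build that already contains this refactor; the custom callable receives `(module, *args)` and is expected to run the module under some checkpointing implementation (the reentrant variant is used here purely as an illustration):

```python
import torch.utils.checkpoint
from diffusers import UNet2DConditionModel

# Any ModelMixin subclass that supports checkpointing works here; the default
# UNet2DConditionModel() config is used only because it needs no extra arguments.
model = UNet2DConditionModel()


def reentrant_checkpoint(module, *args):
    # Custom backend: same call convention as the default helper, but reentrant.
    return torch.utils.checkpoint.checkpoint(module.__call__, *args, use_reentrant=True)


model.enable_gradient_checkpointing(gradient_checkpointing_func=reentrant_checkpoint)
# model.disable_gradient_checkpointing() restores the eager path.
```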
def disable_gradient_checkpointing(self) -> None: def disable_gradient_checkpointing(self) -> None:
""" """
...@@ -208,7 +232,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -208,7 +232,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
*checkpoint activations* in other frameworks). *checkpoint activations* in other frameworks).
""" """
if self._supports_gradient_checkpointing: if self._supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False)) self._set_gradient_checkpointing(enable=False)
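A side note on the default helper above wrapping `module.__call__` rather than `module.forward`: going through `__call__` keeps `nn.Module` forward hooks firing inside the checkpointed segment (and, with `use_reentrant=False`, again during recomputation in the backward pass), which a bare `forward()` call would bypass. A tiny demonstration, independent of diffusers:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint

layer = nn.Linear(4, 4)
calls = []
layer.register_forward_hook(lambda mod, inp, out: calls.append("hook"))

x = torch.randn(2, 4, requires_grad=True)
out = torch.utils.checkpoint.checkpoint(layer.__call__, x, use_reentrant=False)
out.sum().backward()
print(calls)  # ['hook', 'hook']: once in the forward pass, once during recomputation
```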
def set_use_npu_flash_attention(self, valid: bool) -> None: def set_use_npu_flash_attention(self, valid: bool) -> None:
r""" r"""
...@@ -1452,6 +1476,24 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -1452,6 +1476,24 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
mem = mem + mem_bufs mem = mem + mem_bufs
return mem return mem
def _set_gradient_checkpointing(
self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
) -> None:
is_gradient_checkpointing_set = False
for name, module in self.named_modules():
if hasattr(module, "gradient_checkpointing"):
logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
module._gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = enable
is_gradient_checkpointing_set = True
if not is_gradient_checkpointing_set:
raise ValueError(
f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to "
f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`."
)
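The removed per-model `_set_gradient_checkpointing(module, value=...)` overrides scattered through the files above are replaced by this single traversal. Roughly, it does the following (standalone sketch mirroring the logic, not the library code itself):

```python
from typing import Callable

import torch.nn as nn
import torch.utils.checkpoint


def set_gradient_checkpointing(
    root: nn.Module,
    enable: bool = True,
    gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint,
) -> None:
    # Walk every submodule, flip the opt-in flag where it exists, and attach the
    # checkpointing callable that the block's forward() will invoke.
    found = False
    for _, module in root.named_modules():
        if hasattr(module, "gradient_checkpointing"):
            module._gradient_checkpointing_func = gradient_checkpointing_func
            module.gradient_checkpointing = enable
            found = True
    if not found:
        raise ValueError("No submodule defines a `gradient_checkpointing` attribute.")
```

Because the flag lives on the sub-blocks themselves, `is_gradient_checkpointing` can stay a simple `any(...)` over `self.modules()`, as in the unchanged property earlier in this hunk.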
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
deprecated_attention_block_paths = [] deprecated_attention_block_paths = []
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from typing import Any, Dict, Union from typing import Dict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -21,7 +21,7 @@ import torch.nn.functional as F ...@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin from ...loaders import FromOriginalModelMixin
from ...utils import is_torch_version, logging from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import ( from ..attention_processor import (
Attention, Attention,
...@@ -444,10 +444,6 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin ...@@ -444,10 +444,6 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
if self.original_attn_processors is not None: if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors) self.set_attn_processor(self.original_attn_processors)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.FloatTensor,
...@@ -469,23 +465,11 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin ...@@ -469,23 +465,11 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
# MMDiT blocks. # MMDiT blocks.
for index_block, block in enumerate(self.joint_transformer_blocks): for index_block, block in enumerate(self.joint_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing: if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
)

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
**ckpt_kwargs,
)
else: else:
...@@ -500,22 +484,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin ...@@ -500,22 +484,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
for index_block, block in enumerate(self.single_transformer_blocks): for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing: if torch.is_grad_enabled() and self.gradient_checkpointing:
combined_hidden_states = self._gradient_checkpointing_func(
block,
combined_hidden_states,
temb,
)

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
combined_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
combined_hidden_states,
temb,
**ckpt_kwargs,
)
else: else:
......
...@@ -20,7 +20,7 @@ from torch import nn ...@@ -20,7 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward from ..attention import Attention, FeedForward
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
...@@ -331,9 +331,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac ...@@ -331,9 +331,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
self.gradient_checkpointing = False self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
@property @property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]: def attn_processors(self) -> Dict[str, AttentionProcessor]:
...@@ -489,22 +486,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac ...@@ -489,22 +486,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
# 3. Transformer blocks # 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks): for i, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing: if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
attention_kwargs,
)

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
attention_kwargs,
**ckpt_kwargs,
)
else: else:
hidden_states, encoder_hidden_states = block( hidden_states, encoder_hidden_states = block(
......
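The CogVideoX call site above forwards `attention_kwargs` (a dict or `None`) and `image_rotary_emb` straight through the checkpointed call. That works because non-reentrant checkpointing (`use_reentrant=False`) accepts arbitrary non-tensor positional arguments. A hedged sketch with a toy block (illustrative names, not the diffusers class):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyAttnBlock(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states, emb, image_rotary_emb=None, attention_kwargs=None):
        # Non-tensor args (dict / None) flow through the checkpointed segment untouched.
        scale = (attention_kwargs or {}).get("scale", 1.0)
        hidden_states = self.proj(hidden_states) * scale + emb
        return hidden_states, encoder_hidden_states


block = ToyAttnBlock()
h, ctx, emb = torch.randn(2, 8, requires_grad=True), torch.randn(2, 8), torch.randn(2, 8)

hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
    block.__call__, h, ctx, emb, None, {"scale": 0.5}, use_reentrant=False
)
hidden_states.sum().backward()
```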