Unverified Commit c4d4ac21 authored by Aryan and committed by GitHub

Refactor gradient checkpointing (#10611)

* update

* remove unused fn

* apply suggestions based on review

* update + cleanup 🧹

* more cleanup 🧹

* make fix-copies

* update test
parent f295e2ee
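The commit replaces each model's private `_set_gradient_checkpointing` override and the repeated `create_custom_forward` / `torch.utils.checkpoint.checkpoint(...)` boilerplate with a single shared helper, `self._gradient_checkpointing_func(block, *args)`, and gates it on `torch.is_grad_enabled()` rather than `self.training`. The sketch below shows the shape of that pattern under stated assumptions; it is not the diffusers implementation, and `CheckpointingMixin` / `TinyTransformer` are hypothetical names used only for illustration.

import torch
import torch.nn as nn
import torch.utils.checkpoint


class CheckpointingMixin(nn.Module):
    # Hypothetical stand-in for the shared helper the refactor relies on.
    gradient_checkpointing = False

    def enable_gradient_checkpointing(self) -> None:
        self.gradient_checkpointing = True

    def _gradient_checkpointing_func(self, module: nn.Module, *args):
        # Non-reentrant checkpointing, matching the use_reentrant=False kwargs
        # that the diff deletes from every call site.
        return torch.utils.checkpoint.checkpoint(module, *args, use_reentrant=False)


class TinyTransformer(CheckpointingMixin):
    def __init__(self, dim: int = 32, num_layers: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            # Gate on torch.is_grad_enabled() instead of self.training, as in the diff.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(block, hidden_states)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


if __name__ == "__main__":
    model = TinyTransformer()
    model.enable_gradient_checkpointing()
    x = torch.randn(4, 32, requires_grad=True)
    model(x).sum().backward()  # activations are recomputed during backward
    print(x.grad.shape)

Because the helper owns the checkpoint call, each call site shrinks to `self._gradient_checkpointing_func(block, hidden_states, ...)`, which is the one-line shape every file below converges on.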
@@ -20,7 +20,7 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
@@ -595,9 +595,6 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.gradient_checkpointing = False
-    def _set_gradient_checkpointing(self, module, value=False):
-        self.gradient_checkpointing = value
     def _init_face_inputs(self):
         self.local_facial_extractor = LocalFacialExtractor(
             id_dim=self.LFE_id_dim,
@@ -745,22 +742,13 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # 3. Transformer blocks
         ca_idx = 0
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     emb,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
...
@@ -18,7 +18,7 @@ import torch.nn.functional as F
 from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ..attention import BasicTransformerBlock
 from ..embeddings import PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -144,10 +144,6 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
             self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
         )
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -186,19 +182,8 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     None,
                     None,
@@ -206,7 +191,6 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
...
@@ -166,9 +166,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
         self.gradient_checkpointing = False
-    def _set_gradient_checkpointing(self, module, value=False):
-        self.gradient_checkpointing = value
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -243,7 +240,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
             zip(self.transformer_blocks, self.temporal_transformer_blocks)
         ):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = torch.utils.checkpoint.checkpoint(
+                hidden_states = self._gradient_checkpointing_func(
                     spatial_block,
                     hidden_states,
                     None,  # attention_mask
@@ -252,7 +249,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
                     timestep_spatial,
                     None,  # cross_attention_kwargs
                     None,  # class_labels
-                    use_reentrant=False,
                 )
             else:
                 hidden_states = spatial_block(
@@ -276,7 +272,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
                 hidden_states = hidden_states + self.temp_pos_embed
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = torch.utils.checkpoint.checkpoint(
+                hidden_states = self._gradient_checkpointing_func(
                     temp_block,
                     hidden_states,
                     None,  # attention_mask
@@ -285,7 +281,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
                     timestep_temp,
                     None,  # cross_attention_kwargs
                     None,  # class_labels
-                    use_reentrant=False,
                 )
             else:
                 hidden_states = temp_block(
...
@@ -17,7 +17,7 @@ import torch
 from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ..attention import BasicTransformerBlock
 from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -184,10 +184,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
                 in_features=self.config.caption_channels, hidden_size=self.inner_dim
             )
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -388,19 +384,8 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -408,7 +393,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     None,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
...
@@ -19,7 +19,7 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention_processor import (
     Attention,
     AttentionProcessor,
@@ -308,10 +308,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.gradient_checkpointing = False
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -438,21 +434,9 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # 2. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-                return custom_forward
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
             for block in self.transformer_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -460,7 +444,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     timestep,
                     post_patch_height,
                     post_patch_width,
-                    **ckpt_kwargs,
                 )
         else:
...
@@ -13,7 +13,7 @@
 # limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Dict, Optional, Union
 import numpy as np
 import torch
@@ -29,7 +29,7 @@ from ...models.attention_processor import (
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.transformers.transformer_2d import Transformer2DModelOutput
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
@@ -346,10 +346,6 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
         """
         self.set_attn_processor(StableAudioAttnProcessor2_0())
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -416,25 +412,13 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     cross_attention_hidden_states,
                     encoder_attention_mask,
                     rotary_embedding,
-                    **ckpt_kwargs,
                 )
             else:
...
@@ -18,7 +18,7 @@ import torch.nn.functional as F
 from torch import nn
 from ...configuration_utils import LegacyConfigMixin, register_to_config
-from ...utils import deprecate, is_torch_version, logging
+from ...utils import deprecate, logging
 from ..attention import BasicTransformerBlock
 from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
@@ -321,10 +321,6 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
                 in_features=self.caption_channels, hidden_size=self.inner_dim
             )
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -417,19 +413,8 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -437,7 +422,6 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
...
@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional, Tuple
+from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import AllegroAttnProcessor2_0, Attention
@@ -304,9 +304,6 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
         self.gradient_checkpointing = False
-    def _set_gradient_checkpointing(self, module, value=False):
-        self.gradient_checkpointing = value
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -376,23 +373,14 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
         for i, block in enumerate(self.transformer_blocks):
             # TODO(aryan): Implement gradient checkpointing
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     timestep,
                     attention_mask,
                     encoder_attention_mask,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
...
@@ -13,7 +13,7 @@
 # limitations under the License.
-from typing import Any, Dict, Union
+from typing import Dict, Union
 import torch
 import torch.nn as nn
@@ -27,7 +27,7 @@ from ...models.attention_processor import (
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
@@ -289,10 +289,6 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -344,20 +340,11 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
...
@@ -32,7 +32,7 @@ from ...models.attention_processor import (
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..cache_utils import CacheMixin
@@ -423,10 +423,6 @@ class FluxTransformer2DModel(
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -521,24 +517,12 @@ class FluxTransformer2DModel(
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
@@ -565,23 +549,11 @@ class FluxTransformer2DModel(
         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     temb,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
...
@@ -22,7 +22,7 @@ from diffusers.loaders import FromOriginalModelMixin
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
 from ..attention_processor import Attention, AttentionProcessor
 from ..cache_utils import CacheMixin
@@ -672,10 +672,6 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -734,38 +730,24 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-                return custom_forward
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
             for block in self.transformer_blocks:
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
                     attention_mask,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             for block in self.single_transformer_blocks:
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
                     attention_mask,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
         else:
...
@@ -22,7 +22,7 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
@@ -361,10 +361,6 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
         self.gradient_checkpointing = False
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -417,25 +413,13 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
                     encoder_attention_mask,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
...
@@ -21,7 +21,7 @@ import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
@@ -404,10 +404,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
         self.gradient_checkpointing = False
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -460,22 +456,13 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
         for i, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
                     encoder_attention_mask,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
...
@@ -28,7 +28,7 @@ from ...models.attention_processor import (
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -329,10 +329,6 @@ class SD3Transformer2DModel(
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -404,24 +400,12 @@ class SD3Transformer2DModel(
             is_skip = True if skip_layers is not None and index_block in skip_layers else False
             if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
                     joint_attention_kwargs,
-                    **ckpt_kwargs,
                 )
             elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
...
@@ -343,19 +343,11 @@ class TransformerSpatioTemporalModel(nn.Module):
         # 2. Blocks
         for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    block,
-                    hidden_states,
-                    None,
-                    encoder_hidden_states,
-                    None,
-                    use_reentrant=False,
+                hidden_states = self._gradient_checkpointing_func(
+                    block, hidden_states, None, encoder_hidden_states, None
                 )
             else:
-                hidden_states = block(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                )
+                hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
             hidden_states_mix = hidden_states
             hidden_states_mix = hidden_states_mix + emb
...
@@ -248,10 +248,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     def forward(
         self,
         sample: torch.Tensor,
...
@@ -18,7 +18,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
-from ...utils import deprecate, is_torch_version, logging
+from ...utils import deprecate, logging
 from ...utils.torch_utils import apply_freeu
 from ..activations import get_activation
 from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
@@ -737,25 +737,9 @@ class UNetMidBlock2D(nn.Module):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 if attn is not None:
                     hidden_states = attn(hidden_states, temb=temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 if attn is not None:
                     hidden_states = attn(hidden_states, temb=temb)
@@ -883,17 +867,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -902,12 +875,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = attn(
                     hidden_states,
@@ -1156,23 +1124,7 @@ class AttnDownBlock2D(nn.Module):
         for resnet, attn in zip(self.resnets, self.attentions):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(hidden_states, **cross_attention_kwargs)
                 output_states = output_states + (hidden_states,)
             else:
@@ -1304,23 +1256,7 @@ class CrossAttnDownBlock2D(nn.Module):
         for i, (resnet, attn) in enumerate(blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1418,21 +1354,7 @@ class DownBlock2D(nn.Module):
         for resnet in self.resnets:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1906,21 +1828,7 @@ class ResnetDownsampleBlock2D(nn.Module):
         for resnet in self.resnets:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2058,17 +1966,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         for resnet, attn in zip(self.resnets, self.attentions):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -2153,21 +2051,7 @@ class KDownBlock2D(nn.Module):
         for resnet in self.resnets:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2262,22 +2146,10 @@ class KCrossAttnDownBlock2D(nn.Module):
         for resnet, attn in zip(self.resnets, self.attentions):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
+                hidden_states = self._gradient_checkpointing_func(
+                    resnet,
                     hidden_states,
                     temb,
-                    **ckpt_kwargs,
                 )
                 hidden_states = attn(
                     hidden_states,
@@ -2423,23 +2295,7 @@ class AttnUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(hidden_states)
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2588,23 +2444,7 @@ class CrossAttnUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -2721,21 +2561,7 @@ class UpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -3251,21 +3077,7 @@ class ResnetUpsampleBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -3409,17 +3221,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -3512,21 +3314,7 @@ class KUpBlock2D(nn.Module):
         for resnet in self.resnets:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -3640,22 +3428,10 @@ class KCrossAttnUpBlock2D(nn.Module):
         for resnet, attn in zip(self.resnets, self.attentions):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
+                hidden_states = self._gradient_checkpointing_func(
+                    resnet,
                     hidden_states,
                     temb,
-                    **ckpt_kwargs,
                 )
                 hidden_states = attn(
                     hidden_states,
...
@@ -834,10 +834,6 @@ class UNet2DConditionModel(
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
...
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 from torch import nn
 
-from ...utils import deprecate, is_torch_version, logging
+from ...utils import deprecate, logging
 from ...utils.torch_utils import apply_freeu
 from ..attention import Attention
 from ..resnet import (
@@ -1078,31 +1078,14 @@ class UNetMidBlockSpatioTemporal(nn.Module):
         )
 
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     image_only_indicator=image_only_indicator,
                     return_dict=False,
                 )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    image_only_indicator,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
             else:
                 hidden_states = attn(
                     hidden_states,
@@ -1110,11 +1093,7 @@ class UNetMidBlockSpatioTemporal(nn.Module):
                     image_only_indicator=image_only_indicator,
                     return_dict=False,
                 )[0]
-                hidden_states = resnet(
-                    hidden_states,
-                    temb,
-                    image_only_indicator=image_only_indicator,
-                )
+                hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
 
         return hidden_states
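One detail the refactor keeps: inside the checkpointed branch, image_only_indicator is passed positionally, while the eager branch passes it by keyword. The reentrant checkpointing fallback being deleted cannot forward keyword arguments, so positional passing remains the safe common denominator. Below is a small self-contained illustration of the two call paths agreeing; the block is a stand-in, not the real SpatioTemporalResBlock:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyResBlock(nn.Module):
    # Stand-in for a spatio-temporal resnet: (hidden, temb, indicator) -> hidden
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, temb, image_only_indicator=None):
        return self.proj(hidden_states) + temb


block = TinyResBlock()
hidden = torch.randn(2, 8, requires_grad=True)
temb = torch.randn(2, 8)
indicator = torch.zeros(2, 1)

# Checkpointed path: arguments passed positionally, activations recomputed on backward.
out_ckpt = checkpoint(block, hidden, temb, indicator, use_reentrant=False)
# Eager path: same call in keyword form.
out_eager = block(hidden, temb, image_only_indicator=indicator)
assert torch.allclose(out_ckpt, out_eager)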
@@ -1169,34 +1148,9 @@ class DownBlockSpatioTemporal(nn.Module):
         output_states = ()
         for resnet in self.resnets:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet),
-                        hidden_states,
-                        temb,
-                        image_only_indicator,
-                        use_reentrant=False,
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet),
-                        hidden_states,
-                        temb,
-                        image_only_indicator,
-                    )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
             else:
-                hidden_states = resnet(
-                    hidden_states,
-                    temb,
-                    image_only_indicator=image_only_indicator,
-                )
+                hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
 
             output_states = output_states + (hidden_states,)
@@ -1281,25 +1235,8 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))
         for resnet, attn in blocks:
-            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    image_only_indicator,
-                    **ckpt_kwargs,
-                )
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
                 hidden_states = attn(
                     hidden_states,
@@ -1308,11 +1245,7 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
                     return_dict=False,
                 )[0]
             else:
-                hidden_states = resnet(
-                    hidden_states,
-                    temb,
-                    image_only_indicator=image_only_indicator,
-                )
+                hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1385,34 +1318,9 @@ class UpBlockSpatioTemporal(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet),
-                        hidden_states,
-                        temb,
-                        image_only_indicator,
-                        use_reentrant=False,
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet),
-                        hidden_states,
-                        temb,
-                        image_only_indicator,
-                    )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
             else:
-                hidden_states = resnet(
-                    hidden_states,
-                    temb,
-                    image_only_indicator=image_only_indicator,
-                )
+                hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -1495,25 +1403,8 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    image_only_indicator,
-                    **ckpt_kwargs,
-                )
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1521,11 +1412,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
                     return_dict=False,
                 )[0]
             else:
-                hidden_states = resnet(
-                    hidden_states,
-                    temb,
-                    image_only_indicator=image_only_indicator,
-                )
+                hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
......
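None of this changes how checkpointing is switched on from user code: enable_gradient_checkpointing() on the model remains the entry point, and the blocks above pick it up through their gradient_checkpointing flag. A rough usage sketch follows; the repo id and tensor shapes are placeholders for a Stable Diffusion-sized UNet, not part of this diff:

import torch
from diffusers import UNet2DConditionModel

# Placeholder repo id; any UNet2DConditionModel checkpoint works the same way.
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)
unet.enable_gradient_checkpointing()  # sets gradient_checkpointing = True on supporting blocks
unet.train()

sample = torch.randn(1, 4, 64, 64)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)

# With the flag set and grad enabled, the blocks route through
# _gradient_checkpointing_func and recompute activations during backward.
out = unet(sample, timestep, encoder_hidden_states).sample
out.mean().backward()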
@@ -37,11 +37,7 @@ from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
 from ..transformers.transformer_temporal import TransformerTemporalModel
 from .unet_3d_blocks import (
-    CrossAttnDownBlock3D,
-    CrossAttnUpBlock3D,
-    DownBlock3D,
     UNetMidBlock3DCrossAttn,
-    UpBlock3D,
     get_down_block,
     get_up_block,
 )
@@ -472,10 +468,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         self.set_attn_processor(processor)
 
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
-            module.gradient_checkpointing = value
-
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1, s2, b1, b2):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
......
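The commit also updates the tests, and the natural invariant to pin down for a refactor like this is that a block produces the same outputs and gradients with checkpointing on and off. The snippet below is a generic sketch of such a check, not the actual diffusers test:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def run(block: nn.Module, x: torch.Tensor, use_checkpointing: bool) -> torch.Tensor:
    # Mirrors the pattern in the blocks above: only checkpoint when grads are enabled.
    if torch.is_grad_enabled() and use_checkpointing:
        return checkpoint(block, x, use_reentrant=False)
    return block(x)


torch.manual_seed(0)
block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
x = torch.randn(4, 16)

out_plain = run(block, x, use_checkpointing=False)
out_ckpt = run(block, x, use_checkpointing=True)
torch.testing.assert_close(out_plain, out_ckpt)

# Gradients should match across the two code paths as well.
grad_plain = torch.autograd.grad(out_plain.sum(), block[0].weight)[0]
grad_ckpt = torch.autograd.grad(out_ckpt.sum(), block[0].weight)[0]
torch.testing.assert_close(grad_plain, grad_ckpt)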