Unverified Commit 5b972fbd authored by Michael Tkachuk, committed by GitHub

Enabling gradient checkpointing in eval() mode (#9878)

* refactored
parent 0be52c07
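
This commit swaps the `self.training` guard in front of the gradient-checkpointing branches for `torch.is_grad_enabled()`, so activation checkpointing is also used when a module is in `eval()` mode but autograd is still recording (for example, when gradients are backpropagated through a frozen sub-model). Under `torch.no_grad()` the guard is false and the plain forward path runs as before. A minimal sketch of the pattern follows; the `ToyBlock` module is a hypothetical stand-in for the blocks touched by this commit, not code from the diff:

import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyBlock(nn.Module):
    """Hypothetical block using the same guard as the updated blocks."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        for layer in self.layers:
            # Before: `if self.training and self.gradient_checkpointing:` --
            # checkpointing was silently skipped whenever the module was in eval() mode.
            # After: checkpoint whenever autograd is recording, even in eval() mode.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    layer, hidden_states, use_reentrant=False  # needs torch >= 1.11
                )
            else:
                hidden_states = layer(hidden_states)
        return hidden_states
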
@@ -868,7 +868,7 @@ class CrossAttnDownBlock2D(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))
         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1029,7 +1029,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1191,7 +1191,7 @@ class CrossAttnUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1364,7 +1364,7 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
         # Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -215,7 +215,7 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
         # 2. Blocks
         for block_index, block in enumerate(self.transformer.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 # rc todo: for training and gradient checkpointing
                 print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
                 exit(1)
...
@@ -506,7 +506,7 @@ class AllegroEncoder3D(nn.Module):
         sample = self.temp_conv_in(sample)
         sample = sample + residual
-        if self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
@@ -646,7 +646,7 @@ class AllegroDecoder3D(nn.Module):
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
...
@@ -420,7 +420,7 @@ class CogVideoXDownBlock3D(nn.Module):
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -522,7 +522,7 @@ class CogVideoXMidBlock3D(nn.Module):
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -636,7 +636,7 @@ class CogVideoXUpBlock3D(nn.Module):
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -773,7 +773,7 @@ class CogVideoXEncoder3D(nn.Module):
         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
@@ -939,7 +939,7 @@ class CogVideoXDecoder3D(nn.Module):
         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
...
@@ -206,7 +206,7 @@ class MochiDownBlock3D(nn.Module):
         for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
             conv_cache_key = f"resnet_{i}"
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -311,7 +311,7 @@ class MochiMidBlock3D(nn.Module):
         for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
             conv_cache_key = f"resnet_{i}"
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -392,7 +392,7 @@ class MochiUpBlock3D(nn.Module):
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -529,7 +529,7 @@ class MochiEncoder3D(nn.Module):
         hidden_states = self.proj_in(hidden_states)
         hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def create_forward(*inputs):
@@ -646,7 +646,7 @@ class MochiDecoder3D(nn.Module):
         hidden_states = self.conv_in(hidden_states)
         # 1. Mid
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def create_forward(*inputs):
...
@@ -95,7 +95,7 @@ class TemporalDecoder(nn.Module):
         sample = self.conv_in(sample)
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
...
@@ -142,7 +142,7 @@ class Encoder(nn.Module):
         sample = self.conv_in(sample)
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
@@ -291,7 +291,7 @@ class Decoder(nn.Module):
         sample = self.conv_in(sample)
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
@@ -544,7 +544,7 @@ class MaskConditionDecoder(nn.Module):
         sample = self.conv_in(sample)
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
@@ -876,7 +876,7 @@ class EncoderTiny(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `EncoderTiny` class."""
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
@@ -962,7 +962,7 @@ class DecoderTiny(nn.Module):
         # Clamp.
         x = torch.tanh(x / 3) * 3
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
...
@@ -329,7 +329,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         block_samples = ()
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -363,7 +363,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         single_block_samples = ()
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -324,7 +324,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
         block_res_samples = ()
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1466,7 +1466,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
             h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
             # apply base subblock
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 h_base = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(b_res),
@@ -1489,7 +1489,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
             # apply ctrl subblock
             if apply_control:
-                if self.training and self.gradient_checkpointing:
+                if torch.is_grad_enabled() and self.gradient_checkpointing:
                     ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                     h_ctrl = torch.utils.checkpoint.checkpoint(
                         create_custom_forward(c_res),
@@ -1898,7 +1898,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
             hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
             hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet),
...
@@ -466,7 +466,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         # MMDiT blocks.
         for index_block, block in enumerate(self.joint_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -497,7 +497,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -452,7 +452,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -184,7 +184,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -238,7 +238,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
         for i, (spatial_block, temp_block) in enumerate(
             zip(self.transformer_blocks, self.temporal_transformer_blocks)
         ):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     spatial_block,
                     hidden_states,
@@ -271,7 +271,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
             if i == 0 and num_frame > 1:
                 hidden_states = hidden_states + self.temp_pos_embed
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     temp_block,
                     hidden_states,
...
@@ -386,7 +386,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -414,7 +414,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
             attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -415,7 +415,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -371,7 +371,7 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             # TODO(aryan): Implement gradient checkpointing
-            if self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -341,7 +341,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
         hidden_states = hidden_states[:, text_seq_length:]
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -480,7 +480,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         image_rotary_emb = self.pos_embed(ids)
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -525,7 +525,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
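
With the new guard in place, checkpointing also applies to models kept frozen in `eval()` mode while gradients flow through them, e.g. when only a downstream adapter is trained. A rough usage sketch; the checkpoint id, input shape, and loss below are illustrative assumptions, not part of this commit:

import torch
from diffusers import AutoencoderKL

# Any diffusers model whose blocks were touched above behaves the same way;
# this repo id is only an example.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.enable_gradient_checkpointing()
vae.requires_grad_(False)
vae.eval()  # eval() no longer disables checkpointing

sample = torch.randn(1, 3, 256, 256, requires_grad=True)
latents = vae.encode(sample).latent_dist.mean
latents.pow(2).mean().backward()  # encoder activations are recomputed during backward
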