Unverified Commit 5b972fbd authored by Michael Tkachuk, committed by GitHub

Enabling gradient checkpointing in eval() mode (#9878)

* refactored
parent 0be52c07
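
For context, the change gates activation checkpointing on `torch.is_grad_enabled()` instead of `self.training`, so checkpointing also fires when a module sits in `eval()` but autograd is active (for example, when back-propagating into the inputs or into adapter weights around a frozen backbone). Below is a minimal sketch of the pattern, not taken from this diff; the `Block` module and its sizes are illustrative:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        self.gradient_checkpointing = True

    def forward(self, x):
        # Old guard: `self.training and self.gradient_checkpointing`
        # New guard: checkpoint whenever autograd is on and the flag is set.
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            return torch.utils.checkpoint.checkpoint(self.proj, x, use_reentrant=False)
        return self.proj(x)


block = Block().eval()                      # eval() alone no longer disables checkpointing
x = torch.randn(2, 16, requires_grad=True)
block(x).sum().backward()                   # grad enabled -> checkpointed path, activations recomputed
with torch.no_grad():
    block(x)                                # pure inference -> plain path, no checkpointing overhead
```
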
@@ -868,7 +868,7 @@ class CrossAttnDownBlock2D(nn.Module):
blocks = list(zip(self.resnets, self.attentions))
for i, (resnet, attn) in enumerate(blocks):
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -1029,7 +1029,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -1191,7 +1191,7 @@ class CrossAttnUpBlock2D(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -1364,7 +1364,7 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
# Blocks
for block in self.transformer_blocks:
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
......
@@ -215,7 +215,7 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
# 2. Blocks
for block_index, block in enumerate(self.transformer.transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
# rc todo: for training and gradient checkpointing
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
exit(1)
......
@@ -506,7 +506,7 @@ class AllegroEncoder3D(nn.Module):
sample = self.temp_conv_in(sample)
sample = sample + residual
-if self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
@@ -646,7 +646,7 @@ class AllegroDecoder3D(nn.Module):
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-if self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
......
@@ -420,7 +420,7 @@ class CogVideoXDownBlock3D(nn.Module):
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
@@ -522,7 +522,7 @@ class CogVideoXMidBlock3D(nn.Module):
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
@@ -636,7 +636,7 @@ class CogVideoXUpBlock3D(nn.Module):
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
@@ -773,7 +773,7 @@ class CogVideoXEncoder3D(nn.Module):
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
@@ -939,7 +939,7 @@ class CogVideoXDecoder3D(nn.Module):
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
......
@@ -206,7 +206,7 @@ class MochiDownBlock3D(nn.Module):
for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
conv_cache_key = f"resnet_{i}"
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
@@ -311,7 +311,7 @@ class MochiMidBlock3D(nn.Module):
for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
conv_cache_key = f"resnet_{i}"
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
@@ -392,7 +392,7 @@ class MochiUpBlock3D(nn.Module):
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
@@ -529,7 +529,7 @@ class MochiEncoder3D(nn.Module):
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
@@ -646,7 +646,7 @@ class MochiDecoder3D(nn.Module):
hidden_states = self.conv_in(hidden_states)
# 1. Mid
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
......
@@ -95,7 +95,7 @@ class TemporalDecoder(nn.Module):
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
......
@@ -142,7 +142,7 @@ class Encoder(nn.Module):
sample = self.conv_in(sample)
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
@@ -291,7 +291,7 @@ class Decoder(nn.Module):
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
@@ -544,7 +544,7 @@ class MaskConditionDecoder(nn.Module):
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
@@ -876,7 +876,7 @@ class EncoderTiny(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `EncoderTiny` class."""
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
@@ -962,7 +962,7 @@ class DecoderTiny(nn.Module):
# Clamp.
x = torch.tanh(x / 3) * 3
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
......
@@ -329,7 +329,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
block_samples = ()
for index_block, block in enumerate(self.transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -363,7 +363,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
......
@@ -324,7 +324,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
block_res_samples = ()
for block in self.transformer_blocks:
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
......
@@ -1466,7 +1466,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
# apply base subblock
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
h_base = torch.utils.checkpoint.checkpoint(
create_custom_forward(b_res),
@@ -1489,7 +1489,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
# apply ctrl subblock
if apply_control:
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
h_ctrl = torch.utils.checkpoint.checkpoint(
create_custom_forward(c_res),
@@ -1898,7 +1898,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
......
@@ -466,7 +466,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
# MMDiT blocks.
for index_block, block in enumerate(self.joint_transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -497,7 +497,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
......
@@ -452,7 +452,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
......
@@ -184,7 +184,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
# 2. Blocks
for block in self.transformer_blocks:
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
......
@@ -238,7 +238,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
for i, (spatial_block, temp_block) in enumerate(
zip(self.transformer_blocks, self.temporal_transformer_blocks)
):
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
spatial_block,
hidden_states,
@@ -271,7 +271,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
if i == 0 and num_frame > 1:
hidden_states = hidden_states + self.temp_pos_embed
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states,
......
@@ -386,7 +386,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
# 2. Blocks
for block in self.transformer_blocks:
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
......
@@ -414,7 +414,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)
for block in self.transformer_blocks:
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
......
@@ -415,7 +415,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
# 2. Blocks
for block in self.transformer_blocks:
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
......
@@ -371,7 +371,7 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
# TODO(aryan): Implement gradient checkpointing
-if self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
......
@@ -341,7 +341,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
hidden_states = hidden_states[:, text_seq_length:]
for index_block, block in enumerate(self.transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
......
@@ -480,7 +480,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
image_rotary_emb = self.pos_embed(ids)
for index_block, block in enumerate(self.transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -525,7 +525,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
......
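
Taken together, the diff lets a diffusers model keep activation checkpointing active while held in `eval()`, as long as autograd is enabled. A rough usage sketch under that assumption follows; the tiny, randomly initialized `UNet2DConditionModel` config is an illustrative test-sized setup, not anything prescribed by this commit:

```python
import torch
from diffusers import UNet2DConditionModel

# Tiny UNet so the sketch runs without downloading weights.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
unet.enable_gradient_checkpointing()  # existing ModelMixin API
unet.eval()                           # before this change, eval() switched checkpointing off
unet.requires_grad_(False)            # frozen backbone; only input gradients are wanted

latents = torch.randn(1, 4, 32, 32, requires_grad=True)
encoder_hidden_states = torch.randn(1, 77, 32)

# Autograd is on, so the checkpointed branches run even though the model is in eval().
sample = unet(latents, timestep=10, encoder_hidden_states=encoder_hidden_states).sample
sample.mean().backward()
print(latents.grad.shape)
```
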