Unverified commit c4d4ac21, authored by Aryan and committed by GitHub

Refactor gradient checkpointing (#10611)

* update

* remove unused fn

* apply suggestions based on review

* update + cleanup 🧹

* more cleanup 🧹

* make fix-copies

* update test
parent f295e2ee
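The refactor is mechanical but repeated across many models: blocks that previously defined a local `create_custom_forward` closure, branched on `is_torch_version(">=", "1.11.0")` to pick `use_reentrant`, and called `torch.utils.checkpoint.checkpoint` directly now delegate to a shared `self._gradient_checkpointing_func(module, *args)` helper, and the per-model `_set_gradient_checkpointing` overrides are dropped. Below is a minimal before/after sketch of the pattern inside a block's `forward`; the toy module and the `functools.partial` wiring of `_gradient_checkpointing_func` are illustrative assumptions, not code from this PR.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint
from functools import partial


class ToyBlock(nn.Module):
    # Stand-in for a resnet / attention sub-block.
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, temb):
        return self.proj(hidden_states) + temb


class OldStyleBlock(nn.Module):
    # Pre-refactor pattern: every block wraps the checkpoint call itself.
    def __init__(self):
        super().__init__()
        self.resnet = ToyBlock()
        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb):
        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            return torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.resnet), hidden_states, temb, use_reentrant=False
            )
        return self.resnet(hidden_states, temb)


class NewStyleBlock(nn.Module):
    # Post-refactor pattern: blocks call one shared helper (assumed here to be
    # installed as a partial over torch.utils.checkpoint.checkpoint).
    def __init__(self):
        super().__init__()
        self.resnet = ToyBlock()
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func = partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=False
        )

    def forward(self, hidden_states, temb):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            return self._gradient_checkpointing_func(self.resnet, hidden_states, temb)
        return self.resnet(hidden_states, temb)


if __name__ == "__main__":
    x = torch.randn(2, 8, requires_grad=True)
    temb = torch.randn(2, 8)
    for block in (OldStyleBlock(), NewStyleBlock()):
        block.gradient_checkpointing = True
        block(x, temb).sum().backward()
```

In both variants the checkpointed path only runs under `torch.is_grad_enabled() and self.gradient_checkpointing`, the guard the diff keeps at every call site.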
......@@ -35,11 +35,7 @@ from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformers.transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import (
CrossAttnDownBlock3D,
CrossAttnUpBlock3D,
DownBlock3D,
UNetMidBlock3DCrossAttn,
UpBlock3D,
get_down_block,
get_up_block,
)
......@@ -436,11 +432,6 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
self.set_attn_processor(processor)
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel._set_gradient_checkpointing
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
......
......@@ -205,10 +205,6 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
"""
self.set_attn_processor(AttnProcessor())
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
......
......@@ -22,7 +22,7 @@ import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, is_torch_version, logging
from ...utils import BaseOutput, deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..attention import BasicTransformerBlock
from ..attention_processor import (
......@@ -324,25 +324,7 @@ class DownBlockMotion(nn.Module):
blocks = zip(self.resnets, self.motion_modules)
for resnet, motion_module in blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
......@@ -514,23 +496,7 @@ class CrossAttnDownBlockMotion(nn.Module):
blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
for i, (resnet, attn, motion_module) in enumerate(blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
......@@ -543,10 +509,7 @@ class CrossAttnDownBlockMotion(nn.Module):
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
hidden_states = motion_module(hidden_states, num_frames=num_frames)
# apply additional residuals to the output of the last pair of resnet and attention blocks
if i == len(blocks) - 1 and additional_residuals is not None:
......@@ -733,23 +696,7 @@ class CrossAttnUpBlockMotion(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
......@@ -762,10 +709,7 @@ class CrossAttnUpBlockMotion(nn.Module):
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
hidden_states = motion_module(hidden_states, num_frames=num_frames)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
......@@ -896,24 +840,7 @@ class UpBlockMotion(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
......@@ -1080,34 +1007,12 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
)[0]
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(
motion_module, hidden_states, None, None, None, num_frames, None
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
hidden_states = motion_module(hidden_states, None, None, None, num_frames, None)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
return hidden_states
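One detail worth noting in the `UNetMidBlockCrossAttnMotion` hunk above: the refactored call sites pass the wrapped module's inputs positionally, so arguments that the old `create_custom_forward` closure captured as keywords now appear as explicit positional slots padded with `None` (`motion_module, hidden_states, None, None, None, num_frames, None`); the `Blip2QFormerEncoder` hunk further down follows the same convention. A hedged sketch of that calling convention follows; the `ToyMotionModule` signature is invented for illustration and only mirrors the positional layout implied by the diff.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint
from functools import partial


class ToyMotionModule(nn.Module):
    # Invented signature: hidden_states plus several optional slots, mirroring
    # the positional argument layout the diff passes through the helper.
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        class_labels=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        return self.proj(hidden_states)


# Assumed wiring of the shared helper, as in the previous sketch.
checkpoint_fn = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

module = ToyMotionModule()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)

# The call stays purely positional: skipped arguments become explicit Nones
# and num_frames lands in its positional slot.
hidden_states = checkpoint_fn(module, hidden_states, None, None, None, 8, None)
hidden_states.sum().backward()
```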
......@@ -1966,10 +1871,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
self.set_attn_processor(processor)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
module.gradient_checkpointing = value
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
......
......@@ -320,10 +320,6 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
self.set_attn_processor(processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
......
......@@ -387,9 +387,6 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, value=False):
self.gradient_checkpointing = value
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
torch.nn.init.xavier_uniform_(m.weight)
......@@ -456,29 +453,18 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for down_block, downscaler, repmap in block_group:
x = downscaler(x)
for i in range(len(repmap) + 1):
for block in down_block:
if isinstance(block, SDCascadeResBlock):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
x = self._gradient_checkpointing_func(block, x)
elif isinstance(block, SDCascadeAttnBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, clip, use_reentrant=False
)
x = self._gradient_checkpointing_func(block, x, clip)
elif isinstance(block, SDCascadeTimestepBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, r_embed, use_reentrant=False
)
x = self._gradient_checkpointing_func(block, x, r_embed)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), use_reentrant=False)
x = self._gradient_checkpointing_func(block)
if i < len(repmap):
x = repmap[i](x)
level_outputs.insert(0, x)
......@@ -505,13 +491,6 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for i, (up_block, upscaler, repmap) in enumerate(block_group):
for j in range(len(repmap) + 1):
for k, block in enumerate(up_block):
......@@ -523,19 +502,13 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
)
x = x.to(orig_type)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, skip, use_reentrant=False
)
x = self._gradient_checkpointing_func(block, x, skip)
elif isinstance(block, SDCascadeAttnBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, clip, use_reentrant=False
)
x = self._gradient_checkpointing_func(block, x, clip)
elif isinstance(block, SDCascadeTimestepBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, r_embed, use_reentrant=False
)
x = self._gradient_checkpointing_func(block, x, r_embed)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
x = self._gradient_checkpointing_func(block, x)
if j < len(repmap):
x = repmap[j](x)
x = upscaler(x)
......
......@@ -148,9 +148,6 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
pass
def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
......
......@@ -38,7 +38,7 @@ from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ...models.transformers.transformer_2d import Transformer2DModel
from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
from ...models.unets.unet_2d_condition import UNet2DConditionOutput
from ...utils import BaseOutput, is_torch_version, logging
from ...utils import BaseOutput, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -673,11 +673,6 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.Tensor,
......@@ -1114,23 +1109,7 @@ class CrossAttnDownBlock2D(nn.Module):
for i in range(num_layers):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.resnets[i]),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
if cross_attention_dim is not None and idx <= 1:
forward_encoder_hidden_states = encoder_hidden_states
......@@ -1141,8 +1120,8 @@ class CrossAttnDownBlock2D(nn.Module):
else:
forward_encoder_hidden_states = None
forward_encoder_attention_mask = None
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
hidden_states = self._gradient_checkpointing_func(
self.attentions[i * num_attention_per_layer + idx],
hidden_states,
forward_encoder_hidden_states,
None, # timestep
......@@ -1150,7 +1129,6 @@ class CrossAttnDownBlock2D(nn.Module):
cross_attention_kwargs,
attention_mask,
forward_encoder_attention_mask,
**ckpt_kwargs,
)[0]
else:
hidden_states = self.resnets[i](hidden_states, temb)
......@@ -1292,17 +1270,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
for i in range(len(self.resnets[1:])):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
if cross_attention_dim is not None and idx <= 1:
forward_encoder_hidden_states = encoder_hidden_states
......@@ -1313,8 +1280,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
else:
forward_encoder_hidden_states = None
forward_encoder_attention_mask = None
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
hidden_states = self._gradient_checkpointing_func(
self.attentions[i * num_attention_per_layer + idx],
hidden_states,
forward_encoder_hidden_states,
None, # timestep
......@@ -1322,14 +1289,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
cross_attention_kwargs,
attention_mask,
forward_encoder_attention_mask,
**ckpt_kwargs,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.resnets[i + 1]),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(self.resnets[i + 1], hidden_states, temb)
else:
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
if cross_attention_dim is not None and idx <= 1:
......@@ -1466,23 +1427,7 @@ class CrossAttnUpBlock2D(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.resnets[i]),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
if cross_attention_dim is not None and idx <= 1:
forward_encoder_hidden_states = encoder_hidden_states
......@@ -1493,8 +1438,8 @@ class CrossAttnUpBlock2D(nn.Module):
else:
forward_encoder_hidden_states = None
forward_encoder_attention_mask = None
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
hidden_states = self._gradient_checkpointing_func(
self.attentions[i * num_attention_per_layer + idx],
hidden_states,
forward_encoder_hidden_states,
None, # timestep
......@@ -1502,7 +1447,6 @@ class CrossAttnUpBlock2D(nn.Module):
cross_attention_kwargs,
attention_mask,
forward_encoder_attention_mask,
**ckpt_kwargs,
)[0]
else:
hidden_states = self.resnets[i](hidden_states, temb)
......
......@@ -174,19 +174,16 @@ class Blip2QFormerEncoder(nn.Module):
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions, query_length)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self._gradient_checkpointing_func(
layer_module,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
)
else:
layer_outputs = layer_module(
......
......@@ -34,7 +34,7 @@ from ....models.resnet import ResnetBlockCondNorm2D
from ....models.transformers.dual_transformer_2d import DualTransformer2DModel
from ....models.transformers.transformer_2d import Transformer2DModel
from ....models.unets.unet_2d_condition import UNet2DConditionOutput
from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ....utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ....utils.torch_utils import apply_freeu
......@@ -963,10 +963,6 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
......@@ -1597,21 +1593,7 @@ class DownBlockFlat(nn.Module):
for resnet in self.resnets:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
......@@ -1734,23 +1716,7 @@ class CrossAttnDownBlockFlat(nn.Module):
for i, (resnet, attn) in enumerate(blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
......@@ -1876,21 +1842,7 @@ class UpBlockFlat(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
......@@ -2035,23 +1987,7 @@ class CrossAttnUpBlockFlat(nn.Module):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
......@@ -2230,25 +2166,9 @@ class UNetMidBlockFlat(nn.Module):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
......@@ -2377,17 +2297,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
......@@ -2396,12 +2305,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = attn(
hidden_states,
......
......@@ -605,7 +605,7 @@ class GLMTransformer(torch.nn.Module):
layer = self._get_layer(index)
if torch.is_grad_enabled() and self.gradient_checkpointing:
layer_ret = torch.utils.checkpoint.checkpoint(
layer_ret = self._gradient_checkpointing_func(
layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache
)
else:
......@@ -666,10 +666,6 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
return position_ids
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, GLMTransformer):
module.gradient_checkpointing = value
def default_init(cls, *args, **kwargs):
return cls(*args, **kwargs)
......
......@@ -544,10 +544,6 @@ class LDMBertPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LDMBertEncoder,)):
module.gradient_checkpointing = value
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
......@@ -688,15 +684,8 @@ class LDMBertEncoder(LDMBertPreTrainedModel):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self._gradient_checkpointing_func(
encoder_layer,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
......
......@@ -29,7 +29,6 @@ from ...models.attention_processor import (
AttnProcessor,
)
from ...models.modeling_utils import ModelMixin
from ...utils import is_torch_version
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
......@@ -138,9 +137,6 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
self.set_attn_processor(processor)
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
......@@ -159,33 +155,13 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
r_embed = self.gen_r_embedding(r)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
for block in self.blocks:
if isinstance(block, AttnBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, c_embed, use_reentrant=False
)
elif isinstance(block, TimestepBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, r_embed, use_reentrant=False
)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
else:
for block in self.blocks:
if isinstance(block, AttnBlock):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed)
elif isinstance(block, TimestepBlock):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
for block in self.blocks:
if isinstance(block, AttnBlock):
x = self._gradient_checkpointing_func(block, x, c_embed)
elif isinstance(block, TimestepBlock):
x = self._gradient_checkpointing_func(block, x, r_embed)
else:
x = self._gradient_checkpointing_func(block, x)
else:
for block in self.blocks:
if isinstance(block, AttnBlock):
......
......@@ -953,24 +953,15 @@ class ModelTesterMixin:
init_dict["block_out_channels"] = block_out_channels
model_class_copy = copy.copy(self.model_class)
modules_with_gc_enabled = {}
# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
modules_with_gc_enabled = {}
for submodule in model.modules():
if hasattr(submodule, "gradient_checkpointing"):
self.assertTrue(submodule.gradient_checkpointing)
modules_with_gc_enabled[submodule.__class__.__name__] = True
assert set(modules_with_gc_enabled.keys()) == expected_set
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
......
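The test change above drops the monkey-patched `_set_gradient_checkpointing` and instead verifies the `gradient_checkpointing` flag directly on the model's submodules after calling the public `enable_gradient_checkpointing()` entry point. A standalone sketch of the same check against a small UNet follows; the tiny config values are illustrative, not taken from the test suite.

```python
from diffusers import UNet2DConditionModel

# Tiny illustrative config, just big enough to instantiate locally.
model = UNet2DConditionModel(
    sample_size=16,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
model.enable_gradient_checkpointing()

# Same idea as the updated test: every submodule that exposes the flag
# should have it switched on after enabling checkpointing on the model.
enabled = {
    type(m).__name__
    for m in model.modules()
    if getattr(m, "gradient_checkpointing", False)
}
assert enabled, "expected at least one submodule with gradient_checkpointing=True"
print(sorted(enabled))
```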