Unverified Commit c4d4ac21 authored by Aryan, committed by GitHub

Refactor gradient checkpointing (#10611)

* update

* remove unused fn

* apply suggestions based on review

* update + cleanup 🧹

* more cleanup 🧹

* make fix-copies

* update test
parent f295e2ee
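The diff below converges every model on one pattern: the per-class `_set_gradient_checkpointing` hooks and the inline `create_custom_forward` wrappers are deleted, and checkpointed branches call a shared `self._gradient_checkpointing_func(module, *args)` helper instead. A minimal sketch of that pattern, assuming the helper is supplied by `ModelMixin` and defaults to non-reentrant `torch.utils.checkpoint.checkpoint` (the real helper lives in `modeling_utils.py` and is not part of this diff):

# Hypothetical, simplified stand-in for the shared helper; the actual one is
# provided by diffusers' ModelMixin and is not shown in this diff.
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.gradient_checkpointing = False
        # Assumed default: non-reentrant activation checkpointing.
        self._gradient_checkpointing_func = partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=False
        )

    def forward(self, hidden_states):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # The helper takes the module plus positional inputs, mirroring the
            # call sites introduced throughout this PR.
            return self._gradient_checkpointing_func(self.proj, hidden_states)
        return self.proj(hidden_states)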
@@ -35,11 +35,7 @@ from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
 from ..transformers.transformer_temporal import TransformerTemporalModel
 from .unet_3d_blocks import (
-    CrossAttnDownBlock3D,
-    CrossAttnUpBlock3D,
-    DownBlock3D,
     UNetMidBlock3DCrossAttn,
-    UpBlock3D,
     get_down_block,
     get_up_block,
 )
@@ -436,11 +432,6 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         self.set_attn_processor(processor)
 
-    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel._set_gradient_checkpointing
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
-            module.gradient_checkpointing = value
-
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1, s2, b1, b2):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
......
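Every removed `_set_gradient_checkpointing` hook above and below does a variant of the same thing: flip a `gradient_checkpointing` flag on the blocks that support it. A sketch of the generic replacement this refactor relies on, written here only for illustration (the actual traversal is implemented once in `ModelMixin`, outside this diff):

import torch.nn as nn


def set_gradient_checkpointing(model: nn.Module, value: bool = True) -> None:
    # Walk all submodules and flip the opt-in flag wherever it exists, which is
    # what the deleted per-model hooks each re-implemented by hand.
    for submodule in model.modules():
        if hasattr(submodule, "gradient_checkpointing"):
            submodule.gradient_checkpointing = value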
@@ -205,10 +205,6 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
         """
         self.set_attn_processor(AttnProcessor())
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
         if encoder_attention_mask is not None:
             encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
......
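The public API is unchanged by the refactor: users still toggle checkpointing through the `ModelMixin` methods rather than the removed private hooks. A usage sketch (the model class and default config are only an example):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel()  # default config; any ModelMixin subclass behaves the same
unet.enable_gradient_checkpointing()   # turn recomputation on for memory-constrained training
unet.disable_gradient_checkpointing()  # and back off for inference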
@@ -22,7 +22,7 @@ import torch.utils.checkpoint
 from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
-from ...utils import BaseOutput, deprecate, is_torch_version, logging
+from ...utils import BaseOutput, deprecate, logging
 from ...utils.torch_utils import apply_freeu
 from ..attention import BasicTransformerBlock
 from ..attention_processor import (
@@ -324,25 +324,7 @@ class DownBlockMotion(nn.Module):
         blocks = zip(self.resnets, self.motion_modules)
         for resnet, motion_module in blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet),
-                        hidden_states,
-                        temb,
-                        use_reentrant=False,
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)
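The deleted `create_custom_forward` closure only forwarded its positional inputs, so handing the module itself to the checkpoint helper is behaviourally equivalent; the `is_torch_version` branch also disappears because the shared helper owns the `use_reentrant` choice. A small self-contained check of that equivalence (stand-in module, not diffusers code):

import torch
import torch.nn as nn
import torch.utils.checkpoint

resnet = nn.Linear(4, 4)  # stand-in for a resnet block
hidden_states = torch.randn(2, 4, requires_grad=True)


def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


out_old = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet), hidden_states, use_reentrant=False
)
out_new = torch.utils.checkpoint.checkpoint(resnet, hidden_states, use_reentrant=False)
assert torch.allclose(out_old, out_new)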
@@ -514,23 +496,7 @@ class CrossAttnDownBlockMotion(nn.Module):
         blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
         for i, (resnet, attn, motion_module) in enumerate(blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)
@@ -543,10 +509,7 @@ class CrossAttnDownBlockMotion(nn.Module):
                 return_dict=False,
             )[0]
-            hidden_states = motion_module(
-                hidden_states,
-                num_frames=num_frames,
-            )
+            hidden_states = motion_module(hidden_states, num_frames=num_frames)
 
             # apply additional residuals to the output of the last pair of resnet and attention blocks
             if i == len(blocks) - 1 and additional_residuals is not None:
@@ -733,23 +696,7 @@ class CrossAttnUpBlockMotion(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)
@@ -762,10 +709,7 @@ class CrossAttnUpBlockMotion(nn.Module):
                 return_dict=False,
             )[0]
-            hidden_states = motion_module(
-                hidden_states,
-                num_frames=num_frames,
-            )
+            hidden_states = motion_module(hidden_states, num_frames=num_frames)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -896,24 +840,7 @@ class UpBlockMotion(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet),
-                        hidden_states,
-                        temb,
-                        use_reentrant=False,
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)
@@ -1080,34 +1007,12 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
             )[0]
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(motion_module),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(
+                    motion_module, hidden_states, None, None, None, num_frames, None
+                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
-                hidden_states = motion_module(
-                    hidden_states,
-                    num_frames=num_frames,
-                )
+                hidden_states = motion_module(hidden_states, None, None, None, num_frames, None)
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)
 
         return hidden_states
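Because `torch.utils.checkpoint`-style helpers forward only positional arguments, the `num_frames=num_frames` keyword above becomes a positional call that spells out the intervening defaults as `None`, and the eager `else` branch is updated to match. A toy illustration of that constraint (the signature below is assumed for the sketch, not copied from the motion module):

# Assumed signature, loosely modeled on temporal-transformer style forwards.
def motion_module(hidden_states, encoder_hidden_states=None, timestep=None,
                  class_labels=None, num_frames=1, cross_attention_kwargs=None):
    return hidden_states * num_frames


def checkpoint_like(fn, *args):
    # Stand-in for self._gradient_checkpointing_func: positional args only.
    return fn(*args)


x = 2.0
assert motion_module(x, num_frames=8) == checkpoint_like(motion_module, x, None, None, None, 8, None)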
@@ -1966,10 +1871,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
         self.set_attn_processor(processor)
 
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
-            module.gradient_checkpointing = value
-
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
......
@@ -320,10 +320,6 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
         self.set_attn_processor(processor)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
     def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
         """
......
@@ -387,9 +387,6 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.gradient_checkpointing = False
 
-    def _set_gradient_checkpointing(self, value=False):
-        self.gradient_checkpointing = value
-
     def _init_weights(self, m):
         if isinstance(m, (nn.Conv2d, nn.Linear)):
             torch.nn.init.xavier_uniform_(m.weight)
@@ -456,29 +453,18 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             for down_block, downscaler, repmap in block_group:
                 x = downscaler(x)
                 for i in range(len(repmap) + 1):
                     for block in down_block:
                         if isinstance(block, SDCascadeResBlock):
-                            x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
+                            x = self._gradient_checkpointing_func(block, x)
                         elif isinstance(block, SDCascadeAttnBlock):
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, clip, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, clip)
                         elif isinstance(block, SDCascadeTimestepBlock):
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, r_embed, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, r_embed)
                         else:
-                            x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), use_reentrant=False)
+                            x = self._gradient_checkpointing_func(block)
                     if i < len(repmap):
                         x = repmap[i](x)
                 level_outputs.insert(0, x)
@@ -505,13 +491,6 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             for i, (up_block, upscaler, repmap) in enumerate(block_group):
                 for j in range(len(repmap) + 1):
                     for k, block in enumerate(up_block):
@@ -523,19 +502,13 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                                     x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
                                 )
                                 x = x.to(orig_type)
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, skip, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, skip)
                         elif isinstance(block, SDCascadeAttnBlock):
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, clip, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, clip)
                         elif isinstance(block, SDCascadeTimestepBlock):
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, r_embed, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, r_embed)
                         else:
-                            x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
+                            x = self._gradient_checkpointing_func(block, x)
                     if j < len(repmap):
                         x = repmap[j](x)
                 x = upscaler(x)
......
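The two traversals above keep their isinstance dispatch; only the call site changes, because the shared helper simply forwards whatever positional inputs each block type needs (`clip` for attention blocks, `r_embed` for timestep blocks, an optional `skip` for resnet-style blocks). A compact sketch of that dispatch shape with stand-in block types (not the diffusers classes):

import torch.nn as nn


class ResLike(nn.Module):
    def forward(self, x, skip=None):
        return x if skip is None else x + skip


class AttnLike(nn.Module):
    def forward(self, x, clip):
        return x + clip.mean()


class TimestepLike(nn.Module):
    def forward(self, x, r_embed):
        return x + r_embed


def run_block(checkpoint_fn, block, x, clip=None, r_embed=None, skip=None):
    # Same shape as the loops above: pick the extra inputs per block type and
    # let the (checkpointing) helper forward them positionally.
    if isinstance(block, AttnLike):
        return checkpoint_fn(block, x, clip)
    if isinstance(block, TimestepLike):
        return checkpoint_fn(block, x, r_embed)
    if isinstance(block, ResLike):
        return checkpoint_fn(block, x, skip)
    return checkpoint_fn(block, x)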
@@ -148,9 +148,6 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.gradient_checkpointing = False
 
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        pass
-
     def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
         encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
         encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
......
@@ -38,7 +38,7 @@ from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
 from ...models.transformers.transformer_2d import Transformer2DModel
 from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
 from ...models.unets.unet_2d_condition import UNet2DConditionOutput
-from ...utils import BaseOutput, is_torch_version, logging
+from ...utils import BaseOutput, logging
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -673,11 +673,6 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         sample: torch.Tensor,
@@ -1114,23 +1109,7 @@ class CrossAttnDownBlock2D(nn.Module):
         for i in range(num_layers):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.resnets[i]),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
                 for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                     if cross_attention_dim is not None and idx <= 1:
                         forward_encoder_hidden_states = encoder_hidden_states
@@ -1141,8 +1120,8 @@ class CrossAttnDownBlock2D(nn.Module):
                     else:
                         forward_encoder_hidden_states = None
                         forward_encoder_attention_mask = None
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
+                    hidden_states = self._gradient_checkpointing_func(
+                        self.attentions[i * num_attention_per_layer + idx],
                         hidden_states,
                         forward_encoder_hidden_states,
                         None,  # timestep
@@ -1150,7 +1129,6 @@ class CrossAttnDownBlock2D(nn.Module):
                         cross_attention_kwargs,
                         attention_mask,
                         forward_encoder_attention_mask,
-                        **ckpt_kwargs,
                     )[0]
             else:
                 hidden_states = self.resnets[i](hidden_states, temb)
@@ -1292,17 +1270,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         for i in range(len(self.resnets[1:])):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                     if cross_attention_dim is not None and idx <= 1:
                         forward_encoder_hidden_states = encoder_hidden_states
@@ -1313,8 +1280,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                     else:
                         forward_encoder_hidden_states = None
                         forward_encoder_attention_mask = None
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
+                    hidden_states = self._gradient_checkpointing_func(
+                        self.attentions[i * num_attention_per_layer + idx],
                         hidden_states,
                         forward_encoder_hidden_states,
                         None,  # timestep
@@ -1322,14 +1289,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         cross_attention_kwargs,
                         attention_mask,
                         forward_encoder_attention_mask,
-                        **ckpt_kwargs,
                     )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.resnets[i + 1]),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(self.resnets[i + 1], hidden_states, temb)
             else:
                 for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                     if cross_attention_dim is not None and idx <= 1:
@@ -1466,23 +1427,7 @@ class CrossAttnUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.resnets[i]),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
                 for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                     if cross_attention_dim is not None and idx <= 1:
                         forward_encoder_hidden_states = encoder_hidden_states
@@ -1493,8 +1438,8 @@ class CrossAttnUpBlock2D(nn.Module):
                     else:
                         forward_encoder_hidden_states = None
                         forward_encoder_attention_mask = None
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
+                    hidden_states = self._gradient_checkpointing_func(
+                        self.attentions[i * num_attention_per_layer + idx],
                         hidden_states,
                         forward_encoder_hidden_states,
                         None,  # timestep
@@ -1502,7 +1447,6 @@ class CrossAttnUpBlock2D(nn.Module):
                         cross_attention_kwargs,
                         attention_mask,
                         forward_encoder_attention_mask,
-                        **ckpt_kwargs,
                     )[0]
             else:
                 hidden_states = self.resnets[i](hidden_states, temb)
......
@@ -174,19 +174,16 @@ class Blip2QFormerEncoder(nn.Module):
                     )
                     use_cache = False
 
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions, query_length)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
+                    query_length,
                 )
             else:
                 layer_outputs = layer_module(
......
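In the encoder above, the extras that used to be baked into a per-iteration closure (past_key_value, output_attentions, query_length) are now passed through the helper as ordinary positional inputs, which keeps the call site generic. A toy illustration of the two styles (names and the signature are illustrative, not the actual Blip2 layer):

def layer_module(hidden_states, attention_mask, past_key_value=None, output_attentions=False):
    return (hidden_states,)


def checkpoint_like(fn, *args):
    return fn(*args)  # stand-in for self._gradient_checkpointing_func


# Old style: the wrapper closes over the extras.
def create_custom_forward(module, past_key_value, output_attentions):
    def custom_forward(*inputs):
        return module(*inputs, past_key_value, output_attentions)

    return custom_forward


old = checkpoint_like(create_custom_forward(layer_module, None, False), 1.0, None)
# New style: the same extras become explicit inputs to the checkpointed call.
new = checkpoint_like(layer_module, 1.0, None, None, False)
assert old == new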
@@ -34,7 +34,7 @@ from ....models.resnet import ResnetBlockCondNorm2D
 from ....models.transformers.dual_transformer_2d import DualTransformer2DModel
 from ....models.transformers.transformer_2d import Transformer2DModel
 from ....models.unets.unet_2d_condition import UNet2DConditionOutput
-from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ....utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ....utils.torch_utils import apply_freeu
@@ -963,10 +963,6 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def enable_freeu(self, s1, s2, b1, b2):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -1597,21 +1593,7 @@ class DownBlockFlat(nn.Module):
         for resnet in self.resnets:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1734,23 +1716,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         for i, (resnet, attn) in enumerate(blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1876,21 +1842,7 @@ class UpBlockFlat(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2035,23 +1987,7 @@ class CrossAttnUpBlockFlat(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -2230,25 +2166,9 @@ class UNetMidBlockFlat(nn.Module):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 if attn is not None:
                     hidden_states = attn(hidden_states, temb=temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 if attn is not None:
                     hidden_states = attn(hidden_states, temb=temb)
@@ -2377,17 +2297,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -2396,12 +2305,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = attn(
                     hidden_states,
......
@@ -605,7 +605,7 @@ class GLMTransformer(torch.nn.Module):
             layer = self._get_layer(index)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                layer_ret = torch.utils.checkpoint.checkpoint(
+                layer_ret = self._gradient_checkpointing_func(
                     layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache
                 )
             else:
@@ -666,10 +666,6 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
         return position_ids
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, GLMTransformer):
-            module.gradient_checkpointing = value
-
 
 def default_init(cls, *args, **kwargs):
     return cls(*args, **kwargs)
......
@@ -544,10 +544,6 @@ class LDMBertPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (LDMBertEncoder,)):
-            module.gradient_checkpointing = value
-
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -688,15 +684,8 @@ class LDMBertEncoder(LDMBertPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer,
                     hidden_states,
                     attention_mask,
                     (head_mask[idx] if head_mask is not None else None),
......
@@ -29,7 +29,6 @@ from ...models.attention_processor import (
     AttnProcessor,
 )
 from ...models.modeling_utils import ModelMixin
-from ...utils import is_torch_version
 from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
@@ -138,9 +137,6 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
         self.set_attn_processor(processor)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        self.gradient_checkpointing = value
-
     def gen_r_embedding(self, r, max_positions=10000):
         r = r * max_positions
         half_dim = self.c_r // 2
@@ -159,33 +155,13 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
         r_embed = self.gen_r_embedding(r)
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
-            if is_torch_version(">=", "1.11.0"):
-                for block in self.blocks:
-                    if isinstance(block, AttnBlock):
-                        x = torch.utils.checkpoint.checkpoint(
-                            create_custom_forward(block), x, c_embed, use_reentrant=False
-                        )
-                    elif isinstance(block, TimestepBlock):
-                        x = torch.utils.checkpoint.checkpoint(
-                            create_custom_forward(block), x, r_embed, use_reentrant=False
-                        )
-                    else:
-                        x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
-            else:
-                for block in self.blocks:
-                    if isinstance(block, AttnBlock):
-                        x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed)
-                    elif isinstance(block, TimestepBlock):
-                        x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed)
-                    else:
-                        x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
+            for block in self.blocks:
+                if isinstance(block, AttnBlock):
+                    x = self._gradient_checkpointing_func(block, x, c_embed)
+                elif isinstance(block, TimestepBlock):
+                    x = self._gradient_checkpointing_func(block, x, r_embed)
+                else:
+                    x = self._gradient_checkpointing_func(block, x)
         else:
             for block in self.blocks:
                 if isinstance(block, AttnBlock):
......
@@ -953,24 +953,15 @@ class ModelTesterMixin:
             init_dict["block_out_channels"] = block_out_channels
 
         model_class_copy = copy.copy(self.model_class)
-
-        modules_with_gc_enabled = {}
-
-        # now monkey patch the following function:
-        # def _set_gradient_checkpointing(self, module, value=False):
-        #     if hasattr(module, "gradient_checkpointing"):
-        #         module.gradient_checkpointing = value
-        def _set_gradient_checkpointing_new(self, module, value=False):
-            if hasattr(module, "gradient_checkpointing"):
-                module.gradient_checkpointing = value
-                modules_with_gc_enabled[module.__class__.__name__] = True
-
-        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
         model = model_class_copy(**init_dict)
         model.enable_gradient_checkpointing()
+
+        modules_with_gc_enabled = {}
+        for submodule in model.modules():
+            if hasattr(submodule, "gradient_checkpointing"):
+                self.assertTrue(submodule.gradient_checkpointing)
+                modules_with_gc_enabled[submodule.__class__.__name__] = True
+
         assert set(modules_with_gc_enabled.keys()) == expected_set
         assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
......
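With the monkey-patching gone, the shared test simply enables checkpointing and inspects `model.modules()` for the opt-in flag. A hypothetical example of how a model-specific test class might drive it, assuming the mixin accepts `expected_set` as a parameter as the assertion above suggests (the class name, model, and set contents are illustrative, not taken from this diff):

import unittest


class MyUNetModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = MyUNetModel  # hypothetical model under test

    def test_gradient_checkpointing_is_applied(self):
        # List whichever block classes the model is expected to run under checkpointing.
        expected_set = {"CrossAttnDownBlock2D", "DownBlock2D", "UpBlock2D"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)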