Unverified commit 15f1bab1, authored by 7eu7d7, committed by GitHub

Fix gradient checkpointing bugs when freezing part of the model (requires_grad=False) (#3404)



* gradient checkpointing bug fix

* bug fix; changes from review

* reformat

* reformat

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 415c6167
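For context, a minimal sketch (not part of this commit) of the failure mode the title describes: when everything upstream of a checkpointed block is frozen, the tensor fed into torch.utils.checkpoint.checkpoint no longer requires grad, and the reentrant implementation then detaches the segment from autograd, so trainable parameters inside it never receive gradients. Passing use_reentrant=False (available since torch 1.11) avoids this, which is exactly what the hunks below do.

# Minimal repro sketch (not from the diff). With use_reentrant=True the
# checkpoint call below warns that none of its inputs require grad and its
# output is detached, so backward() fails and `trainable` never learns;
# with use_reentrant=False gradients flow normally.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

frozen = nn.Linear(8, 8).requires_grad_(False)  # frozen upstream layers
trainable = nn.Linear(8, 8)                     # the part we still train

x = torch.randn(2, 8)   # plain data, requires_grad=False
h = frozen(x)           # h.requires_grad is False because everything upstream is frozen

out = checkpoint(trainable, h, use_reentrant=False)
out.sum().backward()
assert trainable.weight.grad is not None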
@@ -18,6 +18,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

+from ..utils import is_torch_version
 from .attention import AdaGroupNorm
 from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
@@ -866,13 +867,27 @@ class CrossAttnDownBlock2D(nn.Module):
                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    cross_attention_kwargs,
-                )[0]
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )[0]
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                    )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
@@ -957,7 +972,14 @@ class DownBlock2D(nn.Module):
                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1361,7 +1383,14 @@ class ResnetDownsampleBlock2D(nn.Module):
                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1558,7 +1587,14 @@ class KDownBlock2D(nn.Module):
                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1653,14 +1689,29 @@ class KCrossAttnDownBlock2D(nn.Module):
                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    attention_mask,
-                    cross_attention_kwargs,
-                )
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
@@ -1874,13 +1925,27 @@ class CrossAttnUpBlock2D(nn.Module):
                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    cross_attention_kwargs,
-                )[0]
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )[0]
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                    )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
@@ -1960,7 +2025,14 @@ class UpBlock2D(nn.Module):
                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2388,7 +2460,14 @@ class ResnetUpsampleBlock2D(nn.Module):
                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2593,7 +2672,14 @@ class KUpBlock2D(nn.Module):
                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2714,14 +2800,29 @@ class KCrossAttnUpBlock2D(nn.Module):
                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    attention_mask,
-                    cross_attention_kwargs,
-                )[0]
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )[0]
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        attention_mask,
+                        cross_attention_kwargs,
+                    )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
......
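The same version gate is repeated in every hunk above. Condensed into a hypothetical helper (the `checkpointed` function below is not part of diffusers or of this PR, just an illustration), the pattern is:

# Hypothetical helper condensing the gate each patched block now applies
# around torch.utils.checkpoint.checkpoint.
import torch
import torch.utils.checkpoint
from diffusers.utils import is_torch_version

def checkpointed(fn, *args):
    if is_torch_version(">=", "1.11.0"):
        # torch 1.11 introduced the use_reentrant flag; the non-reentrant path
        # works even when no tensor argument requires grad.
        return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=False)
    # Older torch only has the reentrant implementation, so keep the old call.
    return torch.utils.checkpoint.checkpoint(fn, *args)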
@@ -18,7 +18,7 @@ import numpy as np
 import torch
 import torch.nn as nn

-from ..utils import BaseOutput, randn_tensor
+from ..utils import BaseOutput, is_torch_version, randn_tensor
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@@ -117,11 +117,20 @@ class Encoder(nn.Module):
                 return custom_forward

             # down
-            for down_block in self.down_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
-
-            # middle
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            if is_torch_version(">=", "1.11.0"):
+                for down_block in self.down_blocks:
+                    sample = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(down_block), sample, use_reentrant=False
+                    )
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.mid_block), sample, use_reentrant=False
+                )
+            else:
+                for down_block in self.down_blocks:
+                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
         else:
             # down
@@ -221,13 +230,26 @@ class Decoder(nn.Module):
                 return custom_forward

-            # middle
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
-            sample = sample.to(upscale_dtype)
-
-            # up
-            for up_block in self.up_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
+            if is_torch_version(">=", "1.11.0"):
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.mid_block), sample, use_reentrant=False
+                )
+                sample = sample.to(upscale_dtype)
+
+                # up
+                for up_block in self.up_blocks:
+                    sample = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(up_block), sample, use_reentrant=False
+                    )
+            else:
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+                sample = sample.to(upscale_dtype)
+
+                # up
+                for up_block in self.up_blocks:
+                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
         else:
             # middle
             sample = self.mid_block(sample)
......
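The Encoder and Decoder branches above only run when the module is in training mode and its gradient_checkpointing flag is set; in diffusers that flag is normally flipped through the model-level helper rather than set by hand. A hedged sketch (the model id is a placeholder):

# Hedged sketch (not part of the diff): enabling checkpointing on the VAE so
# the Encoder/Decoder branches patched above are actually taken in training.
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.enable_gradient_checkpointing()  # sets the gradient_checkpointing flag used above
vae.train()                          # the checkpointed path also requires self.training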
@@ -18,7 +18,7 @@ from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
 from ...models.unet_2d_condition import UNet2DConditionOutput
-from ...utils import logging
+from ...utils import is_torch_version, logging


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1077,7 +1077,14 @@ class DownBlockFlat(nn.Module):
                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1198,13 +1205,27 @@ class CrossAttnDownBlockFlat(nn.Module):
                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
-                    hidden_states,
-                    encoder_hidden_states,
-                    cross_attention_kwargs,
-                )[0]
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                        use_reentrant=False,
+                    )[0]
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False),
+                        hidden_states,
+                        encoder_hidden_states,
+                        cross_attention_kwargs,
+                    )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
@@ -1289,7 +1310,14 @@ class UpBlockFlat(nn.Module):
                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
             else:
                 hidden_states = resnet(hidden_states, temb)
return custom_forward return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint( hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
hidden_states, )
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
cross_attention_kwargs, create_custom_forward(attn, return_dict=False),
)[0] hidden_states,
encoder_hidden_states,
cross_attention_kwargs,
use_reentrant=False,
)[0]
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
cross_attention_kwargs,
)[0]
else: else:
hidden_states = resnet(hidden_states, temb) hidden_states = resnet(hidden_states, temb)
hidden_states = attn( hidden_states = attn(
......
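Finally, a hedged end-to-end sketch of the training setup these changes target: most of the UNet frozen, only a subset of parameters trainable, and gradient checkpointing enabled. The model id and the choice of trainable subset are placeholders, not taken from the PR.

# Hedged end-to-end sketch (not part of the diff).
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.requires_grad_(False)               # freeze the whole UNet ...
for p in unet.up_blocks.parameters():    # ... then unfreeze only the up blocks
    p.requires_grad_(True)

# Before this commit, the frozen prefix meant checkpointed blocks received
# inputs with requires_grad=False and the reentrant implementation cut them
# out of the backward pass; with the patched blocks, torch >= 1.11 routes
# through use_reentrant=False and this split trains normally.
unet.enable_gradient_checkpointing()
unet.train()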