Unverified commit 116f70cb authored by Andy, committed by GitHub

Enabling gradient checkpointing for VAE (#2536)



* updated black format

* update black format

* make style format

* updated line endings

* update code formatting

* Update examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/vae.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/vae.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* added vae gradient checkpointing test

* make style

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>
parent a1695715
@@ -412,6 +412,7 @@ def main():
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        vae.enable_gradient_checkpointing()
 
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
...
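For context, this one-line change is all a training script needs once the model class supports checkpointing. A minimal standalone sketch (the checkpoint ID is illustrative, not part of this PR):

```python
from diffusers import AutoencoderKL, UNet2DConditionModel

# Any Stable Diffusion checkpoint with "unet" and "vae" subfolders works here.
model_id = "CompVis/stable-diffusion-v1-4"
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")

# Trade extra forward compute for lower activation memory during training.
unet.enable_gradient_checkpointing()
vae.enable_gradient_checkpointing()  # newly possible with this PR
assert unet.is_gradient_checkpointing and vae.is_gradient_checkpointing
```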
@@ -65,6 +65,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
     Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """
 
+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -121,6 +123,10 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (Encoder, Decoder)):
+            module.gradient_checkpointing = value
+
     def enable_tiling(self, use_tiling: bool = True):
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
...
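These two additions hook into `ModelMixin`'s generic checkpointing machinery: `_supports_gradient_checkpointing` gates the public toggle, and `enable_gradient_checkpointing()` visits every submodule and calls the `_set_gradient_checkpointing` hook on it, so `AutoencoderKL` only has to flip the flag on the `Encoder` and `Decoder` instances it owns. Roughly, the base-class side looks like this (a simplified sketch, not the verbatim diffusers source):

```python
from functools import partial

import torch.nn as nn


class CheckpointingMixinSketch(nn.Module):
    """Condensed sketch of ModelMixin's gradient-checkpointing dispatch."""

    _supports_gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Subclasses such as AutoencoderKL override this hook.
        pass

    def enable_gradient_checkpointing(self):
        if not self._supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        # nn.Module.apply() visits self and every submodule, so the subclass
        # hook sees each Encoder/Decoder and can set module.gradient_checkpointing.
        self.apply(partial(self._set_gradient_checkpointing, value=True))
```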
@@ -24,9 +24,7 @@ from .attention_processor import (  # noqa: F401
     SlicedAttnProcessor,
     XFormersAttnProcessor,
 )
-from .attention_processor import (  # noqa: F401
-    AttnProcessor as AttnProcessorRename,
-)
+from .attention_processor import AttnProcessor as AttnProcessorRename  # noqa: F401
 
 deprecate(
...
@@ -50,7 +50,13 @@ class Encoder(nn.Module):
         super().__init__()
         self.layers_per_block = layers_per_block
 
-        self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
+        self.conv_in = torch.nn.Conv2d(
+            in_channels,
+            block_out_channels[0],
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
 
         self.mid_block = None
         self.down_blocks = nn.ModuleList([])
@@ -96,10 +102,28 @@ class Encoder(nn.Module):
         conv_out_channels = 2 * out_channels if double_z else out_channels
         self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
 
+        self.gradient_checkpointing = False
+
     def forward(self, x):
         sample = x
         sample = self.conv_in(sample)
 
-        # down
-        for down_block in self.down_blocks:
-            sample = down_block(sample)
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            # down
+            for down_block in self.down_blocks:
+                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+
+            # middle
+            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+
+        else:
+            # down
+            for down_block in self.down_blocks:
+                sample = down_block(sample)
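`torch.utils.checkpoint.checkpoint` takes a plain callable rather than an `nn.Module`, which is what the `create_custom_forward` closure adapts. Under checkpointing, a block's intermediate activations are freed after the forward pass and recomputed when backpropagation reaches that segment, trading compute for memory. A self-contained illustration of the same pattern (the toy block is invented for the example):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def create_custom_forward(module):
    # checkpoint() expects a function; close over the module to adapt it.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


toy_block = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.SiLU(),
    nn.Conv2d(8, 3, kernel_size=3, padding=1),
)
# The input must require grad for the (default, reentrant) checkpoint to
# record a backward graph through the recomputed segment.
x = torch.randn(1, 3, 32, 32, requires_grad=True)

out = checkpoint(create_custom_forward(toy_block), x)
out.sum().backward()
print(x.grad.shape)  # torch.Size([1, 3, 32, 32])
```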
@@ -129,7 +153,13 @@ class Decoder(nn.Module):
         super().__init__()
         self.layers_per_block = layers_per_block
 
-        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
+        self.conv_in = nn.Conv2d(
+            in_channels,
+            block_out_channels[-1],
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
 
         self.mid_block = None
         self.up_blocks = nn.ModuleList([])
@@ -176,10 +206,27 @@ class Decoder(nn.Module):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
 
+        self.gradient_checkpointing = False
+
     def forward(self, z):
         sample = z
         sample = self.conv_in(sample)
 
-        # middle
-        sample = self.mid_block(sample)
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            # middle
+            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+
+            # up
+            for up_block in self.up_blocks:
+                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
+
+        else:
+            # middle
+            sample = self.mid_block(sample)
...
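The `Decoder` mirrors the `Encoder`, checkpointing the mid block and each up block. Because both forwards gate on `self.training and self.gradient_checkpointing`, inference is unaffected even when the flag is set; for example (checkpoint ID and tensor shape are illustrative):

```python
import torch

from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.enable_gradient_checkpointing()

# eval() sets self.training = False, so Encoder and Decoder take the plain
# path: no recomputation overhead when reconstructing images at inference.
vae.eval()
with torch.no_grad():
    reconstruction = vae(torch.randn(1, 3, 64, 64)).sample
```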
@@ -68,6 +68,47 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
     def test_training(self):
         pass
 
+    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
+    def test_gradient_checkpointing(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        assert not model.is_gradient_checkpointing and model.training
+
+        out = model(**inputs_dict).sample
+        # run the backward pass; for simplicity we backprop on a dummy
+        # regression loss rather than a real training objective
+        model.zero_grad()
+
+        labels = torch.randn_like(out)
+        loss = (out - labels).mean()
+        loss.backward()
+
+        # re-instantiate the model, now enabling gradient checkpointing
+        model_2 = self.model_class(**init_dict)
+        # clone the weights so both models start from identical parameters
+        model_2.load_state_dict(model.state_dict())
+        model_2.to(torch_device)
+        model_2.enable_gradient_checkpointing()
+
+        assert model_2.is_gradient_checkpointing and model_2.training
+
+        out_2 = model_2(**inputs_dict).sample
+        # run the backward pass against the same labels so the losses are comparable
+        model_2.zero_grad()
+        loss_2 = (out_2 - labels).mean()
+        loss_2.backward()
+
+        # compare the losses and the parameter gradients
+        self.assertTrue((loss - loss_2).abs() < 1e-5)
+        named_params = dict(model.named_parameters())
+        named_params_2 = dict(model_2.named_parameters())
+        for name, param in named_params.items():
+            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
+
     def test_from_pretrained_hub(self):
         model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
         self.assertIsNotNone(model)
...
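The test pins down correctness: with checkpointing enabled, the loss and every parameter gradient must match the plain forward up to numerical noise. It deliberately does not measure the memory saving that motivates the feature; a rough way to observe that on a CUDA machine is sketched below (informal, not part of the test suite; checkpoint ID and sizes are arbitrary):

```python
import torch

from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").cuda().train()
x = torch.randn(1, 3, 256, 256, device="cuda")


def peak_backward_mib():
    # Peak allocated memory across one forward/backward pass, in MiB.
    torch.cuda.reset_peak_memory_stats()
    vae(x).sample.mean().backward()
    vae.zero_grad(set_to_none=True)
    return torch.cuda.max_memory_allocated() / 2**20


baseline = peak_backward_mib()
vae.enable_gradient_checkpointing()
checkpointed = peak_backward_mib()
print(f"peak allocated: {baseline:.0f} MiB -> {checkpointed:.0f} MiB")
```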