Unverified commit f5ccffec authored by Patrick von Platen, committed by GitHub

Use `accelerate` save & loading hooks to have better checkpoint structure (#2048)



* better accelerated saving

* up

* finish

* finish

* uP

* up

* up

* fix

* Apply suggestions from code review

* correct ema

* Remove @

* up

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/training/dreambooth.mdx
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent e619db24
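In short: with `accelerate>=0.16.0`, each training script registers a save-state pre-hook that serializes every tracked model into a named subfolder of the checkpoint, and a load-state pre-hook that restores it from there, instead of letting `accelerate` dump flat binary files. A minimal sketch of the pattern, reduced to a single UNet (the model id and the `checkpoint-0` path are illustrative):

```python
import os

from accelerate import Accelerator
from diffusers import UNet2DConditionModel

accelerator = Accelerator()
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")


def save_model_hook(models, weights, output_dir):
    # called by `accelerator.save_state` before it writes raw weights
    for model in models:
        model.save_pretrained(os.path.join(output_dir, "unet"))
        weights.pop()  # pop the weight so accelerate does not serialize it a second time


def load_model_hook(models, input_dir):
    # called by `accelerator.load_state` before it restores raw weights
    while len(models) > 0:
        model = models.pop()  # pop so accelerate does not load it again
        loaded = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
        model.register_to_config(**loaded.config)
        model.load_state_dict(loaded.state_dict())
        del loaded


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

unet = accelerator.prepare(unet)
accelerator.save_state("checkpoint-0")  # now contains checkpoint-0/unet/ in diffusers format
```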
docs/source/en/training/dreambooth.mdx
@@ -127,7 +127,30 @@ This would be a good opportunity to tweak some of your hyperparameters if you wish.
 Saved checkpoints are stored in a format suitable for resuming training. They not only include the model weights, but also the state of the optimizer, data loaders and learning rate.

-You can use a checkpoint for inference, but first you need to convert it to an inference pipeline. This is how you could do it:
+**Note**: If you have installed `"accelerate>=0.16.0"` you can use the following code to run
+inference from an intermediate checkpoint.
+
+```python
+from diffusers import DiffusionPipeline, UNet2DConditionModel
+from transformers import CLIPTextModel
+import torch
+
+# Load the pipeline with the same arguments (model, revision) that were used for training
+model_id = "CompVis/stable-diffusion-v1-4"
+
+unet = UNet2DConditionModel.from_pretrained("/sddata/dreambooth/daruma-v2-1/checkpoint-100/unet")
+
+# if you have trained with `--train_text_encoder` make sure to also load the text encoder
+text_encoder = CLIPTextModel.from_pretrained("/sddata/dreambooth/daruma-v2-1/checkpoint-100/text_encoder")
+
+pipeline = DiffusionPipeline.from_pretrained(model_id, unet=unet, text_encoder=text_encoder, torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+# Perform inference, or save, or push to the hub
+pipeline.save_pretrained("dreambooth-pipeline")
+```
+
+If you have installed `"accelerate<0.16.0"`, you need to first convert the checkpoint to an inference pipeline. This is how you could do it:

 ```python
 from accelerate import Accelerator
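The hunk context ends at the import, so the legacy (`accelerate<0.16.0`) conversion snippet is cut off here. For orientation, a hedged sketch of the conversion flow it describes, assuming the checkpoint was produced by `accelerator.save_state` with the same model arguments used for training (paths and model id are illustrative):

```python
from accelerate import Accelerator
from diffusers import DiffusionPipeline

# Load the pipeline with the same arguments (model, revision) that were used for training
model_id = "CompVis/stable-diffusion-v1-4"
pipeline = DiffusionPipeline.from_pretrained(model_id)

accelerator = Accelerator()

# Wrap the models exactly as during training so the checkpoint keys line up;
# include text_encoder only if `--train_text_encoder` was used
unet, text_encoder = accelerator.prepare(pipeline.unet, pipeline.text_encoder)

# Restore the training state from a checkpoint directory (use an absolute path)
accelerator.load_state("/sddata/dreambooth/daruma-v2-1/checkpoint-100")

# Rebuild the pipeline with the unwrapped, restored models and save it for inference
pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    unet=accelerator.unwrap_model(unet),
    text_encoder=accelerator.unwrap_model(text_encoder),
)
pipeline.save_pretrained("dreambooth-pipeline")
```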
@@ -271,6 +294,10 @@ accelerate launch train_dreambooth.py \

 Once you have trained a model, inference can be done using the `StableDiffusionPipeline`, by simply indicating the path where the model was saved. Make sure that your prompts include the special `identifier` used during training (`sks` in the previous examples).

+**Note**: If you have installed `"accelerate>=0.16.0"` you can use the following code to run
+inference from an intermediate checkpoint.
+
 ```python
 from diffusers import StableDiffusionPipeline
 import torch
@@ -284,4 +311,4 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
 image.save("dog-bucket.png")
 ```

 You may also run inference from [any of the saved training checkpoints](#performing-inference-using-a-saved-checkpoint).
\ No newline at end of file
examples/dreambooth/train_dreambooth.py
@@ -28,6 +28,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.utils.data import Dataset

+import accelerate
 import diffusers
 import transformers
 from accelerate import Accelerator
@@ -38,6 +39,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
 from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from packaging import version
 from PIL import Image
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -606,6 +608,37 @@ def main(args):
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )

+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            for model in models:
+                sub_dir = "unet" if type(model) == type(unet) else "text_encoder"
+                model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+        def load_model_hook(models, input_dir):
+            while len(models) > 0:
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                if type(model) == type(text_encoder):
+                    # load transformers style into model
+                    load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
+                    model.config = load_model.config
+                else:
+                    # load diffusers style into model
+                    load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+                    model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
     vae.requires_grad_(False)
     if not args.train_text_encoder:
         text_encoder.requires_grad_(False)
...
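With these hooks registered, the periodic `accelerator.save_state(...)` call the training loop already makes yields self-describing checkpoints. A sketch of the expected result; the output directory and step number are illustrative, and the exact sidecar files (optimizer, scheduler, RNG state) are an assumption about `accelerate`'s defaults:

```python
import os

from accelerate import Accelerator

accelerator = Accelerator()
output_dir, global_step = "dreambooth-model", 500  # illustrative values

# what train_dreambooth.py does every --checkpointing_steps steps
save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)

# Expected layout once save_model_hook has run:
#   dreambooth-model/checkpoint-500/unet/          <- diffusers-format UNet weights + config
#   dreambooth-model/checkpoint-500/text_encoder/  <- only when --train_text_encoder was used
#   plus accelerate's own files (optimizer, LR scheduler, RNG state) alongside them
```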
examples/text_to_image/train_text_to_image.py
@@ -26,6 +26,7 @@ import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint

+import accelerate
 import datasets
 import diffusers
 import transformers
@@ -36,9 +37,10 @@ from datasets import load_dataset
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version
+from diffusers.utils import check_min_version, deprecate
 from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -319,6 +321,16 @@ dataset_name_mapping = {
 def main():
     args = parse_args()

+    if args.non_ema_revision is not None:
+        deprecate(
+            "non_ema_revision!=None",
+            "0.15.0",
+            message=(
+                "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+                " use `--variant=non_ema` instead."
+            ),
+        )
+
     logging_dir = os.path.join(args.output_dir, args.logging_dir)

     accelerator = Accelerator(
@@ -396,6 +408,39 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            if args.use_ema:
+                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+            for i, model in enumerate(models):
+                model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+        def load_model_hook(models, input_dir):
+            if args.use_ema:
+                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+                ema_unet.load_state_dict(load_model.state_dict())
+                del load_model
+
+            for i in range(len(models)):
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                # load diffusers style into model
+                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+                model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -552,8 +597,9 @@ def main():
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
     )
+
     if args.use_ema:
-        accelerator.register_for_checkpointing(ema_unet)
+        ema_unet.to(accelerator.device)

     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
@@ -566,8 +612,6 @@ def main():
     # Move text_encode and vae to gpu and cast to weight_dtype
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
-    if args.use_ema:
-        ema_unet.to(accelerator.device)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
...
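Note the EMA change above: `accelerator.register_for_checkpointing(ema_unet)`, which stored the EMA state as an opaque custom-state blob, is dropped; the save hook now writes the EMA weights to a readable `unet_ema` subfolder via the new `EMAModel.save_pretrained`, and the load hook restores them explicitly. Resuming therefore stays a single call; a sketch with an illustrative path:

```python
# Resuming re-populates both the training UNet and its EMA copy:
# load_model_hook rebuilds ema_unet from <checkpoint>/unet_ema and loads
# <checkpoint>/unet into the prepared UNet, before accelerate restores the
# optimizer, LR scheduler and RNG state.
accelerator.load_state("sd-model-finetuned/checkpoint-500")  # path illustrative
```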
examples/unconditional_image_generation/train_unconditional.py
@@ -9,6 +9,7 @@ from typing import Optional
 import torch
 import torch.nn.functional as F

+import accelerate
 import datasets
 import diffusers
 from accelerate import Accelerator
@@ -19,6 +20,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -271,6 +273,40 @@ def main(args):
         logging_dir=logging_dir,
     )

+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            if args.use_ema:
+                ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+            for i, model in enumerate(models):
+                model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+        def load_model_hook(models, input_dir):
+            if args.use_ema:
+                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)
+                ema_model.load_state_dict(load_model.state_dict())
+                ema_model.to(accelerator.device)
+                del load_model
+
+            for i in range(len(models)):
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                # load diffusers style into model
+                load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
+                model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -336,6 +372,8 @@ def main(args):
         use_ema_warmup=True,
         inv_gamma=args.ema_inv_gamma,
         power=args.ema_power,
+        model_cls=UNet2DModel,
+        model_config=model.config,
     )

     # Initialize the scheduler
@@ -411,7 +449,6 @@ def main(args):
     )

     if args.use_ema:
-        accelerator.register_for_checkpointing(ema_model)
         ema_model.to(accelerator.device)

     # We need to initialize the trackers we use, and also store our configuration.
...
src/diffusers/training_utils.py
 import copy
 import os
 import random
-from typing import Iterable, Union
+from typing import Any, Dict, Iterable, Optional, Union

 import numpy as np
 import torch
@@ -57,6 +57,8 @@ class EMAModel:
         use_ema_warmup: bool = False,
         inv_gamma: Union[float, int] = 1.0,
         power: Union[float, int] = 2 / 3,
+        model_cls: Optional[Any] = None,
+        model_config: Dict[str, Any] = None,
         **kwargs,
     ):
         """
@@ -123,6 +125,35 @@
         self.power = power
         self.optimization_step = 0

+        self.model_cls = model_cls
+        self.model_config = model_config
+
+    @classmethod
+    def from_pretrained(cls, path, model_cls) -> "EMAModel":
+        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
+        model = model_cls.from_pretrained(path)
+
+        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)
+
+        ema_model.load_state_dict(ema_kwargs)
+        return ema_model
+
+    def save_pretrained(self, path):
+        if self.model_cls is None:
+            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")
+
+        if self.model_config is None:
+            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")
+
+        model = self.model_cls.from_config(self.model_config)
+        state_dict = self.state_dict()
+        state_dict.pop("shadow_params", None)
+        state_dict.pop("collected_params", None)
+
+        model.register_to_config(**state_dict)
+
+        self.copy_to(model.parameters())
+        model.save_pretrained(path)
+
     def get_decay(self, optimization_step: int) -> float:
         """
         Compute the decay factor for the exponential moving average.
@@ -184,7 +215,7 @@
         """
         parameters = list(parameters)
         for s_param, param in zip(self.shadow_params, parameters):
-            param.data.copy_(s_param.data)
+            param.data.copy_(s_param.to(param.device).data)

     def to(self, device=None, dtype=None) -> None:
         r"""Move internal buffers of the ExponentialMovingAverage to `device`.
@@ -257,13 +288,15 @@
         if not isinstance(self.power, (float, int)):
             raise ValueError("Invalid power")

-        self.shadow_params = state_dict["shadow_params"]
-        if not isinstance(self.shadow_params, list):
-            raise ValueError("shadow_params must be a list")
-        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
-            raise ValueError("shadow_params must all be Tensors")
+        shadow_params = state_dict.get("shadow_params", None)
+        if shadow_params is not None:
+            self.shadow_params = shadow_params
+            if not isinstance(self.shadow_params, list):
+                raise ValueError("shadow_params must be a list")
+            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+                raise ValueError("shadow_params must all be Tensors")

-        self.collected_params = state_dict["collected_params"]
+        self.collected_params = state_dict.get("collected_params", None)
         if self.collected_params is not None:
             if not isinstance(self.collected_params, list):
                 raise ValueError("collected_params must be a list")
...
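Taken together, these `EMAModel` additions give it a diffusers-style serialization round trip. A hedged usage sketch (the small `UNet2DModel` configuration is illustrative):

```python
from diffusers import UNet2DModel
from diffusers.training_utils import EMAModel

model = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)

# `model_cls` and `model_config` are the new constructor arguments that make
# `save_pretrained` / `from_pretrained` possible
ema = EMAModel(model.parameters(), model_cls=UNet2DModel, model_config=model.config)

# ... call ema.step(model.parameters()) after each optimizer step ...

# Writes the averaged weights in diffusers format; the EMA bookkeeping
# (decay, optimization_step, ...) travels in the saved config
ema.save_pretrained("unet_ema")

# Rebuilds the EMA state from disk; the second argument is the model class to load
restored = EMAModel.from_pretrained("unet_ema", UNet2DModel)
restored.copy_to(model.parameters())  # copy the averaged weights into a live model
```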