"vscode:/vscode.git/clone" did not exist on "b9e2f886cd6e9182f1bf1bf7421c6363956f94c5"
Unverified Commit 008b608f authored by Suraj Patil, committed by GitHub

[train_text2image] Fix EMA and make it compatible with deepspeed. (#813)

* fix ema

* style

* add comment about copy

* style

* quality
parent 5afc2b60
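In short: the old EMAModel kept a full copy.deepcopy of the UNet and updated it through state_dict/named_parameters, which becomes awkward once the model is wrapped by accelerate/DeepSpeed; the new version (adapted from torch-ema, see the diff below) only keeps detached shadow copies of the parameter tensors and is driven entirely through parameter iterables. A rough sketch of the API change at the call sites, assuming `unet` is the trained model (this condenses the hunks below, it is not an excerpt):

    # old API: EMA wrapped a deep copy of the whole model
    # ema_unet = EMAModel(unet); ema_unet.step(unet)

    # new API: EMA only shadows the parameter tensors
    ema_unet = EMAModel(unet.parameters())
    ema_unet.step(unet.parameters())      # after each real optimizer step
    ema_unet.copy_to(unet.parameters())   # before saving the pipeline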
 import argparse
-import copy
 import logging
 import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
+from typing import Iterable, Optional

 import numpy as np
 import torch
@@ -234,25 +233,17 @@ dataset_name_mapping = {
 }


+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
     Exponential Moving Average of models weights
     """

-    def __init__(
-        self,
-        model,
-        decay=0.9999,
-        device=None,
-    ):
-        self.averaged_model = copy.deepcopy(model).eval()
-        self.averaged_model.requires_grad_(False)
-
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]
+
         self.decay = decay
-
-        if device is not None:
-            self.averaged_model = self.averaged_model.to(device=device)
-
         self.optimization_step = 0

     def get_decay(self, optimization_step):
@@ -263,34 +254,47 @@ class EMAModel:
         return 1 - min(self.decay, value)

     @torch.no_grad()
-    def step(self, new_model):
-        ema_state_dict = self.averaged_model.state_dict()
+    def step(self, parameters):
+        parameters = list(parameters)

         self.optimization_step += 1
         self.decay = self.get_decay(self.optimization_step)

-        for key, param in new_model.named_parameters():
-            if isinstance(param, dict):
-                continue
-            try:
-                ema_param = ema_state_dict[key]
-            except KeyError:
-                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
-                ema_state_dict[key] = ema_param
-
-            param = param.clone().detach().to(ema_param.dtype).to(ema_param.device)
-
+        for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                ema_state_dict[key].sub_(self.decay * (ema_param - param))
+                tmp = self.decay * (s_param - param)
+                s_param.sub_(tmp)
             else:
-                ema_state_dict[key].copy_(param)
-
-        for key, param in new_model.named_buffers():
-            ema_state_dict[key] = param
-
-        self.averaged_model.load_state_dict(ema_state_dict, strict=False)
+                s_param.copy_(param)
+
         torch.cuda.empty_cache()

+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
+

 def main():
     args = parse_args()
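For reference, the per-parameter update in `step` is the usual EMA rule written with the complemented decay returned by `get_decay` (the `1 - min(...)` value above): each shadow tensor moves toward the live parameter by that fraction, which shrinks toward 1 - 0.9999 as training progresses. A minimal sketch on plain tensors, outside the training script (the names `shadow`, `live` and the 0.0001 factor are illustrative late-training values):

    import torch

    c = 0.0001                           # what self.decay holds after get_decay(), i.e. 1 - 0.9999
    shadow = torch.tensor([1.0, 2.0])    # EMA ("shadow") copy of a parameter
    live = torch.tensor([2.0, 0.0])      # current trained parameter

    # same in-place update as EMAModel.step
    shadow.sub_(c * (shadow - live))

    # equivalent closed form: shadow = (1 - c) * shadow + c * live
    print(shadow)                        # tensor([1.0001, 1.9998])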
@@ -336,9 +340,6 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

-    if args.use_ema:
-        ema_unet = EMAModel(unet)
-
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -510,8 +511,9 @@ def main():
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)

-    # Move the ema_unet to gpu.
-    ema_unet.averaged_model.to(accelerator.device)
+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = EMAModel(unet.parameters())

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -583,7 +585,7 @@ def main():
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 if args.use_ema:
-                    ema_unet.step(unet)
+                    ema_unet.step(unet.parameters())
                 progress_bar.update(1)
                 global_step += 1
                 accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -598,10 +600,14 @@ def main():
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
+        if args.use_ema:
+            ema_unet.copy_to(unet.parameters())
+
         pipeline = StableDiffusionPipeline(
             text_encoder=text_encoder,
             vae=vae,
-            unet=accelerator.unwrap_model(ema_unet.averaged_model if args.use_ema else unet),
+            unet=unet,
             tokenizer=tokenizer,
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
...
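One practical note on the final hunk: `copy_to` writes the averaged weights into the live parameters in place, so the non-EMA weights are gone afterwards unless they are stashed first. A hedged sketch of how one could keep both (the `raw_state` name is made up; nothing like this is in the script):

    # keep the last raw (non-EMA) weights around before copy_to() overwrites them
    raw_state = {k: v.detach().clone() for k, v in unet.state_dict().items()}
    ema_unet.copy_to(unet.parameters())
    # ... build and save the EMA pipeline as in the hunk above ...
    unet.load_state_dict(raw_state)   # restore the raw weights if they are needed again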