"test/vscode:/vscode.git/clone" did not exist on "81da16f6d3bf4d8e24684e6e9a7be61b113de2bb"
import os
from pathlib import Path
import sys
current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent))
from PIL import Image
import torch
from torchvision import transforms as T
import numpy as np
import json
from tqdm import tqdm
import argparse
import threading
from queue import Queue
from pathlib import Path
from torch.utils.data import DataLoader, RandomSampler
from accelerate import Accelerator
from torchvision.transforms.functional import InterpolationMode
from torchvision.datasets.folder import default_loader
from diffusion.model.t5 import T5Embedder
from diffusers.models import AutoencoderKL
from diffusion.data.datasets.InternalData import InternalData
from diffusion.utils.misc import SimpleTimer
from diffusion.utils.data_sampler import AspectRatioBatchSampler
from diffusion.data.builder import DATASETS
from diffusion.data import ASPECT_RATIO_512, ASPECT_RATIO_1024
def get_closest_ratio(height: float, width: float, ratios: dict):
aspect_ratio = height / width
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
return ratios[closest_ratio], float(closest_ratio)
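# A hedged usage sketch (hypothetical numbers): for a 768x512 image the aspect
# ratio is 768/512 = 1.5, so the ASPECT_RATIO_1024 key numerically closest to
# 1.5 is selected and its predefined [height, width] bucket is returned along
# with the parsed ratio:
#   closest_size, closest_ratio = get_closest_ratio(768, 512, ASPECT_RATIO_1024)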
@DATASETS.register_module()
class DatasetMS(InternalData):
def __init__(self, root, image_list_json=None, transform=None, resolution=1024, load_vae_feat=False, aspect_ratio_type=None, start_index=0, end_index=100000000, **kwargs):
if image_list_json is None:
image_list_json = ['data_info.json']
assert os.path.isabs(root), 'root must be an absolute path'
self.root = root
self.img_dir_name = 'InternalImgs' # change according to your data structure
self.json_dir_name = 'InternalData' # change according to your data structure
self.transform = transform
self.load_vae_feat = load_vae_feat
self.resolution = resolution
self.meta_data_clean = []
self.img_samples = []
self.txt_feat_samples = []
self.aspect_ratio = aspect_ratio_type
assert self.aspect_ratio in [ASPECT_RATIO_1024, ASPECT_RATIO_512]
self.ratio_index = {}
self.ratio_nums = {}
for k, v in self.aspect_ratio.items():
self.ratio_index[float(k)] = [] # used for self.getitem
self.ratio_nums[float(k)] = 0 # used for batch-sampler
image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json]
for json_file in image_list_json:
meta_data = self.load_json(os.path.join(self.root, 'partition', json_file))
meta_data_clean = [item for item in meta_data if item['ratio'] <= 4]
self.meta_data_clean.extend(meta_data_clean)
self.img_samples.extend([os.path.join(self.root.replace(self.json_dir_name, self.img_dir_name), item['path']) for item in meta_data_clean])
self.img_samples = self.img_samples[start_index: end_index]
# scan the dataset (first third only) for aspect-ratio statistics
for i, info in enumerate(self.meta_data_clean[:len(self.meta_data_clean)//3]):
ori_h, ori_w = info['height'], info['width']
closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
self.ratio_nums[closest_ratio] += 1
if len(self.ratio_index[closest_ratio]) == 0:
self.ratio_index[closest_ratio].append(i)
# Set loader and extensions
if self.load_vae_feat:
raise ValueError("No VAE loader here")
self.loader = default_loader
def __getitem__(self, idx):
data_info = {}
for _ in range(20):
try:
img_path = self.img_samples[idx]
img = self.loader(img_path)
if self.transform:
img = self.transform(img)
# Calculate the closest aspect-ratio bucket, then resize & center-crop the image to [h, w]
if isinstance(img, Image.Image):
h, w = (img.size[1], img.size[0])
assert (h, w) == (self.meta_data_clean[idx]['height'], self.meta_data_clean[idx]['width'])
closest_size, closest_ratio = get_closest_ratio(h, w, self.aspect_ratio)
closest_size = list(map(lambda x: int(x), closest_size))
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB')),
T.Resize(closest_size, interpolation=InterpolationMode.BICUBIC), # Image.BICUBIC
T.CenterCrop(closest_size),
T.ToTensor(),
T.Normalize([.5], [.5]),
])
img = transform(img)
data_info['img_hw'] = torch.tensor([h, w], dtype=torch.float32)
data_info['aspect_ratio'] = closest_ratio
# change the path according to your data structure
return img, '_'.join(self.img_samples[idx].rsplit('/', 2)[-2:]) # change from 'serial-number-of-dir/serial-number-of-image.png' ---> 'serial-number-of-dir_serial-number-of-image.png'
except Exception as e:
print(f"Error details: {str(e)}")
idx = np.random.randint(len(self))
raise RuntimeError('Too many corrupted data samples.')
def get_data_info(self, idx):
data_info = self.meta_data_clean[idx]
return {'height': data_info['height'], 'width': data_info['width']}
def extract_caption_t5_do(q):
while not q.empty():
item = q.get()
extract_caption_t5_job(item)
q.task_done()
def extract_caption_t5_job(item):
global mutex
global t5
global t5_save_dir
with torch.no_grad():
caption = item['prompt'].strip()
if isinstance(caption, str):
caption = [caption]
save_path = os.path.join(t5_save_dir, Path(item['path']).stem)
if os.path.exists(f"{save_path}.npz"):
return
try:
mutex.acquire()
caption_emb, emb_mask = t5.get_text_embeddings(caption)
mutex.release()
emb_dict = {
'caption_feature': caption_emb.float().cpu().data.numpy(),
'attention_mask': emb_mask.cpu().data.numpy(),
}
np.savez_compressed(save_path, **emb_dict)
except Exception as e:
print(e)
def extract_caption_t5():
global t5
global t5_save_dir
# global images_extension
t5 = T5Embedder(device="cuda", local_cache=True, cache_dir=f'{args.pretrained_models_dir}/t5_ckpts', model_max_length=120)
t5_save_dir = args.t5_save_root
os.makedirs(t5_save_dir, exist_ok=True)
train_data_json = json.load(open(args.json_path, 'r'))
train_data = train_data_json[args.start_index: args.end_index]
global mutex
mutex = threading.Lock()
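# Note: calls into the T5 model are serialized through `mutex`, so the 20 worker
# threads spawned below mainly parallelize the np.savez_compressed disk writes,
# not the GPU embedding itself.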
jobs = Queue()
for item in tqdm(train_data):
jobs.put(item)
for _ in range(20):
worker = threading.Thread(target=extract_caption_t5_do, args=(jobs,))
worker.start()
jobs.join()
def extract_img_vae_do(q):
while not q.empty():
item = q.get()
extract_img_vae_job(item)
q.task_done()
def extract_img_vae_job(item):
return
def extract_img_vae():
vae = AutoencoderKL.from_pretrained(f'{args.pretrained_models_dir}/sd-vae-ft-ema').to(device)
train_data_json = json.load(open(args.json_path, 'r'))
image_names = set()
vae_save_root = f'{args.vae_save_root}/{image_resize}resolution'
os.umask(0o000) # file permission: 666; dir permission: 777
os.makedirs(vae_save_root, exist_ok=True)
vae_save_dir = os.path.join(vae_save_root, 'noflip')
os.makedirs(vae_save_dir, exist_ok=True)
for item in train_data_json:
image_name = item['path']
if image_name in image_names:
continue
image_names.add(image_name)
lines = sorted(image_names)
lines = lines[args.start_index: args.end_index]
_, images_extension = os.path.splitext(lines[0])
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB')),
T.Resize(image_resize), # Image.BICUBIC
T.CenterCrop(image_resize),
T.ToTensor(),
T.Normalize([.5], [.5]),
])
os.umask(0o000) # file permission: 666; dir permission: 777
for image_name in tqdm(lines):
save_path = os.path.join(vae_save_dir, Path(image_name).stem)
if os.path.exists(f"{save_path}.npy"):
continue
try:
img = Image.open(f'{args.dataset_root}/{image_name}')
img = transform(img).to(device)[None]
with torch.no_grad():
posterior = vae.encode(img).latent_dist
z = torch.cat([posterior.mean, posterior.std], dim=1).detach().cpu().numpy().squeeze()
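# z stacks the posterior mean and std along the channel axis (4 + 4 = 8 channels),
# presumably so the training-time dataloader can re-sample a latent from the saved
# Gaussian without re-running the VAE.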
np.save(save_path, z)
except Exception as e:
print(e)
print(image_name)
def save_results(results, paths, signature, work_dir):
timer = SimpleTimer(len(results), log_interval=100, desc="Saving Results")
# save to npy
new_paths = []
os.umask(0o000) # file permission: 666; dir permission: 777
for res, p in zip(results, paths):
file_name = p.split('.')[0] + '.npy'
new_folder = signature
save_folder = os.path.join(work_dir, new_folder)
if os.path.exists(save_folder):
raise FileExistsError(f"{save_folder} exists. Be careful not to overwrite your files. Comment out this error to allow overwriting!")
os.makedirs(save_folder, exist_ok=True)
new_paths.append(os.path.join(new_folder, file_name))
np.save(os.path.join(save_folder, file_name), res)
timer.log()
# save paths
with open(os.path.join(work_dir, f"VAE-{signature}.txt"), 'w') as f:
f.write('\n'.join(new_paths))
def inference(vae, dataloader, signature, work_dir):
timer = SimpleTimer(len(dataloader), log_interval=100, desc="VAE-Inference")
for batch in dataloader:
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=True):
posterior = vae.encode(batch[0]).latent_dist
results = torch.cat([posterior.mean, posterior.std], dim=1).detach().cpu().numpy()
path = batch[1]
save_results(results, path, signature=signature, work_dir=work_dir)
timer.log()
def extract_img_vae_multiscale(bs=1):
assert image_resize in [512, 1024]
work_dir = os.path.abspath(args.vae_save_root)
os.umask(0o000) # file permission: 666; dir permission: 777
os.makedirs(work_dir, exist_ok=True)
accelerator = Accelerator(mixed_precision='fp16')
vae = AutoencoderKL.from_pretrained(f'{args.pretrained_models_dir}/sd-vae-ft-ema').to(device)
signature = 'ms'
aspect_ratio_type = ASPECT_RATIO_1024 if image_resize == 1024 else ASPECT_RATIO_512
dataset = DatasetMS(args.dataset_root, image_list_json=[args.json_file], transform=None, sample_subset=None,
aspect_ratio_type=aspect_ratio_type, start_index=args.start_index, end_index=args.end_index)
# create AspectRatioBatchSampler
sampler = AspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset, batch_size=bs, aspect_ratios=dataset.aspect_ratio, ratio_nums=dataset.ratio_nums)
# create DataLoader
dataloader = DataLoader(dataset, batch_sampler=sampler, num_workers=13, pin_memory=True)
dataloader = accelerator.prepare(dataloader)
inference(vae, dataloader, signature=signature, work_dir=work_dir)
accelerator.wait_for_everyone()
print('done')
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--multi_scale", action='store_true', default=False, help="multi-scale feature extraction")
parser.add_argument("--img_size", default=512, type=int, help="image scale for multi-scale feature extraction")
parser.add_argument('--start_index', default=0, type=int)
parser.add_argument('--end_index', default=1000000, type=int)
parser.add_argument('--json_path', type=str)
parser.add_argument('--t5_save_root', default='data/data_toy/caption_feature_wmask', type=str)
parser.add_argument('--vae_save_root', default='data/data_toy/img_vae_features', type=str)
parser.add_argument('--dataset_root', default='data/data_toy', type=str)
parser.add_argument('--pretrained_models_dir', default='output/pretrained_models', type=str)
### for multi-scale (ms) VAE feature extraction
parser.add_argument('--json_file', type=str)
return parser.parse_args()
if __name__ == '__main__':
args = get_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
image_resize = args.img_size
# prepare extracted caption t5 features for training
extract_caption_t5()
# prepare extracted image vae features for training
if args.multi_scale:
print(f'Extracting Multi-scale Image Resolution based on {image_resize}')
extract_img_vae_multiscale(bs=1) # recommend bs = 1 for AspectRatioBatchSampler
else:
print(f'Extracting Single Image Resolution {image_resize}')
extract_img_vae()
CUDA_VISIBLE_DEVICES=5,6,7 python -m torch.distributed.launch --nproc_per_node=3 \
--master_port=26662 train_scripts/train_controlnet.py \
configs/pixart_app_config/PixArt_xl2_img1024_controlHed_Half.py \
--work-dir output/debug
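# The command above launches ControlNet training on GPUs 5-7 (one process per GPU
# via torch.distributed.launch, rendezvous on port 26662) using the 1024px
# Hed-conditioned "Half" config; adjust CUDA_VISIBLE_DEVICES, --nproc_per_node and
# --master_port to match your machine.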
import os
import sys
import types
from pathlib import Path
current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent))
import argparse
import datetime
import time
import warnings
warnings.filterwarnings("ignore") # ignore warning
import torch
import torch.nn as nn
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import DistributedType
from diffusers.models import AutoencoderKL
from torch.utils.data import RandomSampler
from mmcv.runner import LogBuffer
from copy import deepcopy
from PIL import Image
import numpy as np
from diffusion import IDDPM
from diffusion.utils.checkpoint import save_checkpoint, load_checkpoint
from diffusion.utils.dist_utils import synchronize, get_world_size, clip_grad_norm_
from diffusion.data.builder import build_dataset, build_dataloader, set_data_root
from diffusion.model.builder import build_model
from diffusion.utils.logger import get_root_logger
from diffusion.utils.misc import set_random_seed, read_config, init_random_seed, DebugUnderflowOverflow
from diffusion.utils.optimizer import build_optimizer, auto_scale_lr
from diffusion.utils.lr_scheduler import build_lr_scheduler
from diffusion.utils.data_sampler import AspectRatioBatchSampler, BalancedAspectRatioBatchSampler
def set_fsdp_env():
os.environ["ACCELERATE_USE_FSDP"] = 'true'
os.environ["FSDP_AUTO_WRAP_POLICY"] = 'TRANSFORMER_BASED_WRAP'
os.environ["FSDP_BACKWARD_PREFETCH"] = 'BACKWARD_PRE'
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = 'PixArtBlock'
def ema_update(model_dest: nn.Module, model_src: nn.Module, rate):
param_dict_src = dict(model_src.named_parameters())
for p_name, p_dest in model_dest.named_parameters():
p_src = param_dict_src[p_name]
assert p_src is not p_dest
p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
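# i.e. the EMA weights follow p_dest <- rate * p_dest + (1 - rate) * p_src;
# the training loop calls this with rate = config.ema_rate, and with rate = 0.
# once at startup to hard-copy the online weights into the EMA model.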
def train():
if config.get('debug_nan', False):
DebugUnderflowOverflow(model)
logger.info('NaN debugger registered. Start to detect overflow during training.')
time_start, last_tic = time.time(), time.time()
log_buffer = LogBuffer()
start_step = start_epoch * len(train_dataloader)
global_step = 0
total_steps = len(train_dataloader) * config.num_epochs
# load_vae_feat = getattr(train_dataloader.dataset, 'load_vae_feat', False)
# Now you train the model
for epoch in range(start_epoch + 1, config.num_epochs + 1):
data_time_start= time.time()
data_time_all = 0
for step, batch in enumerate(train_dataloader):
data_time_all += time.time() - data_time_start
# if load_vae_feat:
z = batch[0]
# else:
# with torch.no_grad():
# with torch.cuda.amp.autocast(enabled=config.mixed_precision == 'fp16'):
# posterior = vae.encode(batch[0]).latent_dist
# if config.sample_posterior:
# z = posterior.sample()
# else:
# z = posterior.mode()
clean_images = z * config.scale_factor
y = batch[1]
y_mask = batch[2]
data_info = batch[3]
# Sample a random timestep for each image
bs = clean_images.shape[0]
timesteps = torch.randint(0, config.train_sampling_steps, (bs,), device=clean_images.device).long()
grad_norm = None
with accelerator.accumulate(model):
# Predict the noise residual
optimizer.zero_grad()
loss_term = train_diffusion.training_losses(model, clean_images, timesteps, model_kwargs=dict(y=y, mask=y_mask, data_info=data_info))
loss = loss_term['loss'].mean()
accelerator.backward(loss)
if accelerator.sync_gradients:
grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.gradient_clip)
optimizer.step()
lr_scheduler.step()
if accelerator.sync_gradients:
ema_update(model_ema, model, config.ema_rate)
lr = lr_scheduler.get_last_lr()[0]
logs = {args.loss_report_name: accelerator.gather(loss).mean().item()}
if grad_norm is not None:
logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
log_buffer.update(logs)
# logging on terminal
if (step + 1) % config.log_interval == 0 or (step + 1) == 1:
t = (time.time() - last_tic) / config.log_interval
t_d = data_time_all / config.log_interval
avg_time = (time.time() - time_start) / (global_step + 1)
eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - start_step - global_step - 1))))
eta_epoch = str(datetime.timedelta(seconds=int(avg_time * (len(train_dataloader) - step - 1))))
# avg_loss = sum(loss_buffer) / len(loss_buffer)
log_buffer.average()
info = f"Step/Epoch [{(epoch-1)*len(train_dataloader)+step+1}/{epoch}][{step + 1}/{len(train_dataloader)}]:total_eta: {eta}, " \
f"epoch_eta:{eta_epoch}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:({model.module.h}, {model.module.w}), "
info += ', '.join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()])
logger.info(info)
last_tic = time.time()
log_buffer.clear()
data_time_all = 0
logs.update(lr=lr)
accelerator.log(logs, step=global_step + start_step)
global_step += 1
data_time_start= time.time()
synchronize()
if accelerator.is_main_process:
if ((epoch - 1) * len(train_dataloader) + step + 1) % config.save_model_steps == 0:
os.umask(0o000)
save_checkpoint(os.path.join(config.work_dir, 'checkpoints'),
epoch=epoch,
step=(epoch - 1) * len(train_dataloader) + step + 1,
model=accelerator.unwrap_model(model),
model_ema=accelerator.unwrap_model(model_ema),
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
synchronize()
synchronize()
if accelerator.is_main_process:
if epoch % config.save_model_epochs == 0 or epoch == config.num_epochs:
os.umask(0o000)
save_checkpoint(os.path.join(config.output_dir, 'checkpoints'),
epoch=epoch,
step=(epoch - 1) * len(train_dataloader) + step + 1,
model=accelerator.unwrap_model(model),
model_ema=accelerator.unwrap_model(model_ema),
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
########### EVAL ###################
if epoch % config.save_image_epochs == 0 or epoch == config.num_epochs:
if config.validation_prompts is not None:
logger.info("Running inference for collecting generated images...")
assert config.eval_sampler in ['iddpm', 'dpm-solver', 'sa-solver']
sample_steps_dict = {'iddpm': 100, 'dpm-solver': 20, 'sa-solver': 25}
sample_steps = config.eval_steps if config.eval_steps != -1 else sample_steps_dict[config.eval_sampler]
# base_ratios = eval(f'ASPECT_RATIO_{config.image_size}_TEST')
eval_dir = os.path.join(config.output_dir, 'eval')
os.makedirs(eval_dir, exist_ok=True)
save_path = os.path.join(eval_dir, f'{epoch}_{global_step}.png')
model.eval()
images = []
# device = t5.device
for ip, prompt in enumerate(config.validation_prompts):
prompts = [prompt]
# prompts = []
# prompt_clean, _, hw, ar, custom_hw = prepare_prompt_ar(prompt, base_ratios, device=device, show=False) # ar for aspect ratio
# if config.image_size == 1024:
# latent_size_h, latent_size_w = int(hw[0, 0] // 8), int(hw[0, 1] // 8)
# else:
# hw = torch.tensor([[config.image_size, config.image_size]], dtype=torch.float, device=device).repeat(bs, 1)
# ar = torch.tensor([[1.]], device=device).repeat(bs, 1)
# latent_size_h, latent_size_w = latent_size, latent_size
# prompts.append(prompt_clean.strip())
null_y = model.module.y_embedder.y_embedding[None].repeat(len(prompts), 1, 1)[:, None]
with torch.no_grad():
caption_embs, emb_masks, len_prompts = val_txt_embs[ip]
# caption_embs, emb_masks = t5.get_text_embeddings(prompts)
# caption_embs = caption_embs.float()[:, None]
print('finished computing text embeddings')
n = len_prompts
if config.eval_sampler == 'iddpm':
# Create sampling noise:
z = torch.randn(n, 4, latent_size_h, latent_size_w, device=device).repeat(2, 1, 1, 1)
model_kwargs = dict(y=torch.cat([caption_embs, null_y]),
cfg_scale=config.cfg_scale, data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks)
diffusion = IDDPM(str(sample_steps))
# Sample images:
samples = diffusion.p_sample_loop(
model.module.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True,
device=device
)
samples, _ = samples.chunk(2, dim=0) # Remove null class samples
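# For the IDDPM path the batch is doubled (z repeated, caption embeddings
# concatenated with the null embedding) so forward_with_cfg can apply
# classifier-free guidance in one pass; the unconditional half is dropped above.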
elif config.eval_sampler == 'dpm-solver':
# Create sampling noise:
z = torch.randn(n, 4, latent_size_h, latent_size_w, device=device)
model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks)
dpm_solver = DPMS(model.module.forward_with_dpmsolver,
condition=caption_embs,
uncondition=null_y,
cfg_scale=config.cfg_scale,
model_kwargs=model_kwargs)
samples = dpm_solver.sample(
z,
steps=sample_steps,
order=2,
skip_type="time_uniform",
method="multistep",
)
elif config.eval_sampler == 'sa-solver':
# Create sampling noise:
model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks)
sa_solver = SASolverSampler(model.module.forward_with_dpmsolver, device=device)
samples = sa_solver.sample(
S=25,
batch_size=n,
shape=(4, latent_size_h, latent_size_w),
eta=1,
conditioning=caption_embs,
unconditional_conditioning=null_y,
unconditional_guidance_scale=config.cfg_scale,
model_kwargs=model_kwargs,
)[0]
samples = vae.decode(samples / 0.18215).sample
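# Dividing by 0.18215 (the Stable Diffusion VAE latent scaling factor) undoes the
# scaling applied when the latents were produced, before decoding back to pixels.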
# decode image
image = make_grid(samples, nrow=1, normalize=True, value_range=(-1, 1))
image = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
image = Image.fromarray(image)
images.append(image)
image_grid = make_image_grid(images, 2, len(images)//2)
image_grid.save(save_path)
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
elif tracker.name == "comet_ml":
logger.info('Logging validation images')
tracker.writer.log_image(image_grid, name=f"{epoch}", step=global_step)
else:
logger.warning(f"image logging not implemented for {tracker.name}")
del images, image, samples, image_grid
torch.cuda.empty_cache()
model.train()
synchronize()
def parse_args():
parser = argparse.ArgumentParser(description="Process some integers.")
parser.add_argument("config", type=str, help="config")
parser.add_argument("--cloud", action='store_true', default=False, help="cloud or local machine")
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument('--resume-from', help='the dir to resume the training')
parser.add_argument('--load-from', default=None, help='the dir to load a ckpt for training')
parser.add_argument('--local-rank', type=int, default=-1)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--debug', action='store_true')
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--tracker_project_name",
type=str,
default="text2image-fine-tune",
help=(
"The `project_name` argument passed to Accelerator.init_trackers for"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
parser.add_argument("--loss_report_name", type=str, default="loss")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
config = read_config(args.config)
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
config.work_dir = args.work_dir
if args.cloud:
config.data_root = '/data/data'
if args.resume_from is not None:
config.load_from = None
config.resume_from = dict(
checkpoint=args.resume_from,
load_ema=False,
resume_optimizer=True,
resume_lr_scheduler=True)
if args.debug:
config.log_interval = 1
config.train_batch_size = 8
config.valid_num = 100
os.umask(0o000)
config.output_dir = os.path.join(config.work_dir,
f"""{config.model}_{config.dataset_alias}_{config.image_size}_batch{config.train_batch_size}_{config.lr_schedule}_lr{config.optimizer['lr']}_warmup{config.lr_schedule_args['num_warmup_steps']}_gas{config.gradient_accumulation_steps}""")
os.makedirs(config.output_dir, exist_ok=True)
init_handler = InitProcessGroupKwargs()
init_handler.timeout = datetime.timedelta(seconds=5400) # change timeout to avoid a strange NCCL bug
# Initialize accelerator and tensorboard logging
if config.use_fsdp:
init_train = 'FSDP'
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
set_fsdp_env()
fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),)
else:
init_train = 'DDP'
fsdp_plugin = None
even_batches = True
if config.multi_scale:
even_batches = False
if args.report_to == "comet_ml":
import comet_ml
comet_ml.init(
project_name=args.tracker_project_name,
)
accelerator = Accelerator(
mixed_precision=config.mixed_precision,
gradient_accumulation_steps=config.gradient_accumulation_steps,
log_with=args.report_to,
project_dir=os.path.join(config.output_dir, "logs"),
fsdp_plugin=fsdp_plugin,
even_batches=even_batches,
kwargs_handlers=[init_handler]
)
logger = get_root_logger(os.path.join(config.output_dir, 'train_log.log'))
config.seed = init_random_seed(config.get('seed', None))
set_random_seed(config.seed)
if accelerator.is_main_process:
config.dump(os.path.join(config.output_dir, 'config.py'))
logger.info(f"Config: \n{config.pretty_text}")
logger.info(f"World_size: {get_world_size()}, seed: {config.seed}")
logger.info(f"Initializing: {init_train} for training")
image_size = config.image_size # @param [256, 512]
latent_size = int(image_size) // 8
pred_sigma = getattr(config, 'pred_sigma', True)
learn_sigma = getattr(config, 'learn_sigma', True) and pred_sigma
model_kwargs={"window_block_indexes": config.window_block_indexes, "window_size": config.window_size,
"use_rel_pos": config.use_rel_pos, "lewei_scale": config.lewei_scale, 'config':config,
'model_max_length': config.model_max_length}
if config.validation_prompts is not None:
logger.info('Precompute validation prompt embeddings')
from diffusion.model.utils import prepare_prompt_ar
from diffusion import IDDPM, DPMS, SASolverSampler
from diffusion.model.t5 import T5Embedder
from diffusion.data.datasets import ASPECT_RATIO_256_TEST, ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST
from diffusers.utils import make_image_grid
from torchvision.utils import make_grid
t5 = T5Embedder(device="cuda", local_cache=True, cache_dir='output/pretrained_models/t5_ckpts', torch_dtype=torch.float)
device = t5.device
base_ratios = eval(f'ASPECT_RATIO_{config.image_size}_TEST')
pbs = 1
val_txt_embs = []
for prompt in config.validation_prompts:
prompts = []
prompt_clean, _, hw, ar, custom_hw = prepare_prompt_ar(prompt, base_ratios, device=device, show=False) # ar for aspect ratio
if config.image_size == 1024:
latent_size_h, latent_size_w = int(hw[0, 0] // 8), int(hw[0, 1] // 8)
else:
hw = torch.tensor([[config.image_size, config.image_size]], dtype=torch.float, device=device).repeat(pbs, 1)
ar = torch.tensor([[1.]], device=device).repeat(pbs, 1)
latent_size_h, latent_size_w = latent_size, latent_size
prompts.append(prompt_clean.strip())
with torch.no_grad():
caption_embs, emb_masks = t5.get_text_embeddings(prompts)
caption_embs = caption_embs.float()[:, None]
val_txt_embs.append([caption_embs, emb_masks, len(prompts)])
del t5
import gc # garbage collect library
gc.collect()
torch.cuda.empty_cache()
logger.info('[ DONE ]')
# build models
train_diffusion = IDDPM(str(config.train_sampling_steps), learn_sigma=learn_sigma, pred_sigma=pred_sigma, snr=config.snr_loss)
model = build_model(config.model,
config.grad_checkpointing,
config.get('fp32_attention', False),
input_size=latent_size,
learn_sigma=learn_sigma,
pred_sigma=pred_sigma,
**model_kwargs).train()
logger.info(f"{model.__class__.__name__} Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
logger.info(f"T5 max token length: {config.model_max_length}")
model_ema = deepcopy(model).eval()
if config.load_from is not None:
if args.load_from is not None:
config.load_from = args.load_from
missing, unexpected = load_checkpoint(config.load_from, model, load_ema=config.get('load_ema', False))
logger.warning(f'Missing keys: {missing}')
logger.warning(f'Unexpected keys: {unexpected}')
ema_update(model_ema, model, 0.)
if not config.data.load_vae_feat:
vae = AutoencoderKL.from_pretrained(config.vae_pretrained).cuda()
# prepare for FSDP clip grad norm calculation
if accelerator.distributed_type == DistributedType.FSDP:
for m in accelerator._models:
m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m)
# build dataloader
set_data_root(config.data_root)
dataset = build_dataset(config.data, resolution=image_size, aspect_ratio_type=config.aspect_ratio_type)
if config.multi_scale:
batch_sampler = AspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio, drop_last=True,
ratio_nums=dataset.ratio_nums, config=config, valid_num=config.valid_num)
# used for balanced sampling
# batch_sampler = BalancedAspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
# batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio,
# ratio_nums=dataset.ratio_nums)
train_dataloader = build_dataloader(dataset, batch_sampler=batch_sampler, num_workers=config.num_workers)
else:
logger.info(f'Batch size {config.train_batch_size}')
train_dataloader = build_dataloader(dataset, num_workers=config.num_workers, batch_size=config.train_batch_size, shuffle=True)
# build optimizer and lr scheduler
lr_scale_ratio = 1
if config.get('auto_lr', None):
lr_scale_ratio = auto_scale_lr(config.train_batch_size * get_world_size() * config.gradient_accumulation_steps,
config.optimizer, **config.auto_lr)
optimizer = build_optimizer(model, config.optimizer)
lr_scheduler = build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio)
timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
if accelerator.is_main_process:
tracker_config = dict(vars(config))
accelerator.init_trackers(args.tracker_project_name, tracker_config)
accelerator.get_tracker("comet_ml").writer.add_tags([config.model,
config.dataset_alias,
config.image_size,
config.lr_schedule,
f'bs{config.train_batch_size}',
f'gs{config.gradient_accumulation_steps}'
])
start_epoch = 0
if config.resume_from is not None and config.resume_from['checkpoint'] is not None:
start_epoch, missing, unexpected = load_checkpoint(**config.resume_from,
model=model,
model_ema=model_ema,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
logger.warning(f'Missing keys: {missing}')
logger.warning(f'Unexpected keys: {unexpected}')
# Prepare everything
# There is no specific order to remember, you just need to unpack the
# objects in the same order you gave them to the prepare method.
model, model_ema = accelerator.prepare(model, model_ema)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
train()
import argparse
import datetime
import os
import sys
import time
import types
import warnings
import xformers
from copy import deepcopy
from pathlib import Path
warnings.filterwarnings("ignore") # ignore warning
current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent))
import torch
import torch.nn as nn
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import DistributedType
from diffusers.models import AutoencoderKL
from mmcv.runner import LogBuffer
from torch.utils.data import RandomSampler
from diffusion import IDDPM
from diffusion.data.builder import build_dataset, build_dataloader, set_data_root
from diffusion.model.builder import build_model
from diffusion.utils.checkpoint import save_checkpoint, load_checkpoint
from diffusion.utils.data_sampler import AspectRatioBatchSampler, BalancedAspectRatioBatchSampler
from diffusion.utils.dist_utils import synchronize, get_world_size, clip_grad_norm_
from diffusion.utils.logger import get_root_logger
from diffusion.utils.lr_scheduler import build_lr_scheduler
from diffusion.utils.misc import set_random_seed, read_config, init_random_seed, DebugUnderflowOverflow
from diffusion.utils.optimizer import build_optimizer, auto_scale_lr
def set_fsdp_env():
os.environ["ACCELERATE_USE_FSDP"] = 'true'
os.environ["FSDP_AUTO_WRAP_POLICY"] = 'TRANSFORMER_BASED_WRAP'
os.environ["FSDP_BACKWARD_PREFETCH"] = 'BACKWARD_PRE'
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = 'PixArtBlock'
def ema_update(model_dest: nn.Module, model_src: nn.Module, rate):
param_dict_src = dict(model_src.named_parameters())
for p_name, p_dest in model_dest.named_parameters():
p_src = param_dict_src[p_name]
assert p_src is not p_dest
p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
def train():
if config.get('debug_nan', False):
DebugUnderflowOverflow(model)
logger.info('NaN debugger registered. Start to detect overflow during training.')
time_start, last_tic = time.time(), time.time()
log_buffer = LogBuffer()
start_step = start_epoch * len(train_dataloader)
global_step = 0
total_steps = len(train_dataloader) * config.num_epochs
load_vae_feat = getattr(train_dataloader.dataset, 'load_vae_feat', False)
# Now you train the model
for epoch in range(start_epoch + 1, config.num_epochs + 1):
data_time_start= time.time()
data_time_all = 0
for step, batch in enumerate(train_dataloader):
data_time_all += time.time() - data_time_start
if load_vae_feat:
z = batch[0]
else:
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=config.mixed_precision == 'fp16'):
posterior = vae.encode(batch[0]).latent_dist
if config.sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
clean_images = z * config.scale_factor
y = batch[1]
y_mask = batch[2]
data_info = batch[3]
# Sample a random timestep for each image
bs = clean_images.shape[0]
timesteps = torch.randint(0, config.train_sampling_steps, (bs,), device=clean_images.device).long()
grad_norm = None
with accelerator.accumulate(model):
# Predict the noise residual
optimizer.zero_grad()
loss_term = train_diffusion.training_losses(model, clean_images, timesteps, model_kwargs=dict(y=y, mask=y_mask, data_info=data_info))
loss = loss_term['loss'].mean()
accelerator.backward(loss)
if accelerator.sync_gradients:
grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.gradient_clip)
optimizer.step()
lr_scheduler.step()
if accelerator.sync_gradients:
ema_update(model_ema, model, config.ema_rate)
lr = lr_scheduler.get_last_lr()[0]
logs = {args.loss_report_name: accelerator.gather(loss).mean().item()}
if grad_norm is not None:
logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
log_buffer.update(logs)
if (step + 1) % config.log_interval == 0 or (step + 1) == 1:
t = (time.time() - last_tic) / config.log_interval
t_d = data_time_all / config.log_interval
avg_time = (time.time() - time_start) / (global_step + 1)
eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - start_step - global_step - 1))))
eta_epoch = str(datetime.timedelta(seconds=int(avg_time * (len(train_dataloader) - step - 1))))
# avg_loss = sum(loss_buffer) / len(loss_buffer)
log_buffer.average()
info = f"Step/Epoch [{(epoch-1)*len(train_dataloader)+step+1}/{epoch}][{step + 1}/{len(train_dataloader)}]:total_eta: {eta}, " \
f"epoch_eta:{eta_epoch}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:({model.module.h}, {model.module.w}), "
info += ', '.join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()])
logger.info(info)
last_tic = time.time()
log_buffer.clear()
data_time_all = 0
logs.update(lr=lr)
accelerator.log(logs, step=global_step + start_step)
global_step += 1
data_time_start= time.time()
synchronize()
if accelerator.is_main_process:
if ((epoch - 1) * len(train_dataloader) + step + 1) % config.save_model_steps == 0:
os.umask(0o000)
save_checkpoint(os.path.join(config.work_dir, 'checkpoints'),
epoch=epoch,
step=(epoch - 1) * len(train_dataloader) + step + 1,
model=accelerator.unwrap_model(model),
model_ema=accelerator.unwrap_model(model_ema),
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
synchronize()
synchronize()
if accelerator.is_main_process:
if epoch % config.save_model_epochs == 0 or epoch == config.num_epochs:
os.umask(0o000)
save_checkpoint(os.path.join(config.work_dir, 'checkpoints'),
epoch=epoch,
step=(epoch - 1) * len(train_dataloader) + step + 1,
model=accelerator.unwrap_model(model),
model_ema=accelerator.unwrap_model(model_ema),
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
synchronize()
def parse_args():
parser = argparse.ArgumentParser(description="Process some integers.")
parser.add_argument("config", type=str, help="config")
parser.add_argument("--cloud", action='store_true', default=False, help="cloud or local machine")
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument('--resume-from', help='the dir to resume the training')
parser.add_argument('--load-from', default=None, help='the dir to load a ckpt for training')
parser.add_argument('--local-rank', type=int, default=-1)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--debug', action='store_true')
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--tracker_project_name",
type=str,
default="text2image-fine-tune",
help=(
"The `project_name` argument passed to Accelerator.init_trackers for"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
parser.add_argument("--loss_report_name", type=str, default="loss")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
config = read_config(args.config)
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
config.work_dir = args.work_dir
if args.cloud:
config.data_root = '/data/data'
if args.resume_from is not None:
config.load_from = None
config.resume_from = dict(
checkpoint=args.resume_from,
load_ema=False,
resume_optimizer=True,
resume_lr_scheduler=True)
if args.debug:
config.log_interval = 1
config.train_batch_size = 8
config.valid_num = 100
os.umask(0o000)
os.makedirs(config.work_dir, exist_ok=True)
init_handler = InitProcessGroupKwargs()
init_handler.timeout = datetime.timedelta(seconds=5400) # change timeout to avoid a strange NCCL bug
# Initialize accelerator and tensorboard logging
if config.use_fsdp:
init_train = 'FSDP'
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
set_fsdp_env()
fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),)
else:
init_train = 'DDP'
fsdp_plugin = None
even_batches = True
if config.multi_scale:
even_batches = False
# config.mixed_precision = 'no'
accelerator = Accelerator(
mixed_precision=config.mixed_precision,
gradient_accumulation_steps=config.gradient_accumulation_steps,
log_with=args.report_to,
project_dir=os.path.join(config.work_dir, "logs"),
fsdp_plugin=fsdp_plugin,
even_batches=even_batches,
kwargs_handlers=[init_handler]
)
logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
config.seed = init_random_seed(config.get('seed', None))
set_random_seed(config.seed)
if accelerator.is_main_process:
config.dump(os.path.join(config.work_dir, 'config.py'))
logger.info(f"Config: \n{config.pretty_text}")
logger.info(f"World_size: {get_world_size()}, seed: {config.seed}")
logger.info(f"Initializing: {init_train} for training")
image_size = config.image_size # @param [256, 512, 1024]
latent_size = int(image_size) // 8
pred_sigma = getattr(config, 'pred_sigma', True)
learn_sigma = getattr(config, 'learn_sigma', True) and pred_sigma
model_kwargs={"window_block_indexes": config.window_block_indexes, "window_size": config.window_size,
"use_rel_pos": config.use_rel_pos, "lewei_scale": config.lewei_scale, 'config':config,
'model_max_length': config.model_max_length}
# build models
train_diffusion = IDDPM(str(config.train_sampling_steps), learn_sigma=learn_sigma, pred_sigma=pred_sigma, snr=config.snr_loss)
model = build_model(config.model,
config.grad_checkpointing,
config.get('fp32_attention', False),
input_size=latent_size,
learn_sigma=learn_sigma,
pred_sigma=pred_sigma,
**model_kwargs).train()
logger.info(f"{model.__class__.__name__} Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
model_ema = deepcopy(model).eval()
if config.load_from is not None:
if args.load_from is not None:
config.load_from = args.load_from
missing, unexpected = load_checkpoint(config.load_from, model, load_ema=config.get('load_ema', False))
logger.warning(f'Missing keys: {missing}')
logger.warning(f'Unexpected keys: {unexpected}')
ema_update(model_ema, model, 0.)
if not config.data.load_vae_feat:
vae = AutoencoderKL.from_pretrained(config.vae_pretrained).cuda()
# prepare for FSDP clip grad norm calculation
if accelerator.distributed_type == DistributedType.FSDP:
for m in accelerator._models:
m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m)
# build dataloader
set_data_root(config.data_root)
dataset = build_dataset(config.data, resolution=image_size, aspect_ratio_type=config.aspect_ratio_type)
if config.multi_scale:
batch_sampler = AspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio, drop_last=True,
ratio_nums=dataset.ratio_nums, config=config, valid_num=config.valid_num)
# used for balanced sampling
# batch_sampler = BalancedAspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
# batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio,
# ratio_nums=dataset.ratio_nums)
train_dataloader = build_dataloader(dataset, batch_sampler=batch_sampler, num_workers=config.num_workers)
else:
train_dataloader = build_dataloader(dataset, num_workers=config.num_workers, batch_size=config.train_batch_size, shuffle=True)
# build optimizer and lr scheduler
lr_scale_ratio = 1
if config.get('auto_lr', None):
lr_scale_ratio = auto_scale_lr(config.train_batch_size * get_world_size() * config.gradient_accumulation_steps,
config.optimizer, **config.auto_lr)
optimizer = build_optimizer(model, config.optimizer)
lr_scheduler = build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio)
timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
if accelerator.is_main_process:
tracker_config = dict(vars(config))
try:
accelerator.init_trackers(args.tracker_project_name, tracker_config)
except Exception:
accelerator.init_trackers(f"tb_{timestamp}")
start_epoch = 0
if config.resume_from is not None and config.resume_from['checkpoint'] is not None:
start_epoch, missing, unexpected = load_checkpoint(**config.resume_from,
model=model,
model_ema=model_ema,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
logger.warning(f'Missing keys: {missing}')
logger.warning(f'Unexpected keys: {unexpected}')
# Prepare everything
# There is no specific order to remember, you just need to unpack the
# objects in the same order you gave them to the prepare method.
model, model_ema = accelerator.prepare(model, model_ema)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
train()
import argparse
import datetime
import os
import sys
import time
import types
import warnings
from pathlib import Path
current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent))
import torch
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import DistributedType
from mmcv.runner import LogBuffer
from torch.utils.data import RandomSampler
from diffusion import IDDPM
from diffusion.data.builder import build_dataset, build_dataloader, set_data_root
from diffusion.model.builder import build_model
from diffusion.model.nets import PixArtMS, ControlPixArtHalf, ControlPixArtMSHalf
from diffusion.utils.checkpoint import save_checkpoint, load_checkpoint
from diffusion.utils.data_sampler import AspectRatioBatchSampler, BalancedAspectRatioBatchSampler
from diffusion.utils.dist_utils import synchronize, get_world_size, clip_grad_norm_
from diffusion.utils.logger import get_root_logger
from diffusion.utils.lr_scheduler import build_lr_scheduler
from diffusion.utils.misc import set_random_seed, read_config, init_random_seed, DebugUnderflowOverflow
from diffusion.utils.optimizer import build_optimizer, auto_scale_lr
warnings.filterwarnings("ignore") # ignore warning
def set_fsdp_env():
os.environ["ACCELERATE_USE_FSDP"] = 'true'
os.environ["FSDP_AUTO_WRAP_POLICY"] = 'TRANSFORMER_BASED_WRAP'
os.environ["FSDP_BACKWARD_PREFETCH"] = 'BACKWARD_PRE'
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = 'PixArtBlock'
def train():
if config.get('debug_nan', False):
DebugUnderflowOverflow(model)
logger.info('NaN debugger registered. Start to detect overflow during training.')
time_start, last_tic = time.time(), time.time()
log_buffer = LogBuffer()
start_step = start_epoch * len(train_dataloader)
global_step = 0
total_steps = len(train_dataloader) * config.num_epochs
load_vae_feat = getattr(train_dataloader.dataset, 'load_vae_feat', False)
if not load_vae_feat:
raise ValueError("Only loading precomputed VAE features is supported for now.")
# Now you train the model
for epoch in range(start_epoch + 1, config.num_epochs + 1):
data_time_start = time.time()
data_time_all = 0
for step, batch in enumerate(train_dataloader):
data_time_all += time.time() - data_time_start
z = batch[0] # e.g. 4 x 4 x 128 x 128; VAE latents (3x1024x1024 image -> VAE -> 4x128x128)
clean_images = z * config.scale_factor # apply the VAE latent scale factor
y = batch[1] # e.g. 4 x 1 x 120 x 4096; T5 caption features (120 tokens, dim 4096)
y_mask = batch[2] # e.g. 4 x 1 x 1 x 120; mask marking which caption tokens are valid
data_info = batch[3]
# Sample a random timestep for each image
bs = clean_images.shape[0]
timesteps = torch.randint(0, config.train_sampling_steps, (bs,), device=clean_images.device).long()
grad_norm = None
with accelerator.accumulate(model):
# Predict the noise residual
optimizer.zero_grad()
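# The ControlNet variant additionally passes the control-condition latent
# (data_info['condition'], scaled by the same VAE scale factor) as the `c` kwarg: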
loss_term = train_diffusion.training_losses(model, clean_images, timesteps, model_kwargs=dict(y=y, mask=y_mask, data_info=data_info, c=data_info['condition'] * config.scale_factor))
loss = loss_term['loss'].mean()
accelerator.backward(loss)
if accelerator.sync_gradients:
grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.gradient_clip)
optimizer.step()
lr_scheduler.step()
lr = lr_scheduler.get_last_lr()[0]
logs = {"loss": accelerator.gather(loss).mean().item()}
if grad_norm is not None:
logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
log_buffer.update(logs)
if (step + 1) % config.log_interval == 0 or (step + 1) == 1:
t = (time.time() - last_tic) / config.log_interval
t_d = data_time_all / config.log_interval
avg_time = (time.time() - time_start) / (global_step + 1)
eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - start_step - global_step - 1))))
eta_epoch = str(datetime.timedelta(seconds=int(avg_time * (len(train_dataloader) - step - 1))))
# avg_loss = sum(loss_buffer) / len(loss_buffer)
log_buffer.average()
info = f"Step/Epoch [{(epoch - 1) * len(train_dataloader) + step + 1}/{epoch}][{step + 1}/{len(train_dataloader)}]:total_eta: {eta}, " \
f"epoch_eta:{eta_epoch}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:({data_info['img_hw'][0][0].item()}, {data_info['img_hw'][0][1].item()}), "
info += ', '.join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()])
logger.info(info)
last_tic = time.time()
log_buffer.clear()
data_time_all = 0
logs.update(lr=lr)
accelerator.log(logs, step=global_step + start_step)
if (global_step + 1) % 1000 == 0 and config.s3_work_dir is not None:
logger.info(f"s3_work_dir: {config.s3_work_dir}")
global_step += 1
data_time_start = time.time()
synchronize()
if accelerator.is_main_process:
if ((epoch - 1) * len(train_dataloader) + step + 1) % config.save_model_steps == 0:
os.umask(0o000) # file permission: 666; dir permission: 777
save_checkpoint(os.path.join(config.work_dir, 'checkpoints'),
epoch=epoch,
step=(epoch - 1) * len(train_dataloader) + step + 1,
model=accelerator.unwrap_model(model),
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
synchronize()
synchronize()
# After each epoch you optionally sample some demo images with evaluate() and save the model
if accelerator.is_main_process:
if epoch % config.save_model_epochs == 0 or epoch == config.num_epochs:
os.umask(0o000) # file permission: 666; dir permission: 777
save_checkpoint(os.path.join(config.work_dir, 'checkpoints'),
epoch=epoch,
step=(epoch - 1) * len(train_dataloader) + step + 1,
model=accelerator.unwrap_model(model),
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
synchronize()
def parse_args():
parser = argparse.ArgumentParser(description="Process some integers.")
parser.add_argument("config", type=str, help="config")
parser.add_argument("--cloud", action='store_true', default=False, help="cloud or local machine")
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument('--resume_from', help='the dir to resume the training')
parser.add_argument('--local-rank', type=int, default=-1)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--debug', action='store_true')
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--tracker_project_name",
type=str,
default="text2image-fine-tune",
help=(
"The `project_name` argument passed to Accelerator.init_trackers for"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--data_root', type=str, default=None)
parser.add_argument('--resume_optimizer', action='store_true')
parser.add_argument('--resume_lr_scheduler', action='store_true')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
config = read_config(args.config)
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
config.work_dir = args.work_dir
if args.cloud:
config.data_root = '/data/data'
if args.data_root:
config.data_root = args.data_root
if args.resume_from is not None:
config.load_from = None
config.resume_from = dict(
checkpoint=args.resume_from,
load_ema=False,
resume_optimizer=args.resume_optimizer,
resume_lr_scheduler=args.resume_lr_scheduler)
if args.debug:
config.log_interval = 1
config.train_batch_size = 6
config.optimizer.update({'lr': args.lr})
os.umask(0o000) # file permission: 666; dir permission: 777
os.makedirs(config.work_dir, exist_ok=True)
init_handler = InitProcessGroupKwargs()
init_handler.timeout = datetime.timedelta(seconds=9600) # change timeout to avoid a strange NCCL bug
# Initialize accelerator and tensorboard logging
if config.use_fsdp:
init_train = 'FSDP'
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
set_fsdp_env()
fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),)
else:
init_train = 'DDP'
fsdp_plugin = None
even_batches = True
if config.multi_scale:
even_batches = False
accelerator = Accelerator(
mixed_precision=config.mixed_precision,
gradient_accumulation_steps=config.gradient_accumulation_steps,
log_with=args.report_to,
project_dir=os.path.join(config.work_dir, "logs"),
fsdp_plugin=fsdp_plugin,
even_batches=even_batches,
kwargs_handlers=[init_handler]
)
logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
config.seed = init_random_seed(config.get('seed', None))
set_random_seed(config.seed)
if accelerator.is_main_process:
config.dump(os.path.join(config.work_dir, 'config.py'))
logger.info(f"Config: \n{config.pretty_text}")
logger.info(f"World_size: {get_world_size()}, seed: {config.seed}")
logger.info(f"Initializing: {init_train} for training")
image_size = config.image_size # @param [512, 1024]
latent_size = int(image_size) // 8
pred_sigma = getattr(config, 'pred_sigma', True)
learn_sigma = getattr(config, 'learn_sigma', True) and pred_sigma
model_kwargs={"window_block_indexes": config.window_block_indexes, "window_size": config.window_size,
"use_rel_pos": config.use_rel_pos, "lewei_scale": config.lewei_scale, 'config':config,
'model_max_length': config.model_max_length}
# build models
train_diffusion = IDDPM(str(config.train_sampling_steps))
model: PixArtMS = build_model(config.model,
config.grad_checkpointing,
config.get('fp32_attention', False),
input_size=latent_size,
learn_sigma=learn_sigma,
pred_sigma=pred_sigma,
**model_kwargs)
if config.load_from is not None and args.resume_from is None:
# load from PixArt model
missing, unexpected = load_checkpoint(config.load_from, model)
logger.warning(f'Missing keys: {missing}')
logger.warning(f'Unexpected keys: {unexpected}')
if image_size == 1024:
model: ControlPixArtMSHalf = ControlPixArtMSHalf(model, copy_blocks_num=config.copy_blocks_num).train()
else:
model: ControlPixArtHalf = ControlPixArtHalf(model, copy_blocks_num=config.copy_blocks_num).train()
logger.info(f"{model.__class__.__name__} Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
logger.info(f"T5 max token length: {config.model_max_length}")
# if args.local_rank == 0:
# for name, params in model.named_parameters():
# if params.requires_grad == False: logger.info(f"freeze param: {name}")
#
# for name, params in model.named_parameters():
# if params.requires_grad == True: logger.info(f"trainable param: {name}")
# prepare for FSDP clip grad norm calculation
if accelerator.distributed_type == DistributedType.FSDP:
for m in accelerator._models:
m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m)
# build dataloader
set_data_root(config.data_root)
dataset = build_dataset(config.data, resolution=image_size, aspect_ratio_type=config.aspect_ratio_type, train_ratio=config.train_ratio)
if config.multi_scale:
batch_sampler = AspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio, drop_last=True,
ratio_nums=dataset.ratio_nums, config=config, valid_num=1)
# batch_sampler = BalancedAspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
# batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio,
# ratio_nums=dataset.ratio_nums)
train_dataloader = build_dataloader(dataset, batch_sampler=batch_sampler, num_workers=config.num_workers)
else:
train_dataloader = build_dataloader(dataset, num_workers=config.num_workers, batch_size=config.train_batch_size, shuffle=True)
# build optimizer and lr scheduler
lr_scale_ratio = 1
if config.get('auto_lr', None):
lr_scale_ratio = auto_scale_lr(config.train_batch_size * get_world_size() * config.gradient_accumulation_steps,
config.optimizer, **config.auto_lr)
optimizer = build_optimizer(model.controlnet, config.optimizer)
lr_scheduler = build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio)
timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
if accelerator.is_main_process:
tracker_config = dict(vars(config))
try:
accelerator.init_trackers(args.tracker_project_name, tracker_config)
except Exception:
accelerator.init_trackers(f"tb_{timestamp}")
start_epoch = 0
if config.resume_from is not None and config.resume_from['checkpoint'] is not None:
if not args.resume_optimizer or not args.resume_lr_scheduler:
missing, unexpected = load_checkpoint(args.resume_from, model)
else:
start_epoch, missing, unexpected = load_checkpoint(**config.resume_from,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
logger.warning(f'Missing keys: {missing}')
logger.warning(f'Unexpected keys: {unexpected}')
# Prepare everything
# There is no specific order to remember, you just need to unpack the
# objects in the same order you gave them to the prepare method.
model = accelerator.prepare(model,)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
train()
import argparse
import datetime
import os
import sys
import time
import types
import warnings
from pathlib import Path
current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent))
import accelerate
import gc
import numpy as np
import torch
import torch.nn as nn
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import DistributedType
from copy import deepcopy
from diffusers import AutoencoderKL, Transformer2DModel, PixArtAlphaPipeline, DPMSolverMultistepScheduler
from mmcv.runner import LogBuffer
from packaging import version
from torch.utils.data import RandomSampler
from transformers import T5Tokenizer, T5EncoderModel
from diffusion import IDDPM
from diffusion.data.builder import build_dataset, build_dataloader, set_data_root
from diffusion.utils.data_sampler import AspectRatioBatchSampler, BalancedAspectRatioBatchSampler
from diffusion.utils.dist_utils import get_world_size, clip_grad_norm_, flush
from diffusion.utils.logger import get_root_logger, rename_file_with_creation_time
from diffusion.utils.lr_scheduler import build_lr_scheduler
from diffusion.utils.misc import set_random_seed, read_config, init_random_seed, DebugUnderflowOverflow
from diffusion.utils.optimizer import build_optimizer, auto_scale_lr
warnings.filterwarnings("ignore") # ignore warning
def set_fsdp_env():
os.environ["ACCELERATE_USE_FSDP"] = 'true'
os.environ["FSDP_AUTO_WRAP_POLICY"] = 'TRANSFORMER_BASED_WRAP'
os.environ["FSDP_BACKWARD_PREFETCH"] = 'BACKWARD_PRE'
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = 'Transformer2DModel'
def ema_update(model_dest: nn.Module, model_src: nn.Module, rate):
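    # Exponential moving average: p_dest <- rate * p_dest + (1 - rate) * p_src.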
param_dict_src = dict(model_src.named_parameters())
for p_name, p_dest in model_dest.named_parameters():
p_src = param_dict_src[p_name]
assert p_src is not p_dest
p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
def token_drop(y, y_mask, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
        drop_ids = torch.rand(y.shape[0], device=y.device) < config.class_dropout_prob
else:
drop_ids = force_drop_ids == 1
y = torch.where(drop_ids[:, None, None], uncond_prompt_embeds, y)
y_mask = torch.where(drop_ids[:, None], uncond_prompt_attention_mask, y_mask)
return y, y_mask
def get_null_embed(npz_file, max_length=120):
if os.path.exists(npz_file) and (npz_file.endswith('.npz') or npz_file.endswith('.pth')):
data = torch.load(npz_file)
uncond_prompt_embeds = data['uncond_prompt_embeds'].to(accelerator.device)
uncond_prompt_attention_mask = data['uncond_prompt_attention_mask'].to(accelerator.device)
else:
tokenizer = T5Tokenizer.from_pretrained(args.pipeline_load_from, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(args.pipeline_load_from, subfolder="text_encoder")
uncond = tokenizer("", max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
uncond_prompt_embeds = text_encoder(uncond.input_ids, attention_mask=uncond.attention_mask)[0]
torch.save({
'uncond_prompt_embeds': uncond_prompt_embeds.cpu(),
'uncond_prompt_attention_mask': uncond.attention_mask.cpu()
}, npz_file)
uncond_prompt_embeds = uncond_prompt_embeds.to(accelerator.device)
uncond_prompt_attention_mask = uncond.attention_mask.to(accelerator.device)
return uncond_prompt_embeds, uncond_prompt_attention_mask
def prepare_vis():
if accelerator.is_main_process:
        # prepare embeddings for visualization here to save GPU memory
validation_prompts = [
"dog",
"portrait photo of a girl, photograph, highly detailed face, depth of field",
"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
]
logger.info("Preparing Visualization prompt embeddings...")
logger.info(f"Loading text encoder and tokenizer from {args.pipeline_load_from} ...")
skip = True
for prompt in validation_prompts:
if not os.path.exists(f'output/tmp/{prompt}_{max_length}token.pth'):
skip = False
break
if accelerator.is_main_process and not skip:
print(f"Saving visualizate prompt text embedding at output/tmp/")
tokenizer = T5Tokenizer.from_pretrained(args.pipeline_load_from, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(args.pipeline_load_from, subfolder="text_encoder").to(accelerator.device)
for prompt in validation_prompts:
caption_token = tokenizer(prompt, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt").to(accelerator.device)
caption_emb = text_encoder(caption_token.input_ids, attention_mask=caption_token.attention_mask)[0]
torch.save({'caption_embeds': caption_emb, 'emb_mask': caption_token.attention_mask}, f'output/tmp/{prompt}_{max_length}token.pth')
flush()
@torch.inference_mode()
def log_validation(model, accelerator, weight_dtype, step):
logger.info("Running validation... ")
model = accelerator.unwrap_model(model)
pipeline = PixArtAlphaPipeline.from_pretrained(
args.pipeline_load_from,
transformer=model,
tokenizer=None,
text_encoder=None,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
generator = torch.Generator(device=accelerator.device).manual_seed(0)
validation_prompts = [
"dog",
"portrait photo of a girl, photograph, highly detailed face, depth of field",
"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
]
image_logs = []
images = []
latents = []
for _, prompt in enumerate(validation_prompts):
embed = torch.load(f'output/tmp/{prompt}_{max_length}token.pth', map_location='cpu')
caption_embs, emb_masks = embed['caption_embeds'].to(accelerator.device), embed['emb_mask'].to(accelerator.device)
latents.append(pipeline(
num_inference_steps=14,
num_images_per_prompt=1,
generator=generator,
guidance_scale=4.5,
prompt_embeds=caption_embs,
prompt_attention_mask=emb_masks,
negative_prompt=None,
negative_prompt_embeds=uncond_prompt_embeds,
negative_prompt_attention_mask=uncond_prompt_attention_mask,
output_type="latent",
).images)
flush()
for latent in latents:
images.append(pipeline.vae.decode(latent.to(weight_dtype) / pipeline.vae.config.scaling_factor, return_dict=False)[0])
for prompt, image in zip(validation_prompts, images):
image = pipeline.image_processor.postprocess(image, output_type="pil")
image_logs.append({"validation_prompt": prompt, "images": image})
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
for log in image_logs:
images = log["images"]
validation_prompt = log["validation_prompt"]
formatted_images = []
for image in images:
formatted_images.append(np.asarray(image))
formatted_images = np.stack(formatted_images)
tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
elif tracker.name == "wandb":
import wandb
formatted_images = []
for log in image_logs:
images = log["images"]
validation_prompt = log["validation_prompt"]
for image in images:
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)
tracker.log({"validation": formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
torch.cuda.empty_cache()
return image_logs
def train(model):
if config.get('debug_nan', False):
DebugUnderflowOverflow(model)
logger.info('NaN debugger registered. Start to detect overflow during training.')
time_start, last_tic = time.time(), time.time()
log_buffer = LogBuffer()
global_step = start_step + 1
load_vae_feat = getattr(train_dataloader.dataset, 'load_vae_feat', False)
# Now you train the model
for epoch in range(start_epoch + 1, config.num_epochs + 1):
data_time_start= time.time()
data_time_all = 0
for step, batch in enumerate(train_dataloader):
data_time_all += time.time() - data_time_start
if load_vae_feat:
z = batch[0]
else:
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=config.mixed_precision == 'fp16'):
posterior = vae.encode(batch[0]).latent_dist
if config.sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
latents = (z * config.scale_factor).to(weight_dtype)
y = batch[1].squeeze(1).to(weight_dtype)
y_mask = batch[2].squeeze(1).squeeze(1).to(weight_dtype)
y, y_mask = token_drop(y, y_mask) # classifier-free guidance
data_info = {'resolution': batch[3]['img_hw'].to(weight_dtype), 'aspect_ratio': batch[3]['aspect_ratio'].to(weight_dtype),}
# Sample a random timestep for each image
bs = latents.shape[0]
timesteps = torch.randint(0, config.train_sampling_steps, (bs,), device=latents.device).long()
grad_norm = None
with accelerator.accumulate(model):
# Predict the noise residual
optimizer.zero_grad()
loss_term = train_diffusion.training_losses_diffusers(
model, latents, timesteps,
model_kwargs = dict(encoder_hidden_states=y, encoder_attention_mask=y_mask, added_cond_kwargs=data_info),
)
loss = loss_term['loss'].mean()
accelerator.backward(loss)
if accelerator.sync_gradients:
grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.gradient_clip)
optimizer.step()
lr_scheduler.step()
# if accelerator.sync_gradients:
# ema_update(model_ema, accelerator.unwrap_model(model), config.ema_rate)
lr = lr_scheduler.get_last_lr()[0]
logs = {args.loss_report_name: accelerator.gather(loss).mean().item()}
if grad_norm is not None:
logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
log_buffer.update(logs)
if (step + 1) % config.log_interval == 0 or (step + 1) == 1:
t = (time.time() - last_tic) / config.log_interval
t_d = data_time_all / config.log_interval
avg_time = (time.time() - time_start) / (global_step - start_step)
eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - global_step - 1))))
eta_epoch = str(datetime.timedelta(seconds=int(avg_time * (len(train_dataloader) - step - 1))))
# avg_loss = sum(loss_buffer) / len(loss_buffer)
log_buffer.average()
info = f"Step/Epoch [{global_step}/{epoch}][{step + 1}/{len(train_dataloader)}]:total_eta: {eta}, " \
f"epoch_eta:{eta_epoch}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}," \
f"s:({data_info['resolution'][0][0].item()}, {data_info['resolution'][0][1].item()}), "
# f"s:({data_info['resolution'][0][0].item() * relative_to_1024 // 8}, {data_info['resolution'][0][1].item() * relative_to_1024 // 8}), "
info += ', '.join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()])
logger.info(info)
last_tic = time.time()
log_buffer.clear()
data_time_all = 0
logs.update(lr=lr)
accelerator.log(logs, step=global_step)
global_step += 1
data_time_start= time.time()
accelerator.wait_for_everyone()
if accelerator.is_main_process:
if global_step % config.save_model_steps == 0:
save_path = os.path.join(os.path.join(config.work_dir, 'checkpoints'), f"checkpoint-{global_step}")
os.umask(0o000)
logger.info(f"Start to save state to {save_path}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
if global_step % config.eval_sampling_steps == 0 or (step + 1) == 1:
log_validation(model, accelerator, weight_dtype, global_step)
accelerator.wait_for_everyone()
if epoch % config.save_model_epochs == 0 or epoch == config.num_epochs:
os.umask(0o000)
save_path = os.path.join(os.path.join(config.work_dir, 'checkpoints'), f"checkpoint-{global_step}")
logger.info(f"Start to save state to {save_path}")
            # use a separate name so `model` keeps pointing at the accelerator-wrapped module
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(save_path)
logger.info(f"Saved state to {save_path}")
def parse_args():
parser = argparse.ArgumentParser(description="Process some integers.")
parser.add_argument("config", type=str, help="config")
parser.add_argument("--cloud", action='store_true', default=False, help="cloud or local machine")
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument('--resume-from', help='the dir to resume the training')
parser.add_argument('--load-from', default=None, help='the dir to load a ckpt for training')
parser.add_argument('--local-rank', type=int, default=-1)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--debug', action='store_true')
parser.add_argument("--pipeline_load_from", default='output/pretrained_models/pixart_omega_sdxl_256px_diffusers_from512', type=str, help="path for loading text_encoder, tokenizer and vae")
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--tracker_project_name",
type=str,
default="text2image-pixart-omega",
help=(
"The `project_name` argument passed to Accelerator.init_trackers for"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
parser.add_argument("--loss_report_name", type=str, default="loss")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
config = read_config(args.config)
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
config.work_dir = args.work_dir
if args.cloud:
config.data_root = '/data/data'
if args.resume_from is not None:
config.resume_from = args.resume_from
if args.debug:
config.log_interval = 1
config.train_batch_size = 32
config.valid_num = 100
os.umask(0o000)
os.makedirs(config.work_dir, exist_ok=True)
init_handler = InitProcessGroupKwargs()
init_handler.timeout = datetime.timedelta(seconds=5400) # change timeout to avoid a strange NCCL bug
# Initialize accelerator and tensorboard logging
if config.use_fsdp:
init_train = 'FSDP'
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
set_fsdp_env()
fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),)
else:
init_train = 'DDP'
fsdp_plugin = None
even_batches = True
if config.multi_scale:
        even_batches = False
accelerator = Accelerator(
mixed_precision=config.mixed_precision,
gradient_accumulation_steps=config.gradient_accumulation_steps,
log_with=args.report_to,
project_dir=os.path.join(config.work_dir, "logs"),
fsdp_plugin=fsdp_plugin,
even_batches=even_batches,
kwargs_handlers=[init_handler]
)
log_name = 'train_log.log'
if accelerator.is_main_process:
if os.path.exists(os.path.join(config.work_dir, log_name)):
rename_file_with_creation_time(os.path.join(config.work_dir, log_name))
logger = get_root_logger(os.path.join(config.work_dir, log_name))
logger.info(accelerator.state)
config.seed = init_random_seed(config.get('seed', None))
set_random_seed(config.seed)
if accelerator.is_main_process:
config.dump(os.path.join(config.work_dir, 'config.py'))
logger.info(f"Config: \n{config.pretty_text}")
logger.info(f"World_size: {get_world_size()}, seed: {config.seed}")
logger.info(f"Initializing: {init_train} for training")
image_size = config.image_size # @param [256, 512, 1024]
latent_size = int(image_size) // 8
relative_to_1024 = float(image_size / 1024)
pred_sigma = getattr(config, 'pred_sigma', True)
learn_sigma = getattr(config, 'learn_sigma', True) and pred_sigma
    # Create the unconditional prompt embedding for classifier-free guidance
logger.info("Embedding for classifier free guidance")
max_length = config.model_max_length
uncond_prompt_embeds, uncond_prompt_attention_mask = get_null_embed(
f'output/pretrained_models/null_embed_diffusers_{max_length}token.pth', max_length=max_length
)
    # prepare embeddings for visualization here to save GPU memory
prepare_vis()
# build models
train_diffusion = IDDPM(str(config.train_sampling_steps), learn_sigma=learn_sigma, pred_sigma=pred_sigma, snr=config.snr_loss)
model = Transformer2DModel.from_pretrained(config.load_from, subfolder="transformer").train()
logger.info(f"{model.__class__.__name__} Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
logger.info(f"lewei scale: {model.pos_embed.interpolation_scale} base size: {model.pos_embed.base_size}")
# model_ema = deepcopy(model).eval()
# 9. Handle mixed precision and device placement
    # For mixed precision training we cast all non-trainable weights to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# 11. Enable optimizations
# model.enable_xformers_memory_efficient_attention() # not available for now
# for name, params in model.named_parameters():
# if params.requires_grad == False: logger.info(f"freeze param: {name}")
#
# for name, params in model.named_parameters():
# if params.requires_grad == True: logger.info(f"trainable param: {name}")
# 10. Handle saving and loading of checkpoints
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_ = accelerator.unwrap_model(models[0])
                # save weights in diffusers format so they can be loaded back
transformer_.save_pretrained(output_dir)
for _, model in enumerate(models):
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
def load_model_hook(models, input_dir):
for i in range(len(models)):
# pop models so that they are not loaded again
model = models.pop()
# load diffusers style into model
load_model = Transformer2DModel.from_pretrained(input_dir)
model.register_to_config(**load_model.config)
model.load_state_dict(load_model.state_dict())
del load_model
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
if config.grad_checkpointing:
model.enable_gradient_checkpointing()
if not config.data.load_vae_feat:
vae = AutoencoderKL.from_pretrained(config.vae_pretrained).cuda()
# prepare for FSDP clip grad norm calculation
if accelerator.distributed_type == DistributedType.FSDP:
for m in accelerator._models:
m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m)
# build dataloader
set_data_root(config.data_root)
logger.info(f"ratio of real user prompt: {config.real_prompt_ratio}")
dataset = build_dataset(
config.data, resolution=image_size, aspect_ratio_type=config.aspect_ratio_type,
real_prompt_ratio=config.real_prompt_ratio, max_length=max_length, config=config,
)
if config.multi_scale:
batch_sampler = AspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio, drop_last=True,
ratio_nums=dataset.ratio_nums, config=config, valid_num=config.valid_num)
# used for balanced sampling
# batch_sampler = BalancedAspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
# batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio,
# ratio_nums=dataset.ratio_nums)
train_dataloader = build_dataloader(dataset, batch_sampler=batch_sampler, num_workers=config.num_workers)
else:
train_dataloader = build_dataloader(dataset, num_workers=config.num_workers, batch_size=config.train_batch_size, shuffle=True)
# build optimizer and lr scheduler
lr_scale_ratio = 1
if config.get('auto_lr', None):
lr_scale_ratio = auto_scale_lr(config.train_batch_size * get_world_size() * config.gradient_accumulation_steps,
config.optimizer, **config.auto_lr)
optimizer = build_optimizer(model, config.optimizer)
lr_scheduler = build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio)
timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
if accelerator.is_main_process:
tracker_config = dict(vars(config))
accelerator.init_trackers(f"tb_{timestamp}_{args.tracker_project_name}")
logger.info(f"Training tracker at tb_{timestamp}_{args.tracker_project_name}")
start_epoch = 0
start_step = 0
total_steps = len(train_dataloader) * config.num_epochs
# Prepare everything
# There is no specific order to remember, you just need to unpack the
# objects in the same order you gave them to the prepare method.
# model, model_ema = accelerator.prepare(model, model_ema)
model = accelerator.prepare(model)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if config.resume_from is not None:
if config.resume_from != "latest":
path = os.path.basename(config.resume_from)
else:
# Get the most recent checkpoint
dirs = os.listdir(os.path.join(config.work_dir, 'checkpoints'))
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(f"Checkpoint '{config.resume_from}' does not exist. Starting a new training run.")
config.resume_from = None
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(config.work_dir, 'checkpoints', path))
start_step = int(path.split("-")[1])
start_epoch = start_step // len(train_dataloader)
train(model)
import os
import sys
import types
from pathlib import Path
current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent))
import argparse
import datetime
import time
import warnings
warnings.filterwarnings("ignore") # ignore warning
from mmcv.runner import LogBuffer
from copy import deepcopy
from diffusion.utils.checkpoint import save_checkpoint, load_checkpoint
import torch
import torch.nn as nn
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import DistributedType
from torch.utils.data import RandomSampler
from diffusion import IDDPM
from diffusion.utils.dist_utils import synchronize, get_world_size, clip_grad_norm_
from diffusion.data.builder import build_dataset, build_dataloader, set_data_root
from diffusion.model.builder import build_model
from diffusion.utils.logger import get_root_logger
from diffusion.utils.misc import set_random_seed, read_config, init_random_seed, DebugUnderflowOverflow
from diffusion.utils.optimizer import build_optimizer, auto_scale_lr
from diffusion.utils.lr_scheduler import build_lr_scheduler
from diffusion.model.t5 import T5Embedder
from diffusion.utils.data_sampler import AspectRatioBatchSampler
def set_fsdp_env():
os.environ["ACCELERATE_USE_FSDP"] = 'true'
os.environ["FSDP_AUTO_WRAP_POLICY"] = 'TRANSFORMER_BASED_WRAP'
os.environ["FSDP_BACKWARD_PREFETCH"] = 'BACKWARD_PRE'
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = 'PixArtBlock'
def ema_update(model_dest: nn.Module, model_src: nn.Module, rate):
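    # Exponential moving average: p_dest <- rate * p_dest + (1 - rate) * p_src.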
param_dict_src = dict(model_src.named_parameters())
for p_name, p_dest in model_dest.named_parameters():
p_src = param_dict_src[p_name]
assert p_src is not p_dest
p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
def train():
if config.get('debug_nan', False):
DebugUnderflowOverflow(model)
logger.info('NaN debugger registered. Start to detect overflow during training.')
time_start, last_tic = time.time(), time.time()
log_buffer = LogBuffer()
start_step = start_epoch * len(train_dataloader)
global_step = 0
total_steps = len(train_dataloader) * config.num_epochs
# txt related
prompt = config.data.prompt if isinstance(config.data.prompt, list) else [config.data.prompt]
llm_embed_model = T5Embedder(device="cpu", local_cache=True, cache_dir='output/pretrained_models/t5_ckpts', torch_dtype=torch.float)
prompt_embs, attention_mask = llm_embed_model.get_text_embeddings(prompt)
prompt_embs, attention_mask = prompt_embs[None].cuda(), attention_mask[None].cuda()
del llm_embed_model
# Now you train the model
for epoch in range(start_epoch + 1, config.num_epochs + 1):
data_time_start= time.time()
data_time_all = 0
for step, batch in enumerate(train_dataloader):
data_time_all += time.time() - data_time_start
z = batch[0]
clean_images = z * config.scale_factor
y = prompt_embs
y_mask = attention_mask
data_info = batch[1]
# Sample a random timestep for each image
bs = clean_images.shape[0]
timesteps = torch.randint(0, config.train_sampling_steps, (bs,), device=clean_images.device).long()
grad_norm = None
with accelerator.accumulate(model):
# Predict the noise residual
optimizer.zero_grad()
loss_term = train_diffusion.training_losses(model, clean_images, timesteps, model_kwargs=dict(y=y, mask=y_mask, data_info=data_info))
loss = loss_term['loss'].mean()
accelerator.backward(loss)
if accelerator.sync_gradients:
grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.gradient_clip)
optimizer.step()
lr_scheduler.step()
if accelerator.sync_gradients:
ema_update(model_ema, model, config.ema_rate)
lr = lr_scheduler.get_last_lr()[0]
logs = {"loss": accelerator.gather(loss).mean().item()}
if grad_norm is not None:
logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
log_buffer.update(logs)
if (step + 1) % config.log_interval == 0:
t = (time.time() - last_tic) / config.log_interval
t_d = data_time_all / config.log_interval
avg_time = (time.time() - time_start) / (global_step + 1)
eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - start_step - global_step - 1))))
eta_epoch = str(datetime.timedelta(seconds=int(avg_time * (len(train_dataloader) - step - 1))))
# avg_loss = sum(loss_buffer) / len(loss_buffer)
log_buffer.average()
info = f"Steps [{(epoch-1)*len(train_dataloader)+step+1}][{step + 1}/{len(train_dataloader)}]:total_eta: {eta}, " \
f"epoch_eta:{eta_epoch}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:({model.module.h}, {model.module.w}), "
info += ', '.join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()])
logger.info(info)
last_tic = time.time()
log_buffer.clear()
data_time_all = 0
logs.update(lr=lr)
accelerator.log(logs, step=global_step + start_step)
global_step += 1
data_time_start= time.time()
synchronize()
if accelerator.is_main_process:
if ((epoch - 1) * len(train_dataloader) + step + 1) % config.save_model_steps == 0:
os.umask(0o000)
save_checkpoint(os.path.join(config.work_dir, 'checkpoints'),
epoch=epoch,
step=(epoch - 1) * len(train_dataloader) + step + 1,
model=accelerator.unwrap_model(model),
model_ema=accelerator.unwrap_model(model_ema),
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
synchronize()
synchronize()
if accelerator.is_main_process:
if epoch % config.save_model_epochs == 0 or epoch == config.num_epochs:
os.umask(0o000)
save_checkpoint(os.path.join(config.work_dir, 'checkpoints'),
epoch=epoch,
step=(epoch - 1) * len(train_dataloader) + step + 1,
model=accelerator.unwrap_model(model),
model_ema=accelerator.unwrap_model(model_ema),
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
synchronize()
def parse_args():
parser = argparse.ArgumentParser(description="Process some integers.")
parser.add_argument("config", type=str, help="config")
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument('--resume-from', help='the dir to resume the training')
parser.add_argument('--load-from', default=None, help='the dir to load a ckpt for training')
parser.add_argument('--local-rank', type=int, default=-1)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--debug', action='store_true')
parser.add_argument('--save_step', type=int, default=100)
parser.add_argument('--lr', type=float, default=5e-6)
parser.add_argument('--train_class', type=str)
parser.add_argument('--prompt', type=str, default='a photo of sks dog')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
config = read_config(args.config)
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
config.work_dir = args.work_dir
if args.resume_from is not None:
config.resume_from = dict(
checkpoint=args.resume_from,
load_ema=False,
resume_optimizer=True,
resume_lr_scheduler=True)
if args.debug:
config.log_interval = 1
config.train_batch_size = 1
config.save_model_steps=args.save_step
config.data.update({'prompt': [args.prompt], 'root': args.train_class})
config.optimizer.update({'lr': args.lr})
os.umask(0o000)
os.makedirs(config.work_dir, exist_ok=True)
init_handler = InitProcessGroupKwargs()
init_handler.timeout = datetime.timedelta(seconds=5400) # change timeout to avoid a strange NCCL bug
# Initialize accelerator and tensorboard logging
if config.use_fsdp:
init_train = 'FSDP'
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
set_fsdp_env()
fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),)
else:
init_train = 'DDP'
fsdp_plugin = None
even_batches = True
if config.multi_scale:
        even_batches = False
accelerator = Accelerator(
mixed_precision=config.mixed_precision,
gradient_accumulation_steps=config.gradient_accumulation_steps,
log_with="tensorboard",
project_dir=os.path.join(config.work_dir, "logs"),
fsdp_plugin=fsdp_plugin,
even_batches=even_batches,
kwargs_handlers=[init_handler]
)
logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
config.seed = init_random_seed(config.get('seed', None))
set_random_seed(config.seed)
if accelerator.is_main_process:
config.dump(os.path.join(config.work_dir, 'config.py'))
logger.info(f"Config: \n{config.pretty_text}")
logger.info(f"World_size: {get_world_size()}, seed: {config.seed}")
logger.info(f"Initializing: {init_train} for training")
image_size = config.image_size # @param [256, 512]
latent_size = int(image_size) // 8
pred_sigma = getattr(config, 'pred_sigma', True)
learn_sigma = getattr(config, 'learn_sigma', True) and pred_sigma
model_kwargs={"window_block_indexes": config.window_block_indexes, "window_size": config.window_size,
"use_rel_pos": config.use_rel_pos, "lewei_scale": config.lewei_scale, 'config':config,
'model_max_length': config.model_max_length}
# build models
train_diffusion = IDDPM(str(config.train_sampling_steps))
eval_diffusion = IDDPM(str(config.eval_sampling_steps))
model = build_model(config.model,
config.grad_checkpointing,
config.get('fp32_attention', False),
input_size=latent_size,
learn_sigma=learn_sigma,
pred_sigma=pred_sigma,
**model_kwargs).train()
logger.info(f"{config.model} Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
model_ema = deepcopy(model).eval()
if config.load_from is not None:
if args.load_from is not None:
config.load_from = args.load_from
missing, unexpected = load_checkpoint(config.load_from, model, load_ema=config.get('load_ema', False))
# model.reparametrize()
if accelerator.is_main_process:
            logger.warning(f'Missing keys: {missing}')
            logger.warning(f'Unexpected keys: {unexpected}')
ema_update(model_ema, model, 0.)
# prepare for FSDP clip grad norm calculation
if accelerator.distributed_type == DistributedType.FSDP:
for m in accelerator._models:
m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m)
# build dataloader
logger.warning(f"Training prompt: {config.data['prompt']}, Training data class: {config.data['root']}")
set_data_root(config.data_root)
dataset = build_dataset(config.data, resolution=image_size, aspect_ratio_type=config.aspect_ratio_type)
if config.multi_scale:
batch_sampler = AspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio, drop_last=True,
ratio_nums=dataset.ratio_nums, config=config, valid_num=1)
# batch_sampler = BalancedAspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
# batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio,
# ratio_nums=dataset.ratio_nums)
train_dataloader = build_dataloader(dataset, batch_sampler=batch_sampler, num_workers=config.num_workers)
else:
train_dataloader = build_dataloader(dataset, num_workers=config.num_workers, batch_size=config.train_batch_size, shuffle=True)
# build optimizer and lr scheduler
lr_scale_ratio = 1
if config.get('auto_lr', None):
lr_scale_ratio = auto_scale_lr(config.train_batch_size * get_world_size() * config.gradient_accumulation_steps,
config.optimizer,
**config.auto_lr)
optimizer = build_optimizer(model, config.optimizer)
lr_scheduler = build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio)
timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
if accelerator.is_main_process:
accelerator.init_trackers(f"tb_{timestamp}")
start_epoch = 0
if config.resume_from is not None and config.resume_from['checkpoint'] is not None:
start_epoch, missing, unexpected = load_checkpoint(**config.resume_from,
model=model,
model_ema=model_ema,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
if accelerator.is_main_process:
        logger.warning(f'Missing keys: {missing}')
        logger.warning(f'Unexpected keys: {unexpected}')
# Prepare everything
# There is no specific order to remember, you just need to unpack the
# objects in the same order you gave them to the prepare method.
model, model_ema = accelerator.prepare(model, model_ema)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
train()
import os
import sys
import types
from pathlib import Path
current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent))
import argparse
import datetime
import time
import warnings
warnings.filterwarnings("ignore") # ignore warning
import torch
import torch.nn as nn
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import DistributedType
from diffusers.models import AutoencoderKL
from torch.utils.data import RandomSampler
from mmcv.runner import LogBuffer
from copy import deepcopy
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from diffusion import IDDPM
from diffusion.utils.checkpoint import save_checkpoint, load_checkpoint
from diffusion.utils.dist_utils import synchronize, get_world_size, clip_grad_norm_
from diffusion.data.builder import build_dataset, build_dataloader, set_data_root
from diffusion.model.builder import build_model
from diffusion.utils.logger import get_root_logger
from diffusion.utils.misc import set_random_seed, read_config, init_random_seed, DebugUnderflowOverflow
from diffusion.utils.optimizer import build_optimizer, auto_scale_lr
from diffusion.utils.lr_scheduler import build_lr_scheduler
from diffusion.utils.data_sampler import AspectRatioBatchSampler, BalancedAspectRatioBatchSampler
from diffusion.lcm_scheduler import LCMScheduler
from torchvision.utils import save_image
def set_fsdp_env():
os.environ["ACCELERATE_USE_FSDP"] = 'true'
os.environ["FSDP_AUTO_WRAP_POLICY"] = 'TRANSFORMER_BASED_WRAP'
os.environ["FSDP_BACKWARD_PREFETCH"] = 'BACKWARD_PRE'
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = 'PixArtBlock'
def ema_update(model_dest: nn.Module, model_src: nn.Module, rate):
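    # Exponential moving average: p_dest <- rate * p_dest + (1 - rate) * p_src.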
param_dict_src = dict(model_src.named_parameters())
for p_name, p_dest in model_dest.named_parameters():
p_src = param_dict_src[p_name]
assert p_src is not p_dest
p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
return x[(...,) + (None,) * dims_to_append]
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
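# Note: timestep / 0.1 equals timestep * timestep_scaling for the default timestep_scaling=10.0, so the argument is effectively fixed here.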
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
return c_skip, c_out
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
class DDIMSolver:
def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
# DDIM sampling parameters
step_ratio = timesteps // ddim_timesteps
self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
self.ddim_alpha_cumprods_prev = np.asarray(
[alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
)
# convert to torch tensors
self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
def to(self, device):
self.ddim_timesteps = self.ddim_timesteps.to(device)
self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
return self
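    # Deterministic DDIM step (eta = 0): x_prev = sqrt(alpha_bar_prev) * pred_x0 + sqrt(1 - alpha_bar_prev) * pred_noise.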
def ddim_step(self, pred_x0, pred_noise, timestep_index):
alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
return x_prev
@torch.no_grad()
def log_validation(model, step, device):
if hasattr(model, 'module'):
model = model.module
scheduler = LCMScheduler(beta_start=0.0001, beta_end=0.02, beta_schedule="linear", prediction_type="epsilon")
scheduler.set_timesteps(4, 50)
infer_timesteps = scheduler.timesteps
dog_embed = torch.load('data/tmp/dog.pth', map_location='cpu')
caption_embs, emb_masks = dog_embed['dog_text'].to(device), dog_embed['dog_mask'].to(device)
hw = torch.tensor([[1024, 1024]], dtype=torch.float, device=device).repeat(1, 1)
ar = torch.tensor([[1.]], device=device).repeat(1, 1)
# Create sampling noise:
    infer_latents = torch.randn(1, 4, 1024 // 8, 1024 // 8, device=device)  # latent spatial size is image_size // 8
model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks)
logger.info("Running validation... ")
# 7. LCM MultiStep Sampling Loop:
for i, t in tqdm(list(enumerate(infer_timesteps))):
ts = torch.full((1,), t, device=device, dtype=torch.long)
# model prediction (v-prediction, eps, x)
model_pred = model(infer_latents, ts, caption_embs, **model_kwargs)[:, :4]
# compute the previous noisy sample x_t -> x_t-1
infer_latents, denoised = scheduler.step(model_pred, i, t, infer_latents, return_dict=False)
samples = vae.decode(denoised / 0.18215).sample
torch.cuda.empty_cache()
    os.makedirs('output_cv/vis', exist_ok=True)
    save_image(samples[0], f'output_cv/vis/{step}.jpg', nrow=1, normalize=True, value_range=(-1, 1))
def train():
if config.get('debug_nan', False):
DebugUnderflowOverflow(model)
logger.info('NaN debugger registered. Start to detect overflow during training.')
time_start, last_tic = time.time(), time.time()
log_buffer = LogBuffer()
start_step = start_epoch * len(train_dataloader)
global_step = 0
total_steps = len(train_dataloader) * config.num_epochs
load_vae_feat = getattr(train_dataloader.dataset, 'load_vae_feat', False)
# Create uncond embeds for classifier free guidance
uncond_prompt_embeds = model.module.y_embedder.y_embedding.repeat(config.train_batch_size, 1, 1, 1)
# Now you train the model
for epoch in range(start_epoch + 1, config.num_epochs + 1):
data_time_start= time.time()
data_time_all = 0
for step, batch in enumerate(train_dataloader):
data_time_all += time.time() - data_time_start
if load_vae_feat:
z = batch[0]
else:
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=config.mixed_precision == 'fp16'):
posterior = vae.encode(batch[0]).latent_dist
if config.sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
latents = z * config.scale_factor
y = batch[1]
y_mask = batch[2]
data_info = batch[3]
# Sample a random timestep for each image
grad_norm = None
with accelerator.accumulate(model):
# Predict the noise residual
optimizer.zero_grad()
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
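                # topk is the gap k between consecutive DDIM timesteps (train_sampling_steps / num_ddim_timesteps).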
topk = config.train_sampling_steps // config.num_ddim_timesteps
index = torch.randint(0, config.num_ddim_timesteps, (bsz,), device=latents.device).long()
start_timesteps = solver.ddim_timesteps[index]
timesteps = start_timesteps - topk
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
# Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
# Sample a random guidance scale w from U[w_min, w_max] and embed it
# w = (config.w_max - config.w_min) * torch.rand((bsz,)) + config.w_min
w = config.cfg_scale * torch.ones((bsz,))
w = w.reshape(bsz, 1, 1, 1)
w = w.to(device=latents.device, dtype=latents.dtype)
# Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
_, pred_x_0, noisy_model_input = train_diffusion.training_losses(model, latents, start_timesteps, model_kwargs=dict(y=y, mask=y_mask, data_info=data_info), noise=noise)
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
# Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
# Get teacher model prediction on noisy_latents and conditional embedding
with torch.no_grad():
with torch.autocast("cuda"):
cond_teacher_output, cond_pred_x0, _ = train_diffusion.training_losses(model_teacher, latents, start_timesteps, model_kwargs=dict(y=y, mask=y_mask, data_info=data_info), noise=noise)
# Get teacher model prediction on noisy_latents and unconditional embedding
uncond_teacher_output, uncond_pred_x0, _ = train_diffusion.training_losses(model_teacher, latents, start_timesteps, model_kwargs=dict(y=uncond_prompt_embeds, mask=y_mask, data_info=data_info), noise=noise)
# Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
# Get target LCM prediction on x_prev, w, c, t_n
with torch.no_grad():
with torch.autocast("cuda", enabled=True):
_, pred_x_0, _ = train_diffusion.training_losses(model_ema, x_prev.float(), timesteps, model_kwargs=dict(y=y, mask=y_mask, data_info=data_info), skip_noise=True)
target = c_skip * x_prev + c_out * pred_x_0
# Calculate loss
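                # "huber" is the pseudo-Huber loss sqrt(diff^2 + c^2) - c: quadratic near zero, roughly linear for large residuals.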
if config.loss_type == "l2":
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif config.loss_type == "huber":
loss = torch.mean(torch.sqrt((model_pred.float() - target.float()) ** 2 + config.huber_c**2) - config.huber_c)
# Backpropagation on the online student model (`model`)
accelerator.backward(loss)
if accelerator.sync_gradients:
grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.gradient_clip)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if accelerator.sync_gradients:
ema_update(model_ema, model, config.ema_decay)
lr = lr_scheduler.get_last_lr()[0]
logs = {"loss": accelerator.gather(loss).mean().item()}
if grad_norm is not None:
logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
log_buffer.update(logs)
if (step + 1) % config.log_interval == 0 or (step + 1) == 1:
t = (time.time() - last_tic) / config.log_interval
t_d = data_time_all / config.log_interval
avg_time = (time.time() - time_start) / (global_step + 1)
eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - start_step - global_step - 1))))
eta_epoch = str(datetime.timedelta(seconds=int(avg_time * (len(train_dataloader) - step - 1))))
# avg_loss = sum(loss_buffer) / len(loss_buffer)
log_buffer.average()
info = f"Step/Epoch [{(epoch-1)*len(train_dataloader)+step+1}/{epoch}][{step + 1}/{len(train_dataloader)}]:total_eta: {eta}, " \
f"epoch_eta:{eta_epoch}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:({data_info['resolution'][0][0].item()}, {data_info['resolution'][0][1].item()}), "
info += ', '.join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()])
logger.info(info)
last_tic = time.time()
log_buffer.clear()
data_time_all = 0
logs.update(lr=lr)
accelerator.log(logs, step=global_step + start_step)
global_step += 1
data_time_start= time.time()
synchronize()
torch.cuda.empty_cache()
if accelerator.is_main_process:
# log_validation(model_ema, step, model.device)
if ((epoch - 1) * len(train_dataloader) + step + 1) % config.save_model_steps == 0:
os.umask(0o000)
save_checkpoint(os.path.join(config.work_dir, 'checkpoints'),
epoch=epoch,
step=(epoch - 1) * len(train_dataloader) + step + 1,
model=accelerator.unwrap_model(model),
model_ema=accelerator.unwrap_model(model_ema),
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
synchronize()
synchronize()
if accelerator.is_main_process:
if epoch % config.save_model_epochs == 0 or epoch == config.num_epochs:
os.umask(0o000)
save_checkpoint(os.path.join(config.work_dir, 'checkpoints'),
epoch=epoch,
step=(epoch - 1) * len(train_dataloader) + step + 1,
model=accelerator.unwrap_model(model),
model_ema=accelerator.unwrap_model(model_ema),
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
synchronize()
def parse_args():
parser = argparse.ArgumentParser(description="Process some integers.")
parser.add_argument("config", type=str, help="config")
parser.add_argument("--cloud", action='store_true', default=False, help="cloud or local machine")
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument('--resume-from', help='the dir to resume the training')
parser.add_argument('--load-from', default=None, help='the dir to load a ckpt for training')
parser.add_argument('--local-rank', type=int, default=-1)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--debug', action='store_true')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
config = read_config(args.config)
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
config.work_dir = args.work_dir
if args.cloud:
config.data_root = '/data/data'
if args.resume_from is not None:
config.load_from = None
config.resume_from = dict(
checkpoint=args.resume_from,
load_ema=False,
resume_optimizer=True,
resume_lr_scheduler=True)
if args.debug:
config.log_interval = 1
config.train_batch_size = 11
config.valid_num = 100
config.load_from = None
os.umask(0o000)
os.makedirs(config.work_dir, exist_ok=True)
init_handler = InitProcessGroupKwargs()
init_handler.timeout = datetime.timedelta(seconds=5400) # change timeout to avoid a strange NCCL bug
# Initialize accelerator and tensorboard logging
if config.use_fsdp:
init_train = 'FSDP'
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
set_fsdp_env()
fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),)
else:
init_train = 'DDP'
fsdp_plugin = None
even_batches = True
if config.multi_scale:
        even_batches = False
accelerator = Accelerator(
mixed_precision=config.mixed_precision,
gradient_accumulation_steps=config.gradient_accumulation_steps,
log_with="tensorboard",
project_dir=os.path.join(config.work_dir, "logs"),
fsdp_plugin=fsdp_plugin,
even_batches=even_batches,
kwargs_handlers=[init_handler]
)
logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
config.seed = init_random_seed(config.get('seed', None))
set_random_seed(config.seed)
if accelerator.is_main_process:
config.dump(os.path.join(config.work_dir, 'config.py'))
logger.info(f"Config: \n{config.pretty_text}")
logger.info(f"World_size: {get_world_size()}, seed: {config.seed}")
logger.info(f"Initializing: {init_train} for training")
image_size = config.image_size # @param [256, 512]
latent_size = int(image_size) // 8
pred_sigma = getattr(config, 'pred_sigma', True)
learn_sigma = getattr(config, 'learn_sigma', True) and pred_sigma
model_kwargs={"window_block_indexes": config.window_block_indexes, "window_size": config.window_size,
"use_rel_pos": config.use_rel_pos, "lewei_scale": config.lewei_scale, 'config':config,
'model_max_length': config.model_max_length}
# build models
train_diffusion = IDDPM(str(config.train_sampling_steps), learn_sigma=learn_sigma, pred_sigma=pred_sigma,
snr=config.snr_loss, return_startx=True)
model = build_model(config.model,
config.grad_checkpointing,
config.get('fp32_attention', False),
input_size=latent_size,
learn_sigma=learn_sigma,
pred_sigma=pred_sigma,
**model_kwargs).train()
logger.info(f"{model.__class__.__name__} Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
if config.load_from is not None:
if args.load_from is not None:
config.load_from = args.load_from
missing, unexpected = load_checkpoint(config.load_from, model, load_ema=config.get('load_ema', False))
logger.warning(f'Missing keys: {missing}')
logger.warning(f'Unexpected keys: {unexpected}')
model_ema = deepcopy(model).eval()
model_teacher = deepcopy(model).eval()
if not config.data.load_vae_feat:
vae = AutoencoderKL.from_pretrained(config.vae_pretrained).cuda()
# prepare for FSDP clip grad norm calculation
if accelerator.distributed_type == DistributedType.FSDP:
for m in accelerator._models:
m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m)
# build dataloader
set_data_root(config.data_root)
dataset = build_dataset(config.data, resolution=image_size, aspect_ratio_type=config.aspect_ratio_type)
if config.multi_scale:
batch_sampler = AspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio, drop_last=True,
ratio_nums=dataset.ratio_nums, config=config, valid_num=config.valid_num)
# used for balanced sampling
# batch_sampler = BalancedAspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
# batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio,
# ratio_nums=dataset.ratio_nums)
train_dataloader = build_dataloader(dataset, batch_sampler=batch_sampler, num_workers=config.num_workers)
else:
train_dataloader = build_dataloader(dataset, num_workers=config.num_workers, batch_size=config.train_batch_size, shuffle=True)
# build optimizer and lr scheduler
lr_scale_ratio = 1
if config.get('auto_lr', None):
lr_scale_ratio = auto_scale_lr(config.train_batch_size * get_world_size() * config.gradient_accumulation_steps,
config.optimizer,
**config.auto_lr)
optimizer = build_optimizer(model, config.optimizer)
lr_scheduler = build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio)
timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
if accelerator.is_main_process:
accelerator.init_trackers(f"tb_{timestamp}")
start_epoch = 0
if config.resume_from is not None and config.resume_from['checkpoint'] is not None:
start_epoch, missing, unexpected = load_checkpoint(**config.resume_from,
model=model,
model_ema=model_ema,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
logger.warning(f'Missing keys: {missing}')
logger.warning(f'Unexpected keys: {unexpected}')
solver = DDIMSolver(train_diffusion.alphas_cumprod, timesteps=config.train_sampling_steps, ddim_timesteps=config.num_ddim_timesteps)
solver.to(accelerator.device)
# Prepare everything
# There is no specific order to remember, you just need to unpack the
# objects in the same order you gave them to the prepare method.
model, model_ema, model_teacher = accelerator.prepare(model, model_ema, model_teacher)
# model, model_ema = accelerator.prepare(model, model_ema)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
train()
import os
import sys
import types
from pathlib import Path
current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent))
import argparse
import datetime
import time
import warnings
warnings.filterwarnings("ignore") # ignore warning
import torch
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import DistributedType
from torch.utils.data import RandomSampler
from mmcv.runner import LogBuffer
import torch.nn.functional as F
import numpy as np
import re
from packaging import version
import accelerate
from diffusion import IDDPM
from diffusion.utils.dist_utils import get_world_size, clip_grad_norm_
from diffusion.data.builder import build_dataset, build_dataloader, set_data_root
from diffusion.utils.logger import get_root_logger
from diffusion.utils.misc import set_random_seed, read_config, init_random_seed, DebugUnderflowOverflow
from diffusion.utils.optimizer import build_optimizer, auto_scale_lr
from diffusion.utils.lr_scheduler import build_lr_scheduler
from diffusion.utils.data_sampler import AspectRatioBatchSampler, BalancedAspectRatioBatchSampler
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
from diffusers import AutoencoderKL, Transformer2DModel, StableDiffusionPipeline, PixArtAlphaPipeline
def set_fsdp_env():
os.environ["ACCELERATE_USE_FSDP"] = 'true'
os.environ["FSDP_AUTO_WRAP_POLICY"] = 'TRANSFORMER_BASED_WRAP'
os.environ["FSDP_BACKWARD_PREFETCH"] = 'BACKWARD_PRE'
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = 'PixArtBlock'
def filter_keys(key_set):
def _f(dictionary):
return {k: v for k, v in dictionary.items() if k in key_set}
return _f
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
return x[(...,) + (None,) * dims_to_append]
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
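# Note: timestep / 0.1 equals timestep * timestep_scaling for the default timestep_scaling=10.0, so the argument is effectively fixed here.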
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
return c_skip, c_out
# Compare LCMScheduler.step, Step 4
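# epsilon-prediction: x_0 = (x_t - sigma_t * eps) / alpha_t; v-prediction: x_0 = alpha_t * x_t - sigma_t * v.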
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
if prediction_type == "epsilon":
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
pred_x_0 = (sample - sigmas * model_output) / alphas
elif prediction_type == "v_prediction":
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
pred_x_0 = alphas * sample - sigmas * model_output
else:
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
return pred_x_0
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
class DDIMSolver:
def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
# DDIM sampling parameters
step_ratio = timesteps // ddim_timesteps
self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
self.ddim_alpha_cumprods_prev = np.asarray(
[alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
)
# convert to torch tensors
self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
def to(self, device):
self.ddim_timesteps = self.ddim_timesteps.to(device)
self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
return self
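    # Deterministic DDIM step (eta = 0): x_prev = sqrt(alpha_bar_prev) * pred_x0 + sqrt(1 - alpha_bar_prev) * pred_noise.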
def ddim_step(self, pred_x0, pred_noise, timestep_index):
alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
return x_prev
def train(model):
if config.get('debug_nan', False):
DebugUnderflowOverflow(model)
logger.info('NaN debugger registered. Start to detect overflow during training.')
time_start, last_tic = time.time(), time.time()
log_buffer = LogBuffer()
global_step = start_step
load_vae_feat = getattr(train_dataloader.dataset, 'load_vae_feat', False)
# Create uncond embeds for classifier free guidance
uncond_prompt_embeds = torch.load('output/pretrained_models/null_embed.pth', map_location='cpu').to(accelerator.device).repeat(config.train_batch_size, 1, 1, 1)
# Now you train the model
for epoch in range(start_epoch + 1, config.num_epochs + 1):
data_time_start= time.time()
data_time_all = 0
for step, batch in enumerate(train_dataloader):
data_time_all += time.time() - data_time_start
if load_vae_feat:
z = batch[0]
else:
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=config.mixed_precision == 'fp16'):
posterior = vae.encode(batch[0]).latent_dist
if config.sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
latents = (z * config.scale_factor).to(weight_dtype)
y = batch[1].squeeze(1).to(weight_dtype)
y_mask = batch[2].squeeze(1).squeeze(1).to(weight_dtype)
data_info = {'resolution': batch[3]['img_hw'].to(weight_dtype), 'aspect_ratio': batch[3]['aspect_ratio'].to(weight_dtype),}
# Sample a random timestep for each image
grad_norm = None
with accelerator.accumulate(model):
# Predict the noise residual
optimizer.zero_grad()
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
topk = config.train_sampling_steps // config.num_ddim_timesteps
index = torch.randint(0, config.num_ddim_timesteps, (bsz,), device=latents.device).long()
start_timesteps = solver.ddim_timesteps[index]
timesteps = start_timesteps - topk
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
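                # start_timesteps is t_{n+k} on the DDIM grid and timesteps is t_n = t_{n+k} - k
                # (clamped at 0), where k = topk is the LCM skipping interval between student and target.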
# Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
                # Use a fixed guidance scale w = config.cfg_scale (sampling w ~ U[w_min, w_max] is kept commented out below)
# w = (config.w_max - config.w_min) * torch.rand((bsz,)) + config.w_min
w = config.cfg_scale * torch.ones((bsz,))
w = w.reshape(bsz, 1, 1, 1)
w = w.to(device=latents.device, dtype=latents.dtype)
# Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
_, pred_x_0, noisy_model_input = train_diffusion.training_losses_diffusers(
model, latents, start_timesteps,
model_kwargs=dict(encoder_hidden_states=y, encoder_attention_mask=y_mask, added_cond_kwargs=data_info),
noise=noise
)
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
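                # model_pred is the online consistency function f_theta(z_{t_{n+k}}) = c_skip * z + c_out * x0_hat,
                # evaluated at the start (later) timestep.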
with torch.no_grad():
with torch.autocast("cuda"):
cond_teacher_output, cond_pred_x0, _ = train_diffusion.training_losses_diffusers(
model_teacher, latents, start_timesteps,
model_kwargs=dict(encoder_hidden_states=y, encoder_attention_mask=y_mask, added_cond_kwargs=data_info),
noise=noise
)
# Get teacher model prediction on noisy_latents and unconditional embedding
uncond_teacher_output, uncond_pred_x0, _ = train_diffusion.training_losses_diffusers(
model_teacher, latents, start_timesteps,
model_kwargs=dict(encoder_hidden_states=uncond_prompt_embeds, encoder_attention_mask=y_mask, added_cond_kwargs=data_info),
noise=noise
)
# Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
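                        # The teacher's conditional/unconditional predictions are combined with the LCM
                        # guidance formulation and pushed one DDIM step back to estimate z_{t_n}; this
                        # estimate is the input of the (stop-gradient) target evaluation below.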
# Get target LCM prediction on x_prev, w, c, t_n
with torch.no_grad():
with torch.autocast("cuda", enabled=True):
_, pred_x_0, _ = train_diffusion.training_losses_diffusers(
model, x_prev.float(), timesteps,
model_kwargs=dict(encoder_hidden_states=y, encoder_attention_mask=y_mask, added_cond_kwargs=data_info),
skip_noise=True
)
target = c_skip * x_prev + c_out * pred_x_0
# Calculate loss
if config.loss_type == "l2":
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif config.loss_type == "huber":
loss = torch.mean(torch.sqrt((model_pred.float() - target.float()) ** 2 + config.huber_c**2) - config.huber_c)
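                # Loss choices: "l2" is plain MSE; "huber" is the pseudo-Huber loss sqrt(d**2 + c**2) - c
                # with c = config.huber_c, behaving like L2 for small residuals and like L1 for large ones.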
accelerator.backward(loss)
if accelerator.sync_gradients:
grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.gradient_clip)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
lr = lr_scheduler.get_last_lr()[0]
logs = {"loss": accelerator.gather(loss).mean().item()}
if grad_norm is not None:
logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
log_buffer.update(logs)
if (step + 1) % config.log_interval == 0 or (step + 1) == 1:
t = (time.time() - last_tic) / config.log_interval
t_d = data_time_all / config.log_interval
avg_time = (time.time() - time_start) / (global_step + 1)
eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - start_step - global_step - 1))))
eta_epoch = str(datetime.timedelta(seconds=int(avg_time * (len(train_dataloader) - step - 1))))
# avg_loss = sum(loss_buffer) / len(loss_buffer)
log_buffer.average()
info = f"Step/Epoch [{(epoch-1)*len(train_dataloader)+step+1}/{epoch}][{step + 1}/{len(train_dataloader)}]:total_eta: {eta}, " \
f"epoch_eta:{eta_epoch}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:({data_info['resolution'][0][0].item()}, {data_info['resolution'][0][1].item()}), "
info += ', '.join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()])
logger.info(info)
last_tic = time.time()
log_buffer.clear()
data_time_all = 0
logs.update(lr=lr)
accelerator.log(logs, step=global_step + start_step)
global_step += 1
data_time_start= time.time()
accelerator.wait_for_everyone()
if accelerator.is_main_process:
if ((epoch - 1) * len(train_dataloader) + step + 1) % config.save_model_steps == 0:
save_path = os.path.join(os.path.join(config.work_dir, 'checkpoints'), f"checkpoint-{(epoch - 1) * len(train_dataloader) + step + 1}")
os.umask(0o000)
logger.info(f"Start to save state to {save_path}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
accelerator.wait_for_everyone()
if epoch % config.save_model_epochs == 0 or epoch == config.num_epochs:
os.umask(0o000)
save_path = os.path.join(os.path.join(config.work_dir, 'checkpoints'), f"checkpoint-{(epoch - 1) * len(train_dataloader) + step + 1}")
logger.info(f"Start to save state to {save_path}")
model = accelerator.unwrap_model(model)
model.save_pretrained(save_path)
lora_state_dict = get_peft_model_state_dict(model, adapter_name="default")
StableDiffusionPipeline.save_lora_weights(os.path.join(save_path, "transformer_lora"), lora_state_dict)
logger.info(f"Saved state to {save_path}")
def parse_args():
    parser = argparse.ArgumentParser(description="PixArt LCM-LoRA training")
parser.add_argument("config", type=str, help="config")
parser.add_argument("--cloud", action='store_true', default=False, help="cloud or local machine")
parser.add_argument("--work-dir", default='output', help='the dir to save logs and models')
parser.add_argument("--resume-from", help='the dir to save logs and models')
parser.add_argument("--local-rank", type=int, default=-1)
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--debug", action='store_true')
parser.add_argument("--lora_rank", type=int, default=64, help="The rank of the LoRA projection matrix.", )
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
config = read_config(args.config)
config.resume_from = None
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
config.work_dir = args.work_dir
if args.cloud:
config.data_root = '/data/data'
if args.resume_from is not None:
config.resume_from = args.resume_from
if args.debug:
config.log_interval = 1
config.train_batch_size = 4
config.valid_num = 10
config.save_model_steps = 10
os.umask(0o000)
os.makedirs(config.work_dir, exist_ok=True)
init_handler = InitProcessGroupKwargs()
init_handler.timeout = datetime.timedelta(seconds=5400) # change timeout to avoid a strange NCCL bug
# Initialize accelerator and tensorboard logging
if config.use_fsdp:
init_train = 'FSDP'
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
set_fsdp_env()
fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),)
else:
init_train = 'DDP'
fsdp_plugin = None
even_batches = True
if config.multi_scale:
        even_batches = False
accelerator = Accelerator(
mixed_precision=config.mixed_precision,
gradient_accumulation_steps=config.gradient_accumulation_steps,
log_with="tensorboard",
project_dir=os.path.join(config.work_dir, "logs"),
fsdp_plugin=fsdp_plugin,
even_batches=even_batches,
kwargs_handlers=[init_handler]
)
logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
logger.info(accelerator.state)
config.seed = init_random_seed(config.get('seed', None))
set_random_seed(config.seed)
if accelerator.is_main_process:
config.dump(os.path.join(config.work_dir, 'config.py'))
logger.info(f"Config: \n{config.pretty_text}")
logger.info(f"World_size: {get_world_size()}, seed: {config.seed}")
logger.info(f"Initializing: {init_train} for training")
image_size = config.image_size # @param [256, 512]
latent_size = int(image_size) // 8
pred_sigma = getattr(config, 'pred_sigma', True)
learn_sigma = getattr(config, 'learn_sigma', True) and pred_sigma
# prepare null_embedding for training
if not os.path.exists('output/pretrained_models/null_embed.pth'):
logger.info(f"Creating output/pretrained_models/null_embed.pth")
os.makedirs('output/pretrained_models/', exist_ok=True)
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16, use_safetensors=True,).to("cuda")
torch.save(pipe.encode_prompt(""), 'output/pretrained_models/null_embed.pth')
del pipe
torch.cuda.empty_cache()
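    # The cached null embedding is loaded in train() and repeated across the batch as the
    # unconditional prompt for the CFG term of the distillation target.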
# build models
train_diffusion = IDDPM(str(config.train_sampling_steps), learn_sigma=learn_sigma, pred_sigma=pred_sigma, return_startx=True)
model_teacher = Transformer2DModel.from_pretrained(config.load_from, subfolder="transformer")
model_teacher.requires_grad_(False)
model = Transformer2DModel.from_pretrained(config.load_from, subfolder="transformer").train()
logger.info(f"{model.__class__.__name__} Model Parameters: {sum(p.numel() for p in model.parameters()):}")
lora_config = LoraConfig(
r=config.lora_rank,
target_modules=[
"to_q",
"to_k",
"to_v",
"to_out.0",
"proj_in",
"proj_out",
"ff.net.0.proj",
"ff.net.2",
"proj",
"linear",
"linear_1",
"linear_2",
# "scale_shift_table", # not available due to the implementation in huggingface/peft, working on it.
],
)
print(lora_config)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
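    # LoRA adapters are injected into the attention projections (to_q/k/v/out), the feed-forward
    # layers, and the remaining listed Linear layers of the transformer; all other weights stay frozen,
    # as reported by print_trainable_parameters().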
# 9. Handle mixed precision and device placement
    # For mixed precision training we cast all non-trainable weights to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# 11. Enable optimizations
# model.enable_xformers_memory_efficient_attention()
# model_teacher.enable_xformers_memory_efficient_attention()
lora_layers = filter(lambda p: p.requires_grad, model.parameters())
# for name, params in model.named_parameters():
# if params.requires_grad == False: logger.info(f"freeze param: {name}")
#
# for name, params in model.named_parameters():
# if params.requires_grad == True: logger.info(f"trainable param: {name}")
# 10. Handle saving and loading of checkpoints
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_ = accelerator.unwrap_model(models[0])
lora_state_dict = get_peft_model_state_dict(transformer_, adapter_name="default")
StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "transformer_lora"), lora_state_dict)
# save weights in peft format to be able to load them back
transformer_.save_pretrained(output_dir)
for _, model in enumerate(models):
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
def load_model_hook(models, input_dir):
# load the LoRA into the model
transformer_ = accelerator.unwrap_model(models[0])
transformer_.load_adapter(input_dir, "default", is_trainable=True)
for _ in range(len(models)):
# pop models so that they are not loaded again
models.pop()
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
if config.grad_checkpointing:
model.enable_gradient_checkpointing()
if not config.data.load_vae_feat:
vae = AutoencoderKL.from_pretrained(config.vae_pretrained).cuda()
# prepare for FSDP clip grad norm calculation
if accelerator.distributed_type == DistributedType.FSDP:
for m in accelerator._models:
m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m)
# build dataloader
set_data_root(config.data_root)
dataset = build_dataset(config.data, resolution=image_size, aspect_ratio_type=config.aspect_ratio_type)
if config.multi_scale:
batch_sampler = AspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio, drop_last=True,
ratio_nums=dataset.ratio_nums, config=config, valid_num=config.valid_num)
# used for balanced sampling
# batch_sampler = BalancedAspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
# batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio,
# ratio_nums=dataset.ratio_nums)
train_dataloader = build_dataloader(dataset, batch_sampler=batch_sampler, num_workers=config.num_workers)
else:
train_dataloader = build_dataloader(dataset, num_workers=config.num_workers, batch_size=config.train_batch_size, shuffle=True)
# build optimizer and lr scheduler
lr_scale_ratio = 1
if config.get('auto_lr', None):
lr_scale_ratio = auto_scale_lr(config.train_batch_size * get_world_size() * config.gradient_accumulation_steps,
config.optimizer,
**config.auto_lr)
optimizer = build_optimizer(model, config.optimizer)
lr_scheduler = build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio)
timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
if accelerator.is_main_process:
accelerator.init_trackers(f"tb_{timestamp}")
start_epoch = 0
start_step = 0
total_steps = len(train_dataloader) * config.num_epochs
solver = DDIMSolver(train_diffusion.alphas_cumprod, timesteps=config.train_sampling_steps, ddim_timesteps=config.num_ddim_timesteps)
solver.to(accelerator.device)
# Prepare everything
# There is no specific order to remember, you just need to unpack the
# objects in the same order you gave them to the prepare method.
model, model_teacher = accelerator.prepare(model, model_teacher)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if config.resume_from is not None:
if config.resume_from != "latest":
path = os.path.basename(config.resume_from)
else:
# Get the most recent checkpoint
dirs = os.listdir(os.path.join(config.work_dir, 'checkpoints'))
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(f"Checkpoint '{config.resume_from}' does not exist. Starting a new training run.")
config.resume_from = None
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(config.work_dir, 'checkpoints', path))
start_step = int(path.split("-")[1])
start_epoch = start_step // len(train_dataloader)
train(model)
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
import argparse
import logging
import math
import os
import random
import shutil
from pathlib import Path
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
import accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig, get_peft_model_state_dict, get_peft_model, PeftModel
from torchvision import transforms
from tqdm.auto import tqdm
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, PixArtAlphaPipeline, Transformer2DModel
from transformers import T5EncoderModel, T5Tokenizer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(repo_id: str, images=None, base_model: str = "", dataset_name: str = "", repo_folder=None):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- lora
inference: true
---
"""
model_card = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. Some example images are shown below. \n
{img_str}
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--train_data_dir",
type=str,
default=None,
help=(
"A folder containing the training data. Folder contents must follow the structure described in"
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
),
)
parser.add_argument(
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
)
parser.add_argument(
"--caption_column",
type=str,
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument(
"--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
)
parser.add_argument(
"--num_validation_images",
type=int,
default=4,
help="Number of images that should be generated during validation with `validation_prompt`.",
)
parser.add_argument(
"--validation_epochs",
type=int,
default=1,
help=(
"Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`."
),
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="sd-model-finetuned-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-6,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
# ----Diffusion Training Arguments----
parser.add_argument(
"--proportion_empty_prompts",
type=float,
default=0,
help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
)
parser.add_argument(
"--prediction_type",
type=str,
default=None,
help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
)
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
parser.add_argument(
"--rank",
type=int,
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--local-rank", type=int, default=-1)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
# Sanity checks
if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Need either a dataset name or a training folder.")
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
return args
DATASET_NAME_MAPPING = {"lambdalabs/pokemon-blip-captions": ("image", "text"),}
def main():
args = parse_args()
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
repo_id = create_repo(repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token).repo_id
# See Section 3.1. of the paper.
max_length = 120
# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision)
text_encoder = T5EncoderModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant)
transformer = Transformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer", torch_dtype=torch.float16)
# freeze parameters of models to save more memory
transformer.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Freeze the transformer parameters before adding adapters
for param in transformer.parameters():
param.requires_grad_(False)
lora_config = LoraConfig(
r=args.rank,
init_lora_weights="gaussian",
target_modules=[
"to_k",
"to_q",
"to_v",
"to_out.0",
"proj_in",
"proj_out",
"ff.net.0.proj",
"ff.net.2",
"proj",
"linear",
"linear_1",
"linear_2",
# "scale_shift_table", # not available due to the implementation in huggingface/peft, working on it.
]
)
# Move transformer, vae and text_encoder to device and cast to weight_dtype
transformer.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
transformer = get_peft_model(transformer, lora_config)
transformer.print_trainable_parameters()
# 10. Handle saving and loading of checkpoints
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_ = accelerator.unwrap_model(transformer)
lora_state_dict = get_peft_model_state_dict(transformer_, adapter_name="default")
StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "transformer_lora"), lora_state_dict)
# save weights in peft format to be able to load them back
transformer_.save_pretrained(output_dir)
for _, model in enumerate(models):
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
def load_model_hook(models, input_dir):
# load the LoRA into the model
transformer_ = accelerator.unwrap_model(transformer)
transformer_.load_adapter(input_dir, "default", is_trainable=True)
for _ in range(len(models)):
# pop models so that they are not loaded again
models.pop()
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
                logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
transformer.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
lora_layers = filter(lambda p: p.requires_grad, transformer.parameters())
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
if args.scale_lr:
args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
# Initialize the optimizer
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`")
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
lora_layers,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
data_dir=args.train_data_dir,
)
else:
data_files = {}
if args.train_data_dir is not None:
data_files["train"] = os.path.join(args.train_data_dir, "**")
dataset = load_dataset(
"imagefolder",
data_files=data_files,
cache_dir=args.cache_dir,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
if args.image_column is None:
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
)
if args.caption_column is None:
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
caption_column = args.caption_column
if caption_column not in column_names:
raise ValueError(
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
)
# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
def tokenize_captions(examples, is_train=True, proportion_empty_prompts=0., max_length=120):
captions = []
for caption in examples[caption_column]:
if random.random() < proportion_empty_prompts:
captions.append("")
elif isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column}` should contain either strings or lists of strings."
)
inputs = tokenizer(captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
return inputs.input_ids, inputs.attention_mask
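    # tokenize_captions returns fixed-length (max_length) T5 token ids together with the attention mask;
    # the mask is later passed to the text encoder and to the transformer as `encoder_attention_mask`
    # so that padding tokens are ignored.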
# Preprocessing the datasets.
train_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]]
examples["pixel_values"] = [train_transforms(image) for image in images]
examples["input_ids"], examples['prompt_attention_mask'] = tokenize_captions(examples, proportion_empty_prompts=args.proportion_empty_prompts, max_length=max_length)
return examples
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = torch.stack([example["input_ids"] for example in examples])
prompt_attention_mask = torch.stack([example["prompt_attention_mask"] for example in examples])
return {"pixel_values": pixel_values, "input_ids": input_ids, 'prompt_attention_mask': prompt_attention_mask}
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
# Prepare everything with our `accelerator`.
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(transformer, optimizer, train_dataloader, lr_scheduler)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("text2image-fine-tune", config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
else:
initial_global_step = 0
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(transformer):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
prompt_embeds = text_encoder(batch["input_ids"], attention_mask=batch['prompt_attention_mask'])[0]
prompt_attention_mask = batch['prompt_attention_mask']
# Get the target for loss depending on the prediction type
if args.prediction_type is not None:
# set prediction_type of scheduler if defined
noise_scheduler.register_to_config(prediction_type=args.prediction_type)
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
if getattr(transformer, 'module', transformer).config.sample_size == 128:
resolution = torch.tensor([args.resolution, args.resolution]).repeat(bsz, 1)
aspect_ratio = torch.tensor([float(args.resolution / args.resolution)]).repeat(bsz, 1)
resolution = resolution.to(dtype=weight_dtype, device=latents.device)
aspect_ratio = aspect_ratio.to(dtype=weight_dtype, device=latents.device)
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
# Predict the noise residual and compute loss
model_pred = transformer(noisy_latents,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=timesteps,
added_cond_kwargs=added_cond_kwargs).sample.chunk(2, 1)[0]
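                # The transformer predicts noise and (learned) variance concatenated along the channel dim;
                # chunk(2, 1)[0] keeps only the noise half for the loss.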
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective requires that we add one to SNR values before we divide by them.
snr = snr + 1
mse_loss_weights = (torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
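                    # Min-SNR-gamma weighting: each sample's MSE is scaled by min(SNR, gamma) / SNR
                    # (with SNR + 1 for v-prediction), down-weighting very high-SNR (low-noise) timesteps.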
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = lora_layers
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints")
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
transformer_lora_state_dict = get_peft_model_state_dict(transformer)
StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
unet_lora_layers=transformer_lora_state_dict,
safe_serialization=True,
)
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
# create pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device)
if args.seed is not None:
generator = generator.manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
images.append(pipeline(args.validation_prompt, num_inference_steps=20, generator=generator).images[0])
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"validation": [wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)]
}
)
del pipeline
torch.cuda.empty_cache()
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = accelerator.unwrap_model(transformer, keep_fp32_wrapper=False)
transformer.save_pretrained(args.output_dir)
lora_state_dict = get_peft_model_state_dict(transformer)
StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "transformer_lora"), lora_state_dict)
if args.push_to_hub:
save_model_card(
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
dataset_name=args.dataset_name,
repo_folder=args.output_dir,
)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
# Final inference
# Load previous transformer
transformer = Transformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='transformer', torch_dtype=weight_dtype)
# load lora weight
transformer = PeftModel.from_pretrained(transformer, args.output_dir)
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, transformer=transformer, text_encoder=text_encoder, vae=vae, torch_dtype=weight_dtype,)
pipeline = pipeline.to(accelerator.device)
del transformer
torch.cuda.empty_cache()
# run inference
generator = torch.Generator(device=accelerator.device)
if args.seed is not None:
generator = generator.manual_seed(args.seed)
images = []
with torch.autocast("cuda", dtype=weight_dtype):
for _ in range(args.num_validation_images):
images.append(pipeline(args.validation_prompt, num_inference_steps=20, generator=generator).images[0])
if accelerator.is_main_process:
for tracker in accelerator.trackers:
if len(images) != 0:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"test": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)
accelerator.end_training()
if __name__ == "__main__":
main()