Commit e532679c authored by oahzxl's avatar oahzxl
Browse files

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents c1492e50 7d5640b9
"""Utils for monoDepth."""
import sys
import re
import numpy as np
import cv2
import torch
def read_pfm(path):
"""Read pfm file.
Args:
path (str): path to file
Returns:
tuple: (data, scale)
"""
with open(path, "rb") as file:
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().rstrip()
if header.decode("ascii") == "PF":
color = True
elif header.decode("ascii") == "Pf":
color = False
else:
raise Exception("Not a PFM file: " + path)
dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
if dim_match:
width, height = list(map(int, dim_match.groups()))
else:
raise Exception("Malformed PFM header.")
scale = float(file.readline().decode("ascii").rstrip())
if scale < 0:
# little-endian
endian = "<"
scale = -scale
else:
# big-endian
endian = ">"
data = np.fromfile(file, endian + "f")
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
return data, scale
def write_pfm(path, image, scale=1):
"""Write pfm file.
Args:
path (str): pathto file
image (array): data
scale (int, optional): Scale. Defaults to 1.
"""
with open(path, "wb") as file:
color = None
if image.dtype.name != "float32":
raise Exception("Image dtype must be float32.")
image = np.flipud(image)
if len(image.shape) == 3 and image.shape[2] == 3: # color image
color = True
elif (
len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
): # greyscale
color = False
else:
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
file.write("PF\n" if color else "Pf\n".encode())
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
endian = image.dtype.byteorder
if endian == "<" or endian == "=" and sys.byteorder == "little":
scale = -scale
file.write("%f\n".encode() % scale)
image.tofile(file)
def read_image(path):
"""Read image and output RGB image (0-1).
Args:
path (str): path to file
Returns:
array: RGB image (0-1)
"""
img = cv2.imread(path)
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
return img
def resize_image(img):
"""Resize image and make it fit for network.
Args:
img (array): image
Returns:
tensor: data ready for network
"""
height_orig = img.shape[0]
width_orig = img.shape[1]
if width_orig > height_orig:
scale = width_orig / 384
else:
scale = height_orig / 384
height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
img_resized = (
torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
)
img_resized = img_resized.unsqueeze(0)
return img_resized
def resize_depth(depth, width, height):
"""Resize depth map and bring to CPU (numpy).
Args:
depth (tensor): depth
width (int): image width
height (int): image height
Returns:
array: processed depth
"""
depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
depth_resized = cv2.resize(
depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
)
return depth_resized
def write_depth(path, depth, bits=1):
"""Write depth map to pfm and png file.
Args:
path (str): filepath without extension
depth (array): depth
"""
write_pfm(path + ".pfm", depth.astype(np.float32))
depth_min = depth.min()
depth_max = depth.max()
max_val = (2**(8*bits))-1
if depth_max - depth_min > np.finfo("float").eps:
out = max_val * (depth - depth_min) / (depth_max - depth_min)
else:
out = np.zeros(depth.shape, dtype=depth.type)
if bits == 1:
cv2.imwrite(path + ".png", out.astype("uint8"))
elif bits == 2:
cv2.imwrite(path + ".png", out.astype("uint16"))
return
import importlib
import torch
from torch import optim
import numpy as np
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x,torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
return total_params
def instantiate_from_config(config):
if not "target" in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
class AdamWwithEMAandWings(optim.Optimizer):
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
ema_power=1., param_names=()):
"""AdamW that saves EMA versions of the parameters."""
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if not 0.0 <= ema_decay <= 1.0:
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
ema_power=ema_power, param_names=param_names)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
ema_params_with_grad = []
state_sums = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']
beta1, beta2 = group['betas']
ema_decay = group['ema_decay']
ema_power = group['ema_power']
for p in group['params']:
if p.grad is None:
continue
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of parameter values
state['param_exp_avg'] = p.detach().float().clone()
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
ema_params_with_grad.append(state['param_exp_avg'])
if amsgrad:
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
# update the steps for each param group update
state['step'] += 1
# record the step after step update
state_steps.append(state['step'])
optim._functional.adamw(params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
maximize=False)
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
return loss
\ No newline at end of file
import argparse
import csv
import datetime
import glob
import importlib
import os
import sys
import time
import numpy as np
import torch
import torchvision
try:
import lightning.pytorch as pl
except:
import pytorch_lightning as pl
from functools import partial
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader, Dataset, Subset, random_split
try:
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
LIGHTNING_PACK_NAME = "lightning.pytorch."
except:
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
LIGHTNING_PACK_NAME = "pytorch_lightning."
from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
# from ldm.modules.attention import enable_flash_attentions
class DataLoaderX(DataLoader):
def __iter__(self):
return BackgroundGenerator(super().__iter__())
def get_parser(**parser_kwargs):
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
parser = argparse.ArgumentParser(**parser_kwargs)
parser.add_argument(
"-n",
"--name",
type=str,
const=True,
default="",
nargs="?",
help="postfix for logdir",
)
parser.add_argument(
"-r",
"--resume",
type=str,
const=True,
default="",
nargs="?",
help="resume from logdir or checkpoint in logdir",
)
parser.add_argument(
"-b",
"--base",
nargs="*",
metavar="base_config.yaml",
help="paths to base configs. Loaded from left-to-right. "
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
default=list(),
)
parser.add_argument(
"-t",
"--train",
type=str2bool,
const=True,
default=False,
nargs="?",
help="train",
)
parser.add_argument(
"--no-test",
type=str2bool,
const=True,
default=False,
nargs="?",
help="disable test",
)
parser.add_argument("-p", "--project", help="name of new or path to existing project")
parser.add_argument(
"-d",
"--debug",
type=str2bool,
nargs="?",
const=True,
default=False,
help="enable post-mortem debugging",
)
parser.add_argument(
"-s",
"--seed",
type=int,
default=23,
help="seed for seed_everything",
)
parser.add_argument(
"-f",
"--postfix",
type=str,
default="",
help="post-postfix for default name",
)
parser.add_argument(
"-l",
"--logdir",
type=str,
default="logs",
help="directory for logging dat shit",
)
parser.add_argument(
"--scale_lr",
type=str2bool,
nargs="?",
const=True,
default=True,
help="scale base-lr by ngpu * batch_size * n_accumulate",
)
parser.add_argument(
"--use_fp16",
type=str2bool,
nargs="?",
const=True,
default=True,
help="whether to use fp16",
)
parser.add_argument(
"--flash",
type=str2bool,
const=True,
default=False,
nargs="?",
help="whether to use flash attention",
)
return parser
def nondefault_trainer_args(opt):
parser = argparse.ArgumentParser()
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args([])
return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
class WrappedDataset(Dataset):
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
def __init__(self, dataset):
self.data = dataset
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def worker_init_fn(_):
worker_info = torch.utils.data.get_worker_info()
dataset = worker_info.dataset
worker_id = worker_info.id
if isinstance(dataset, Txt2ImgIterableBaseDataset):
split_size = dataset.num_records // worker_info.num_workers
# reset num_records to the true number to retain reliable length information
dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
else:
return np.random.seed(np.random.get_state()[1][0] + worker_id)
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(self,
batch_size,
train=None,
validation=None,
test=None,
predict=None,
wrap=False,
num_workers=None,
shuffle_test_loader=False,
use_worker_init_fn=False,
shuffle_val_dataloader=False):
super().__init__()
self.batch_size = batch_size
self.dataset_configs = dict()
self.num_workers = num_workers if num_workers is not None else batch_size * 2
self.use_worker_init_fn = use_worker_init_fn
if train is not None:
self.dataset_configs["train"] = train
self.train_dataloader = self._train_dataloader
if validation is not None:
self.dataset_configs["validation"] = validation
self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
if test is not None:
self.dataset_configs["test"] = test
self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
if predict is not None:
self.dataset_configs["predict"] = predict
self.predict_dataloader = self._predict_dataloader
self.wrap = wrap
def prepare_data(self):
for data_cfg in self.dataset_configs.values():
instantiate_from_config(data_cfg)
def setup(self, stage=None):
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
if self.wrap:
for k in self.datasets:
self.datasets[k] = WrappedDataset(self.datasets[k])
def _train_dataloader(self):
is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
return DataLoaderX(self.datasets["train"],
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False if is_iterable_dataset else True,
worker_init_fn=init_fn)
def _val_dataloader(self, shuffle=False):
if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
return DataLoaderX(self.datasets["validation"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
shuffle=shuffle)
def _test_dataloader(self, shuffle=False):
is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
# do not shuffle dataloader for iterable dataset
shuffle = shuffle and (not is_iterable_dataset)
return DataLoaderX(self.datasets["test"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
shuffle=shuffle)
def _predict_dataloader(self, shuffle=False):
if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
return DataLoaderX(self.datasets["predict"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn)
class SetupCallback(Callback):
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
super().__init__()
self.resume = resume
self.now = now
self.logdir = logdir
self.ckptdir = ckptdir
self.cfgdir = cfgdir
self.config = config
self.lightning_config = lightning_config
def on_keyboard_interrupt(self, trainer, pl_module):
if trainer.global_rank == 0:
print("Summoning checkpoint.")
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
# def on_pretrain_routine_start(self, trainer, pl_module):
def on_fit_start(self, trainer, pl_module):
if trainer.global_rank == 0:
# Create logdirs and save configs
os.makedirs(self.logdir, exist_ok=True)
os.makedirs(self.ckptdir, exist_ok=True)
os.makedirs(self.cfgdir, exist_ok=True)
if "callbacks" in self.lightning_config:
if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
print("Project config")
print(OmegaConf.to_yaml(self.config))
OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
print("Lightning config")
print(OmegaConf.to_yaml(self.lightning_config))
OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
else:
# ModelCheckpoint callback created log directory --- remove it
if not self.resume and os.path.exists(self.logdir):
dst, name = os.path.split(self.logdir)
dst = os.path.join(dst, "child_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
try:
os.rename(self.logdir, dst)
except FileNotFoundError:
pass
class ImageLogger(Callback):
def __init__(self,
batch_frequency,
max_images,
clamp=True,
increase_log_steps=True,
rescale=True,
disabled=False,
log_on_batch_idx=False,
log_first_step=False,
log_images_kwargs=None):
super().__init__()
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = {
pl.loggers.CSVLogger: self._testtube,
}
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
self.disabled = disabled
self.log_on_batch_idx = log_on_batch_idx
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.log_first_step = log_first_step
@rank_zero_only
def _testtube(self, pl_module, images, batch_idx, split):
for k in images:
grid = torchvision.utils.make_grid(images[k])
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
tag = f"{split}/{k}"
pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)
@rank_zero_only
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "images", split)
for k in images:
grid = torchvision.utils.make_grid(images[k], nrow=4)
if self.rescale:
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
def log_img(self, pl_module, batch, batch_idx, split="train"):
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0):
logger = type(pl_module.logger)
is_train = pl_module.training
if is_train:
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
for k in images:
N = min(images[k].shape[0], self.max_images)
images[k] = images[k][:N]
if isinstance(images[k], torch.Tensor):
images[k] = images[k].detach().cpu()
if self.clamp:
images[k] = torch.clamp(images[k], -1., 1.)
self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch,
batch_idx)
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
logger_log_images(pl_module, images, pl_module.global_step, split)
if is_train:
pl_module.train()
def check_frequency(self, check_idx):
if ((check_idx % self.batch_freq) == 0 or
(check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step):
try:
self.log_steps.pop(0)
except IndexError as e:
print(e)
pass
return True
return False
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
# if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
# self.log_img(pl_module, batch, batch_idx, split="train")
pass
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split="val")
if hasattr(pl_module, 'calibrate_grad_norm'):
if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
class CUDACallback(Callback):
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
def on_train_start(self, trainer, pl_module):
rank_zero_info("Training is starting")
def on_train_end(self, trainer, pl_module):
rank_zero_info("Training is ending")
def on_train_epoch_start(self, trainer, pl_module):
# Reset the memory use counter
torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
torch.cuda.synchronize(trainer.strategy.root_device.index)
self.start_time = time.time()
def on_train_epoch_end(self, trainer, pl_module):
torch.cuda.synchronize(trainer.strategy.root_device.index)
max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2**20
epoch_time = time.time() - self.start_time
try:
max_memory = trainer.strategy.reduce(max_memory)
epoch_time = trainer.strategy.reduce(epoch_time)
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
except AttributeError:
pass
if __name__ == "__main__":
# custom parser to specify config files, train, test and debug mode,
# postfix, resume.
# `--key value` arguments are interpreted as arguments to the trainer.
# `nested.key=value` arguments are interpreted as config parameters.
# configs are merged from left-to-right followed by command line parameters.
# model:
# base_learning_rate: float
# target: path to lightning module
# params:
# key: value
# data:
# target: main.DataModuleFromConfig
# params:
# batch_size: int
# wrap: bool
# train:
# target: path to train dataset
# params:
# key: value
# validation:
# target: path to validation dataset
# params:
# key: value
# test:
# target: path to test dataset
# params:
# key: value
# lightning: (optional, has sane defaults and can be specified on cmdline)
# trainer:
# additional arguments to trainer
# logger:
# logger to instantiate
# modelcheckpoint:
# modelcheckpoint to instantiate
# callbacks:
# callback1:
# target: importpath
# params:
# key: value
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
# add cwd for convenience and to make classes in this file available when
# running as `python main.py`
# (in particular `main.DataModuleFromConfig`)
sys.path.append(os.getcwd())
parser = get_parser()
parser = Trainer.add_argparse_args(parser)
opt, unknown = parser.parse_known_args()
if opt.name and opt.resume:
raise ValueError("-n/--name and -r/--resume cannot be specified both."
"If you want to resume training in a new log folder, "
"use -n/--name in combination with --resume_from_checkpoint")
if opt.resume:
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
paths = opt.resume.split("/")
# idx = len(paths)-paths[::-1].index("logs")+1
# logdir = "/".join(paths[:idx])
logdir = "/".join(paths[:-2])
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
opt.resume_from_checkpoint = ckpt
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
opt.base = base_configs + opt.base
_tmp = logdir.split("/")
nowname = _tmp[-1]
else:
if opt.name:
name = "_" + opt.name
elif opt.base:
cfg_fname = os.path.split(opt.base[0])[-1]
cfg_name = os.path.splitext(cfg_fname)[0]
name = "_" + cfg_name
else:
name = ""
nowname = now + name + opt.postfix
logdir = os.path.join(opt.logdir, nowname)
ckptdir = os.path.join(logdir, "checkpoints")
cfgdir = os.path.join(logdir, "configs")
seed_everything(opt.seed)
try:
# init and save configs
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
lightning_config = config.pop("lightning", OmegaConf.create())
# merge trainer cli with config
trainer_config = lightning_config.get("trainer", OmegaConf.create())
for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k)
print(trainer_config)
if not trainer_config["accelerator"] == "gpu":
del trainer_config["accelerator"]
cpu = True
print("Running on CPU")
else:
cpu = False
print("Running on GPU")
trainer_opt = argparse.Namespace(**trainer_config)
lightning_config.trainer = trainer_config
# model
use_fp16 = trainer_config.get("precision", 32) == 16
if use_fp16:
config.model["params"].update({"use_fp16": True})
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
else:
config.model["params"].update({"use_fp16": False})
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
model = instantiate_from_config(config.model)
# trainer and callbacks
trainer_kwargs = dict()
# config the logger
# default logger configs
default_logger_cfgs = {
"wandb": {
"target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
"params": {
"name": nowname,
"save_dir": logdir,
"offline": opt.debug,
"id": nowname,
}
},
"tensorboard": {
"target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
"params": {
"save_dir": logdir,
"name": "diff_tb",
"log_graph": True
}
}
}
default_logger_cfg = default_logger_cfgs["tensorboard"]
if "logger" in lightning_config:
logger_cfg = lightning_config.logger
else:
logger_cfg = default_logger_cfg
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
# config the strategy, defualt is ddp
if "strategy" in trainer_config:
strategy_cfg = trainer_config["strategy"]
print("Using strategy: {}".format(strategy_cfg["target"]))
strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
else:
strategy_cfg = {
"target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
"params": {
"find_unused_parameters": False
}
}
print("Using strategy: DDPStrategy")
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
# specify which metric is used to determine best models
default_modelckpt_cfg = {
"target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
}
}
if hasattr(model, "monitor"):
print(f"Monitoring {model.monitor} as checkpoint metric.")
default_modelckpt_cfg["params"]["monitor"] = model.monitor
default_modelckpt_cfg["params"]["save_top_k"] = 3
if "modelcheckpoint" in lightning_config:
modelckpt_cfg = lightning_config.modelcheckpoint
else:
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
if version.parse(pl.__version__) < version.parse('1.4.0'):
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
# add callback which sets up log directory
default_callbacks_cfg = {
"setup_callback": {
"target": "main.SetupCallback",
"params": {
"resume": opt.resume,
"now": now,
"logdir": logdir,
"ckptdir": ckptdir,
"cfgdir": cfgdir,
"config": config,
"lightning_config": lightning_config,
}
},
"image_logger": {
"target": "main.ImageLogger",
"params": {
"batch_frequency": 750,
"max_images": 4,
"clamp": True
}
},
"learning_rate_logger": {
"target": "main.LearningRateMonitor",
"params": {
"logging_interval": "step",
# "log_momentum": True
}
},
"cuda_callback": {
"target": "main.CUDACallback"
},
}
if version.parse(pl.__version__) >= version.parse('1.4.0'):
default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks
else:
callbacks_cfg = OmegaConf.create()
if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
print(
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
default_metrics_over_trainsteps_ckpt_dict = {
'metrics_over_trainsteps_checkpoint': {
"target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
'params': {
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
"filename": "{epoch:06}-{step:09}",
"verbose": True,
'save_top_k': -1,
'every_n_train_steps': 10000,
'save_weights_only': True
}
}
}
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
elif 'ignore_keys_callback' in callbacks_cfg:
del callbacks_cfg['ignore_keys_callback']
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir ###
# data
data = instantiate_from_config(config.data)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
print("#### Data #####")
for k in data.datasets:
print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
# configure learning rate
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
if not cpu:
ngpu = trainer_config["devices"]
else:
ngpu = 1
if 'accumulate_grad_batches' in lightning_config.trainer:
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
else:
accumulate_grad_batches = 1
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
if opt.scale_lr:
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print(
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
.format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
else:
model.learning_rate = base_lr
print("++++ NOT USING LR SCALING ++++")
print(f"Setting learning rate to {model.learning_rate:.2e}")
# allow checkpointing via USR1
def melk(*args, **kwargs):
# run all checkpoint hooks
if trainer.global_rank == 0:
print("Summoning checkpoint.")
ckpt_path = os.path.join(ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def divein(*args, **kwargs):
if trainer.global_rank == 0:
import pudb
pudb.set_trace()
import signal
signal.signal(signal.SIGUSR1, melk)
signal.signal(signal.SIGUSR2, divein)
# run
if opt.train:
try:
trainer.fit(model, data)
except Exception:
melk()
raise
# if not opt.no_test and not trainer.interrupted:
# trainer.test(model, data)
except Exception:
if opt.debug and trainer.global_rank == 0:
try:
import pudb as debugger
except ImportError:
import pdb as debugger
debugger.post_mortem()
raise
finally:
# move newly created debug project to debug_runs
if opt.debug and not opt.resume and trainer.global_rank == 0:
dst, name = os.path.split(logdir)
dst = os.path.join(dst, "debug_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
os.rename(logdir, dst)
if trainer.global_rank == 0:
print(trainer.profiler.summary())
albumentations==1.3.0
opencv-python==4.6.0
pudb==2019.2
prefetch_generator
imageio==2.9.0
imageio-ffmpeg==0.4.2
torchmetrics==0.6
omegaconf==2.1.1
test-tube>=0.7.5
streamlit>=0.73.1
einops==0.3.0
transformers==4.19.2
webdataset==0.2.5
open-clip-torch==2.7.0
gradio==3.11
datasets
colossalai
-e .
#!/bin/bash
wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip
wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
cd models/first_stage_models/kl-f4
unzip -o model.zip
cd ../kl-f8
unzip -o model.zip
cd ../kl-f16
unzip -o model.zip
cd ../kl-f32
unzip -o model.zip
cd ../vq-f4
unzip -o model.zip
cd ../vq-f4-noattn
unzip -o model.zip
cd ../vq-f8
unzip -o model.zip
cd ../vq-f8-n256
unzip -o model.zip
cd ../vq-f16
unzip -o model.zip
cd ../..
\ No newline at end of file
#!/bin/bash
wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip
wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip
wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip
wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
cd models/ldm/celeba256
unzip -o celeba-256.zip
cd ../ffhq256
unzip -o ffhq-256.zip
cd ../lsun_churches256
unzip -o lsun_churches-256.zip
cd ../lsun_beds256
unzip -o lsun_beds-256.zip
cd ../text2img256
unzip -o model.zip
cd ../cin256
unzip -o model.zip
cd ../semantic_synthesis512
unzip -o model.zip
cd ../semantic_synthesis256
unzip -o model.zip
cd ../bsr_sr
unzip -o model.zip
cd ../layout2img-openimages256
unzip -o model.zip
cd ../inpainting_big
unzip -o model.zip
cd ../..
"""make variations of input image"""
import argparse, os
import PIL
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid
from torch import autocast
from contextlib import nullcontext
try:
from lightning.pytorch import seed_everything
except:
from pytorch_lightning import seed_everything
from imwatermark import WatermarkEncoder
from scripts.txt2img import put_watermark
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from utils import replace_module, getModelSize
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.eval()
return model
def load_img(path):
image = Image.open(path).convert("RGB")
w, h = image.size
print(f"loaded input image of size ({w}, {h}) from {path}")
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2. * image - 1.
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt",
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render"
)
parser.add_argument(
"--init-img",
type=str,
nargs="?",
help="path to the input image"
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
default="outputs/img2img-samples"
)
parser.add_argument(
"--ddim_steps",
type=int,
default=50,
help="number of ddim sampling steps",
)
parser.add_argument(
"--fixed_code",
action='store_true',
help="if enabled, uses the same starting code across all samples ",
)
parser.add_argument(
"--ddim_eta",
type=float,
default=0.0,
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
)
parser.add_argument(
"--n_iter",
type=int,
default=1,
help="sample this often",
)
parser.add_argument(
"--C",
type=int,
default=4,
help="latent channels",
)
parser.add_argument(
"--f",
type=int,
default=8,
help="downsampling factor, most often 8 or 16",
)
parser.add_argument(
"--n_samples",
type=int,
default=2,
help="how many samples to produce for each given prompt. A.k.a batch size",
)
parser.add_argument(
"--n_rows",
type=int,
default=0,
help="rows in the grid (default: n_samples)",
)
parser.add_argument(
"--scale",
type=float,
default=9.0,
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
)
parser.add_argument(
"--strength",
type=float,
default=0.8,
help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
)
parser.add_argument(
"--from-file",
type=str,
help="if specified, load prompts from this file",
)
parser.add_argument(
"--config",
type=str,
default="configs/stable-diffusion/v2-inference.yaml",
help="path to config which constructs model",
)
parser.add_argument(
"--ckpt",
type=str,
help="path to checkpoint of model",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="the seed (for reproducible sampling)",
)
parser.add_argument(
"--precision",
type=str,
help="evaluate at this precision",
choices=["full", "autocast"],
default="autocast"
)
parser.add_argument(
"--use_int8",
type=bool,
default=False,
help="use int8 for inference",
)
opt = parser.parse_args()
seed_everything(opt.seed)
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
# quantize model
if opt.use_int8:
model = replace_module(model)
# # to compute the model size
# getModelSize(model)
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
wm = "SDV2"
wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
if not opt.from_file:
prompt = opt.prompt
assert prompt is not None
data = [batch_size * [prompt]]
else:
print(f"reading prompts from {opt.from_file}")
with open(opt.from_file, "r") as f:
data = f.read().splitlines()
data = list(chunk(data, batch_size))
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1
assert os.path.isfile(opt.init_img)
init_image = load_img(opt.init_img).to(device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(opt.strength * opt.ddim_steps)
print(f"target t_enc is {t_enc} steps")
precision_scope = autocast if opt.precision == "autocast" else nullcontext
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
uc = None
if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc, )
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
img = Image.fromarray(x_sample.astype(np.uint8))
img = put_watermark(img, wm_encoder)
img.save(os.path.join(sample_path, f"{base_count:05}.png"))
base_count += 1
all_samples.append(x_samples)
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
grid = Image.fromarray(grid.astype(np.uint8))
grid = put_watermark(grid, wm_encoder)
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid_count += 1
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
if __name__ == "__main__":
main()
# # to compute the mem allocated
# print(torch.cuda.max_memory_allocated() / 1024 / 1024)
import argparse, os, sys, glob
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
def make_batch(image, mask, device):
image = np.array(Image.open(image).convert("RGB"))
image = image.astype(np.float32)/255.0
image = image[None].transpose(0,3,1,2)
image = torch.from_numpy(image)
mask = np.array(Image.open(mask).convert("L"))
mask = mask.astype(np.float32)/255.0
mask = mask[None,None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = (1-mask)*image
batch = {"image": image, "mask": mask, "masked_image": masked_image}
for k in batch:
batch[k] = batch[k].to(device=device)
batch[k] = batch[k]*2.0-1.0
return batch
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--indir",
type=str,
nargs="?",
help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
)
parser.add_argument(
"--steps",
type=int,
default=50,
help="number of ddim sampling steps",
)
opt = parser.parse_args()
masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
images = [x.replace("_mask.png", ".png") for x in masks]
print(f"Found {len(masks)} inputs.")
config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
model = instantiate_from_config(config.model)
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
strict=False)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)
with torch.no_grad():
with model.ema_scope():
for image, mask in tqdm(zip(images, masks)):
outpath = os.path.join(opt.outdir, os.path.split(image)[1])
batch = make_batch(image, mask, device=device)
# encode masked image and concat downsampled mask
c = model.cond_stage_model.encode(batch["masked_image"])
cc = torch.nn.functional.interpolate(batch["mask"],
size=c.shape[-2:])
c = torch.cat((c, cc), dim=1)
shape = (c.shape[1]-1,)+c.shape[2:]
samples_ddim, _ = sampler.sample(S=opt.steps,
conditioning=c,
batch_size=c.shape[0],
shape=shape,
verbose=False)
x_samples_ddim = model.decode_first_stage(samples_ddim)
image = torch.clamp((batch["image"]+1.0)/2.0,
min=0.0, max=1.0)
mask = torch.clamp((batch["mask"]+1.0)/2.0,
min=0.0, max=1.0)
predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
min=0.0, max=1.0)
inpainted = (1-mask)*image+mask*predicted_image
inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
import argparse, os, sys, glob
import clip
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid
import scann
import time
from multiprocessing import cpu_count
from ldm.util import instantiate_from_config, parallel_data_prefetch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
DATABASES = [
"openimages",
"artbench-art_nouveau",
"artbench-baroque",
"artbench-expressionism",
"artbench-impressionism",
"artbench-post_impressionism",
"artbench-realism",
"artbench-romanticism",
"artbench-renaissance",
"artbench-surrealism",
"artbench-ukiyo_e",
]
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.cuda()
model.eval()
return model
class Searcher(object):
def __init__(self, database, retriever_version='ViT-L/14'):
assert database in DATABASES
# self.database = self.load_database(database)
self.database_name = database
self.searcher_savedir = f'data/rdm/searchers/{self.database_name}'
self.database_path = f'data/rdm/retrieval_databases/{self.database_name}'
self.retriever = self.load_retriever(version=retriever_version)
self.database = {'embedding': [],
'img_id': [],
'patch_coords': []}
self.load_database()
self.load_searcher()
def train_searcher(self, k,
metric='dot_product',
searcher_savedir=None):
print('Start training searcher')
searcher = scann.scann_ops_pybind.builder(self.database['embedding'] /
np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis],
k, metric)
self.searcher = searcher.score_brute_force().build()
print('Finish training searcher')
if searcher_savedir is not None:
print(f'Save trained searcher under "{searcher_savedir}"')
os.makedirs(searcher_savedir, exist_ok=True)
self.searcher.serialize(searcher_savedir)
def load_single_file(self, saved_embeddings):
compressed = np.load(saved_embeddings)
self.database = {key: compressed[key] for key in compressed.files}
print('Finished loading of clip embeddings.')
def load_multi_files(self, data_archive):
out_data = {key: [] for key in self.database}
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
for key in d.files:
out_data[key].append(d[key])
return out_data
def load_database(self):
print(f'Load saved patch embedding from "{self.database_path}"')
file_content = glob.glob(os.path.join(self.database_path, '*.npz'))
if len(file_content) == 1:
self.load_single_file(file_content[0])
elif len(file_content) > 1:
data = [np.load(f) for f in file_content]
prefetched_data = parallel_data_prefetch(self.load_multi_files, data,
n_proc=min(len(data), cpu_count()), target_data_type='dict')
self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
self.database}
else:
raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?')
print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')
def load_retriever(self, version='ViT-L/14', ):
model = FrozenClipImageEmbedder(model=version)
if torch.cuda.is_available():
model.cuda()
model.eval()
return model
def load_searcher(self):
print(f'load searcher for database {self.database_name} from {self.searcher_savedir}')
self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
print('Finished loading searcher.')
def search(self, x, k):
if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
self.train_searcher(k) # quickly fit searcher on the fly for small databases
assert self.searcher is not None, 'Cannot search with uninitialized searcher'
if isinstance(x, torch.Tensor):
x = x.detach().cpu().numpy()
if len(x.shape) == 3:
x = x[:, 0]
query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]
start = time.time()
nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
end = time.time()
out_embeddings = self.database['embedding'][nns]
out_img_ids = self.database['img_id'][nns]
out_pc = self.database['patch_coords'][nns]
out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
'img_ids': out_img_ids,
'patch_coords': out_pc,
'queries': x,
'exec_time': end - start,
'nns': nns,
'q_embeddings': query_embeddings}
return out
def __call__(self, x, n):
return self.search(x, n)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
# TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
parser.add_argument(
"--prompt",
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render"
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
default="outputs/txt2img-samples"
)
parser.add_argument(
"--skip_grid",
action='store_true',
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
"--ddim_steps",
type=int,
default=50,
help="number of ddim sampling steps",
)
parser.add_argument(
"--n_repeat",
type=int,
default=1,
help="number of repeats in CLIP latent space",
)
parser.add_argument(
"--plms",
action='store_true',
help="use plms sampling",
)
parser.add_argument(
"--ddim_eta",
type=float,
default=0.0,
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
)
parser.add_argument(
"--n_iter",
type=int,
default=1,
help="sample this often",
)
parser.add_argument(
"--H",
type=int,
default=768,
help="image height, in pixel space",
)
parser.add_argument(
"--W",
type=int,
default=768,
help="image width, in pixel space",
)
parser.add_argument(
"--n_samples",
type=int,
default=3,
help="how many samples to produce for each given prompt. A.k.a batch size",
)
parser.add_argument(
"--n_rows",
type=int,
default=0,
help="rows in the grid (default: n_samples)",
)
parser.add_argument(
"--scale",
type=float,
default=5.0,
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
)
parser.add_argument(
"--from-file",
type=str,
help="if specified, load prompts from this file",
)
parser.add_argument(
"--config",
type=str,
default="configs/retrieval-augmented-diffusion/768x768.yaml",
help="path to config which constructs model",
)
parser.add_argument(
"--ckpt",
type=str,
default="models/rdm/rdm768x768/model.ckpt",
help="path to checkpoint of model",
)
parser.add_argument(
"--clip_type",
type=str,
default="ViT-L/14",
help="which CLIP model to use for retrieval and NN encoding",
)
parser.add_argument(
"--database",
type=str,
default='artbench-surrealism',
choices=DATABASES,
help="The database used for the search, only applied when --use_neighbors=True",
)
parser.add_argument(
"--use_neighbors",
default=False,
action='store_true',
help="Include neighbors in addition to text prompt for conditioning",
)
parser.add_argument(
"--knn",
default=10,
type=int,
help="The number of included neighbors, only applied when --use_neighbors=True",
)
opt = parser.parse_args()
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)
if opt.plms:
sampler = PLMSSampler(model)
else:
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir
batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
if not opt.from_file:
prompt = opt.prompt
assert prompt is not None
data = [batch_size * [prompt]]
else:
print(f"reading prompts from {opt.from_file}")
with open(opt.from_file, "r") as f:
data = f.read().splitlines()
data = list(chunk(data, batch_size))
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1
print(f"sampling scale for cfg is {opt.scale:.2f}")
searcher = None
if opt.use_neighbors:
searcher = Searcher(opt.database)
with torch.no_grad():
with model.ema_scope():
for n in trange(opt.n_iter, desc="Sampling"):
all_samples = list()
for prompts in tqdm(data, desc="data"):
print("sampling prompts:", prompts)
if isinstance(prompts, tuple):
prompts = list(prompts)
c = clip_text_encoder.encode(prompts)
uc = None
if searcher is not None:
nn_dict = searcher(c, opt.knn)
c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
if opt.scale != 1.0:
uc = torch.zeros_like(c)
if isinstance(prompts, tuple):
prompts = list(prompts)
shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
conditioning=c,
batch_size=c.shape[0],
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, f"{base_count:05}.png"))
base_count += 1
all_samples.append(x_samples_ddim)
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid_count += 1
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
import argparse, os, sys, glob, datetime, yaml
import torch
import time
import numpy as np
from tqdm import trange
from omegaconf import OmegaConf
from PIL import Image
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
rescale = lambda x: (x + 1.) / 2.
def custom_to_pil(x):
x = x.detach().cpu()
x = torch.clamp(x, -1., 1.)
x = (x + 1.) / 2.
x = x.permute(1, 2, 0).numpy()
x = (255 * x).astype(np.uint8)
x = Image.fromarray(x)
if not x.mode == "RGB":
x = x.convert("RGB")
return x
def custom_to_np(x):
# saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
sample = x.detach().cpu()
sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
sample = sample.permute(0, 2, 3, 1)
sample = sample.contiguous()
return sample
def logs2pil(logs, keys=["sample"]):
imgs = dict()
for k in logs:
try:
if len(logs[k].shape) == 4:
img = custom_to_pil(logs[k][0, ...])
elif len(logs[k].shape) == 3:
img = custom_to_pil(logs[k])
else:
print(f"Unknown format for key {k}. ")
img = None
except:
img = None
imgs[k] = img
return imgs
@torch.no_grad()
def convsample(model, shape, return_intermediates=True,
verbose=True,
make_prog_row=False):
if not make_prog_row:
return model.p_sample_loop(None, shape,
return_intermediates=return_intermediates, verbose=verbose)
else:
return model.progressive_denoising(
None, shape, verbose=True
)
@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0
):
ddim = DDIMSampler(model)
bs = shape[0]
shape = shape[1:]
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,)
return samples, intermediates
@torch.no_grad()
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):
log = dict()
shape = [batch_size,
model.model.diffusion_model.in_channels,
model.model.diffusion_model.image_size,
model.model.diffusion_model.image_size]
with model.ema_scope("Plotting"):
t0 = time.time()
if vanilla:
sample, progrow = convsample(model, shape,
make_prog_row=True)
else:
sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape,
eta=eta)
t1 = time.time()
x_sample = model.decode_first_stage(sample)
log["sample"] = x_sample
log["time"] = t1 - t0
log['throughput'] = sample.shape[0] / (t1 - t0)
print(f'Throughput for this batch: {log["throughput"]}')
return log
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
if vanilla:
print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
else:
print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')
tstart = time.time()
n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1
# path = logdir
if model.cond_stage_model is None:
all_images = []
print(f"Running unconditional sampling for {n_samples} samples")
for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"):
logs = make_convolutional_sample(model, batch_size=batch_size,
vanilla=vanilla, custom_steps=custom_steps,
eta=eta)
n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
all_images.extend([custom_to_np(logs["sample"])])
if n_saved >= n_samples:
print(f'Finish after generating {n_saved} samples')
break
all_img = np.concatenate(all_images, axis=0)
all_img = all_img[:n_samples]
shape_str = "x".join([str(x) for x in all_img.shape])
nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
np.savez(nppath, all_img)
else:
raise NotImplementedError('Currently only sampling for unconditional models supported.')
print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
for k in logs:
if k == key:
batch = logs[key]
if np_path is None:
for x in batch:
img = custom_to_pil(x)
imgpath = os.path.join(path, f"{key}_{n_saved:06}.png")
img.save(imgpath)
n_saved += 1
else:
npbatch = custom_to_np(batch)
shape_str = "x".join([str(x) for x in npbatch.shape])
nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz")
np.savez(nppath, npbatch)
n_saved += npbatch.shape[0]
return n_saved
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-r",
"--resume",
type=str,
nargs="?",
help="load from logdir or checkpoint in logdir",
)
parser.add_argument(
"-n",
"--n_samples",
type=int,
nargs="?",
help="number of samples to draw",
default=50000
)
parser.add_argument(
"-e",
"--eta",
type=float,
nargs="?",
help="eta for ddim sampling (0.0 yields deterministic sampling)",
default=1.0
)
parser.add_argument(
"-v",
"--vanilla_sample",
default=False,
action='store_true',
help="vanilla sampling (default option is DDIM sampling)?",
)
parser.add_argument(
"-l",
"--logdir",
type=str,
nargs="?",
help="extra logdir",
default="none"
)
parser.add_argument(
"-c",
"--custom_steps",
type=int,
nargs="?",
help="number of steps for ddim and fastdpm sampling",
default=50
)
parser.add_argument(
"--batch_size",
type=int,
nargs="?",
help="the bs",
default=10
)
return parser
def load_model_from_config(config, sd):
model = instantiate_from_config(config)
model.load_state_dict(sd,strict=False)
model.cuda()
model.eval()
return model
def load_model(config, ckpt, gpu, eval_mode):
if ckpt:
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
global_step = pl_sd["global_step"]
else:
pl_sd = {"state_dict": None}
global_step = None
model = load_model_from_config(config.model,
pl_sd["state_dict"])
return model, global_step
if __name__ == "__main__":
now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
sys.path.append(os.getcwd())
command = " ".join(sys.argv)
parser = get_parser()
opt, unknown = parser.parse_known_args()
ckpt = None
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
# paths = opt.resume.split("/")
try:
logdir = '/'.join(opt.resume.split('/')[:-1])
# idx = len(paths)-paths[::-1].index("logs")+1
print(f'Logdir is {logdir}')
except ValueError:
paths = opt.resume.split("/")
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
logdir = "/".join(paths[:idx])
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory"
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "model.ckpt")
base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
opt.base = base_configs
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
gpu = True
eval_mode = True
if opt.logdir != "none":
locallog = logdir.split(os.sep)[-1]
if locallog == "": locallog = logdir.split(os.sep)[-2]
print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
logdir = os.path.join(opt.logdir, locallog)
print(config)
model, global_step = load_model(config, ckpt, gpu, eval_mode)
print(f"global step: {global_step}")
print(75 * "=")
print("logging to:")
logdir = os.path.join(logdir, "samples", f"{global_step:08}", now)
imglogdir = os.path.join(logdir, "img")
numpylogdir = os.path.join(logdir, "numpy")
os.makedirs(imglogdir)
os.makedirs(numpylogdir)
print(logdir)
print(75 * "=")
# write config out
sampling_file = os.path.join(logdir, "sampling_config.yaml")
sampling_conf = vars(opt)
with open(sampling_file, 'w') as f:
yaml.dump(sampling_conf, f, default_flow_style=False)
print(sampling_conf)
run(model, imglogdir, eta=opt.eta,
vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
batch_size=opt.batch_size, nplog=numpylogdir)
print("done.")
import os
import sys
from copy import deepcopy
import yaml
from datetime import datetime
from diffusers import StableDiffusionPipeline
import torch
from ldm.util import instantiate_from_config
from main import get_parser
if __name__ == "__main__":
with torch.no_grad():
yaml_path = "../../train_colossalai.yaml"
with open(yaml_path, 'r', encoding='utf-8') as f:
config = f.read()
base_config = yaml.load(config, Loader=yaml.FullLoader)
unet_config = base_config['model']['params']['unet_config']
diffusion_model = instantiate_from_config(unet_config).to("cuda:0")
pipe = StableDiffusionPipeline.from_pretrained(
"/data/scratch/diffuser/stable-diffusion-v1-4"
).to("cuda:0")
dif_model_2 = pipe.unet
random_input_ = torch.rand((4, 4, 32, 32)).to("cuda:0")
random_input_2 = torch.clone(random_input_).to("cuda:0")
time_stamp = torch.randint(20, (4,)).to("cuda:0")
time_stamp2 = torch.clone(time_stamp).to("cuda:0")
context_ = torch.rand((4, 77, 768)).to("cuda:0")
context_2 = torch.clone(context_).to("cuda:0")
out_1 = diffusion_model(random_input_, time_stamp, context_)
out_2 = dif_model_2(random_input_2, time_stamp2, context_2)
print(out_1.shape)
print(out_2['sample'].shape)
\ No newline at end of file
import cv2
import fire
from imwatermark import WatermarkDecoder
def testit(img_path):
bgr = cv2.imread(img_path)
decoder = WatermarkDecoder('bytes', 136)
watermark = decoder.decode(bgr, 'dwtDct')
try:
dec = watermark.decode('utf-8')
except:
dec = "null"
print(dec)
if __name__ == "__main__":
fire.Fire(testit)
\ No newline at end of file
import os, sys
import numpy as np
import scann
import argparse
import glob
from multiprocessing import cpu_count
from tqdm import tqdm
from ldm.util import parallel_data_prefetch
def search_bruteforce(searcher):
return searcher.score_brute_force().build()
def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
partioning_trainsize, num_leaves, num_leaves_to_search):
return searcher.tree(num_leaves=num_leaves,
num_leaves_to_search=num_leaves_to_search,
training_sample_size=partioning_trainsize). \
score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
reorder_k).build()
def load_datapool(dpath):
def load_single_file(saved_embeddings):
compressed = np.load(saved_embeddings)
database = {key: compressed[key] for key in compressed.files}
return database
def load_multi_files(data_archive):
database = {key: [] for key in data_archive[0].files}
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
for key in d.files:
database[key].append(d[key])
return database
print(f'Load saved patch embedding from "{dpath}"')
file_content = glob.glob(os.path.join(dpath, '*.npz'))
if len(file_content) == 1:
data_pool = load_single_file(file_content[0])
elif len(file_content) > 1:
data = [np.load(f) for f in file_content]
prefetched_data = parallel_data_prefetch(load_multi_files, data,
n_proc=min(len(data), cpu_count()), target_data_type='dict')
data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
else:
raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?')
print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
return data_pool
def train_searcher(opt,
metric='dot_product',
partioning_trainsize=None,
reorder_k=None,
# todo tune
aiq_thld=0.2,
dims_per_block=2,
num_leaves=None,
num_leaves_to_search=None,):
data_pool = load_datapool(opt.database)
k = opt.knn
if not reorder_k:
reorder_k = 2 * k
# normalize
# embeddings =
searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
pool_size = data_pool['embedding'].shape[0]
print(*(['#'] * 100))
print('Initializing scaNN searcher with the following values:')
print(f'k: {k}')
print(f'metric: {metric}')
print(f'reorder_k: {reorder_k}')
print(f'anisotropic_quantization_threshold: {aiq_thld}')
print(f'dims_per_block: {dims_per_block}')
print(*(['#'] * 100))
print('Start training searcher....')
print(f'N samples in pool is {pool_size}')
# this reflects the recommended design choices proposed at
# https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
if pool_size < 2e4:
print('Using brute force search.')
searcher = search_bruteforce(searcher)
elif 2e4 <= pool_size and pool_size < 1e5:
print('Using asymmetric hashing search and reordering.')
searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
else:
print('Using using partioning, asymmetric hashing search and reordering.')
if not partioning_trainsize:
partioning_trainsize = data_pool['embedding'].shape[0] // 10
if not num_leaves:
num_leaves = int(np.sqrt(pool_size))
if not num_leaves_to_search:
num_leaves_to_search = max(num_leaves // 20, 1)
print('Partitioning params:')
print(f'num_leaves: {num_leaves}')
print(f'num_leaves_to_search: {num_leaves_to_search}')
# self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
partioning_trainsize, num_leaves, num_leaves_to_search)
print('Finish training searcher')
searcher_savedir = opt.target_path
os.makedirs(searcher_savedir, exist_ok=True)
searcher.serialize(searcher_savedir)
print(f'Saved trained searcher under "{searcher_savedir}"')
if __name__ == '__main__':
sys.path.append(os.getcwd())
parser = argparse.ArgumentParser()
parser.add_argument('--database',
'-d',
default='data/rdm/retrieval_databases/openimages',
type=str,
help='path to folder containing the clip feature of the database')
parser.add_argument('--target_path',
'-t',
default='data/rdm/searchers/openimages',
type=str,
help='path to the target folder where the searcher shall be stored.')
parser.add_argument('--knn',
'-k',
default=20,
type=int,
help='number of nearest neighbors, for which the searcher shall be optimized')
opt, _ = parser.parse_known_args()
train_searcher(opt,)
\ No newline at end of file
import argparse, os
import cv2
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
try:
from lightning.pytorch import seed_everything
except:
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
from imwatermark import WatermarkEncoder
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
from utils import replace_module, getModelSize
torch.set_grad_enabled(False)
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.eval()
return model
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt",
type=str,
nargs="?",
default="a professional photograph of an astronaut riding a triceratops",
help="the prompt to render"
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
default="outputs/txt2img-samples"
)
parser.add_argument(
"--steps",
type=int,
default=50,
help="number of ddim sampling steps",
)
parser.add_argument(
"--plms",
action='store_true',
help="use plms sampling",
)
parser.add_argument(
"--dpm",
action='store_true',
help="use DPM (2) sampler",
)
parser.add_argument(
"--fixed_code",
action='store_true',
help="if enabled, uses the same starting code across all samples ",
)
parser.add_argument(
"--ddim_eta",
type=float,
default=0.0,
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
)
parser.add_argument(
"--n_iter",
type=int,
default=3,
help="sample this often",
)
parser.add_argument(
"--H",
type=int,
default=512,
help="image height, in pixel space",
)
parser.add_argument(
"--W",
type=int,
default=512,
help="image width, in pixel space",
)
parser.add_argument(
"--C",
type=int,
default=4,
help="latent channels",
)
parser.add_argument(
"--f",
type=int,
default=8,
help="downsampling factor, most often 8 or 16",
)
parser.add_argument(
"--n_samples",
type=int,
default=3,
help="how many samples to produce for each given prompt. A.k.a batch size",
)
parser.add_argument(
"--n_rows",
type=int,
default=0,
help="rows in the grid (default: n_samples)",
)
parser.add_argument(
"--scale",
type=float,
default=9.0,
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
)
parser.add_argument(
"--from-file",
type=str,
help="if specified, load prompts from this file, separated by newlines",
)
parser.add_argument(
"--config",
type=str,
default="configs/stable-diffusion/v2-inference.yaml",
help="path to config which constructs model",
)
parser.add_argument(
"--ckpt",
type=str,
help="path to checkpoint of model",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="the seed (for reproducible sampling)",
)
parser.add_argument(
"--precision",
type=str,
help="evaluate at this precision",
choices=["full", "autocast"],
default="autocast"
)
parser.add_argument(
"--repeat",
type=int,
default=1,
help="repeat each prompt in file this often",
)
parser.add_argument(
"--use_int8",
type=bool,
default=False,
help="use int8 for inference",
)
opt = parser.parse_args()
return opt
def put_watermark(img, wm_encoder=None):
if wm_encoder is not None:
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
img = wm_encoder.encode(img, 'dwtDct')
img = Image.fromarray(img[:, :, ::-1])
return img
def main(opt):
seed_everything(opt.seed)
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
# quantize model
if opt.use_int8:
model = replace_module(model)
# # to compute the model size
# getModelSize(model)
if opt.plms:
sampler = PLMSSampler(model)
elif opt.dpm:
sampler = DPMSolverSampler(model)
else:
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
wm = "SDV2"
wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
if not opt.from_file:
prompt = opt.prompt
assert prompt is not None
data = [batch_size * [prompt]]
else:
print(f"reading prompts from {opt.from_file}")
with open(opt.from_file, "r") as f:
data = f.read().splitlines()
data = [p for p in data for i in range(opt.repeat)]
data = list(chunk(data, batch_size))
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
sample_count = 0
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1
start_code = None
if opt.fixed_code:
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
precision_scope = autocast if opt.precision == "autocast" else nullcontext
with torch.no_grad(), \
precision_scope("cuda"), \
model.ema_scope():
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
uc = None
if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
samples, _ = sampler.sample(S=opt.steps,
conditioning=c,
batch_size=opt.n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code)
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
img = Image.fromarray(x_sample.astype(np.uint8))
img = put_watermark(img, wm_encoder)
img.save(os.path.join(sample_path, f"{base_count:05}.png"))
base_count += 1
sample_count += 1
all_samples.append(x_samples)
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
grid = Image.fromarray(grid.astype(np.uint8))
grid = put_watermark(grid, wm_encoder)
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid_count += 1
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
f" \nEnjoy.")
if __name__ == "__main__":
opt = parse_args()
main(opt)
# # to compute the mem allocated
# print(torch.cuda.max_memory_allocated() / 1024 / 1024)
python scripts/txt2img.py --prompt "Teyvat, Name:Layla, Element: Cryo, Weapon:Sword, Region:Sumeru, Model type:Medium Female, Description:a woman in a blue outfit holding a sword" --plms \
--outdir ./output \
--config /home/lcmql/data2/Genshin/2022-11-18T16-38-46_train_colossalai_teyvattest/checkpoints/last.ckpt \
--ckpt /home/lcmql/data2/Genshin/2022-11-18T16-38-46_train_colossalai_teyvattest/configs/2022-11-18T16-38-46-project.yaml \
--n_samples 4
import bitsandbytes as bnb
import torch.nn as nn
import torch
class Linear8bit(nn.Linear):
def __init__(
self,
input_features,
output_features,
bias=True,
has_fp16_weights=False,
memory_efficient_backward=False,
threshold=6.0,
weight_data=None,
bias_data=None
):
super(Linear8bit, self).__init__(
input_features, output_features, bias
)
self.state = bnb.MatmulLtState()
self.bias = bias_data
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
self.state.memory_efficient_backward = memory_efficient_backward
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
self.register_parameter("SCB", nn.Parameter(torch.empty(0), requires_grad=False))
self.weight = weight_data
self.quant()
def quant(self):
weight = self.weight.data.contiguous().half().cuda()
CB, _, SCB, _, _ = bnb.functional.double_quant(weight)
delattr(self, "weight")
setattr(self, "weight", nn.Parameter(CB, requires_grad=False))
delattr(self, "SCB")
setattr(self, "SCB", nn.Parameter(SCB, requires_grad=False))
del weight
def forward(self, x):
self.state.is_training = self.training
if self.bias is not None and self.bias.dtype != torch.float16:
self.bias.data = self.bias.data.half()
self.state.CB = self.weight.data
self.state.SCB = self.SCB.data
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
del self.state.CxB
return out
def replace_module(model):
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_module(module)
if isinstance(module, nn.Linear) and "out_proj" not in name:
model._modules[name] = Linear8bit(
input_features=module.in_features,
output_features=module.out_features,
threshold=6.0,
weight_data=module.weight,
bias_data=module.bias,
)
return model
def getModelSize(model):
param_size = 0
param_sum = 0
for param in model.parameters():
param_size += param.nelement() * param.element_size()
param_sum += param.nelement()
buffer_size = 0
buffer_sum = 0
for buffer in model.buffers():
buffer_size += buffer.nelement() * buffer.element_size()
buffer_sum += buffer.nelement()
all_size = (param_size + buffer_size) / 1024 / 1024
print('Model Size: {:.3f}MB'.format(all_size))
return (param_size, param_sum, buffer_size, buffer_sum, all_size)
from setuptools import setup, find_packages
setup(
name='latent-diffusion',
version='0.0.1',
description='',
packages=find_packages(),
install_requires=[
'torch',
'numpy',
'tqdm',
],
)
\ No newline at end of file
HF_DATASETS_OFFLINE=1
TRANSFORMERS_OFFLINE=1
DIFFUSERS_OFFLINE=1
python main.py --logdir /tmp -t -b /configs/train_colossalai.yaml
HF_DATASETS_OFFLINE=1
TRANSFORMERS_OFFLINE=1
DIFFUSERS_OFFLINE=1
python main.py --logdir /tmp -t -b /configs/train_ddp.yaml
# [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) by [colossalai](https://github.com/hpcaitech/ColossalAI.git)
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.
The `train_dreambooth_colossalai.py` script shows how to implement the training procedure and adapt it for stable diffusion.
By accommodating model data in CPU and GPU and moving the data to the computing device when necessary, [Gemini](https://www.colossalai.org/docs/advanced_tutorials/meet_gemini), the Heterogeneous Memory Manager of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) can breakthrough the GPU memory wall by using GPU and CPU memory (composed of CPU DRAM or nvme SSD memory) together at the same time. Moreover, the model scale can be further improved by combining heterogeneous training with the other parallel approaches, such as data parallel, tensor parallel and pipeline parallel.
## Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
```bash
pip install -r requirements_colossalai.txt
```
### Install [colossalai](https://github.com/hpcaitech/ColossalAI.git)
```bash
pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org
```
**From source**
```bash
git clone https://github.com/hpcaitech/ColossalAI.git
python setup.py install
```
## Dataset for Teyvat BLIP captions
Dataset used to train [Teyvat characters text to image model](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion).
BLIP generated captions for characters images from [genshin-impact fandom wiki](https://genshin-impact.fandom.com/wiki/Character#Playable_Characters)and [biligame wiki for genshin impact](https://wiki.biligame.com/ys/%E8%A7%92%E8%89%B2).
For each row the dataset contains `image` and `text` keys. `image` is a varying size PIL png, and `text` is the accompanying text caption. Only a train split is provided.
The `text` include the tag `Teyvat`, `Name`,`Element`, `Weapon`, `Region`, `Model type`, and `Description`, the `Description` is captioned with the [pre-trained BLIP model](https://github.com/salesforce/BLIP).
## Training
The arguement `placement` can be `cpu`, `auto`, `cuda`, with `cpu` the GPU RAM required can be minimized to 4GB but will deceleration, with `cuda` you can also reduce GPU memory by half but accelerated training, with `auto` a more balanced solution for speed and memory can be obtained。
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"
torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--learning_rate=5e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=400 \
--placement="cuda"
```
### Training with prior-preservation loss
Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time.
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$OUTPUT_DIR \
--with_prior_preservation --prior_loss_weight=1.0 \
--instance_prompt="a photo of sks dog" \
--class_prompt="a photo of dog" \
--resolution=512 \
--train_batch_size=1 \
--learning_rate=5e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=800 \
--placement="cuda"
```
## Inference
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt.
```python
from diffusers import StableDiffusionPipeline
import torch
model_id = "path-to-save-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
prompt = "A photo of sks dog in a bucket"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog-bucket.png")
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment