Commit 76b9024b authored by yangzhong

git init
import sys
import contextlib
from functools import lru_cache
import torch
#from modules import errors
if sys.platform == "darwin":
from modules import mac_specific
def has_mps() -> bool:
if sys.platform != "darwin":
return False
else:
return mac_specific.has_mps
def get_cuda_device_string():
return "cuda"
def get_optimal_device_name():
if torch.cuda.is_available():
return get_cuda_device_string()
if has_mps():
return "mps"
return "cpu"
def get_optimal_device():
return torch.device(get_optimal_device_name())
def get_device_for(task):
return get_optimal_device()
def torch_gc():
if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
if has_mps():
mac_specific.torch_mps_gc()
def enable_tf32():
if torch.cuda.is_available():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
enable_tf32()
#errors.run(enable_tf32, "Enabling TF32")
cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda")
dtype = torch.float16
dtype_vae = torch.float16
dtype_unet = torch.float16
unet_needs_upcast = False
def cond_cast_unet(input):
return input.to(dtype_unet) if unet_needs_upcast else input
def cond_cast_float(input):
return input.float() if unet_needs_upcast else input
def randn(seed, shape):
torch.manual_seed(seed)
return torch.randn(shape, device=device)
def randn_without_seed(shape):
return torch.randn(shape, device=device)
def autocast(disable=False):
if disable:
return contextlib.nullcontext()
return torch.autocast("cuda")
def without_autocast(disable=False):
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
class NansException(Exception):
pass
def test_for_nans(x, where):
if not torch.all(torch.isnan(x)).item():
return
if where == "unet":
message = "A tensor with all NaNs was produced in Unet."
elif where == "vae":
message = "A tensor with all NaNs was produced in VAE."
else:
message = "A tensor with all NaNs was produced."
message += " Use --disable-nan-check commandline argument to disable this check."
raise NansException(message)
@lru_cache
def first_time_calculation():
"""
just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
spends about 2.7 seconds doing that, at least with NVidia.
"""
x = torch.zeros((1, 1)).to(device, dtype)
linear = torch.nn.Linear(1, 1).to(device, dtype)
linear(x)
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)
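# Hedged usage sketch (not part of the original module): how these helpers are
# typically combined at startup. The seed and latent shape are arbitrary
# illustration values; `device` is the module-level CUDA device defined above.
def _example_devices_usage():
    first_time_calculation()                      # warm up the allocator / cuDNN once
    with autocast():                              # mixed-precision context for model calls
        noise = randn(seed=42, shape=(1, 4, 64, 64))
    test_for_nans(noise, "unet")                  # raises NansException only if every value is NaN
    torch_gc()                                    # release cached CUDA / MPS memory
    return noise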
import torch
import torchvision
import torch.nn.functional as F
def attn_cosine_sim(x, eps=1e-08):
x = x[0] # TEMP: getting rid of redundant dimension, TBF
norm1 = x.norm(dim=2, keepdim=True)
factor = torch.clamp(norm1 @ norm1.permute(0, 2, 1), min=eps)
sim_matrix = (x @ x.permute(0, 2, 1)) / factor
return sim_matrix
class VitExtractor:
BLOCK_KEY = 'block'
ATTN_KEY = 'attn'
PATCH_IMD_KEY = 'patch_imd'
QKV_KEY = 'qkv'
KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]
def __init__(self, model_name, device):
# pdb.set_trace()
self.model = torch.hub.load('facebookresearch/dino:main', model_name).to(device)
self.model.eval()
self.model_name = model_name
self.hook_handlers = []
self.layers_dict = {}
self.outputs_dict = {}
for key in VitExtractor.KEY_LIST:
self.layers_dict[key] = []
self.outputs_dict[key] = []
self._init_hooks_data()
def _init_hooks_data(self):
self.layers_dict[VitExtractor.BLOCK_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
self.layers_dict[VitExtractor.ATTN_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
self.layers_dict[VitExtractor.QKV_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
self.layers_dict[VitExtractor.PATCH_IMD_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
for key in VitExtractor.KEY_LIST:
# self.layers_dict[key] = kwargs[key] if key in kwargs.keys() else []
self.outputs_dict[key] = []
def _register_hooks(self, **kwargs):
for block_idx, block in enumerate(self.model.blocks):
if block_idx in self.layers_dict[VitExtractor.BLOCK_KEY]:
self.hook_handlers.append(block.register_forward_hook(self._get_block_hook()))
if block_idx in self.layers_dict[VitExtractor.ATTN_KEY]:
self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
if block_idx in self.layers_dict[VitExtractor.QKV_KEY]:
self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook()))
if block_idx in self.layers_dict[VitExtractor.PATCH_IMD_KEY]:
self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook()))
def _clear_hooks(self):
for handler in self.hook_handlers:
handler.remove()
self.hook_handlers = []
def _get_block_hook(self):
def _get_block_output(model, input, output):
self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)
return _get_block_output
def _get_attn_hook(self):
def _get_attn_output(model, inp, output):
self.outputs_dict[VitExtractor.ATTN_KEY].append(output)
return _get_attn_output
def _get_qkv_hook(self):
def _get_qkv_output(model, inp, output):
self.outputs_dict[VitExtractor.QKV_KEY].append(output)
return _get_qkv_output
# TODO: CHECK ATTN OUTPUT TUPLE
def _get_patch_imd_hook(self):
def _get_attn_output(model, inp, output):
self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])
return _get_attn_output
def get_feature_from_input(self, input_img): # List([B, N, D])
self._register_hooks()
self.model(input_img)
feature = self.outputs_dict[VitExtractor.BLOCK_KEY]
self._clear_hooks()
self._init_hooks_data()
return feature
def get_qkv_feature_from_input(self, input_img):
self._register_hooks()
self.model(input_img)
feature = self.outputs_dict[VitExtractor.QKV_KEY]
self._clear_hooks()
self._init_hooks_data()
return feature
def get_attn_feature_from_input(self, input_img):
self._register_hooks()
self.model(input_img)
feature = self.outputs_dict[VitExtractor.ATTN_KEY]
self._clear_hooks()
self._init_hooks_data()
return feature
def get_patch_size(self):
return 8 if "8" in self.model_name else 16
def get_width_patch_num(self, input_img_shape):
b, c, h, w = input_img_shape
patch_size = self.get_patch_size()
return w // patch_size
def get_height_patch_num(self, input_img_shape):
b, c, h, w = input_img_shape
patch_size = self.get_patch_size()
return h // patch_size
def get_patch_num(self, input_img_shape):
patch_num = 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape))
return patch_num
def get_head_num(self):
if "dino" in self.model_name:
return 6 if "s" in self.model_name else 12
return 6 if "small" in self.model_name else 12
def get_embedding_dim(self):
if "dino" in self.model_name:
return 384 if "s" in self.model_name else 768
return 384 if "small" in self.model_name else 768
def get_queries_from_qkv(self, qkv, input_img_shape):
patch_num = self.get_patch_num(input_img_shape)
head_num = self.get_head_num()
embedding_dim = self.get_embedding_dim()
q = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[0]
return q
def get_keys_from_qkv(self, qkv, input_img_shape):
patch_num = self.get_patch_num(input_img_shape)
head_num = self.get_head_num()
embedding_dim = self.get_embedding_dim()
k = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[1]
return k
def get_values_from_qkv(self, qkv, input_img_shape):
patch_num = self.get_patch_num(input_img_shape)
head_num = self.get_head_num()
embedding_dim = self.get_embedding_dim()
v = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[2]
return v
def get_keys_from_input(self, input_img, layer_num):
qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
keys = self.get_keys_from_qkv(qkv_features, input_img.shape)
return keys
def get_keys_self_sim_from_input(self, input_img, layer_num):
keys = self.get_keys_from_input(input_img, layer_num=layer_num)
h, t, d = keys.shape
concatenated_keys = keys.transpose(0, 1).reshape(t, h * d)
ssim_map = attn_cosine_sim(concatenated_keys[None, None, ...])
return ssim_map
class DinoStructureLoss:
def __init__(self, ):
self.extractor = VitExtractor(model_name="dino_vitb8", device="cuda")
self.preprocess = torchvision.transforms.Compose([
torchvision.transforms.Resize(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
def calculate_global_ssim_loss(self, outputs, inputs):
loss = 0.0
for a, b in zip(inputs, outputs): # avoid memory limitations
with torch.no_grad():
target_keys_self_sim = self.extractor.get_keys_self_sim_from_input(a.unsqueeze(0), layer_num=11)
keys_ssim = self.extractor.get_keys_self_sim_from_input(b.unsqueeze(0), layer_num=11)
loss += F.mse_loss(keys_ssim, target_keys_self_sim)
return loss
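# Hedged usage sketch (not part of the original file): computing the DINO
# structure loss between generated outputs and their inputs. Downloads the
# dino_vitb8 weights via torch.hub on first use and assumes a CUDA device,
# matching the hard-coded device in DinoStructureLoss.__init__; the random
# 224x224 tensors stand in for real image batches in [0, 1].
def _example_dino_structure_loss():
    criterion = DinoStructureLoss()
    inputs = torch.rand(2, 3, 224, 224, device="cuda")
    outputs = torch.rand(2, 3, 224, 224, device="cuda")
    return criterion.calculate_global_ssim_loss(outputs, inputs)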
import argparse
import json
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from glob import glob
import cv2
import math
import numpy as np
import os
import os.path as osp
import random
import time
import torch
from pathlib import Path
from torch.utils import data as data
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.data.transforms import paired_random_crop, triplet_random_crop
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, random_add_speckle_noise_pt, random_add_saltpepper_noise_pt, bivariate_Gaussian
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
def parse_args_paired_testing(input_args=None):
"""
Parses command-line arguments used for configuring a paired session (pix2pix-Turbo).
This function sets up an argument parser to handle various training options.
Returns:
argparse.Namespace: The parsed command-line arguments.
"""
parser = argparse.ArgumentParser()
parser.add_argument("--ref_path", type=str, default=None,)
parser.add_argument("--base_config", default="./configs/sr_test.yaml", type=str)
parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.")
# details about the model architecture
parser.add_argument("--sd_path")
parser.add_argument("--de_net_path")
parser.add_argument("--pretrained_path", type=str, default=None,)
parser.add_argument("--revision", type=str, default=None,)
parser.add_argument("--variant", type=str, default=None,)
parser.add_argument("--tokenizer_name", type=str, default=None)
parser.add_argument("--lora_rank_unet", default=32, type=int)
parser.add_argument("--lora_rank_vae", default=16, type=int)
parser.add_argument("--scale", type=int, default=4, help="Scale factor for SR.")
parser.add_argument("--chop_size", type=int, default=128, choices=[512, 256, 128], help="Chopping forward.")
parser.add_argument("--chop_stride", type=int, default=96, help="Chopping stride.")
parser.add_argument("--padding_offset", type=int, default=32, help="padding offset.")
parser.add_argument("--vae_decoder_tiled_size", type=int, default=224)
parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024)
parser.add_argument("--latent_tiled_size", type=int, default=96)
parser.add_argument("--latent_tiled_overlap", type=int, default=32)
parser.add_argument("--align_method", type=str, default="wavelet")
parser.add_argument("--pos_prompt", type=str, default="A high-resolution, 8K, ultra-realistic image with sharp focus, vibrant colors, and natural lighting.")
parser.add_argument("--neg_prompt", type=str, default="oil painting, cartoon, blur, dirty, messy, low quality, deformation, low resolution, oversmooth")
# training details
parser.add_argument("--output_dir", type=str, default='output/')
parser.add_argument("--cache_dir", default=None,)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument("--resolution", type=int, default=512,)
parser.add_argument("--checkpointing_steps", type=int, default=500,)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",)
parser.add_argument("--gradient_checkpointing", action="store_true",)
parser.add_argument("--dataloader_num_workers", type=int, default=0,)
parser.add_argument("--allow_tf32", action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument("--report_to", type=str, default="wandb",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],)
parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
parser.add_argument("--set_grads_to_none", action="store_true",)
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
return args
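# Hedged usage sketch (not part of the original file): parsing an explicit
# argument list instead of sys.argv, e.g. from a test. All paths and model
# ids below are placeholders.
def _example_parse_test_args():
    return parse_args_paired_testing([
        "--sd_path", "stabilityai/sd-turbo",           # placeholder model id
        "--pretrained_path", "checkpoints/model.pkl",  # placeholder checkpoint
        "--ref_path", "testset/gt",                    # placeholder reference folder
        "--chop_size", "128",
    ])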
class PlainDataset(data.Dataset):
"""Modified dataset based on the dataset used for Real-ESRGAN model:
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
It loads gt (Ground-Truth) images, and augments them.
It also generates blur kernels and sinc kernels for generating low-quality images.
Note that the low-quality images are processed as tensors on GPUs for faster processing.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
meta_info (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
Please see more options in the codes.
"""
def __init__(self, opt):
super(PlainDataset, self).__init__()
self.opt = opt
self.file_client = None
self.io_backend_opt = opt['io_backend']
if 'image_type' not in opt:
opt['image_type'] = 'png'
# support multiple types of data: file paths and meta info; lmdb support has been removed
self.lr_paths = []
if 'lr_path' in opt:
if isinstance(opt['lr_path'], str):
self.lr_paths.extend(sorted(
[str(x) for x in Path(opt['lr_path']).glob('*.png')] +
[str(x) for x in Path(opt['lr_path']).glob('*.jpg')] +
[str(x) for x in Path(opt['lr_path']).glob('*.jpeg')]
))
else:
self.lr_paths.extend(sorted([str(x) for x in Path(opt['lr_path'][0]).glob('*.'+opt['image_type'])]))
if len(opt['lr_path']) > 1:
for i in range(len(opt['lr_path'])-1):
self.lr_paths.extend(sorted([str(x) for x in Path(opt['lr_path'][i+1]).glob('*.'+opt['image_type'])]))
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# -------------------------------- Load gt images -------------------------------- #
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
lr_path = self.lr_paths[index]
# avoid errors caused by high latency in reading files
retry = 3
while retry > 0:
try:
lr_img_bytes = self.file_client.get(lr_path, 'gt')
except (IOError, OSError) as e:
# logger = get_root_logger()
# logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
# change another file to read
index = random.randint(0, self.__len__()-1)
lr_path = self.lr_paths[index]
time.sleep(1) # sleep 1s for occasional server congestion
else:
break
finally:
retry -= 1
img_lr = imfrombytes(lr_img_bytes, float32=True)
# BGR to RGB, HWC to CHW, numpy to tensor
img_lr = img2tensor([img_lr], bgr2rgb=True, float32=True)[0]
return_d = {'lr': img_lr, 'lr_path': lr_path}
return return_d
def __len__(self):
return len(self.lr_paths)
def lr_proc(config, batch, device):
im_lr = batch['lr'].cuda()
im_lr = im_lr.to(memory_format=torch.contiguous_format).float()
ori_lr = im_lr
im_lr = F.interpolate(
im_lr,
size=(im_lr.size(-2) * config.sf,
im_lr.size(-1) * config.sf),
mode='bicubic',
)
im_lr = im_lr.contiguous()
im_lr = im_lr * 2 - 1.0
im_lr = torch.clamp(im_lr, -1.0, 1.0)
ori_h, ori_w = im_lr.size(-2), im_lr.size(-1)
pad_h = (math.ceil(ori_h / 64)) * 64 - ori_h
pad_w = (math.ceil(ori_w / 64)) * 64 - ori_w
im_lr = F.pad(im_lr, pad=(0, pad_w, 0, pad_h), mode='reflect')
return im_lr.to(device), ori_lr.to(device), (ori_h, ori_w)
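# Hedged worked example (not part of the original file): the tail of lr_proc
# pads the bicubic-upscaled LR tensor on the bottom/right so height and width
# become multiples of 64, and returns the original size so the padding can be
# cropped off afterwards. The 300x520 size is an arbitrary illustration.
def _example_pad_to_multiple_of_64():
    h, w = 300, 520
    pad_h = math.ceil(h / 64) * 64 - h                # 320 - 300 = 20
    pad_w = math.ceil(w / 64) * 64 - w                # 576 - 520 = 56
    x = torch.zeros(1, 3, h, w)
    x = F.pad(x, pad=(0, pad_w, 0, pad_h), mode='reflect')
    return x.shape                                    # torch.Size([1, 3, 320, 576])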
import argparse
import json
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from glob import glob
import cv2
import math
import numpy as np
import os
import os.path as osp
import random
import time
import torch
from pathlib import Path
from torch.utils import data as data
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.data.transforms import paired_random_crop, triplet_random_crop
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, random_add_speckle_noise_pt, random_add_saltpepper_noise_pt, bivariate_Gaussian
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
def parse_args_paired_training(input_args=None):
"""
Parses command-line arguments used for configuring a paired session (pix2pix-Turbo).
This function sets up an argument parser to handle various training options.
Returns:
argparse.Namespace: The parsed command-line arguments.
"""
parser = argparse.ArgumentParser()
# args for the loss function
parser.add_argument("--gan_disc_type", default="vagan")
parser.add_argument("--gan_loss_type", default="multilevel_sigmoid_s")
parser.add_argument("--lambda_gan", default=0.5, type=float)
parser.add_argument("--lambda_lpips", default=5.0, type=float)
parser.add_argument("--lambda_l2", default=2.0, type=float)
parser.add_argument("--base_config", default="./configs/sr.yaml", type=str)
# validation eval args
parser.add_argument("--eval_freq", default=100, type=int)
parser.add_argument("--save_val", default=True, action="store_false")
parser.add_argument("--num_samples_eval", type=int, default=100, help="Number of samples to use for all evaluation")
parser.add_argument("--viz_freq", type=int, default=100, help="Frequency of visualizing the outputs.")
# details about the model architecture
parser.add_argument("--sd_path")
parser.add_argument("--pretrained_path", type=str, default=None,)
parser.add_argument("--de_net_path")
parser.add_argument("--revision", type=str, default=None,)
parser.add_argument("--variant", type=str, default=None,)
parser.add_argument("--tokenizer_name", type=str, default=None)
parser.add_argument("--lora_rank_unet", default=32, type=int)
parser.add_argument("--lora_rank_vae", default=16, type=int)
parser.add_argument("--neg_prob", default=0.05, type=float)
parser.add_argument("--pos_prompt", type=str, default="A high-resolution, 8K, ultra-realistic image with sharp focus, vibrant colors, and natural lighting.")
parser.add_argument("--neg_prompt", type=str, default="oil painting, cartoon, blur, dirty, messy, low quality, deformation, low resolution, oversmooth")
# training details
parser.add_argument("--output_dir", required=True)
parser.add_argument("--cache_dir", default=None,)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument("--resolution", type=int, default=512,)
parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
parser.add_argument("--num_training_epochs", type=int, default=50)
parser.add_argument("--max_train_steps", type=int, default=50000,)
parser.add_argument("--checkpointing_steps", type=int, default=500,)
parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="Number of updates steps to accumulate before performing a backward/update pass.",)
parser.add_argument("--gradient_checkpointing", action="store_true",)
parser.add_argument("--learning_rate", type=float, default=2e-5)
parser.add_argument("--lr_scheduler", type=str, default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "piecewise_constant", "constant_with_warmup"]'
),
)
parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
parser.add_argument("--lr_num_cycles", type=int, default=1,
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
)
parser.add_argument("--lr_power", type=float, default=0.1, help="Power factor of the polynomial scheduler.")
parser.add_argument("--dataloader_num_workers", type=int, default=0,)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--allow_tf32", action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument("--report_to", type=str, default="wandb",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],)
parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
parser.add_argument("--set_grads_to_none", action="store_true",)
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
return args
# @DATASET_REGISTRY.register(suffix='basicsr')
class PairedDataset(data.Dataset):
"""Modified dataset based on the dataset used for Real-ESRGAN model:
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
It loads gt (Ground-Truth) images, and augments them.
It also generates blur kernels and sinc kernels for generating low-quality images.
Note that the low-quality images are processed as tensors on GPUs for faster processing.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
meta_info (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
Please see more options in the codes.
"""
def __init__(self, opt):
super(PairedDataset, self).__init__()
self.opt = opt
self.file_client = None
self.io_backend_opt = opt['io_backend']
if 'crop_size' in opt:
self.crop_size = opt['crop_size']
else:
self.crop_size = 512
if 'image_type' not in opt:
opt['image_type'] = 'png'
# support multiple types of data: file paths and meta info; lmdb support has been removed
self.paths = []
if 'meta_info' in opt:
with open(self.opt['meta_info']) as fin:
paths = [line.strip().split(' ')[0] for line in fin]
self.paths = [v for v in paths]
if 'meta_num' in opt:
self.paths = sorted(self.paths)[:opt['meta_num']]
if 'gt_path' in opt:
if isinstance(opt['gt_path'], str):
# Use rglob to recursively search for images
self.paths.extend(sorted([str(x) for x in Path(opt['gt_path']).rglob('*.' + opt['image_type'])]))
else:
for path in opt['gt_path']:
self.paths.extend(sorted([str(x) for x in Path(path).rglob('*.' + opt['image_type'])]))
# if 'gt_path' in opt:
# if isinstance(opt['gt_path'], str):
# self.paths.extend(sorted([str(x) for x in Path(opt['gt_path']).glob('*.'+opt['image_type'])]))
# else:
# self.paths.extend(sorted([str(x) for x in Path(opt['gt_path'][0]).glob('*.'+opt['image_type'])]))
# if len(opt['gt_path']) > 1:
# for i in range(len(opt['gt_path'])-1):
# self.paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]).glob('*.'+opt['image_type'])]))
if 'imagenet_path' in opt:
class_list = os.listdir(opt['imagenet_path'])
for class_file in class_list:
self.paths.extend(sorted([str(x) for x in Path(os.path.join(opt['imagenet_path'], class_file)).glob('*.'+'JPEG')]))
if 'face_gt_path' in opt:
if isinstance(opt['face_gt_path'], str):
face_list = sorted([str(x) for x in Path(opt['face_gt_path']).glob('*.'+opt['image_type'])])
self.paths.extend(face_list[:opt['num_face']])
else:
face_list = sorted([str(x) for x in Path(opt['face_gt_path'][0]).glob('*.'+opt['image_type'])])
self.paths.extend(face_list[:opt['num_face']])
if len(opt['face_gt_path']) > 1:
for i in range(len(opt['face_gt_path'])-1):
self.paths.extend(sorted([str(x) for x in Path(opt['face_gt_path'][0]).glob('*.'+opt['image_type'])])[:opt['num_face']])
# limit number of pictures for test
if 'num_pic' in opt:
if 'val' in opt or 'test' in opt:
random.shuffle(self.paths)
self.paths = self.paths[:opt['num_pic']]
else:
self.paths = self.paths[:opt['num_pic']]
if 'mul_num' in opt:
self.paths = self.paths * opt['mul_num']
# print('>>>>>>>>>>>>>>>>>>>>>')
# print(self.paths)
# blur settings for the first degradation
self.blur_kernel_size = opt['blur_kernel_size']
self.kernel_list = opt['kernel_list']
self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
self.blur_sigma = opt['blur_sigma']
self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
# blur settings for the second degradation
self.blur_kernel_size2 = opt['blur_kernel_size2']
self.kernel_list2 = opt['kernel_list2']
self.kernel_prob2 = opt['kernel_prob2']
self.blur_sigma2 = opt['blur_sigma2']
self.betag_range2 = opt['betag_range2']
self.betap_range2 = opt['betap_range2']
self.sinc_prob2 = opt['sinc_prob2']
# a final sinc filter
self.final_sinc_prob = opt['final_sinc_prob']
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
# TODO: kernel range is now hard-coded, should be in the configure file
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
self.pulse_tensor[10, 10] = 1
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# -------------------------------- Load gt images -------------------------------- #
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
gt_path = self.paths[index]
# avoid errors caused by high latency in reading files
retry = 3
while retry > 0:
try:
img_bytes = self.file_client.get(gt_path, 'gt')
except (IOError, OSError) as e:
# logger = get_root_logger()
# logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
# change another file to read
index = random.randint(0, self.__len__()-1)
gt_path = self.paths[index]
time.sleep(1) # sleep 1s for occasional server congestion
else:
break
finally:
retry -= 1
img_gt = imfrombytes(img_bytes, float32=True)
# filter the dataset and remove images with too low quality
img_size = os.path.getsize(gt_path)
img_size = img_size / 1024
while img_gt.shape[0] * img_gt.shape[1] < 384*384 or img_size<100:
index = random.randint(0, self.__len__()-1)
gt_path = self.paths[index]
time.sleep(0.1) # brief pause for occasional server congestion
img_bytes = self.file_client.get(gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
img_size = os.path.getsize(gt_path)
img_size = img_size / 1024
# -------------------- Do augmentation for training: flip, rotation -------------------- #
img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
# crop or pad to the target crop size (self.crop_size, 512 by default)
h, w = img_gt.shape[0:2]
crop_pad_size = self.crop_size
# pad
if h < crop_pad_size or w < crop_pad_size:
pad_h = max(0, crop_pad_size - h)
pad_w = max(0, crop_pad_size - w)
img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
# crop
if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
h, w = img_gt.shape[0:2]
# randomly choose top and left coordinates
top = random.randint(0, h - crop_pad_size)
left = random.randint(0, w - crop_pad_size)
# top = (h - crop_pad_size) // 2 -1
# left = (w - crop_pad_size) // 2 -1
img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.opt['sinc_prob']:
# this sinc filter setting is for kernels ranging from [7, 21]
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel = random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
kernel_size,
self.blur_sigma,
self.blur_sigma, [-math.pi, math.pi],
self.betag_range,
self.betap_range,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.opt['sinc_prob2']:
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel2 = random_mixed_kernels(
self.kernel_list2,
self.kernel_prob2,
kernel_size,
self.blur_sigma2,
self.blur_sigma2, [-math.pi, math.pi],
self.betag_range2,
self.betap_range2,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------------------- the final sinc kernel ------------------------------------- #
if np.random.uniform() < self.opt['final_sinc_prob']:
kernel_size = random.choice(self.kernel_range)
omega_c = np.random.uniform(np.pi / 3, np.pi)
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
sinc_kernel = torch.FloatTensor(sinc_kernel)
else:
sinc_kernel = self.pulse_tensor
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
kernel = torch.FloatTensor(kernel)
kernel2 = torch.FloatTensor(kernel2)
return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
return return_d
def __len__(self):
return len(self.paths)
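# Hedged usage sketch (not part of the original file): a minimal opt dict for
# PairedDataset. Every value here is a placeholder; real values come from the
# training YAML and follow the Real-ESRGAN degradation settings.
def _example_paired_dataset():
    example_opt = {
        'io_backend': {'type': 'disk'},
        'gt_path': 'datasets/train_HR',               # placeholder path
        'crop_size': 512,
        'use_hflip': True, 'use_rot': False,
        # first-stage blur kernels
        'blur_kernel_size': 21, 'kernel_list': ['iso', 'aniso'], 'kernel_prob': [0.5, 0.5],
        'blur_sigma': [0.2, 3.0], 'betag_range': [0.5, 4.0], 'betap_range': [1.0, 2.0],
        'sinc_prob': 0.1,
        # second-stage blur kernels
        'blur_kernel_size2': 21, 'kernel_list2': ['iso', 'aniso'], 'kernel_prob2': [0.5, 0.5],
        'blur_sigma2': [0.2, 1.5], 'betag_range2': [0.5, 4.0], 'betap_range2': [1.0, 2.0],
        'sinc_prob2': 0.1,
        # final sinc filter
        'final_sinc_prob': 0.8,
    }
    return PairedDataset(example_opt)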
def randn_cropinput(lq, gt, base_size=[64, 128, 256, 512]):
cur_size_h = random.choice(base_size)
cur_size_w = random.choice(base_size)
init_h = lq.size(-2)//2
init_w = lq.size(-1)//2
lq = lq[:, :, init_h-cur_size_h//2:init_h+cur_size_h//2, init_w-cur_size_w//2:init_w+cur_size_w//2]
gt = gt[:, :, init_h-cur_size_h//2:init_h+cur_size_h//2, init_w-cur_size_w//2:init_w+cur_size_w//2]
assert lq.size(-1)>=64
assert lq.size(-2)>=64
return [lq, gt]
def degradation_proc(configs, batch, device, val=False, use_usm=False, resize_lq=True, random_size=False):
"""Degradation pipeline, modified from Real-ESRGAN:
https://github.com/xinntao/Real-ESRGAN
"""
jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
usm_sharpener = USMSharp().cuda() # do usm sharpening
im_gt = batch['gt'].cuda()
if use_usm:
im_gt = usm_sharpener(im_gt)
im_gt = im_gt.to(memory_format=torch.contiguous_format).float()
kernel1 = batch['kernel1'].cuda()
kernel2 = batch['kernel2'].cuda()
sinc_kernel = batch['sinc_kernel'].cuda()
ori_h, ori_w = im_gt.size()[2:4]
# ----------------------- The first degradation process ----------------------- #
# blur
out = filter2D(im_gt, kernel1)
# random resize
updown_type = random.choices(
['up', 'down', 'keep'],
configs.degradation['resize_prob'],
)[0]
if updown_type == 'up':
scale = random.uniform(1, configs.degradation['resize_range'][1])
elif updown_type == 'down':
scale = random.uniform(configs.degradation['resize_range'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, scale_factor=scale, mode=mode)
# add noise
gray_noise_prob = configs.degradation['gray_noise_prob']
if random.random() < configs.degradation['gaussian_noise_prob']:
out = random_add_gaussian_noise_pt(
out,
sigma_range=configs.degradation['noise_range'],
clip=True,
rounds=False,
gray_prob=gray_noise_prob,
)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=configs.degradation['poisson_scale_range'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*configs.degradation['jpeg_range'])
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
out = jpeger(out, quality=jpeg_p)
# ----------------------- The second degradation process ----------------------- #
# blur
if random.random() < configs.degradation['second_blur_prob']:
out = filter2D(out, kernel2)
# random resize
updown_type = random.choices(
['up', 'down', 'keep'],
configs.degradation['resize_prob2'],
)[0]
if updown_type == 'up':
scale = random.uniform(1, configs.degradation['resize_range2'][1])
elif updown_type == 'down':
scale = random.uniform(configs.degradation['resize_range2'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out,
size=(int(ori_h / configs.sf * scale),
int(ori_w / configs.sf * scale)),
mode=mode,
)
# add noise
gray_noise_prob = configs.degradation['gray_noise_prob2']
if random.random() < configs.degradation['gaussian_noise_prob2']:
out = random_add_gaussian_noise_pt(
out,
sigma_range=configs.degradation['noise_range2'],
clip=True,
rounds=False,
gray_prob=gray_noise_prob,
)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=configs.degradation['poisson_scale_range2'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False,
)
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
# 2. JPEG compression + [resize back + sinc filter]
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
if random.random() < 0.5:
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out,
size=(ori_h // configs.sf,
ori_w // configs.sf),
mode=mode,
)
out = filter2D(out, sinc_kernel)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*configs.degradation['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = jpeger(out, quality=jpeg_p)
else:
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*configs.degradation['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = jpeger(out, quality=jpeg_p)
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out,
size=(ori_h // configs.sf,
ori_w // configs.sf),
mode=mode,
)
out = filter2D(out, sinc_kernel)
# clamp and round
im_lq = torch.clamp(out, 0, 1.0)
# random crop
gt_size = configs.degradation['gt_size']
im_gt, im_lq = paired_random_crop(im_gt, im_lq, gt_size, configs.sf)
lq, gt = im_lq, im_gt
ori_lq = im_lq
if resize_lq:
lq = F.interpolate(
lq,
size=(gt.size(-2),
gt.size(-1)),
mode='bicubic',
)
if random.random() < configs.degradation['no_degradation_prob'] or torch.isnan(lq).any():
lq = gt
# with probability no_degradation_prob (or if the degraded LQ contains NaNs), fall back to the clean GT as the LQ input
lq = lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
lq = lq * 2 - 1.0 # TODO 0~1?
gt = gt * 2 - 1.0
if random_size:
lq, gt = randn_cropinput(lq, gt)
lq = torch.clamp(lq, -1.0, 1.0)
return lq.to(device), gt.to(device), ori_lq.to(device)
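# Hedged usage sketch (not part of the original file): running one batch through
# the degradation pipeline. The config handling is an assumption here (an
# OmegaConf-style object exposing `sf` and the `degradation` dict used above,
# plus the dataset options); adjust the names to the actual sr.yaml layout.
def _example_degradation_proc(configs, dataset_opt):
    from torch.utils.data import DataLoader
    loader = DataLoader(PairedDataset(dataset_opt), batch_size=2, shuffle=True)
    batch = next(iter(loader))
    lq, gt, ori_lq = degradation_proc(configs, batch, device='cuda')
    return lq.shape, gt.shape                         # both scaled to [-1, 1]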
import importlib
import torch
import numpy as np
from collections import abc
from einops import rearrange
from functools import partial
import multiprocessing as mp
from threading import Thread
from queue import Queue
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
return total_params
def instantiate_from_config(config):
if not "target" in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def instantiate_from_config_sr(config):
if not "target" in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(config.get("params", dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
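# Hedged usage sketch (not part of the original file): instantiate_from_config
# builds an object from a {"target": ..., "params": ...} dict, which is how the
# LDM-style configs reference classes by dotted import path. torch.nn.Linear is
# used purely as an illustration target.
def _example_instantiate_from_config():
    example_config = {
        "target": "torch.nn.Linear",
        "params": {"in_features": 8, "out_features": 4},
    }
    return instantiate_from_config(example_config)    # Linear(in_features=8, out_features=4)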
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
# create dummy dataset instance
# run prefetching
if idx_to_fn:
res = func(data, worker_id=idx)
else:
res = func(data)
Q.put([idx, res])
Q.put("Done")
def parallel_data_prefetch(
func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
):
# if target_data_type not in ["ndarray", "list"]:
# raise ValueError(
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
# )
if isinstance(data, np.ndarray) and target_data_type == "list":
raise ValueError("list expected but function got ndarray.")
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
print(
f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data = list(data.values())
if target_data_type == "ndarray":
data = np.asarray(data)
else:
data = list(data)
else:
raise TypeError(
f"The data to be processed in parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
)
if cpu_intensive:
Q = mp.Queue(1000)
proc = mp.Process
else:
Q = Queue(1000)
proc = Thread
# spawn processes
if target_data_type == "ndarray":
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(np.array_split(data, n_proc))
]
else:
step = (
int(len(data) / n_proc + 1)
if len(data) % n_proc != 0
else int(len(data) / n_proc)
)
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(
[data[i: i + step] for i in range(0, len(data), step)]
)
]
processes = []
for i in range(n_proc):
p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
processes += [p]
# start processes
print(f"Start prefetching...")
import time
start = time.time()
gather_res = [[] for _ in range(n_proc)]
try:
for p in processes:
p.start()
k = 0
while k < n_proc:
# get result
res = Q.get()
if res == "Done":
k += 1
else:
gather_res[res[0]] = res[1]
except Exception as e:
print("Exception: ", e)
for p in processes:
p.terminate()
raise e
finally:
for p in processes:
p.join()
print(f"Prefetching complete. [{time.time() - start} sec.]")
if target_data_type == 'ndarray':
if not isinstance(gather_res[0], np.ndarray):
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
# order outputs
return np.concatenate(gather_res, axis=0)
elif target_data_type == 'list':
out = []
for r in gather_res:
out.extend(r)
return out
else:
return gather_res
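# Hedged usage sketch (not part of the original file): prefetching a list with
# four worker threads (cpu_intensive=False avoids the pickling requirements of
# multiprocessing). The squaring function is a stand-in for a real
# preprocessing callable.
def _example_parallel_data_prefetch():
    def square_all(chunk):
        return [v * v for v in chunk]
    return parallel_data_prefetch(square_all, list(range(100)), n_proc=4,
                                  target_data_type="list", cpu_intensive=False)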
# ------------------------------------------------------------------------
#
# Ultimate VAE Tile Optimization
#
# Introducing a revolutionary new optimization designed to make
# the VAE work with giant images on limited VRAM!
# Say goodbye to the frustration of OOM and hello to seamless output!
#
# ------------------------------------------------------------------------
#
# This script is a wild hack that splits the image into tiles,
# encodes each tile separately, and merges the result back together.
#
# Advantages:
# - The VAE can now work with giant images on limited VRAM
# (~10 GB for 8K images!)
# - The merged output is completely seamless without any post-processing.
#
# Drawbacks:
# - Giant RAM needed. To store the intermediate results for a 4096x4096
# image, you need a 32 GB RAM machine (it consumes ~20 GB); for 8192x8192,
# you need a 128 GB RAM machine (it consumes ~100 GB).
# - NaNs always appear for 8k images when you use the fp16 (half) VAE.
# You must use --no-half-vae to disable the half VAE for such giant images.
# - Slow speed. With default tile size, it takes around 50/200 seconds
# to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode
# a 8192x8192 image. (The speed is limited by both the GPU and the CPU.)
# - The gradient calculation is not compatible with this hack. It
# will break any backward() or torch.autograd.grad() that passes VAE.
# (But you can still use the VAE to generate training data.)
#
# How it works:
# 1) The image is split into tiles.
# - To ensure perfect results, each tile is padded with 32 pixels
# on each side.
# - Then the conv2d/silu/upsample/downsample can produce identical
# results to the original image without splitting.
# 2) The original forward is decomposed into a task queue and a task worker.
# - The task queue is a list of functions that will be executed in order.
# - The task worker is a loop that executes the tasks in the queue.
# 3) The task queue is executed for each tile.
# - Current tile is sent to GPU.
# - local operations are directly executed.
# - Group norm calculation is temporarily suspended until the mean
# and var of all tiles are calculated.
# - The residual is pre-calculated, stored, and added back later.
# - When moving to the next tile, the current tile is sent to the CPU.
# 4) After all tiles are processed, the tiles are merged on the CPU and returned.
#
# Enjoy!
#
# @author: LI YI @ Nanyang Technological University - Singapore
# @date: 2023-03-02
# @license: MIT License
#
# Please give me a star if you like this project!
#
# -------------------------------------------------------------------------
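# Hedged illustration (not part of the original script): a toy version of the
# task-queue decomposition described above. Each task is a (name, fn) pair; the
# same queue is replayed for every tile, which is what lets the real
# implementation pause at group-norm tasks and move tiles between CPU and GPU
# between steps. The two-layer net and tile sizes are arbitrary.
def _toy_task_queue_demo():
    import torch
    net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.SiLU())
    task_queue = [('conv', net[0]), ('silu', net[1])]      # flattened forward pass
    tiles = [torch.randn(1, 3, 16, 16) for _ in range(4)]
    outputs = []
    for tile in tiles:                                     # replay the same queue per tile
        for _name, fn in task_queue:
            tile = fn(tile)
        outputs.append(tile)
    return outputs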
import gc
from time import time
import math
from tqdm import tqdm
import torch
import torch.version
import torch.nn.functional as F
from einops import rearrange
import os
import sys
sys.path.append(os.getcwd())
import my_utils.devices as devices
try:
import xformers
import xformers.ops
except ImportError:
pass
sd_flag = False
def get_recommend_encoder_tile_size():
if torch.cuda.is_available():
total_memory = torch.cuda.get_device_properties(
devices.device).total_memory // 2**20
if total_memory > 16*1000:
ENCODER_TILE_SIZE = 3072
elif total_memory > 12*1000:
ENCODER_TILE_SIZE = 2048
elif total_memory > 8*1000:
ENCODER_TILE_SIZE = 1536
else:
ENCODER_TILE_SIZE = 960
else:
ENCODER_TILE_SIZE = 512
return ENCODER_TILE_SIZE
def get_recommend_decoder_tile_size():
if torch.cuda.is_available():
total_memory = torch.cuda.get_device_properties(
devices.device).total_memory // 2**20
if total_memory > 30*1000:
DECODER_TILE_SIZE = 256
elif total_memory > 16*1000:
DECODER_TILE_SIZE = 192
elif total_memory > 12*1000:
DECODER_TILE_SIZE = 128
elif total_memory > 8*1000:
DECODER_TILE_SIZE = 96
else:
DECODER_TILE_SIZE = 64
else:
DECODER_TILE_SIZE = 64
return DECODER_TILE_SIZE
if 'global const':
DEFAULT_ENABLED = False
DEFAULT_MOVE_TO_GPU = False
DEFAULT_FAST_ENCODER = True
DEFAULT_FAST_DECODER = True
DEFAULT_COLOR_FIX = 0
DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()
# inplace version of silu
def inplace_nonlinearity(x):
# Test: fix for Nans
return F.silu(x, inplace=True)
# extracted from ldm.modules.diffusionmodules.model
# from diffusers lib
def attn_forward_new(self, h_):
batch_size, channel, height, width = h_.shape
hidden_states = h_.view(batch_size, channel, height * width).transpose(1, 2)
attention_mask = None
encoder_hidden_states = None
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = self.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif self.norm_cross:
encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
query = self.head_to_batch_dim(query)
key = self.head_to_batch_dim(key)
value = self.head_to_batch_dim(value)
attention_probs = self.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = self.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states
def attn_forward(self, h_):
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h*w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h*w) # b,c,hw
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h*w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return h_
def xformer_attn_forward(self, h_):
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
B, C, H, W = q.shape
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(B, t.shape[1], 1, C)
.permute(0, 2, 1, 3)
.reshape(B * 1, t.shape[1], C)
.contiguous(),
(q, k, v),
)
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op)
out = (
out.unsqueeze(0)
.reshape(B, 1, out.shape[1], C)
.permute(0, 2, 1, 3)
.reshape(B, out.shape[1], C)
)
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
out = self.proj_out(out)
return out
def attn2task(task_queue, net):
if False: #isinstance(net, AttnBlock):
task_queue.append(('store_res', lambda x: x))
task_queue.append(('pre_norm', net.norm))
task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
task_queue.append(['add_res', None])
elif False: #isinstance(net, MemoryEfficientAttnBlock):
task_queue.append(('store_res', lambda x: x))
task_queue.append(('pre_norm', net.norm))
task_queue.append(
('attn', lambda x, net=net: xformer_attn_forward(net, x)))
task_queue.append(['add_res', None])
else:
task_queue.append(('store_res', lambda x: x))
task_queue.append(('pre_norm', net.group_norm))
task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))
task_queue.append(['add_res', None])
def resblock2task(queue, block):
"""
Turn a ResNetBlock into a sequence of tasks and append to the task queue
@param queue: the target task queue
@param block: ResNetBlock
"""
if block.in_channels != block.out_channels:
if sd_flag:
if block.use_conv_shortcut:
queue.append(('store_res', block.conv_shortcut))
else:
queue.append(('store_res', block.nin_shortcut))
else:
if block.use_in_shortcut:
queue.append(('store_res', block.conv_shortcut))
else:
queue.append(('store_res', block.nin_shortcut))
else:
queue.append(('store_res', lambda x: x))
queue.append(('pre_norm', block.norm1))
queue.append(('silu', inplace_nonlinearity))
queue.append(('conv1', block.conv1))
queue.append(('pre_norm', block.norm2))
queue.append(('silu', inplace_nonlinearity))
queue.append(('conv2', block.conv2))
queue.append(['add_res', None])
def build_sampling(task_queue, net, is_decoder):
"""
Build the sampling part of a task queue
@param task_queue: the target task queue
@param net: the network
@param is_decoder: currently building decoder or encoder
"""
if is_decoder:
# resblock2task(task_queue, net.mid.block_1)
# attn2task(task_queue, net.mid.attn_1)
# resblock2task(task_queue, net.mid.block_2)
# resolution_iter = reversed(range(net.num_resolutions))
# block_ids = net.num_res_blocks + 1
# condition = 0
# module = net.up
# func_name = 'upsample'
resblock2task(task_queue, net.mid_block.resnets[0])
attn2task(task_queue, net.mid_block.attentions[0])
resblock2task(task_queue, net.mid_block.resnets[1])
resolution_iter = (range(len(net.up_blocks))) # range(0,4)
block_ids = 2 + 1
condition = len(net.up_blocks) - 1
module = net.up_blocks
func_name = 'upsamplers'
else:
# resolution_iter = range(net.num_resolutions)
# block_ids = net.num_res_blocks
# condition = net.num_resolutions - 1
# module = net.down
# func_name = 'downsample'
resolution_iter = (range(len(net.down_blocks))) # range(0,4)
block_ids = 2
condition = len(net.down_blocks) - 1
module = net.down_blocks
func_name = 'downsamplers'
for i_level in resolution_iter:
for i_block in range(block_ids):
resblock2task(task_queue, module[i_level].resnets[i_block])
if i_level != condition:
if is_decoder:
task_queue.append((func_name, module[i_level].upsamplers[0]))
else:
task_queue.append((func_name, module[i_level].downsamplers[0]))
if not is_decoder:
resblock2task(task_queue, net.mid_block.resnets[0])
attn2task(task_queue, net.mid_block.attentions[0])
resblock2task(task_queue, net.mid_block.resnets[1])
def build_task_queue(net, is_decoder):
"""
Build a single task queue for the encoder or decoder
@param net: the VAE decoder or encoder network
@param is_decoder: currently building decoder or encoder
@return: the task queue
"""
task_queue = []
task_queue.append(('conv_in', net.conv_in))
# construct the sampling part of the task queue
# because encoder and decoder share the same architecture, we extract the sampling part
build_sampling(task_queue, net, is_decoder)
if is_decoder and not sd_flag:
net.give_pre_end = False
net.tanh_out = False
if not is_decoder or not net.give_pre_end:
if sd_flag:
task_queue.append(('pre_norm', net.norm_out))
else:
task_queue.append(('pre_norm', net.conv_norm_out))
task_queue.append(('silu', inplace_nonlinearity))
task_queue.append(('conv_out', net.conv_out))
if is_decoder and net.tanh_out:
task_queue.append(('tanh', torch.tanh))
return task_queue
def clone_task_queue(task_queue):
"""
Clone a task queue
@param task_queue: the task queue to be cloned
@return: the cloned task queue
"""
return [[item for item in task] for task in task_queue]
def get_var_mean(input, num_groups, eps=1e-6):
"""
Get mean and var for group norm
"""
b, c = input.size(0), input.size(1)
channel_in_group = int(c/num_groups)
input_reshaped = input.contiguous().view(
1, int(b * num_groups), channel_in_group, *input.size()[2:])
var, mean = torch.var_mean(
input_reshaped, dim=[0, 2, 3, 4], unbiased=False)
return var, mean
def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
"""
Custom group norm with fixed mean and var
@param input: input tensor
@param num_groups: number of groups. by default, num_groups = 32
@param mean: mean, must be pre-calculated by get_var_mean
@param var: var, must be pre-calculated by get_var_mean
@param weight: weight, should be fetched from the original group norm
@param bias: bias, should be fetched from the original group norm
@param eps: epsilon, by default, eps = 1e-6 to match the original group norm
@return: normalized tensor
"""
b, c = input.size(0), input.size(1)
channel_in_group = int(c/num_groups)
input_reshaped = input.contiguous().view(
1, int(b * num_groups), channel_in_group, *input.size()[2:])
out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None,
training=False, momentum=0, eps=eps)
out = out.view(b, c, *input.size()[2:])
# post affine transform
if weight is not None:
out *= weight.view(1, -1, 1, 1)
if bias is not None:
out += bias.view(1, -1, 1, 1)
return out
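# Hedged self-check (not part of the original script): with mean/var taken from
# get_var_mean over the full tensor, custom_group_norm should match torch's own
# group norm up to numerical precision. Shapes and affine parameters are arbitrary.
def _example_group_norm_equivalence():
    x = torch.randn(2, 64, 16, 16)
    weight, bias = torch.randn(64), torch.randn(64)
    var, mean = get_var_mean(x, 32)
    ours = custom_group_norm(x, 32, mean, var, weight, bias, eps=1e-6)
    ref = F.group_norm(x, 32, weight, bias, eps=1e-6)
    return torch.allclose(ours, ref, atol=1e-5)            # expected: True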
def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
"""
Crop the valid region from the tile
@param x: input tile
@param input_bbox: original input bounding box
@param target_bbox: output bounding box
@param scale: scale factor
@return: cropped tile
"""
padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox]
margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]
return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]]
# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
def perfcount(fn):
def wrapper(*args, **kwargs):
ts = time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats(devices.device)
devices.torch_gc()
gc.collect()
ret = fn(*args, **kwargs)
devices.torch_gc()
gc.collect()
if torch.cuda.is_available():
vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
torch.cuda.reset_peak_memory_stats(devices.device)
print(
f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
else:
print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')
return ret
return wrapper
# copy end :)
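# Hedged usage sketch (not part of the original script): perfcount wraps the
# tiled encode/decode entry points further below, but any callable can be
# timed the same way; this toy matmul workload is only an illustration.
@perfcount
def _example_perfcount_workload(n=512):
    return torch.randn(n, n) @ torch.randn(n, n)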
class GroupNormParam:
def __init__(self):
self.var_list = []
self.mean_list = []
self.pixel_list = []
self.weight = None
self.bias = None
def add_tile(self, tile, layer):
var, mean = get_var_mean(tile, 32)
# For giant images, the variance can exceed the float16 maximum.
# In that case, recompute the statistics on a float32 copy of the tile.
if var.dtype == torch.float16 and var.isinf().any():
fp32_tile = tile.float()
var, mean = get_var_mean(fp32_tile, 32)
# ============= DEBUG: test for infinite =============
# if torch.isinf(var).any():
# print('var: ', var)
# ====================================================
self.var_list.append(var)
self.mean_list.append(mean)
self.pixel_list.append(
tile.shape[2]*tile.shape[3])
if hasattr(layer, 'weight'):
self.weight = layer.weight
self.bias = layer.bias
else:
self.weight = None
self.bias = None
def summary(self):
"""
summarize the mean and var and return a function
that apply group norm on each tile
"""
if len(self.var_list) == 0:
return None
var = torch.vstack(self.var_list)
mean = torch.vstack(self.mean_list)
max_value = max(self.pixel_list)
pixels = torch.tensor(
self.pixel_list, dtype=torch.float32, device=devices.device) / max_value
sum_pixels = torch.sum(pixels)
pixels = pixels.unsqueeze(
1) / sum_pixels
var = torch.sum(
var * pixels, dim=0)
mean = torch.sum(
mean * pixels, dim=0)
return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)
@staticmethod
def from_tile(tile, norm):
"""
create a function from a single tile without summary
"""
var, mean = get_var_mean(tile, 32)
if var.dtype == torch.float16 and var.isinf().any():
fp32_tile = tile.float()
var, mean = get_var_mean(fp32_tile, 32)
# on Apple MPS devices, convert back to float16
if var.device.type == 'mps':
# clamp to avoid overflow
var = torch.clamp(var, 0, 60000)
var = var.half()
mean = mean.half()
if hasattr(norm, 'weight'):
weight = norm.weight
bias = norm.bias
else:
weight = None
bias = None
def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
return group_norm_func
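# Usage sketch (illustrative): accumulate statistics from several tiles, then
# apply the pixel-count-weighted group norm returned by summary() to each tile.
# `norm_layer` stands in for the pre_norm layer captured in the task queue.
def _group_norm_param_demo():
    dev = devices.device
    norm_layer = torch.nn.GroupNorm(32, 64).to(dev)
    tiles = [torch.randn(1, 64, 24, 24, device=dev),
             torch.randn(1, 64, 24, 32, device=dev)]
    param = GroupNormParam()
    for t in tiles:
        param.add_tile(t, norm_layer)
    apply_norm = param.summary()  # merged mean/var, weighted by tile pixel count
    return [apply_norm(t) for t in tiles]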
class VAEHook:
def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):
self.net = net # encoder | decoder
self.tile_size = tile_size
self.is_decoder = is_decoder
self.fast_mode = (fast_encoder and not is_decoder) or (
fast_decoder and is_decoder)
self.color_fix = color_fix and not is_decoder
self.to_gpu = to_gpu
self.pad = 11 if is_decoder else 32
def __call__(self, x):
B, C, H, W = x.shape
original_device = next(self.net.parameters()).device
try:
if self.to_gpu:
self.net.to(devices.get_optimal_device())
if max(H, W) <= self.pad * 2 + self.tile_size:
print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
return self.net.original_forward(x)
else:
return self.vae_tile_forward(x)
finally:
self.net.to(original_device)
def get_best_tile_size(self, lowerbound, upperbound):
"""
Get the best tile size for GPU memory
"""
divider = 32
while divider >= 2:
remainder = lowerbound % divider
if remainder == 0:
return lowerbound
candidate = lowerbound - remainder + divider
if candidate <= upperbound:
return candidate
divider //= 2
return lowerbound
def split_tiles(self, h, w):
"""
Tool function to split the image into tiles
@param h: height of the image
@param w: width of the image
@return: tile_input_bboxes, tile_output_bboxes
"""
tile_input_bboxes, tile_output_bboxes = [], []
tile_size = self.tile_size
pad = self.pad
num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
# If any of the numbers are 0, we let it be 1
# This is to deal with long and thin images
num_height_tiles = max(num_height_tiles, 1)
num_width_tiles = max(num_width_tiles, 1)
# Suggestions from https://github.com/Kahsolt: auto shrink the tile size
real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)
print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')
for i in range(num_height_tiles):
for j in range(num_width_tiles):
# bbox: [x1, x2, y1, y2]
# the padding is unnecessary at the image borders, so we start directly from (pad, pad)
input_bbox = [
pad + j * real_tile_width,
min(pad + (j + 1) * real_tile_width, w),
pad + i * real_tile_height,
min(pad + (i + 1) * real_tile_height, h),
]
# if the output bbox is close to the image boundary, we extend it to the image boundary
output_bbox = [
input_bbox[0] if input_bbox[0] > pad else 0,
input_bbox[1] if input_bbox[1] < w - pad else w,
input_bbox[2] if input_bbox[2] > pad else 0,
input_bbox[3] if input_bbox[3] < h - pad else h,
]
# scale to get the final output bbox
output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
tile_output_bboxes.append(output_bbox)
# unconditionally expand the input bbox by pad pixels, clamped to the image bounds
tile_input_bboxes.append([
max(0, input_bbox[0] - pad),
min(w, input_bbox[1] + pad),
max(0, input_bbox[2] - pad),
min(h, input_bbox[3] + pad),
])
return tile_input_bboxes, tile_output_bboxes
@torch.no_grad()
def estimate_group_norm(self, z, task_queue, color_fix):
device = z.device
tile = z
last_id = len(task_queue) - 1
while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
last_id -= 1
if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
raise ValueError('No group norm found in the task queue')
# estimate until the last group norm
for i in range(last_id + 1):
task = task_queue[i]
if task[0] == 'pre_norm':
group_norm_func = GroupNormParam.from_tile(tile, task[1])
task_queue[i] = ('apply_norm', group_norm_func)
if i == last_id:
return True
tile = group_norm_func(tile)
elif task[0] == 'store_res':
task_id = i + 1
while task_id < last_id and task_queue[task_id][0] != 'add_res':
task_id += 1
if task_id >= last_id:
continue
task_queue[task_id][1] = task[1](tile)
elif task[0] == 'add_res':
tile += task[1].to(device)
task[1] = None
elif color_fix and task[0] == 'downsample':
for j in range(i, last_id + 1):
if task_queue[j][0] == 'store_res':
task_queue[j] = ('store_res_cpu', task_queue[j][1])
return True
else:
tile = task[1](tile)
try:
devices.test_for_nans(tile, "vae")
except Exception:
print('NaN detected in fast mode estimation. Fast mode disabled.')
return False
raise IndexError('Should not reach here')
@perfcount
@torch.no_grad()
def vae_tile_forward(self, z):
"""
Decode a latent vector z into an image in a tiled manner.
@param z: latent vector
@return: image
"""
device = next(self.net.parameters()).device
net = self.net
tile_size = self.tile_size
is_decoder = self.is_decoder
z = z.detach() # detach the input to avoid backprop
N, height, width = z.shape[0], z.shape[2], z.shape[3]
net.last_z_shape = z.shape
# Split the input into tiles and build a task queue for each tile
print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')
in_bboxes, out_bboxes = self.split_tiles(height, width)
# Prepare tiles by splitting the input latents
tiles = []
for input_bbox in in_bboxes:
tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
tiles.append(tile)
num_tiles = len(tiles)
num_completed = 0
# Build task queues
single_task_queue = build_task_queue(net, is_decoder)
#print(single_task_queue)
if self.fast_mode:
# Fast mode: downsample the input image to the tile size,
# then estimate the group norm parameters on the downsampled image
scale_factor = tile_size / max(height, width)
z = z.to(device)
downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
# use nearest-exact to keep the statistics as close as possible
print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')
# ======= Special thanks to @Kahsolt for distribution shift issue ======= #
# The downsampling will heavily distort its mean and std, so we need to recover it.
std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
del std_old, mean_old, std_new, mean_new
# occasionally std_new is too small or too large for the float16 range,
# so we clamp the result to z's value range.
downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
estimate_task_queue = clone_task_queue(single_task_queue)
if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
single_task_queue = estimate_task_queue
del downsampled_z
task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]
# Dummy result
result = None
result_approx = None
#try:
# with devices.autocast():
# result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu()
#except: pass
# Free memory of input latent tensor
del z
# Task queue execution
pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")
# execute the tasks back and forth when switching tiles so that we always
# keep one tile on the GPU and reduce unnecessary data transfers
forward = True
interrupted = False
#state.interrupted = interrupted
while True:
#if state.interrupted: interrupted = True ; break
group_norm_param = GroupNormParam()
for i in range(num_tiles) if forward else reversed(range(num_tiles)):
#if state.interrupted: interrupted = True ; break
tile = tiles[i].to(device)
input_bbox = in_bboxes[i]
task_queue = task_queues[i]
interrupted = False
while len(task_queue) > 0:
#if state.interrupted: interrupted = True ; break
# DEBUG: current task
# print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape)
task = task_queue.pop(0)
if task[0] == 'pre_norm':
group_norm_param.add_tile(tile, task[1])
break
elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
task_id = 0
res = task[1](tile)
if not self.fast_mode or task[0] == 'store_res_cpu':
res = res.cpu()
while task_queue[task_id][0] != 'add_res':
task_id += 1
task_queue[task_id][1] = res
elif task[0] == 'add_res':
tile += task[1].to(device)
task[1] = None
else:
tile = task[1](tile)
pbar.update(1)
if interrupted: break
# check for NaNs in the tile.
# If there are NaNs, we abort the process to save user's time
#devices.test_for_nans(tile, "vae")
#print(tiles[i].shape, tile.shape, i, num_tiles)
if len(task_queue) == 0:
tiles[i] = None
num_completed += 1
if result is None: # NOTE: dim C varies between cases, so the result tensor can only be initialized dynamically
result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
del tile
elif i == num_tiles - 1 and forward:
forward = False
tiles[i] = tile
elif i == 0 and not forward:
forward = True
tiles[i] = tile
else:
tiles[i] = tile.cpu()
del tile
if interrupted: break
if num_completed == num_tiles: break
# insert the group norm task to the head of each task queue
group_norm_func = group_norm_param.summary()
if group_norm_func is not None:
for i in range(num_tiles):
task_queue = task_queues[i]
task_queue.insert(0, ('apply_norm', group_norm_func))
# Done!
pbar.close()
return result if result is not None else result_approx.to(device)
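# Usage sketch: VAEHook replaces a VAE encoder/decoder forward with a tiled
# version. This mirrors how it is installed in _init_tiled_vae further below
# (`vae` is assumed to be a diffusers AutoencoderKL):
#
#   if not hasattr(vae.decoder, 'original_forward'):
#       vae.decoder.original_forward = vae.decoder.forward
#   vae.decoder.forward = VAEHook(vae.decoder, tile_size=256, is_decoder=True,
#                                 fast_decoder=False, fast_encoder=False,
#                                 color_fix=False, to_gpu=False)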
import os
import re
import requests
import sys
import copy
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from peft import LoraConfig, get_peft_model
p = "src/"
sys.path.append(p)
from model import make_1step_sched, my_lora_fwd
from basicsr.archs.arch_util import default_init_weights
def get_layer_number(module_name):
base_layers = {
'down_blocks': 0,
'mid_block': 4,
'up_blocks': 5
}
if module_name == 'conv_out':
return 9
base_layer = None
for key in base_layers:
if key in module_name:
base_layer = base_layers[key]
break
if base_layer is None:
return None
additional_layers = int(re.findall(r'\.(\d+)', module_name)[0]) #sum(int(num) for num in re.findall(r'\d+', module_name))
final_layer = base_layer + additional_layers
return final_layer
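# Illustrative mapping produced by get_layer_number (the module paths below are
# hypothetical examples of the kind found in a diffusers UNet):
#   get_layer_number('conv_out')                        -> 9
#   get_layer_number('down_blocks.2.resnets.0.conv1')   -> 0 + 2 = 2
#   get_layer_number('mid_block.attentions.0.proj_in')  -> 4 + 0 = 4
#   get_layer_number('up_blocks.1.attentions.0.to_q')   -> 5 + 1 = 6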
class S3Diff(torch.nn.Module):
def __init__(self, sd_path=None, pretrained_path=None, lora_rank_unet=32, lora_rank_vae=16, block_embedding_dim=64):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
self.text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").cuda()
self.sched = make_1step_sched(sd_path)
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet")
target_modules_vae = r"^encoder\..*(conv1|conv2|conv_in|conv_shortcut|conv|conv_out|to_k|to_q|to_v|to_out\.0)$"
target_modules_unet = [
"to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_shortcut", "conv_out",
"proj_in", "proj_out", "ff.net.2", "ff.net.0.proj"
]
num_embeddings = 64
self.W = nn.Parameter(torch.randn(num_embeddings), requires_grad=False)
self.vae_de_mlp = nn.Sequential(
nn.Linear(num_embeddings * 4, 256),
nn.ReLU(True),
)
self.unet_de_mlp = nn.Sequential(
nn.Linear(num_embeddings * 4, 256),
nn.ReLU(True),
)
self.vae_block_mlp = nn.Sequential(
nn.Linear(block_embedding_dim, 64),
nn.ReLU(True),
)
self.unet_block_mlp = nn.Sequential(
nn.Linear(block_embedding_dim, 64),
nn.ReLU(True),
)
self.vae_fuse_mlp = nn.Linear(256 + 64, lora_rank_vae ** 2)
self.unet_fuse_mlp = nn.Linear(256 + 64, lora_rank_unet ** 2)
default_init_weights([self.vae_de_mlp, self.unet_de_mlp, self.vae_block_mlp, self.unet_block_mlp, \
self.vae_fuse_mlp, self.unet_fuse_mlp], 1e-5)
# vae
self.vae_block_embeddings = nn.Embedding(6, block_embedding_dim)
self.unet_block_embeddings = nn.Embedding(10, block_embedding_dim)
if pretrained_path is not None:
sd = torch.load(pretrained_path, map_location="cpu")
vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
_sd_vae = vae.state_dict()
for k in sd["state_dict_vae"]:
_sd_vae[k] = sd["state_dict_vae"][k]
vae.load_state_dict(_sd_vae)
unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
unet.add_adapter(unet_lora_config)
_sd_unet = unet.state_dict()
for k in sd["state_dict_unet"]:
_sd_unet[k] = sd["state_dict_unet"][k]
unet.load_state_dict(_sd_unet)
_vae_de_mlp = self.vae_de_mlp.state_dict()
for k in sd["state_dict_vae_de_mlp"]:
_vae_de_mlp[k] = sd["state_dict_vae_de_mlp"][k]
self.vae_de_mlp.load_state_dict(_vae_de_mlp)
_unet_de_mlp = self.unet_de_mlp.state_dict()
for k in sd["state_dict_unet_de_mlp"]:
_unet_de_mlp[k] = sd["state_dict_unet_de_mlp"][k]
self.unet_de_mlp.load_state_dict(_unet_de_mlp)
_vae_block_mlp = self.vae_block_mlp.state_dict()
for k in sd["state_dict_vae_block_mlp"]:
_vae_block_mlp[k] = sd["state_dict_vae_block_mlp"][k]
self.vae_block_mlp.load_state_dict(_vae_block_mlp)
_unet_block_mlp = self.unet_block_mlp.state_dict()
for k in sd["state_dict_unet_block_mlp"]:
_unet_block_mlp[k] = sd["state_dict_unet_block_mlp"][k]
self.unet_block_mlp.load_state_dict(_unet_block_mlp)
_vae_fuse_mlp = self.vae_fuse_mlp.state_dict()
for k in sd["state_dict_vae_fuse_mlp"]:
_vae_fuse_mlp[k] = sd["state_dict_vae_fuse_mlp"][k]
self.vae_fuse_mlp.load_state_dict(_vae_fuse_mlp)
_unet_fuse_mlp = self.unet_fuse_mlp.state_dict()
for k in sd["state_dict_unet_fuse_mlp"]:
_unet_fuse_mlp[k] = sd["state_dict_unet_fuse_mlp"][k]
self.unet_fuse_mlp.load_state_dict(_unet_fuse_mlp)
self.W = nn.Parameter(sd["w"], requires_grad=False)
embeddings_state_dict = sd["state_embeddings"]
self.vae_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_vae_block'])
self.unet_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_unet_block'])
else:
print("Initializing model with random weights")
vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian",
target_modules=target_modules_vae)
vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
unet_lora_config = LoraConfig(r=lora_rank_unet, init_lora_weights="gaussian",
target_modules=target_modules_unet
)
unet.add_adapter(unet_lora_config)
self.lora_rank_unet = lora_rank_unet
self.lora_rank_vae = lora_rank_vae
self.target_modules_vae = target_modules_vae
self.target_modules_unet = target_modules_unet
self.vae_lora_layers = []
for name, module in vae.named_modules():
if 'base_layer' in name:
self.vae_lora_layers.append(name[:-len(".base_layer")])
for name, module in vae.named_modules():
if name in self.vae_lora_layers:
module.forward = my_lora_fwd.__get__(module, module.__class__)
self.unet_lora_layers = []
for name, module in unet.named_modules():
if 'base_layer' in name:
self.unet_lora_layers.append(name[:-len(".base_layer")])
for name, module in unet.named_modules():
if name in self.unet_lora_layers:
module.forward = my_lora_fwd.__get__(module, module.__class__)
self.unet_layer_dict = {name: get_layer_number(name) for name in self.unet_lora_layers}
unet.to("cuda")
vae.to("cuda")
self.unet, self.vae = unet, vae
self.timesteps = torch.tensor([999], device="cuda").long()
self.text_encoder.requires_grad_(False)
def set_eval(self):
self.unet.eval()
self.vae.eval()
self.vae_de_mlp.eval()
self.unet_de_mlp.eval()
self.vae_block_mlp.eval()
self.unet_block_mlp.eval()
self.vae_fuse_mlp.eval()
self.unet_fuse_mlp.eval()
self.vae_block_embeddings.requires_grad_(False)
self.unet_block_embeddings.requires_grad_(False)
self.unet.requires_grad_(False)
self.vae.requires_grad_(False)
def set_train(self):
self.unet.train()
self.vae.train()
self.vae_de_mlp.train()
self.unet_de_mlp.train()
self.vae_block_mlp.train()
self.unet_block_mlp.train()
self.vae_fuse_mlp.train()
self.unet_fuse_mlp.train()
self.vae_block_embeddings.requires_grad_(True)
self.unet_block_embeddings.requires_grad_(True)
for n, _p in self.unet.named_parameters():
if "lora" in n:
_p.requires_grad = True
self.unet.conv_in.requires_grad_(True)
for n, _p in self.vae.named_parameters():
if "lora" in n:
_p.requires_grad = True
def forward(self, c_t, deg_score, prompt):
if prompt is not None:
# encode the text prompt
caption_tokens = self.tokenizer(prompt, max_length=self.tokenizer.model_max_length,
padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
caption_enc = self.text_encoder(caption_tokens)[0]
else:
caption_enc = self.text_encoder(prompt_tokens)[0]
# degradation fourier embedding
deg_proj = deg_score[..., None] * self.W[None, None, :] * 2 * np.pi
deg_proj = torch.cat([torch.sin(deg_proj), torch.cos(deg_proj)], dim=-1)
deg_proj = torch.cat([deg_proj[:, 0], deg_proj[:, 1]], dim=-1)
# degradation mlp forward
vae_de_c_embed = self.vae_de_mlp(deg_proj)
unet_de_c_embed = self.unet_de_mlp(deg_proj)
# block embedding mlp forward
vae_block_c_embeds = self.vae_block_mlp(self.vae_block_embeddings.weight)
unet_block_c_embeds = self.unet_block_mlp(self.unet_block_embeddings.weight)
vae_embeds = self.vae_fuse_mlp(torch.cat([vae_de_c_embed.unsqueeze(1).repeat(1, vae_block_c_embeds.shape[0], 1), \
vae_block_c_embeds.unsqueeze(0).repeat(vae_de_c_embed.shape[0],1,1)], -1))
unet_embeds = self.unet_fuse_mlp(torch.cat([unet_de_c_embed.unsqueeze(1).repeat(1, unet_block_c_embeds.shape[0], 1), \
unet_block_c_embeds.unsqueeze(0).repeat(unet_de_c_embed.shape[0],1,1)], -1))
for layer_name, module in self.vae.named_modules():
if layer_name in self.vae_lora_layers:
split_name = layer_name.split(".")
if split_name[1] == 'down_blocks':
block_id = int(split_name[2])
vae_embed = vae_embeds[:, block_id]
elif split_name[1] == 'mid_block':
vae_embed = vae_embeds[:, -2]
else:
vae_embed = vae_embeds[:, -1]
module.de_mod = vae_embed.reshape(-1, self.lora_rank_vae, self.lora_rank_vae)
for layer_name, module in self.unet.named_modules():
if layer_name in self.unet_lora_layers:
split_name = layer_name.split(".")
if split_name[0] == 'down_blocks':
block_id = int(split_name[1])
unet_embed = unet_embeds[:, block_id]
elif split_name[0] == 'mid_block':
unet_embed = unet_embeds[:, 4]
elif split_name[0] == 'up_blocks':
block_id = int(split_name[1]) + 5
unet_embed = unet_embeds[:, block_id]
else:
unet_embed = unet_embeds[:, -1]
module.de_mod = unet_embed.reshape(-1, self.lora_rank_unet, self.lora_rank_unet)
encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc,).sample
x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
return output_image
def save_model(self, outf):
sd = {}
sd["unet_lora_target_modules"] = self.target_modules_unet
sd["vae_lora_target_modules"] = self.target_modules_vae
sd["rank_unet"] = self.lora_rank_unet
sd["rank_vae"] = self.lora_rank_vae
sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip_conv" in k}
sd["state_dict_vae_de_mlp"] = {k: v for k, v in self.vae_de_mlp.state_dict().items()}
sd["state_dict_unet_de_mlp"] = {k: v for k, v in self.unet_de_mlp.state_dict().items()}
sd["state_dict_vae_block_mlp"] = {k: v for k, v in self.vae_block_mlp.state_dict().items()}
sd["state_dict_unet_block_mlp"] = {k: v for k, v in self.unet_block_mlp.state_dict().items()}
sd["state_dict_vae_fuse_mlp"] = {k: v for k, v in self.vae_fuse_mlp.state_dict().items()}
sd["state_dict_unet_fuse_mlp"] = {k: v for k, v in self.unet_fuse_mlp.state_dict().items()}
sd["w"] = self.W
sd["state_embeddings"] = {
"state_dict_vae_block": self.vae_block_embeddings.state_dict(),
"state_dict_unet_block": self.unet_block_embeddings.state_dict(),
}
torch.save(sd, outf)
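# Standalone shape check (not part of the model): reproduces the degradation
# Fourier embedding built in forward() above, assuming deg_score has shape
# (B, 2); the result width matches the num_embeddings * 4 input of the de_mlp heads.
def _deg_embedding_shape_check(batch_size=2, num_embeddings=64):
    W = torch.randn(num_embeddings)
    deg_score = torch.rand(batch_size, 2)                         # assumed (B, 2)
    proj = deg_score[..., None] * W[None, None, :] * 2 * np.pi    # (B, 2, 64)
    proj = torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)  # (B, 2, 128)
    proj = torch.cat([proj[:, 0], proj[:, 1]], dim=-1)            # (B, 256)
    assert proj.shape == (batch_size, num_embeddings * 4)
    return proj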
import os
import re
import requests
import sys
import copy
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from peft import LoraConfig, get_peft_model
p = "src/"
sys.path.append(p)
from model import make_1step_sched, my_lora_fwd
from basicsr.archs.arch_util import default_init_weights
def get_layer_number(module_name):
base_layers = {
'down_blocks': 0,
'mid_block': 4,
'up_blocks': 5
}
if module_name == 'conv_out':
return 9
base_layer = None
for key in base_layers:
if key in module_name:
base_layer = base_layers[key]
break
if base_layer is None:
return None
additional_layers = int(re.findall(r'\.(\d+)', module_name)[0]) #sum(int(num) for num in re.findall(r'\d+', module_name))
final_layer = base_layer + additional_layers
return final_layer
class S3Diff(torch.nn.Module):
def __init__(self, sd_path=None, pretrained_path=None, lora_rank_unet=8, lora_rank_vae=4, block_embedding_dim=64):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
self.text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").cuda()
self.sched = make_1step_sched(sd_path)
self.guidance_scale = 1.07
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet")
target_modules_vae = r"^encoder\..*(conv1|conv2|conv_in|conv_shortcut|conv|conv_out|to_k|to_q|to_v|to_out\.0)$"
target_modules_unet = [
"to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_shortcut", "conv_out",
"proj_in", "proj_out", "ff.net.2", "ff.net.0.proj"
]
num_embeddings = 64
self.W = nn.Parameter(torch.randn(num_embeddings), requires_grad=False)
self.vae_de_mlp = nn.Sequential(
nn.Linear(num_embeddings * 4, 256),
nn.ReLU(True),
)
self.unet_de_mlp = nn.Sequential(
nn.Linear(num_embeddings * 4, 256),
nn.ReLU(True),
)
self.vae_block_mlp = nn.Sequential(
nn.Linear(block_embedding_dim, 64),
nn.ReLU(True),
)
self.unet_block_mlp = nn.Sequential(
nn.Linear(block_embedding_dim, 64),
nn.ReLU(True),
)
self.vae_fuse_mlp = nn.Linear(256 + 64, lora_rank_vae ** 2)
self.unet_fuse_mlp = nn.Linear(256 + 64, lora_rank_unet ** 2)
default_init_weights([self.vae_de_mlp, self.unet_de_mlp, self.vae_block_mlp, self.unet_block_mlp, \
self.vae_fuse_mlp, self.unet_fuse_mlp], 1e-5)
# vae
self.vae_block_embeddings = nn.Embedding(6, block_embedding_dim)
self.unet_block_embeddings = nn.Embedding(10, block_embedding_dim)
if pretrained_path is not None:
sd = torch.load(pretrained_path, map_location="cpu")
vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
_sd_vae = vae.state_dict()
for k in sd["state_dict_vae"]:
_sd_vae[k] = sd["state_dict_vae"][k]
vae.load_state_dict(_sd_vae)
unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
unet.add_adapter(unet_lora_config)
_sd_unet = unet.state_dict()
for k in sd["state_dict_unet"]:
_sd_unet[k] = sd["state_dict_unet"][k]
unet.load_state_dict(_sd_unet)
_vae_de_mlp = self.vae_de_mlp.state_dict()
for k in sd["state_dict_vae_de_mlp"]:
_vae_de_mlp[k] = sd["state_dict_vae_de_mlp"][k]
self.vae_de_mlp.load_state_dict(_vae_de_mlp)
_unet_de_mlp = self.unet_de_mlp.state_dict()
for k in sd["state_dict_unet_de_mlp"]:
_unet_de_mlp[k] = sd["state_dict_unet_de_mlp"][k]
self.unet_de_mlp.load_state_dict(_unet_de_mlp)
_vae_block_mlp = self.vae_block_mlp.state_dict()
for k in sd["state_dict_vae_block_mlp"]:
_vae_block_mlp[k] = sd["state_dict_vae_block_mlp"][k]
self.vae_block_mlp.load_state_dict(_vae_block_mlp)
_unet_block_mlp = self.unet_block_mlp.state_dict()
for k in sd["state_dict_unet_block_mlp"]:
_unet_block_mlp[k] = sd["state_dict_unet_block_mlp"][k]
self.unet_block_mlp.load_state_dict(_unet_block_mlp)
_vae_fuse_mlp = self.vae_fuse_mlp.state_dict()
for k in sd["state_dict_vae_fuse_mlp"]:
_vae_fuse_mlp[k] = sd["state_dict_vae_fuse_mlp"][k]
self.vae_fuse_mlp.load_state_dict(_vae_fuse_mlp)
_unet_fuse_mlp = self.unet_fuse_mlp.state_dict()
for k in sd["state_dict_unet_fuse_mlp"]:
_unet_fuse_mlp[k] = sd["state_dict_unet_fuse_mlp"][k]
self.unet_fuse_mlp.load_state_dict(_unet_fuse_mlp)
self.W = nn.Parameter(sd["w"], requires_grad=False)
embeddings_state_dict = sd["state_embeddings"]
self.vae_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_vae_block'])
self.unet_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_unet_block'])
else:
print("Initializing model with random weights")
vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian",
target_modules=target_modules_vae)
vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
unet_lora_config = LoraConfig(r=lora_rank_unet, init_lora_weights="gaussian",
target_modules=target_modules_unet
)
unet.add_adapter(unet_lora_config)
self.lora_rank_unet = lora_rank_unet
self.lora_rank_vae = lora_rank_vae
self.target_modules_vae = target_modules_vae
self.target_modules_unet = target_modules_unet
self.vae_lora_layers = []
for name, module in vae.named_modules():
if 'base_layer' in name:
self.vae_lora_layers.append(name[:-len(".base_layer")])
for name, module in vae.named_modules():
if name in self.vae_lora_layers:
module.forward = my_lora_fwd.__get__(module, module.__class__)
self.unet_lora_layers = []
for name, module in unet.named_modules():
if 'base_layer' in name:
self.unet_lora_layers.append(name[:-len(".base_layer")])
for name, module in unet.named_modules():
if name in self.unet_lora_layers:
module.forward = my_lora_fwd.__get__(module, module.__class__)
self.unet_layer_dict = {name: get_layer_number(name) for name in self.unet_lora_layers}
unet.to("cuda")
vae.to("cuda")
self.unet, self.vae = unet, vae
self.timesteps = torch.tensor([999], device="cuda").long()
self.text_encoder.requires_grad_(False)
def set_eval(self):
self.unet.eval()
self.vae.eval()
self.vae_de_mlp.eval()
self.unet_de_mlp.eval()
self.vae_block_mlp.eval()
self.unet_block_mlp.eval()
self.vae_fuse_mlp.eval()
self.unet_fuse_mlp.eval()
self.vae_block_embeddings.requires_grad_(False)
self.unet_block_embeddings.requires_grad_(False)
self.unet.requires_grad_(False)
self.vae.requires_grad_(False)
def set_train(self):
self.unet.train()
self.vae.train()
self.vae_de_mlp.train()
self.unet_de_mlp.train()
self.vae_block_mlp.train()
self.unet_block_mlp.train()
self.vae_fuse_mlp.train()
self.unet_fuse_mlp.train()
self.vae_block_embeddings.requires_grad_(True)
self.unet_block_embeddings.requires_grad_(True)
for n, _p in self.unet.named_parameters():
if "lora" in n:
_p.requires_grad = True
self.unet.conv_in.requires_grad_(True)
for n, _p in self.vae.named_parameters():
if "lora" in n:
_p.requires_grad = True
def forward(self, c_t, deg_score, pos_prompt, neg_prompt):
if pos_prompt is not None:
# encode the text prompt
pos_caption_tokens = self.tokenizer(pos_prompt, max_length=self.tokenizer.model_max_length,
padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
pos_caption_enc = self.text_encoder(pos_caption_tokens)[0]
else:
pos_caption_enc = self.text_encoder(prompt_tokens)[0]
if neg_prompt is not None:
# encode the text prompt
neg_caption_tokens = self.tokenizer(neg_prompt, max_length=self.tokenizer.model_max_length,
padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
neg_caption_enc = self.text_encoder(neg_caption_tokens)[0]
else:
neg_caption_enc = self.text_encoder(neg_prompt_tokens)[0]
# degradation fourier embedding
deg_proj = deg_score[..., None] * self.W[None, None, :] * 2 * np.pi
deg_proj = torch.cat([torch.sin(deg_proj), torch.cos(deg_proj)], dim=-1)
deg_proj = torch.cat([deg_proj[:, 0], deg_proj[:, 1]], dim=-1)
# degradation mlp forward
vae_de_c_embed = self.vae_de_mlp(deg_proj)
unet_de_c_embed = self.unet_de_mlp(deg_proj)
# block embedding mlp forward
vae_block_c_embeds = self.vae_block_mlp(self.vae_block_embeddings.weight)
unet_block_c_embeds = self.unet_block_mlp(self.unet_block_embeddings.weight)
vae_embeds = self.vae_fuse_mlp(torch.cat([vae_de_c_embed.unsqueeze(1).repeat(1, vae_block_c_embeds.shape[0], 1), \
vae_block_c_embeds.unsqueeze(0).repeat(vae_de_c_embed.shape[0],1,1)], -1))
unet_embeds = self.unet_fuse_mlp(torch.cat([unet_de_c_embed.unsqueeze(1).repeat(1, unet_block_c_embeds.shape[0], 1), \
unet_block_c_embeds.unsqueeze(0).repeat(unet_de_c_embed.shape[0],1,1)], -1))
for layer_name, module in self.vae.named_modules():
if layer_name in self.vae_lora_layers:
split_name = layer_name.split(".")
if split_name[1] == 'down_blocks':
block_id = int(split_name[2])
vae_embed = vae_embeds[:, block_id]
elif split_name[1] == 'mid_block':
vae_embed = vae_embeds[:, -2]
else:
vae_embed = vae_embeds[:, -1]
module.de_mod = vae_embed.reshape(-1, self.lora_rank_vae, self.lora_rank_vae)
for layer_name, module in self.unet.named_modules():
if layer_name in self.unet_lora_layers:
split_name = layer_name.split(".")
if split_name[0] == 'down_blocks':
block_id = int(split_name[1])
unet_embed = unet_embeds[:, block_id]
elif split_name[0] == 'mid_block':
unet_embed = unet_embeds[:, 4]
elif split_name[0] == 'up_blocks':
block_id = int(split_name[1]) + 5
unet_embed = unet_embeds[:, block_id]
else:
unet_embed = unet_embeds[:, -1]
module.de_mod = unet_embed.reshape(-1, self.lora_rank_unet, self.lora_rank_unet)
encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
pos_model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=pos_caption_enc).sample
neg_model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=neg_caption_enc).sample
model_pred = neg_model_pred + self.guidance_scale * (pos_model_pred - neg_model_pred)
x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
return output_image
def save_model(self, outf):
sd = {}
sd["unet_lora_target_modules"] = self.target_modules_unet
sd["vae_lora_target_modules"] = self.target_modules_vae
sd["rank_unet"] = self.lora_rank_unet
sd["rank_vae"] = self.lora_rank_vae
sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip_conv" in k}
sd["state_dict_vae_de_mlp"] = {k: v for k, v in self.vae_de_mlp.state_dict().items()}
sd["state_dict_unet_de_mlp"] = {k: v for k, v in self.unet_de_mlp.state_dict().items()}
sd["state_dict_vae_block_mlp"] = {k: v for k, v in self.vae_block_mlp.state_dict().items()}
sd["state_dict_unet_block_mlp"] = {k: v for k, v in self.unet_block_mlp.state_dict().items()}
sd["state_dict_vae_fuse_mlp"] = {k: v for k, v in self.vae_fuse_mlp.state_dict().items()}
sd["state_dict_unet_fuse_mlp"] = {k: v for k, v in self.unet_fuse_mlp.state_dict().items()}
sd["w"] = self.W
sd["state_embeddings"] = {
"state_dict_vae_block": self.vae_block_embeddings.state_dict(),
"state_dict_unet_block": self.unet_block_embeddings.state_dict(),
}
torch.save(sd, outf)
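# Minimal illustration (not from the original code) of how the positive and
# negative predictions are blended in forward() above; with guidance_scale = 1.07
# the output is pushed slightly past the positive-prompt prediction, in the
# usual classifier-free-guidance fashion.
def _cfg_blend_demo(guidance_scale=1.07):
    pos_pred = torch.full((1, 4, 8, 8), 1.0)  # stand-in positive-prompt output
    neg_pred = torch.zeros((1, 4, 8, 8))      # stand-in negative-prompt output
    blended = neg_pred + guidance_scale * (pos_pred - neg_pred)
    assert torch.allclose(blended, torch.full((1, 4, 8, 8), 1.07))
    return blended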
import os
import re
import requests
import sys
import copy
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from peft import LoraConfig, get_peft_model
p = "src/"
sys.path.append(p)
from model import make_1step_sched, my_lora_fwd
from basicsr.archs.arch_util import default_init_weights
from my_utils.vaehook import VAEHook, perfcount
def get_layer_number(module_name):
base_layers = {
'down_blocks': 0,
'mid_block': 4,
'up_blocks': 5
}
if module_name == 'conv_out':
return 9
base_layer = None
for key in base_layers:
if key in module_name:
base_layer = base_layers[key]
break
if base_layer is None:
return None
additional_layers = int(re.findall(r'\.(\d+)', module_name)[0]) #sum(int(num) for num in re.findall(r'\d+', module_name))
final_layer = base_layer + additional_layers
return final_layer
class S3Diff(torch.nn.Module):
def __init__(self, sd_path=None, pretrained_path=None, lora_rank_unet=32, lora_rank_vae=16, block_embedding_dim=64, args=None):
super().__init__()
self.args = args
self.latent_tiled_size = args.latent_tiled_size
self.latent_tiled_overlap = args.latent_tiled_overlap
self.tokenizer = AutoTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
self.text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").cuda()
self.sched = make_1step_sched(sd_path)
self.guidance_scale = 1.07
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet")
target_modules_vae = r"^encoder\..*(conv1|conv2|conv_in|conv_shortcut|conv|conv_out|to_k|to_q|to_v|to_out\.0)$"
target_modules_unet = [
"to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_shortcut", "conv_out",
"proj_in", "proj_out", "ff.net.2", "ff.net.0.proj"
]
num_embeddings = 64
self.W = nn.Parameter(torch.randn(num_embeddings), requires_grad=False)
self.vae_de_mlp = nn.Sequential(
nn.Linear(num_embeddings * 4, 256),
nn.ReLU(True),
)
self.unet_de_mlp = nn.Sequential(
nn.Linear(num_embeddings * 4, 256),
nn.ReLU(True),
)
self.vae_block_mlp = nn.Sequential(
nn.Linear(block_embedding_dim, 64),
nn.ReLU(True),
)
self.unet_block_mlp = nn.Sequential(
nn.Linear(block_embedding_dim, 64),
nn.ReLU(True),
)
self.vae_fuse_mlp = nn.Linear(256 + 64, lora_rank_vae ** 2)
self.unet_fuse_mlp = nn.Linear(256 + 64, lora_rank_unet ** 2)
default_init_weights([self.vae_de_mlp, self.unet_de_mlp, self.vae_block_mlp, self.unet_block_mlp, \
self.vae_fuse_mlp, self.unet_fuse_mlp], 1e-5)
# vae
self.vae_block_embeddings = nn.Embedding(6, block_embedding_dim)
self.unet_block_embeddings = nn.Embedding(10, block_embedding_dim)
if pretrained_path is not None:
sd = torch.load(pretrained_path, map_location="cpu")
vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
_sd_vae = vae.state_dict()
for k in sd["state_dict_vae"]:
_sd_vae[k] = sd["state_dict_vae"][k]
vae.load_state_dict(_sd_vae)
unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
unet.add_adapter(unet_lora_config)
_sd_unet = unet.state_dict()
for k in sd["state_dict_unet"]:
_sd_unet[k] = sd["state_dict_unet"][k]
unet.load_state_dict(_sd_unet)
_vae_de_mlp = self.vae_de_mlp.state_dict()
for k in sd["state_dict_vae_de_mlp"]:
_vae_de_mlp[k] = sd["state_dict_vae_de_mlp"][k]
self.vae_de_mlp.load_state_dict(_vae_de_mlp)
_unet_de_mlp = self.unet_de_mlp.state_dict()
for k in sd["state_dict_unet_de_mlp"]:
_unet_de_mlp[k] = sd["state_dict_unet_de_mlp"][k]
self.unet_de_mlp.load_state_dict(_unet_de_mlp)
_vae_block_mlp = self.vae_block_mlp.state_dict()
for k in sd["state_dict_vae_block_mlp"]:
_vae_block_mlp[k] = sd["state_dict_vae_block_mlp"][k]
self.vae_block_mlp.load_state_dict(_vae_block_mlp)
_unet_block_mlp = self.unet_block_mlp.state_dict()
for k in sd["state_dict_unet_block_mlp"]:
_unet_block_mlp[k] = sd["state_dict_unet_block_mlp"][k]
self.unet_block_mlp.load_state_dict(_unet_block_mlp)
_vae_fuse_mlp = self.vae_fuse_mlp.state_dict()
for k in sd["state_dict_vae_fuse_mlp"]:
_vae_fuse_mlp[k] = sd["state_dict_vae_fuse_mlp"][k]
self.vae_fuse_mlp.load_state_dict(_vae_fuse_mlp)
_unet_fuse_mlp = self.unet_fuse_mlp.state_dict()
for k in sd["state_dict_unet_fuse_mlp"]:
_unet_fuse_mlp[k] = sd["state_dict_unet_fuse_mlp"][k]
self.unet_fuse_mlp.load_state_dict(_unet_fuse_mlp)
self.W = nn.Parameter(sd["w"], requires_grad=False)
embeddings_state_dict = sd["state_embeddings"]
self.vae_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_vae_block'])
self.unet_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_unet_block'])
else:
print("Initializing model with random weights")
vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian",
target_modules=target_modules_vae)
vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
unet_lora_config = LoraConfig(r=lora_rank_unet, init_lora_weights="gaussian",
target_modules=target_modules_unet
)
unet.add_adapter(unet_lora_config)
self.lora_rank_unet = lora_rank_unet
self.lora_rank_vae = lora_rank_vae
self.target_modules_vae = target_modules_vae
self.target_modules_unet = target_modules_unet
self.vae_lora_layers = []
for name, module in vae.named_modules():
if 'base_layer' in name:
self.vae_lora_layers.append(name[:-len(".base_layer")])
for name, module in vae.named_modules():
if name in self.vae_lora_layers:
module.forward = my_lora_fwd.__get__(module, module.__class__)
self.unet_lora_layers = []
for name, module in unet.named_modules():
if 'base_layer' in name:
self.unet_lora_layers.append(name[:-len(".base_layer")])
for name, module in unet.named_modules():
if name in self.unet_lora_layers:
module.forward = my_lora_fwd.__get__(module, module.__class__)
self.unet_layer_dict = {name: get_layer_number(name) for name in self.unet_lora_layers}
unet.to("cuda")
vae.to("cuda")
self.unet, self.vae = unet, vae
self.timesteps = torch.tensor([999], device="cuda").long()
self.text_encoder.requires_grad_(False)
# vae tile
self._init_tiled_vae(encoder_tile_size=args.vae_encoder_tiled_size, decoder_tile_size=args.vae_decoder_tiled_size)
def set_eval(self):
self.unet.eval()
self.vae.eval()
self.vae_de_mlp.eval()
self.unet_de_mlp.eval()
self.vae_block_mlp.eval()
self.unet_block_mlp.eval()
self.vae_fuse_mlp.eval()
self.unet_fuse_mlp.eval()
self.vae_block_embeddings.requires_grad_(False)
self.unet_block_embeddings.requires_grad_(False)
self.unet.requires_grad_(False)
self.vae.requires_grad_(False)
def set_train(self):
self.unet.train()
self.vae.train()
self.vae_de_mlp.train()
self.unet_de_mlp.train()
self.vae_block_mlp.train()
self.unet_block_mlp.train()
self.vae_fuse_mlp.train()
self.unet_fuse_mlp.train()
self.vae_block_embeddings.requires_grad_(True)
self.unet_block_embeddings.requires_grad_(True)
for n, _p in self.unet.named_parameters():
if "lora" in n:
_p.requires_grad = True
self.unet.conv_in.requires_grad_(True)
for n, _p in self.vae.named_parameters():
if "lora" in n:
_p.requires_grad = True
@perfcount
@torch.no_grad()
def forward(self, c_t, deg_score, pos_prompt, neg_prompt):
if pos_prompt is not None:
# encode the text prompt
pos_caption_tokens = self.tokenizer(pos_prompt, max_length=self.tokenizer.model_max_length,
padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
pos_caption_enc = self.text_encoder(pos_caption_tokens)[0]
else:
pos_caption_enc = self.text_encoder(prompt_tokens)[0]
if neg_prompt is not None:
# encode the text prompt
neg_caption_tokens = self.tokenizer(neg_prompt, max_length=self.tokenizer.model_max_length,
padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
neg_caption_enc = self.text_encoder(neg_caption_tokens)[0]
else:
neg_caption_enc = self.text_encoder(neg_prompt_tokens)[0]
# degradation fourier embedding
deg_proj = deg_score[..., None] * self.W[None, None, :] * 2 * np.pi
deg_proj = torch.cat([torch.sin(deg_proj), torch.cos(deg_proj)], dim=-1)
deg_proj = torch.cat([deg_proj[:, 0], deg_proj[:, 1]], dim=-1)
# degradation mlp forward
vae_de_c_embed = self.vae_de_mlp(deg_proj)
unet_de_c_embed = self.unet_de_mlp(deg_proj)
# block embedding mlp forward
vae_block_c_embeds = self.vae_block_mlp(self.vae_block_embeddings.weight)
unet_block_c_embeds = self.unet_block_mlp(self.unet_block_embeddings.weight)
vae_embeds = self.vae_fuse_mlp(torch.cat([vae_de_c_embed.unsqueeze(1).repeat(1, vae_block_c_embeds.shape[0], 1), \
vae_block_c_embeds.unsqueeze(0).repeat(vae_de_c_embed.shape[0],1,1)], -1))
unet_embeds = self.unet_fuse_mlp(torch.cat([unet_de_c_embed.unsqueeze(1).repeat(1, unet_block_c_embeds.shape[0], 1), \
unet_block_c_embeds.unsqueeze(0).repeat(unet_de_c_embed.shape[0],1,1)], -1))
for layer_name, module in self.vae.named_modules():
if layer_name in self.vae_lora_layers:
split_name = layer_name.split(".")
if split_name[1] == 'down_blocks':
block_id = int(split_name[2])
vae_embed = vae_embeds[:, block_id]
elif split_name[1] == 'mid_block':
vae_embed = vae_embeds[:, -2]
else:
vae_embed = vae_embeds[:, -1]
module.de_mod = vae_embed.reshape(-1, self.lora_rank_vae, self.lora_rank_vae)
for layer_name, module in self.unet.named_modules():
if layer_name in self.unet_lora_layers:
split_name = layer_name.split(".")
if split_name[0] == 'down_blocks':
block_id = int(split_name[1])
unet_embed = unet_embeds[:, block_id]
elif split_name[0] == 'mid_block':
unet_embed = unet_embeds[:, 4]
elif split_name[0] == 'up_blocks':
block_id = int(split_name[1]) + 5
unet_embed = unet_embeds[:, block_id]
else:
unet_embed = unet_embeds[:, -1]
module.de_mod = unet_embed.reshape(-1, self.lora_rank_unet, self.lora_rank_unet)
lq_latent = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
## add tile function
_, _, h, w = lq_latent.size()
tile_size, tile_overlap = (self.latent_tiled_size, self.latent_tiled_overlap)
if h * w <= tile_size * tile_size:
print(f"[Tiled Latent]: the input size is tiny and unnecessary to tile.")
pos_model_pred = self.unet(lq_latent, self.timesteps, encoder_hidden_states=pos_caption_enc).sample
neg_model_pred = self.unet(lq_latent, self.timesteps, encoder_hidden_states=neg_caption_enc).sample
model_pred = neg_model_pred + self.guidance_scale * (pos_model_pred - neg_model_pred)
else:
print(f"[Tiled Latent]: the input size is {c_t.shape[-2]}x{c_t.shape[-1]}, need to tiled")
# tile_weights = self._gaussian_weights(tile_size, tile_size, 1).to()
tile_size = min(tile_size, min(h, w))
tile_weights = self._gaussian_weights(tile_size, tile_size, 1).to(c_t.device)
grid_rows = 0
cur_x = 0
while cur_x < lq_latent.size(-1):
cur_x = max(grid_rows * tile_size-tile_overlap * grid_rows, 0)+tile_size
grid_rows += 1
grid_cols = 0
cur_y = 0
while cur_y < lq_latent.size(-2):
cur_y = max(grid_cols * tile_size-tile_overlap * grid_cols, 0)+tile_size
grid_cols += 1
input_list = []
noise_preds = []
for row in range(grid_rows):
noise_preds_row = []
for col in range(grid_cols):
if col < grid_cols-1 or row < grid_rows-1:
# extract tile from input image
ofs_x = max(row * tile_size-tile_overlap * row, 0)
ofs_y = max(col * tile_size-tile_overlap * col, 0)
# input tile area on total image
if row == grid_rows-1:
ofs_x = w - tile_size
if col == grid_cols-1:
ofs_y = h - tile_size
input_start_x = ofs_x
input_end_x = ofs_x + tile_size
input_start_y = ofs_y
input_end_y = ofs_y + tile_size
# input tile dimensions
input_tile = lq_latent[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
input_list.append(input_tile)
if len(input_list) == 1 or col == grid_cols-1:
input_list_t = torch.cat(input_list, dim=0)
# predict the noise residual
pos_model_pred = self.unet(input_list_t, self.timesteps, encoder_hidden_states=pos_caption_enc).sample
neg_model_pred = self.unet(input_list_t, self.timesteps, encoder_hidden_states=neg_caption_enc).sample
model_out = neg_model_pred + self.guidance_scale * (pos_model_pred - neg_model_pred)
input_list = []
noise_preds.append(model_out)
# Stitch noise predictions for all tiles
noise_pred = torch.zeros(lq_latent.shape, device=lq_latent.device)
contributors = torch.zeros(lq_latent.shape, device=lq_latent.device)
# Add each tile contribution to overall latents
for row in range(grid_rows):
for col in range(grid_cols):
if col < grid_cols-1 or row < grid_rows-1:
# extract tile from input image
ofs_x = max(row * tile_size-tile_overlap * row, 0)
ofs_y = max(col * tile_size-tile_overlap * col, 0)
# input tile area on total image
if row == grid_rows-1:
ofs_x = w - tile_size
if col == grid_cols-1:
ofs_y = h - tile_size
input_start_x = ofs_x
input_end_x = ofs_x + tile_size
input_start_y = ofs_y
input_end_y = ofs_y + tile_size
noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row*grid_cols + col] * tile_weights
contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights
# Average overlapping areas with more than 1 contributor
noise_pred /= contributors
model_pred = noise_pred
x_denoised = self.sched.step(model_pred, self.timesteps, lq_latent, return_dict=True).prev_sample
output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
return output_image
def save_model(self, outf):
sd = {}
sd["unet_lora_target_modules"] = self.target_modules_unet
sd["vae_lora_target_modules"] = self.target_modules_vae
sd["rank_unet"] = self.lora_rank_unet
sd["rank_vae"] = self.lora_rank_vae
sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip_conv" in k}
sd["state_dict_vae_de_mlp"] = {k: v for k, v in self.vae_de_mlp.state_dict().items()}
sd["state_dict_unet_de_mlp"] = {k: v for k, v in self.unet_de_mlp.state_dict().items()}
sd["state_dict_vae_block_mlp"] = {k: v for k, v in self.vae_block_mlp.state_dict().items()}
sd["state_dict_unet_block_mlp"] = {k: v for k, v in self.unet_block_mlp.state_dict().items()}
sd["state_dict_vae_fuse_mlp"] = {k: v for k, v in self.vae_fuse_mlp.state_dict().items()}
sd["state_dict_unet_fuse_mlp"] = {k: v for k, v in self.unet_fuse_mlp.state_dict().items()}
sd["w"] = self.W
sd["state_embeddings"] = {
"state_dict_vae_block": self.vae_block_embeddings.state_dict(),
"state_dict_unet_block": self.unet_block_embeddings.state_dict(),
}
torch.save(sd, outf)
def _set_latent_tile(self,
latent_tiled_size = 96,
latent_tiled_overlap = 32):
self.latent_tiled_size = latent_tiled_size
self.latent_tiled_overlap = latent_tiled_overlap
def _init_tiled_vae(self,
encoder_tile_size = 256,
decoder_tile_size = 256,
fast_decoder = False,
fast_encoder = False,
color_fix = False,
vae_to_gpu = True):
# save original forward (only once)
if not hasattr(self.vae.encoder, 'original_forward'):
setattr(self.vae.encoder, 'original_forward', self.vae.encoder.forward)
if not hasattr(self.vae.decoder, 'original_forward'):
setattr(self.vae.decoder, 'original_forward', self.vae.decoder.forward)
encoder = self.vae.encoder
decoder = self.vae.decoder
self.vae.encoder.forward = VAEHook(
encoder, encoder_tile_size, is_decoder=False, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)
self.vae.decoder.forward = VAEHook(
decoder, decoder_tile_size, is_decoder=True, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)
def _gaussian_weights(self, tile_width, tile_height, nbatches):
"""Generates a gaussian mask of weights for tile contributions"""
from numpy import pi, exp, sqrt
import numpy as np
latent_width = tile_width
latent_height = tile_height
var = 0.01
midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
x_probs = [exp(-(x-midpoint)*(x-midpoint)/(latent_width*latent_width)/(2*var)) / sqrt(2*pi*var) for x in range(latent_width)]
midpoint = latent_height / 2
y_probs = [exp(-(y-midpoint)*(y-midpoint)/(latent_height*latent_height)/(2*var)) / sqrt(2*pi*var) for y in range(latent_height)]
weights = np.outer(y_probs, x_probs)
return torch.tile(torch.tensor(weights), (nbatches, self.unet.config.in_channels, 1, 1))
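# Standalone re-derivation (illustrative, not part of the class) of the
# grid_rows / grid_cols while-loops in forward(): count how many tiles are
# needed to cover one spatial dimension for a given tile size and overlap.
def _count_tiles_1d(length, tile_size, tile_overlap):
    n, cur = 0, 0
    while cur < length:
        cur = max(n * tile_size - tile_overlap * n, 0) + tile_size
        n += 1
    return n
# e.g. a 160-wide latent with tile_size=96 and overlap=32:
# _count_tiles_1d(160, 96, 32) == 2 tiles per row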
import os
import re
import requests
import sys
import copy
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from peft import LoraConfig, get_peft_model
from peft import PeftModel
import tempfile
p = "src/"
sys.path.append(p)
from model import make_1step_sched, my_lora_fwd
from basicsr.archs.arch_util import default_init_weights
from my_utils.vaehook import VAEHook, perfcount
def get_layer_number(module_name):
base_layers = {
'down_blocks': 0,
'mid_block': 4,
'up_blocks': 5
}
if module_name == 'conv_out':
return 9
base_layer = None
for key in base_layers:
if key in module_name:
base_layer = base_layers[key]
break
if base_layer is None:
return None
additional_layers = int(re.findall(r'\.(\d+)', module_name)[0]) #sum(int(num) for num in re.findall(r'\d+', module_name))
final_layer = base_layer + additional_layers
return final_layer
class S3Diff(torch.nn.Module):
def __init__(self, sd_path=None, pretrained_path=None, lora_rank_unet=32, lora_rank_vae=16, block_embedding_dim=64, args=None):
super().__init__()
self.args = args
self.latent_tiled_size = args.latent_tiled_size
self.latent_tiled_overlap = args.latent_tiled_overlap
self.tokenizer = AutoTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
self.text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").cuda()
self.sched = make_1step_sched(sd_path)
self.guidance_scale = 1.07
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet")
target_modules_vae = r"^encoder\..*(conv1|conv2|conv_in|conv_shortcut|conv|conv_out|to_k|to_q|to_v|to_out\.0)$"
target_modules_unet = [
"to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_shortcut", "conv_out",
"proj_in", "proj_out", "ff.net.2", "ff.net.0.proj"
]
num_embeddings = 64
self.W = nn.Parameter(torch.randn(num_embeddings), requires_grad=False)
self.vae_de_mlp = nn.Sequential(
nn.Linear(num_embeddings * 4, 256),
nn.ReLU(True),
)
self.unet_de_mlp = nn.Sequential(
nn.Linear(num_embeddings * 4, 256),
nn.ReLU(True),
)
self.vae_block_mlp = nn.Sequential(
nn.Linear(block_embedding_dim, 64),
nn.ReLU(True),
)
self.unet_block_mlp = nn.Sequential(
nn.Linear(block_embedding_dim, 64),
nn.ReLU(True),
)
self.vae_fuse_mlp = nn.Linear(256 + 64, lora_rank_vae ** 2)
self.unet_fuse_mlp = nn.Linear(256 + 64, lora_rank_unet ** 2)
default_init_weights([self.vae_de_mlp, self.unet_de_mlp, self.vae_block_mlp, self.unet_block_mlp, \
self.vae_fuse_mlp, self.unet_fuse_mlp], 1e-5)
# vae
self.vae_block_embeddings = nn.Embedding(6, block_embedding_dim)
self.unet_block_embeddings = nn.Embedding(10, block_embedding_dim)
if pretrained_path is not None:
sd = torch.load(pretrained_path, map_location="cpu")
#vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
#vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
# ========== Fix for loading the VAE LoRA weights ==========
# 1. Extract the VAE LoRA configuration and weights from the checkpoint
vae_lora_rank = sd["rank_vae"]
vae_lora_target_modules = sd["vae_lora_target_modules"]
vae_lora_weights = sd["state_dict_vae"]  # LoRA adapter weights
# 2. Create a temporary directory and dump the LoRA weights to a file
# (peft's load_adapter expects a file path; the dump is kept only for inspection)
with tempfile.TemporaryDirectory() as temp_dir:
    temp_vae_lora_path = os.path.join(temp_dir, "vae_lora.bin")
    torch.save(vae_lora_weights, temp_vae_lora_path)
# 3. Define the VAE LoRA configuration
# (the adapter name is given to add_adapter below rather than to LoraConfig;
# alpha is kept at 2x the rank, a common default)
vae_lora_config = LoraConfig(
    r=vae_lora_rank,
    lora_alpha=vae_lora_rank * 2,
    target_modules=vae_lora_target_modules,
    lora_dropout=0.05,
    bias="none",
)
# 4. Attach the adapter to the VAE so the LoRA keys exist in its state dict
vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
# 5. Load the VAE LoRA weights (original logic kept unchanged)
_sd_vae = vae.state_dict()
for k in sd["state_dict_vae"]:
_sd_vae[k] = sd["state_dict_vae"][k]
vae.load_state_dict(_sd_vae)
unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
unet.add_adapter(unet_lora_config)
_sd_unet = unet.state_dict()
for k in sd["state_dict_unet"]:
_sd_unet[k] = sd["state_dict_unet"][k]
unet.load_state_dict(_sd_unet)
_vae_de_mlp = self.vae_de_mlp.state_dict()
for k in sd["state_dict_vae_de_mlp"]:
_vae_de_mlp[k] = sd["state_dict_vae_de_mlp"][k]
self.vae_de_mlp.load_state_dict(_vae_de_mlp)
_unet_de_mlp = self.unet_de_mlp.state_dict()
for k in sd["state_dict_unet_de_mlp"]:
_unet_de_mlp[k] = sd["state_dict_unet_de_mlp"][k]
self.unet_de_mlp.load_state_dict(_unet_de_mlp)
_vae_block_mlp = self.vae_block_mlp.state_dict()
for k in sd["state_dict_vae_block_mlp"]:
_vae_block_mlp[k] = sd["state_dict_vae_block_mlp"][k]
self.vae_block_mlp.load_state_dict(_vae_block_mlp)
_unet_block_mlp = self.unet_block_mlp.state_dict()
for k in sd["state_dict_unet_block_mlp"]:
_unet_block_mlp[k] = sd["state_dict_unet_block_mlp"][k]
self.unet_block_mlp.load_state_dict(_unet_block_mlp)
_vae_fuse_mlp = self.vae_fuse_mlp.state_dict()
for k in sd["state_dict_vae_fuse_mlp"]:
_vae_fuse_mlp[k] = sd["state_dict_vae_fuse_mlp"][k]
self.vae_fuse_mlp.load_state_dict(_vae_fuse_mlp)
_unet_fuse_mlp = self.unet_fuse_mlp.state_dict()
for k in sd["state_dict_unet_fuse_mlp"]:
_unet_fuse_mlp[k] = sd["state_dict_unet_fuse_mlp"][k]
self.unet_fuse_mlp.load_state_dict(_unet_fuse_mlp)
self.W = nn.Parameter(sd["w"], requires_grad=False)
embeddings_state_dict = sd["state_embeddings"]
self.vae_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_vae_block'])
self.unet_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_unet_block'])
else:
print("Initializing model with random weights")
vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian",
target_modules=target_modules_vae)
vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
unet_lora_config = LoraConfig(r=lora_rank_unet, init_lora_weights="gaussian",
target_modules=target_modules_unet
)
unet.add_adapter(unet_lora_config)
self.lora_rank_unet = lora_rank_unet
self.lora_rank_vae = lora_rank_vae
self.target_modules_vae = target_modules_vae
self.target_modules_unet = target_modules_unet
self.vae_lora_layers = []
for name, module in vae.named_modules():
if 'base_layer' in name:
self.vae_lora_layers.append(name[:-len(".base_layer")])
for name, module in vae.named_modules():
if name in self.vae_lora_layers:
module.forward = my_lora_fwd.__get__(module, module.__class__)
self.unet_lora_layers = []
for name, module in unet.named_modules():
if 'base_layer' in name:
self.unet_lora_layers.append(name[:-len(".base_layer")])
for name, module in unet.named_modules():
if name in self.unet_lora_layers:
module.forward = my_lora_fwd.__get__(module, module.__class__)
self.unet_layer_dict = {name: get_layer_number(name) for name in self.unet_lora_layers}
unet.to("cuda")
vae.to("cuda")
self.unet, self.vae = unet, vae
self.timesteps = torch.tensor([999], device="cuda").long()
self.text_encoder.requires_grad_(False)
# vae tile
self._init_tiled_vae(encoder_tile_size=args.vae_encoder_tiled_size, decoder_tile_size=args.vae_decoder_tiled_size)
def set_eval(self):
self.unet.eval()
self.vae.eval()
self.vae_de_mlp.eval()
self.unet_de_mlp.eval()
self.vae_block_mlp.eval()
self.unet_block_mlp.eval()
self.vae_fuse_mlp.eval()
self.unet_fuse_mlp.eval()
self.vae_block_embeddings.requires_grad_(False)
self.unet_block_embeddings.requires_grad_(False)
self.unet.requires_grad_(False)
self.vae.requires_grad_(False)
def set_train(self):
self.unet.train()
self.vae.train()
self.vae_de_mlp.train()
self.unet_de_mlp.train()
self.vae_block_mlp.train()
self.unet_block_mlp.train()
self.vae_fuse_mlp.train()
self.unet_fuse_mlp.train()
self.vae_block_embeddings.requires_grad_(True)
self.unet_block_embeddings.requires_grad_(True)
for n, _p in self.unet.named_parameters():
if "lora" in n:
_p.requires_grad = True
self.unet.conv_in.requires_grad_(True)
for n, _p in self.vae.named_parameters():
if "lora" in n:
_p.requires_grad = True
@perfcount
@torch.no_grad()
def forward(self, c_t, deg_score, pos_prompt, neg_prompt):
if pos_prompt is not None:
# encode the text prompt
pos_caption_tokens = self.tokenizer(pos_prompt, max_length=self.tokenizer.model_max_length,
padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
pos_caption_enc = self.text_encoder(pos_caption_tokens)[0]
else:
pos_caption_enc = self.text_encoder(prompt_tokens)[0]
if neg_prompt is not None:
# encode the text prompt
neg_caption_tokens = self.tokenizer(neg_prompt, max_length=self.tokenizer.model_max_length,
padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
neg_caption_enc = self.text_encoder(neg_caption_tokens)[0]
else:
neg_caption_enc = self.text_encoder(neg_prompt_tokens)[0]
# degradation fourier embedding
deg_proj = deg_score[..., None] * self.W[None, None, :] * 2 * np.pi
deg_proj = torch.cat([torch.sin(deg_proj), torch.cos(deg_proj)], dim=-1)
deg_proj = torch.cat([deg_proj[:, 0], deg_proj[:, 1]], dim=-1)
# degradation mlp forward
vae_de_c_embed = self.vae_de_mlp(deg_proj)
unet_de_c_embed = self.unet_de_mlp(deg_proj)
# block embedding mlp forward
vae_block_c_embeds = self.vae_block_mlp(self.vae_block_embeddings.weight)
unet_block_c_embeds = self.unet_block_mlp(self.unet_block_embeddings.weight)
vae_embeds = self.vae_fuse_mlp(torch.cat([vae_de_c_embed.unsqueeze(1).repeat(1, vae_block_c_embeds.shape[0], 1), \
vae_block_c_embeds.unsqueeze(0).repeat(vae_de_c_embed.shape[0],1,1)], -1))
unet_embeds = self.unet_fuse_mlp(torch.cat([unet_de_c_embed.unsqueeze(1).repeat(1, unet_block_c_embeds.shape[0], 1), \
unet_block_c_embeds.unsqueeze(0).repeat(unet_de_c_embed.shape[0],1,1)], -1))
for layer_name, module in self.vae.named_modules():
if layer_name in self.vae_lora_layers:
split_name = layer_name.split(".")
if split_name[1] == 'down_blocks':
block_id = int(split_name[2])
vae_embed = vae_embeds[:, block_id]
elif split_name[1] == 'mid_block':
vae_embed = vae_embeds[:, -2]
else:
vae_embed = vae_embeds[:, -1]
module.de_mod = vae_embed.reshape(-1, self.lora_rank_vae, self.lora_rank_vae)
for layer_name, module in self.unet.named_modules():
if layer_name in self.unet_lora_layers:
split_name = layer_name.split(".")
if split_name[0] == 'down_blocks':
block_id = int(split_name[1])
unet_embed = unet_embeds[:, block_id]
elif split_name[0] == 'mid_block':
unet_embed = unet_embeds[:, 4]
elif split_name[0] == 'up_blocks':
block_id = int(split_name[1]) + 5
unet_embed = unet_embeds[:, block_id]
else:
unet_embed = unet_embeds[:, -1]
module.de_mod = unet_embed.reshape(-1, self.lora_rank_unet, self.lora_rank_unet)
lq_latent = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
## add tile function
_, _, h, w = lq_latent.size()
tile_size, tile_overlap = (self.latent_tiled_size, self.latent_tiled_overlap)
if h * w <= tile_size * tile_size:
print(f"[Tiled Latent]: the input size is tiny and unnecessary to tile.")
pos_model_pred = self.unet(lq_latent, self.timesteps, encoder_hidden_states=pos_caption_enc).sample
neg_model_pred = self.unet(lq_latent, self.timesteps, encoder_hidden_states=neg_caption_enc).sample
model_pred = neg_model_pred + self.guidance_scale * (pos_model_pred - neg_model_pred)
else:
print(f"[Tiled Latent]: the input size is {c_t.shape[-2]}x{c_t.shape[-1]}, need to tiled")
# tile_weights = self._gaussian_weights(tile_size, tile_size, 1).to()
tile_size = min(tile_size, min(h, w))
tile_weights = self._gaussian_weights(tile_size, tile_size, 1).to(c_t.device)
grid_rows = 0
cur_x = 0
while cur_x < lq_latent.size(-1):
cur_x = max(grid_rows * tile_size-tile_overlap * grid_rows, 0)+tile_size
grid_rows += 1
grid_cols = 0
cur_y = 0
while cur_y < lq_latent.size(-2):
cur_y = max(grid_cols * tile_size-tile_overlap * grid_cols, 0)+tile_size
grid_cols += 1
input_list = []
noise_preds = []
for row in range(grid_rows):
noise_preds_row = []
for col in range(grid_cols):
if col < grid_cols-1 or row < grid_rows-1:
# extract tile from input image
ofs_x = max(row * tile_size-tile_overlap * row, 0)
ofs_y = max(col * tile_size-tile_overlap * col, 0)
# input tile area on total image
if row == grid_rows-1:
ofs_x = w - tile_size
if col == grid_cols-1:
ofs_y = h - tile_size
input_start_x = ofs_x
input_end_x = ofs_x + tile_size
input_start_y = ofs_y
input_end_y = ofs_y + tile_size
# input tile dimensions
input_tile = lq_latent[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
input_list.append(input_tile)
if len(input_list) == 1 or col == grid_cols-1:
input_list_t = torch.cat(input_list, dim=0)
# predict the noise residual
pos_model_pred = self.unet(input_list_t, self.timesteps, encoder_hidden_states=pos_caption_enc).sample
neg_model_pred = self.unet(input_list_t, self.timesteps, encoder_hidden_states=neg_caption_enc).sample
model_out = neg_model_pred + self.guidance_scale * (pos_model_pred - neg_model_pred)
input_list = []
noise_preds.append(model_out)
# Stitch noise predictions for all tiles
noise_pred = torch.zeros(lq_latent.shape, device=lq_latent.device)
contributors = torch.zeros(lq_latent.shape, device=lq_latent.device)
# Add each tile contribution to overall latents
for row in range(grid_rows):
for col in range(grid_cols):
if col < grid_cols-1 or row < grid_rows-1:
# extract tile from input image
ofs_x = max(row * tile_size-tile_overlap * row, 0)
ofs_y = max(col * tile_size-tile_overlap * col, 0)
# input tile area on total image
if row == grid_rows-1:
ofs_x = w - tile_size
if col == grid_cols-1:
ofs_y = h - tile_size
input_start_x = ofs_x
input_end_x = ofs_x + tile_size
input_start_y = ofs_y
input_end_y = ofs_y + tile_size
noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row*grid_cols + col] * tile_weights
contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights
# Average overlapping areas with more than 1 contributor
noise_pred /= contributors
model_pred = noise_pred
x_denoised = self.sched.step(model_pred, self.timesteps, lq_latent, return_dict=True).prev_sample
output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
return output_image
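# Note on the guidance step above: the two UNet passes are combined as
# neg + guidance_scale * (pos - neg), i.e. classifier-free guidance between the
# positive-prompt and negative-prompt predictions. A minimal numeric sketch
# (hypothetical values, not taken from this repo's configs):
#   guidance_scale = 0.0 -> the negative-prompt prediction is returned unchanged
#   guidance_scale = 1.0 -> the positive-prompt prediction is returned unchanged
#   guidance_scale = 2.0 with pos = 1.2, neg = 0.8 -> 0.8 + 2.0 * 0.4 = 1.6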
def save_model(self, outf):
sd = {}
sd["unet_lora_target_modules"] = self.target_modules_unet
sd["vae_lora_target_modules"] = self.target_modules_vae
sd["rank_unet"] = self.lora_rank_unet
sd["rank_vae"] = self.lora_rank_vae
sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip_conv" in k}
sd["state_dict_vae_de_mlp"] = {k: v for k, v in self.vae_de_mlp.state_dict().items()}
sd["state_dict_unet_de_mlp"] = {k: v for k, v in self.unet_de_mlp.state_dict().items()}
sd["state_dict_vae_block_mlp"] = {k: v for k, v in self.vae_block_mlp.state_dict().items()}
sd["state_dict_unet_block_mlp"] = {k: v for k, v in self.unet_block_mlp.state_dict().items()}
sd["state_dict_vae_fuse_mlp"] = {k: v for k, v in self.vae_fuse_mlp.state_dict().items()}
sd["state_dict_unet_fuse_mlp"] = {k: v for k, v in self.unet_fuse_mlp.state_dict().items()}
sd["w"] = self.W
sd["state_embeddings"] = {
"state_dict_vae_block": self.vae_block_embeddings.state_dict(),
"state_dict_unet_block": self.unet_block_embeddings.state_dict(),
}
torch.save(sd, outf)
def _set_latent_tile(self,
latent_tiled_size = 96,
latent_tiled_overlap = 32):
self.latent_tiled_size = latent_tiled_size
self.latent_tiled_overlap = latent_tiled_overlap
def _init_tiled_vae(self,
encoder_tile_size = 256,
decoder_tile_size = 256,
fast_decoder = False,
fast_encoder = False,
color_fix = False,
vae_to_gpu = True):
# save original forward (only once)
if not hasattr(self.vae.encoder, 'original_forward'):
setattr(self.vae.encoder, 'original_forward', self.vae.encoder.forward)
if not hasattr(self.vae.decoder, 'original_forward'):
setattr(self.vae.decoder, 'original_forward', self.vae.decoder.forward)
encoder = self.vae.encoder
decoder = self.vae.decoder
self.vae.encoder.forward = VAEHook(
encoder, encoder_tile_size, is_decoder=False, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)
self.vae.decoder.forward = VAEHook(
decoder, decoder_tile_size, is_decoder=True, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)
def _gaussian_weights(self, tile_width, tile_height, nbatches):
"""Generates a gaussian mask of weights for tile contributions"""
from numpy import pi, exp, sqrt
import numpy as np
latent_width = tile_width
latent_height = tile_height
var = 0.01
midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
x_probs = [exp(-(x-midpoint)*(x-midpoint)/(latent_width*latent_width)/(2*var)) / sqrt(2*pi*var) for x in range(latent_width)]
midpoint = latent_height / 2
y_probs = [exp(-(y-midpoint)*(y-midpoint)/(latent_height*latent_height)/(2*var)) / sqrt(2*pi*var) for y in range(latent_height)]
weights = np.outer(y_probs, x_probs)
return torch.tile(torch.tensor(weights), (nbatches, self.unet.config.in_channels, 1, 1))
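# A minimal, self-contained sketch of the tile-blending idea used by the tiled
# forward pass above (shapes and sizes are illustrative assumptions only). Each
# tile's prediction is multiplied by a Gaussian weight map, accumulated into the
# full latent, and the sum of weight maps ("contributors") is divided out at the
# end, so overlapping regions become a weighted average:
#
#   import torch
#   full = torch.zeros(1, 4, 96, 128)          # accumulated prediction
#   contrib = torch.zeros(1, 4, 96, 128)       # accumulated weights
#   w = torch.ones(1, 4, 96, 96)               # stand-in for _gaussian_weights(96, 96, 1)
#   for x0 in (0, 32):                         # two horizontally overlapping tiles
#       pred = torch.randn(1, 4, 96, 96)       # stand-in for a UNet tile prediction
#       full[..., :, x0:x0 + 96] += pred * w
#       contrib[..., :, x0:x0 + 96] += w
#   blended = full / contrib                   # overlap region is averaged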
import os
os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO'
import gc
import lpips
import clip
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from omegaconf import OmegaConf
from accelerate import Accelerator
from accelerate.utils import set_seed
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
import diffusers
from diffusers.utils.import_utils import is_xformers_available
from diffusers.optimization import get_scheduler
from de_net import DEResNet
from s3diff import S3Diff
from my_utils.training_utils import parse_args_paired_training, PairedDataset, degradation_proc
def main(args):
# init and save configs
config = OmegaConf.load(args.base_config)
if args.sd_path is None:
from huggingface_hub import snapshot_download
sd_path = snapshot_download(repo_id="stabilityai/sd-turbo")
else:
sd_path = args.sd_path
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
if args.seed is not None:
set_seed(args.seed)
if accelerator.is_main_process:
os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)
# initialize degradation estimation network
net_de = DEResNet(num_in_ch=3, num_degradation=2)
net_de.load_model(args.de_net_path)
net_de = net_de.cuda()
net_de.eval()
# initialize net_sr
net_sr = S3Diff(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae, sd_path=sd_path, pretrained_path=args.pretrained_path)
net_sr.set_train()
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
net_sr.unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available, please install it by running `pip install xformers`")
if args.gradient_checkpointing:
net_sr.unet.enable_gradient_checkpointing()
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.gan_disc_type == "vagan":
import vision_aided_loss
net_disc = vision_aided_loss.Discriminator(cv_type='dino', output_type='conv_multi_level', loss_type=args.gan_loss_type, device="cuda")
else:
raise NotImplementedError(f"Discriminator type {args.gan_disc_type} not implemented")
net_disc = net_disc.cuda()
net_disc.requires_grad_(True)
net_disc.cv_ensemble.requires_grad_(False)
net_disc.train()
net_lpips = lpips.LPIPS(net='vgg').cuda()
net_lpips.requires_grad_(False)
# make the optimizer
layers_to_opt = []
layers_to_opt = layers_to_opt + list(net_sr.vae_block_embeddings.parameters()) + list(net_sr.unet_block_embeddings.parameters())
layers_to_opt = layers_to_opt + list(net_sr.vae_de_mlp.parameters()) + list(net_sr.unet_de_mlp.parameters()) + \
list(net_sr.vae_block_mlp.parameters()) + list(net_sr.unet_block_mlp.parameters()) + \
list(net_sr.vae_fuse_mlp.parameters()) + list(net_sr.unet_fuse_mlp.parameters())
for n, _p in net_sr.unet.named_parameters():
if "lora" in n:
assert _p.requires_grad
layers_to_opt.append(_p)
layers_to_opt += list(net_sr.unet.conv_in.parameters())
for n, _p in net_sr.vae.named_parameters():
if "lora" in n:
assert _p.requires_grad
layers_to_opt.append(_p)
dataset_train = PairedDataset(config.train)
dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)
dataset_val = PairedDataset(config.validation)
dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)
optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,)
lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles, power=args.lr_power,)
optimizer_disc = torch.optim.AdamW(net_disc.parameters(), lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,)
lr_scheduler_disc = get_scheduler(args.lr_scheduler, optimizer=optimizer_disc,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles, power=args.lr_power)
# Prepare everything with our `accelerator`.
net_sr, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc = accelerator.prepare(
net_sr, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc
)
net_de, net_lpips = accelerator.prepare(net_de, net_lpips)
# choose the weight dtype according to the accelerator's mixed-precision setting
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move all networks to device and cast to weight_dtype
net_sr.to(accelerator.device, dtype=weight_dtype)
net_de.to(accelerator.device, dtype=weight_dtype)
net_disc.to(accelerator.device, dtype=weight_dtype)
net_lpips.to(accelerator.device, dtype=weight_dtype)
progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps",
disable=not accelerator.is_local_main_process,)
for name, module in net_disc.named_modules():
if "attn" in name:
module.fused_attn = False
# start the training loop
global_step = 0
for epoch in range(0, args.num_training_epochs):
for step, batch in enumerate(dl_train):
l_acc = [net_sr, net_disc]
with accelerator.accumulate(*l_acc):
x_src, x_tgt, x_ori_size_src = degradation_proc(config, batch, accelerator.device)
B, C, H, W = x_src.shape
with torch.no_grad():
deg_score = net_de(x_ori_size_src.detach()).detach()
pos_tag_prompt = [args.pos_prompt for _ in range(B)]
neg_tag_prompt = [args.neg_prompt for _ in range(B)]
neg_probs = torch.rand(B).to(accelerator.device)
# build mixed prompt and target
mixed_tag_prompt = [_neg_tag if p_i < args.neg_prob else _pos_tag for _neg_tag, _pos_tag, p_i in zip(neg_tag_prompt, pos_tag_prompt, neg_probs)]
neg_probs = neg_probs.reshape(B, 1, 1, 1)
mixed_tgt = torch.where(neg_probs < args.neg_prob, x_src, x_tgt)
x_tgt_pred = net_sr(x_src.detach(), deg_score, mixed_tag_prompt)
loss_l2 = F.mse_loss(x_tgt_pred.float(), mixed_tgt.detach().float(), reduction="mean") * args.lambda_l2
loss_lpips = net_lpips(x_tgt_pred.float(), mixed_tgt.detach().float()).mean() * args.lambda_lpips
loss = loss_l2 + loss_lpips
accelerator.backward(loss, retain_graph=False)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
"""
Generator loss: fool the discriminator
"""
x_tgt_pred = net_sr(x_src.detach(), deg_score, pos_tag_prompt)
lossG = net_disc(x_tgt_pred, for_G=True).mean() * args.lambda_gan
accelerator.backward(lossG)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
"""
Discriminator loss: fake image vs real image
"""
# real image
lossD_real = net_disc(x_tgt.detach(), for_real=True).mean() * args.lambda_gan
accelerator.backward(lossD_real.mean())
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(net_disc.parameters(), args.max_grad_norm)
optimizer_disc.step()
lr_scheduler_disc.step()
optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none)
# fake image
lossD_fake = net_disc(x_tgt_pred.detach(), for_real=False).mean() * args.lambda_gan
accelerator.backward(lossD_fake.mean())
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(net_disc.parameters(), args.max_grad_norm)
optimizer_disc.step()
optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none)
lossD = lossD_real + lossD_fake
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
logs = {}
logs["lossG"] = lossG.detach().item()
logs["lossD"] = lossD.detach().item()
logs["loss_l2"] = loss_l2.detach().item()
logs["loss_lpips"] = loss_lpips.detach().item()
progress_bar.set_postfix(**logs)
# checkpoint the model
if global_step % args.checkpointing_steps == 1:
outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
accelerator.unwrap_model(net_sr).save_model(outf)
# compute validation-set L2 and LPIPS
if global_step % args.eval_freq == 1:
l_l2, l_lpips = [], []
val_count = 0
for step, batch_val in enumerate(dl_val):
if step >= args.num_samples_eval:
break
x_src, x_tgt, x_ori_size_src = degradation_proc(config, batch_val, accelerator.device)
B, C, H, W = x_src.shape
assert B == 1, "Use batch size 1 for eval."
with torch.no_grad():
# forward pass
with torch.no_grad():
deg_score = net_de(x_ori_size_src.detach())
pos_tag_prompt = [args.pos_prompt for _ in range(B)]
x_tgt_pred = accelerator.unwrap_model(net_sr)(x_src.detach(), deg_score, pos_tag_prompt)
# compute the reconstruction losses
loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.detach().float(), reduction="mean")
loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.detach().float()).mean()
l_l2.append(loss_l2.item())
l_lpips.append(loss_lpips.item())
if args.save_val and val_count < 5:
x_src = x_src.cpu().detach() * 0.5 + 0.5
x_tgt = x_tgt.cpu().detach() * 0.5 + 0.5
x_tgt_pred = x_tgt_pred.cpu().detach() * 0.5 + 0.5
combined = torch.cat([x_src, x_tgt_pred, x_tgt], dim=3)
output_pil = transforms.ToPILImage()(combined[0])
outf = os.path.join(args.output_dir, f"val_{step}.png")
output_pil.save(outf)
val_count += 1
logs["val/l2"] = np.mean(l_l2)
logs["val/lpips"] = np.mean(l_lpips)
gc.collect()
torch.cuda.empty_cache()
accelerator.log(logs, step=global_step)
if __name__ == "__main__":
args = parse_args_paired_training()
main(args)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import builtins
import datetime
import os
import time
from collections import defaultdict, deque
from pathlib import Path
import json
import subprocess
import torch
import torch.distributed as dist
from typing import List, Dict, Tuple, Optional
from torch import Tensor
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
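# Usage sketch for SmoothedValue (hypothetical values): with window_size=3 the
# windowed statistics track only the most recent values, while global_avg tracks
# everything seen so far.
#
#   meter = SmoothedValue(window_size=3)
#   for v in (1.0, 2.0, 3.0, 4.0):
#       meter.update(v)
#   meter.median      # 3.0  (median of the last 3 values: 2, 3, 4)
#   meter.global_avg  # 2.5  (10.0 / 4 over all updates)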
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if v is None:
continue
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
log_msg = [
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
]
if torch.cuda.is_available():
log_msg.append('max mem: {memory:.0f}')
log_msg = self.delimiter.join(log_msg)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
builtin_print = builtins.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
force = force or (get_world_size() > 8)
if is_master or force:
now = datetime.datetime.now().time()
builtin_print('[{}] '.format(now), end='') # print with time stamp
builtin_print(*args, **kwargs)
builtins.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
args.dist_url = 'env://'
os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
elif 'SLURM_PROCID' in os.environ:
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
addr = subprocess.getoutput(
'scontrol show hostname {} | head -n1'.format(node_list))
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29200')
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['RANK'] = str(proc_id)
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
os.environ['LOCAL_SIZE'] = str(num_gpus)
args.dist_url = 'env://'
args.world_size = ntasks
args.rank = proc_id
args.gpu = proc_id % num_gpus
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
def clip_grad_norm_(
parameters, max_norm: float, norm_type: float = 2.0,
error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
r"""Clips gradient norm of an iterable of parameters.
The norm is computed over all gradients together, as if they were
concatenated into a single vector. Gradients are modified in-place.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float): max norm of the gradients
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
error_if_nonfinite (bool): if True, an error is thrown if the total
norm of the gradients from :attr:`parameters` is ``nan``,
``inf``, or ``-inf``. Default: False (will switch to True in the future)
foreach (bool): use the faster foreach-based implementation.
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
fall back to the slow implementation for other device types.
Default: ``None``
Returns:
Total norm of the parameter gradients (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
grads = [p.grad for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
if len(grads) == 0:
return torch.tensor(0.)
first_device = grads[0].device
grouped_grads: Dict[Tuple[torch.device, torch.dtype], List[List[Tensor]]] \
= {(first_device, grads[0].dtype): [[g.detach() for g in grads]]}
norms = [torch.norm(g) for g in grads]
total_norm = torch.norm(torch.stack(norms))
clip_coef = max_norm / (total_norm + 1e-6)
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
# when the gradients do not reside in CPU memory.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for ((device, _), [grads]) in grouped_grads.items():
if (foreach is None or foreach):
torch._foreach_mul_(grads, clip_coef_clamped.to(device)) # type: ignore[call-overload]
elif foreach:
raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
else:
clip_coef_clamped_device = clip_coef_clamped.to(device)
for g in grads:
g.detach().mul_(clip_coef_clamped_device)
return total_norm
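# Numeric sketch of the clipping math above (hypothetical gradients): if two
# parameters have gradient norms 3 and 4, the total norm is 5. With max_norm=1.0
# the clip coefficient is roughly 1.0 / 5 = 0.2 (ignoring the 1e-6 term), so both
# gradients are scaled in-place by 0.2 and the clipped norms become 0.6 and 0.8.
# With max_norm >= 5 the coefficient clamps to 1.0 and the gradients are unchanged.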
class NativeScalerWithGradNormCount:
state_dict_key = "amp_scaler"
def __init__(self):
self._scaler = torch.cuda.amp.GradScaler()
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
self._scaler.scale(loss).backward(create_graph=create_graph)
if update_grad:
if clip_grad is not None:
assert parameters is not None
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
norm = clip_grad_norm_(parameters, clip_grad)
else:
self._scaler.unscale_(optimizer)
norm = get_grad_norm_(parameters)
self._scaler.step(optimizer)
self._scaler.update()
else:
norm = None
return norm
def state_dict(self):
return self._scaler.state_dict()
def load_state_dict(self, state_dict):
self._scaler.load_state_dict(state_dict)
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
norm_type = float(norm_type)
if len(parameters) == 0:
return torch.tensor(0.)
device = parameters[0].grad.device
if norm_type == float('inf'):
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
else:
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
return total_norm
def save_model(args, epoch, model, model_without_ddp, optimizer):
output_dir = Path(args.output_dir)
epoch_name = str(epoch)
# checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
checkpoint_paths = [output_dir / 'checkpoint.pth']
for checkpoint_path in checkpoint_paths:
to_save = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'args': args,
}
save_on_master(to_save, checkpoint_path)
def load_model(args, model_without_ddp, optimizer):
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
print("Resume checkpoint %s" % args.resume)
if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
optimizer.load_state_dict(checkpoint['optimizer'])
args.start_epoch = checkpoint['epoch'] + 1
print("With optim & sched!")
def auto_load_model(args, model, model_without_ddp, optimizer):
output_dir = Path(args.output_dir)
# torch.amp
if args.auto_resume and len(args.resume) == 0:
import glob
all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
latest_ckpt = -1
for ckpt in all_checkpoints:
t = ckpt.split('-')[-1].split('.')[0]
if t.isdigit():
latest_ckpt = max(int(t), latest_ckpt)
if latest_ckpt >= 0:
args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
print("Auto resume checkpoint: %s" % args.resume)
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
print("Resume checkpoint %s" % args.resume)
if 'optimizer' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
args.start_epoch = checkpoint['epoch'] + 1
print("With optim & sched!")
def all_reduce_mean(x):
world_size = get_world_size()
if world_size > 1:
x_reduce = torch.tensor(x).cuda()
dist.all_reduce(x_reduce)
x_reduce /= world_size
return x_reduce.item()
else:
return x
def create_ds_config(args):
args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
with open(args.deepspeed_config, mode="w") as writer:
ds_config = {
"train_batch_size": args.batch_size * args.accum_iter * get_world_size(),
"train_micro_batch_size_per_gpu": args.batch_size,
"steps_per_print": 1000,
"optimizer": {
"type": "Adam",
"adam_w_mode": True,
"params": {
"lr": args.lr,
"weight_decay": args.weight_decay,
"bias_correction": True,
"betas": [
args.opt_betas[0],
args.opt_betas[1]
],
"eps": args.opt_eps
}
},
"fp16": {
"enabled": True,
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
# "bf16": {
# "enabled": True
# },
"amp": {
"enabled": False,
"opt_level": "O2"
},
"flops_profiler": {
"enabled": True,
"profile_step": -1,
"module_depth": -1,
"top_modules": 1,
"detailed": True,
},
}
if args.clip_grad is not None:
ds_config.update({'gradient_clipping': args.clip_grad})
if args.zero_stage == 1:
ds_config.update({"zero_optimization": {"stage": args.zero_stage, "reduce_bucket_size": 5e8}})
elif args.zero_stage > 1:
raise NotImplementedError()
writer.write(json.dumps(ds_config, indent=2))
def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
parameter_group_names = {}
parameter_group_vars = {}
for name, param in model.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
group_name = "no_decay"
this_weight_decay = 0.
else:
group_name = "decay"
this_weight_decay = weight_decay
if get_num_layer is not None:
layer_id = get_num_layer(name)
group_name = "layer_%d_%s" % (layer_id, group_name)
else:
layer_id = None
if group_name not in parameter_group_names:
if get_layer_scale is not None:
scale = get_layer_scale(layer_id)
else:
scale = 1.
parameter_group_names[group_name] = {
"weight_decay": this_weight_decay,
"params": [],
"lr_scale": scale
}
parameter_group_vars[group_name] = {
"weight_decay": this_weight_decay,
"params": [],
"lr_scale": scale
}
parameter_group_vars[group_name]["params"].append(param)
parameter_group_names[group_name]["params"].append(name)
print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
return list(parameter_group_vars.values())
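# Usage sketch (assumed model/hyperparameters): biases and 1-D parameters such as
# LayerNorm weights land in the "no_decay" group with weight_decay=0, everything
# else in "decay". The returned list of group dicts can be passed directly to a
# torch optimizer; the extra "lr_scale" key is ignored by the optimizer itself and
# is meant for layer-wise LR decay schemes that rescale group LRs externally.
#
#   param_groups = get_parameter_groups(model, weight_decay=0.05)
#   optimizer = torch.optim.AdamW(param_groups, lr=1e-3)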
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2021-11-24 16:54:19
import sys
import cv2
import math
import torch
import random
import numpy as np
from scipy import fft
from pathlib import Path
from einops import rearrange
from skimage import img_as_ubyte, img_as_float32
from torchvision.utils import make_grid
# --------------------------Metrics----------------------------
def ssim(img1, img2):
C1 = (0.01 * 255)**2
C2 = (0.03 * 255)**2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
(sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
def calculate_ssim(im1, im2, border=0, ycbcr=False):
'''
Compute SSIM, giving the same outputs as MATLAB's implementation.
im1, im2: h x w x c, [0, 255], uint8
'''
if not im1.shape == im2.shape:
raise ValueError('Input images must have the same dimensions.')
if ycbcr:
im1 = rgb2ycbcr(im1, True)
im2 = rgb2ycbcr(im2, True)
h, w = im1.shape[:2]
im1 = im1[border:h-border, border:w-border]
im2 = im2[border:h-border, border:w-border]
if im1.ndim == 2:
return ssim(im1, im2)
elif im1.ndim == 3:
if im1.shape[2] == 3:
ssims = []
for i in range(3):
ssims.append(ssim(im1[:,:,i], im2[:,:,i]))
return np.array(ssims).mean()
elif im1.shape[2] == 1:
return ssim(np.squeeze(im1), np.squeeze(im2))
else:
raise ValueError('Wrong input image dimensions.')
def calculate_psnr(im1, im2, border=0, ycbcr=False):
'''
PSNR metric.
im1, im2: h x w x c, [0, 255], uint8
'''
if not im1.shape == im2.shape:
raise ValueError('Input images must have the same dimensions.')
if ycbcr:
im1 = rgb2ycbcr(im1, True)
im2 = rgb2ycbcr(im2, True)
h, w = im1.shape[:2]
im1 = im1[border:h-border, border:w-border]
im2 = im2[border:h-border, border:w-border]
im1 = im1.astype(np.float64)
im2 = im2.astype(np.float64)
mse = np.mean((im1 - im2)**2)
if mse == 0:
return float('inf')
return 20 * math.log10(255.0 / math.sqrt(mse))
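# Quick sanity checks for the PSNR definition above: identical uint8 images give
# mse == 0 and therefore +inf; a constant squared error of 1.0 gives
# 20 * log10(255 / 1), which is approximately 48.13 dB.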
def batch_PSNR(img, imclean, border=0, ycbcr=False):
if ycbcr:
img = rgb2ycbcrTorch(img, True)
imclean = rgb2ycbcrTorch(imclean, True)
Img = img.data.cpu().numpy()
Iclean = imclean.data.cpu().numpy()
Img = img_as_ubyte(Img)
Iclean = img_as_ubyte(Iclean)
PSNR = 0
h, w = Iclean.shape[2:]
for i in range(Img.shape[0]):
PSNR += calculate_psnr(Iclean[i,:,].transpose((1,2,0)), Img[i,:,].transpose((1,2,0)), border)
return PSNR
def batch_SSIM(img, imclean, border=0, ycbcr=False):
if ycbcr:
img = rgb2ycbcrTorch(img, True)
imclean = rgb2ycbcrTorch(imclean, True)
Img = img.data.cpu().numpy()
Iclean = imclean.data.cpu().numpy()
Img = img_as_ubyte(Img)
Iclean = img_as_ubyte(Iclean)
SSIM = 0
for i in range(Img.shape[0]):
SSIM += calculate_ssim(Iclean[i,:,].transpose((1,2,0)), Img[i,:,].transpose((1,2,0)), border)
return SSIM
def normalize_np(im, mean=0.5, std=0.5, reverse=False):
'''
Input:
im: h x w x c, numpy array
Normalize: (im - mean) / std
Reverse: im * std + mean
'''
if not isinstance(mean, (list, tuple)):
mean = [mean, ] * im.shape[2]
mean = np.array(mean).reshape([1, 1, im.shape[2]])
if not isinstance(std, (list, tuple)):
std = [std, ] * im.shape[2]
std = np.array(std).reshape([1, 1, im.shape[2]])
if not reverse:
out = (im.astype(np.float32) - mean) / std
else:
out = im.astype(np.float32) * std + mean
return out
def normalize_th(im, mean=0.5, std=0.5, reverse=False):
'''
Input:
im: b x c x h x w, torch tensor
Normalize: (im - mean) / std
Reverse: im * std + mean
'''
if not isinstance(mean, (list, tuple)):
mean = [mean, ] * im.shape[1]
mean = torch.tensor(mean, device=im.device).view([1, im.shape[1], 1, 1])
if not isinstance(std, (list, tuple)):
std = [std, ] * im.shape[1]
std = torch.tensor(std, device=im.device).view([1, im.shape[1], 1, 1])
if not reverse:
out = (im - mean) / std
else:
out = im * std + mean
return out
# ------------------------Image format--------------------------
def rgb2ycbcr(im, only_y=True):
'''
same as matlab rgb2ycbcr
Input:
im: uint8 [0,255] or float [0,1]
only_y: only return Y channel
'''
# transform to float64 data type, range [0, 255]
if im.dtype == np.uint8:
im_temp = im.astype(np.float64)
else:
im_temp = (im * 255).astype(np.float64)
# convert
if only_y:
rlt = np.dot(im_temp, np.array([65.481, 128.553, 24.966])/ 255.0) + 16.0
else:
rlt = np.matmul(im_temp, np.array([[65.481, -37.797, 112.0 ],
[128.553, -74.203, -93.786],
[24.966, 112.0, -18.214]])/255.0) + [16, 128, 128]
if im.dtype == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
return rlt.astype(im.dtype)
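# Example of the limited ("studio swing") range produced above: for a float input
# pixel (R, G, B) in [0, 1], the Y channel is
# (65.481*R + 128.553*G + 24.966*B + 16) / 255, so pure white maps to
# (219.0 + 16) / 255, i.e. about 0.922 (235/255) rather than 1.0, and pure black
# maps to 16/255.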
def rgb2ycbcrTorch(im, only_y=True):
'''
same as matlab rgb2ycbcr
Input:
im: float [0,1], N x 3 x H x W
only_y: only return Y channel
'''
# transform to range [0,255.0]
im_temp = im.permute([0,2,3,1]) * 255.0 # N x C x H x W --> N x H x W x C
# convert
if only_y:
rlt = torch.matmul(im_temp, torch.tensor([65.481, 128.553, 24.966],
device=im.device, dtype=im.dtype).view([3,1])/ 255.0) + 16.0
else:
rlt = torch.matmul(im_temp, torch.tensor([[65.481, -37.797, 112.0 ],
[128.553, -74.203, -93.786],
[24.966, 112.0, -18.214]],
device=im.device, dtype=im.dtype)/255.0) + \
torch.tensor([16, 128, 128], device=im.device, dtype=im.dtype).view([1, 1, 1, 3])
rlt /= 255.0
rlt.clamp_(0.0, 1.0)
return rlt.permute([0, 3, 1, 2])
def bgr2rgb(im): return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
def rgb2bgr(im): return cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
"""Convert torch Tensors into image numpy arrays.
After clamping to [min, max], values will be normalized to [0, 1].
Args:
tensor (Tensor or list[Tensor]): Accept shapes:
1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
2) 3D Tensor of shape (3/1 x H x W);
3) 2D Tensor of shape (H x W).
Tensor channel should be in RGB order.
rgb2bgr (bool): Whether to change rgb to bgr.
out_type (numpy type): output types. If ``np.uint8``, transform outputs
to uint8 type with range [0, 255]; otherwise, float type with
range [0, 1]. Default: ``np.uint8``.
min_max (tuple[int]): min and max values for clamp.
Returns:
(ndarray or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
shape (H x W). The channel order is BGR.
"""
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
flag_tensor = torch.is_tensor(tensor)
if flag_tensor:
tensor = [tensor]
result = []
for _tensor in tensor:
_tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
_tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
n_dim = _tensor.dim()
if n_dim == 4:
img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
img_np = img_np.transpose(1, 2, 0)
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 3:
img_np = _tensor.numpy()
img_np = img_np.transpose(1, 2, 0)
if img_np.shape[2] == 1: # gray image
img_np = np.squeeze(img_np, axis=2)
else:
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 2:
img_np = _tensor.numpy()
else:
raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
if out_type == np.uint8:
# Unlike MATLAB, numpy.uint8() WILL NOT round by default.
img_np = (img_np * 255.0).round()
img_np = img_np.astype(out_type)
result.append(img_np)
if len(result) == 1 and flag_tensor:
result = result[0]
return result
def img2tensor(imgs, out_type=torch.float32):
"""Convert image numpy arrays into torch tensor.
Args:
imgs (ndarray or list[ndarray]): Accept shapes:
1) 3D numpy array of shape (H x W x 3/1);
2) 2D numpy array of shape (H x W);
3) list of numpy arrays of the above shapes.
Channel order should be RGB.
Returns:
(Tensor or list): 4D tensor of shape (1 x C x H x W)
"""
def _img2tensor(img):
if img.ndim == 2:
tensor = torch.from_numpy(img[None, None,]).type(out_type)
elif img.ndim == 3:
tensor = torch.from_numpy(rearrange(img, 'h w c -> c h w')).type(out_type).unsqueeze(0)
else:
raise TypeError(f'2D or 3D numpy array expected, got {img.ndim}D array')
return tensor
if not (isinstance(imgs, np.ndarray) or (isinstance(imgs, list) and all(isinstance(t, np.ndarray) for t in imgs))):
raise TypeError(f'Numpy array or list of numpy array expected, got {type(imgs)}')
flag_numpy = isinstance(imgs, np.ndarray)
if flag_numpy:
imgs = [imgs,]
result = []
for _img in imgs:
result.append(_img2tensor(_img))
if len(result) == 1 and flag_numpy:
result = result[0]
return result
# ------------------------Image resize-----------------------------
def imresize_np(img, scale, antialiasing=True):
# Now the scale should be the same for H and W
# input: img: Numpy, HWC or HW [0,1]
# output: HWC or HW [0,1] w/o round
img = torch.from_numpy(img)
need_squeeze = True if img.dim() == 2 else False
if need_squeeze:
img.unsqueeze_(2)
in_H, in_W, in_C = img.size()
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
kernel_width = 4
kernel = 'cubic'
# Return the desired dimension order for performing the resize. The
# strategy is to perform the resize first along the dimension with the
# smallest scale factor.
# Now we do not support this.
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
sym_patch = img[:sym_len_Hs, :, :]
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(0, inv_idx)
img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
sym_patch = img[-sym_len_He:, :, :]
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(0, inv_idx)
img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
out_1 = torch.FloatTensor(out_H, in_W, in_C)
kernel_width = weights_H.size(1)
for i in range(out_H):
idx = int(indices_H[i][0])
for j in range(out_C):
out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
# process W dimension
# symmetric copying
out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
sym_patch = out_1[:, :sym_len_Ws, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
sym_patch = out_1[:, -sym_len_We:, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
out_2 = torch.FloatTensor(out_H, out_W, in_C)
kernel_width = weights_W.size(1)
for i in range(out_W):
idx = int(indices_W[i][0])
for j in range(out_C):
out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
if need_squeeze:
out_2.squeeze_()
return out_2.numpy()
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
if (scale < 1) and (antialiasing):
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
kernel_width = kernel_width / scale
# Output-space coordinates
x = torch.linspace(1, out_length, out_length)
# Input-space coordinates. Calculate the inverse mapping such that 0.5
# in output space maps to 0.5 in input space, and 0.5+scale in output
# space maps to 1.5 in input space.
u = x / scale + 0.5 * (1 - 1 / scale)
# What is the left-most pixel that can be involved in the computation?
left = torch.floor(u - kernel_width / 2)
# What is the maximum number of pixels that can be involved in the
# computation? Note: it's OK to use an extra pixel here; if the
# corresponding weights are all zero, it will be eliminated at the end
# of this function.
P = math.ceil(kernel_width) + 2
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
1, P).expand(out_length, P)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
# apply cubic kernel
if (scale < 1) and (antialiasing):
weights = scale * cubic(distance_to_center * scale)
else:
weights = cubic(distance_to_center)
# Normalize the weights matrix so that each row sums to 1.
weights_sum = torch.sum(weights, 1).view(out_length, 1)
weights = weights / weights_sum.expand(out_length, P)
# If a column in weights is all zero, get rid of it. only consider the first and last column.
weights_zero_tmp = torch.sum((weights == 0), 0)
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
indices = indices.narrow(1, 1, P - 2)
weights = weights.narrow(1, 1, P - 2)
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
indices = indices.narrow(1, 0, P - 2)
weights = weights.narrow(1, 0, P - 2)
weights = weights.contiguous()
indices = indices.contiguous()
sym_len_s = -indices.min() + 1
sym_len_e = indices.max() - in_length
indices = indices + sym_len_s - 1
return weights, indices, int(sym_len_s), int(sym_len_e)
# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
(-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
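# A few reference values of the Keys bicubic kernel (a = -0.5) implemented above:
# cubic(0) = 1, cubic(0.5) = 0.5625, cubic(1) = 0, cubic(1.5) = -0.0625, and
# cubic(x) = 0 for |x| >= 2. The small negative lobe between 1 and 2 is what gives
# bicubic resampling its mild sharpening compared to bilinear.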
# ------------------------Image I/O-----------------------------
def imread(path, chn='rgb', dtype='float32'):
'''
Read image.
chn: 'rgb', 'bgr' or 'gray'
out:
im: h x w x c, numpy array
'''
im = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) # BGR, uint8
try:
if chn.lower() == 'rgb':
if im.ndim == 3:
im = bgr2rgb(im)
else:
im = np.stack((im, im, im), axis=2)
elif chn.lower() == 'gray':
assert im.ndim == 2
except:
print(str(path))
if dtype == 'float32':
im = im.astype(np.float32) / 255.
elif dtype == 'float64':
im = im.astype(np.float64) / 255.
elif dtype == 'uint8':
pass
else:
sys.exit('Please input a correct dtype: float32, float64 or uint8!')
return im
def imwrite(im_in, path, chn='rgb', dtype_in='float32', qf=None):
'''
Save image.
Input:
im: h x w x c, numpy array
path: the saving path
chn: the channel order of im, 'rgb' or 'bgr'
'''
im = im_in.copy()
if isinstance(path, str):
path = Path(path)
if dtype_in != 'uint8':
im = img_as_ubyte(im)
if chn.lower() == 'rgb' and im.ndim == 3:
im = rgb2bgr(im)
if qf is not None and path.suffix.lower() in ['.jpg', '.jpeg']:
flag = cv2.imwrite(str(path), im, [int(cv2.IMWRITE_JPEG_QUALITY), int(qf)])
else:
flag = cv2.imwrite(str(path), im)
return flag
def jpeg_compress(im, qf, chn_in='rgb'):
'''
Input:
im: h x w x 3 array
qf: compress factor, (0, 100]
chn_in: 'rgb' or 'bgr'
Return:
Compressed Image with channel order: chn_in
'''
# transform to BGR channel order and uint8 data type
im_bgr = rgb2bgr(im) if chn_in.lower() == 'rgb' else im
if im.dtype != np.dtype('uint8'): im_bgr = img_as_ubyte(im_bgr)
# JPEG compress
flag, encimg = cv2.imencode('.jpg', im_bgr, [int(cv2.IMWRITE_JPEG_QUALITY), qf])
assert flag
im_jpg_bgr = cv2.imdecode(encimg, 1) # uint8, BGR
# transform back to original channel and the original data type
im_out = bgr2rgb(im_jpg_bgr) if chn_in.lower() == 'rgb' else im_jpg_bgr
if im.dtype != np.dtype('uint8'): im_out = img_as_float32(im_out).astype(im.dtype)
return im_out
# ------------------------Augmentation-----------------------------
def data_aug_np(image, mode):
'''
Performs data augmentation of the input image
Input:
image: a cv2 (OpenCV) image
mode: int. Choice of transformation to apply to the image
0 - no transformation
1 - flip up and down
2 - rotate counterclockwise 90 degrees
3 - rotate 90 degree and flip up and down
4 - rotate 180 degree
5 - rotate 180 degree and flip
6 - rotate 270 degree
7 - rotate 270 degree and flip
'''
if mode == 0:
# original
out = image
elif mode == 1:
# flip up and down
out = np.flipud(image)
elif mode == 2:
# rotate counterclockwise 90 degrees
out = np.rot90(image)
elif mode == 3:
# rotate 90 degree and flip up and down
out = np.rot90(image)
out = np.flipud(out)
elif mode == 4:
# rotate 180 degree
out = np.rot90(image, k=2)
elif mode == 5:
# rotate 180 degree and flip
out = np.rot90(image, k=2)
out = np.flipud(out)
elif mode == 6:
# rotate 270 degree
out = np.rot90(image, k=3)
elif mode == 7:
# rotate 270 degree and flip
out = np.rot90(image, k=3)
out = np.flipud(out)
else:
raise Exception('Invalid choice of image transformation')
return out.copy()
def inverse_data_aug_np(image, mode):
'''
Performs inverse data augmentation of the input image
'''
if mode == 0:
# original
out = image
elif mode == 1:
out = np.flipud(image)
elif mode == 2:
out = np.rot90(image, axes=(1,0))
elif mode == 3:
out = np.flipud(image)
out = np.rot90(out, axes=(1,0))
elif mode == 4:
out = np.rot90(image, k=2, axes=(1,0))
elif mode == 5:
out = np.flipud(image)
out = np.rot90(out, k=2, axes=(1,0))
elif mode == 6:
out = np.rot90(image, k=3, axes=(1,0))
elif mode == 7:
# rotate 270 degree and flip
out = np.flipud(image)
out = np.rot90(out, k=3, axes=(1,0))
else:
raise Exception('Invalid choice of image transformation')
return out
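# The two functions above are exact inverses for every mode, which is what makes
# them usable for test-time augmentation: augment the input, run the model, then
# undo the transform on the output before averaging. A quick check (hypothetical
# array):
#
#   im = np.random.rand(8, 8, 3)
#   assert all(
#       np.array_equal(inverse_data_aug_np(data_aug_np(im, m), m), im)
#       for m in range(8)
#   )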
class SpatialAug:
def __init__(self):
pass
def __call__(self, im, flag=None):
if flag is None:
flag = random.randint(0, 7)
out = data_aug_np(im, flag)
return out
# ----------------------Visualization----------------------------
def imshow(x, title=None, cbar=False):
import matplotlib.pyplot as plt
plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
if title:
plt.title(title)
if cbar:
plt.colorbar()
plt.show()
# -----------------------Convolution------------------------------
def imgrad(im, pading_mode='mirror'):
'''
Calculate image gradient.
Input:
im: h x w x c numpy array
'''
from scipy.ndimage import correlate # lazy import
wx = np.array([[0, 0, 0],
[-1, 1, 0],
[0, 0, 0]], dtype=np.float32)
wy = np.array([[0, -1, 0],
[0, 1, 0],
[0, 0, 0]], dtype=np.float32)
if im.ndim == 3:
gradx = np.stack(
[correlate(im[:,:,c], wx, mode=pading_mode) for c in range(im.shape[2])],
axis=2
)
grady = np.stack(
[correlate(im[:,:,c], wy, mode=pading_mode) for c in range(im.shape[2])],
axis=2
)
grad = np.concatenate((gradx, grady), axis=2)
else:
gradx = correlate(im, wx, mode=pading_mode)
grady = correlate(im, wy, mode=pading_mode)
grad = np.stack((gradx, grady), axis=2)
return {'gradx': gradx, 'grady': grady, 'grad':grad}
def imgrad_fft(im):
'''
Calculate image gradient.
Input:
im: h x w x c numpy array
'''
wx = np.rot90(np.array([[0, 0, 0],
[-1, 1, 0],
[0, 0, 0]], dtype=np.float32), k=2)
gradx = convfft(im, wx)
wy = np.rot90(np.array([[0, -1, 0],
[0, 1, 0],
[0, 0, 0]], dtype=np.float32), k=2)
grady = convfft(im, wy)
grad = np.concatenate((gradx, grady), axis=2)
return {'gradx': gradx, 'grady': grady, 'grad':grad}
def convfft(im, weight):
'''
Convolution with FFT
Input:
im: h1 x w1 x c numpy array
weight: h2 x w2 numpy array
Output:
out: h1 x w1 x c numpy array
'''
axes = (0,1)
otf = psf2otf(weight, im.shape[:2])
if im.ndim == 3:
otf = np.tile(otf[:, :, None], (1,1,im.shape[2]))
out = fft.ifft2(fft.fft2(im, axes=axes) * otf, axes=axes).real
return out
def psf2otf(psf, shape):
"""
MATLAB psf2otf function.
Borrowed from https://github.com/aboucaud/pypher/blob/master/pypher/pypher.py.
Input:
psf : h x w numpy array
shape : list or tuple, output shape of the OTF array
Output:
otf : OTF array with the desirable shape
"""
if np.all(psf == 0):
return np.zeros_like(psf)
inshape = psf.shape
# Pad the PSF to outsize
psf = zero_pad(psf, shape, position='corner')
# Circularly shift OTF so that the 'center' of the PSF is [0,0] element of the array
for axis, axis_size in enumerate(inshape):
psf = np.roll(psf, -int(axis_size / 2), axis=axis)
# Compute the OTF
otf = fft.fft2(psf)
# Estimate the rough number of operations involved in the FFT
# and discard the PSF imaginary part if within roundoff error
# roundoff error = machine epsilon = sys.float_info.epsilon
# or np.finfo().eps
n_ops = np.sum(psf.size * np.log2(psf.shape))
otf = np.real_if_close(otf, tol=n_ops)
return otf
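# Sanity check for the FFT-based convolution built on psf2otf (hypothetical
# array): a 1x1 PSF of value 1 is the identity filter, so its OTF is all ones and
# convfft returns the input unchanged up to floating-point error.
#
#   im = np.random.rand(16, 16, 3)
#   assert np.allclose(convfft(im, np.array([[1.0]])), im)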
# ----------------------Patch Cropping----------------------------
def random_crop(im, pch_size):
'''
Randomly crop a patch from the given image.
'''
h, w = im.shape[:2]
if h == pch_size and w == pch_size:
im_pch = im
else:
assert h >= pch_size and w >= pch_size
ind_h = random.randint(0, h-pch_size)
ind_w = random.randint(0, w-pch_size)
im_pch = im[ind_h:ind_h+pch_size, ind_w:ind_w+pch_size,]
return im_pch
class RandomCrop:
def __init__(self, pch_size):
self.pch_size = pch_size
def __call__(self, im):
return random_crop(im, self.pch_size)
class ImageSpliterNp:
def __init__(self, im, pch_size, stride, sf=1):
'''
Input:
im: h x w x c, numpy array, [0, 1], low-resolution image in SR
pch_size, stride: patch setting
sf: scale factor in image super-resolution
'''
assert stride <= pch_size
self.stride = stride
self.pch_size = pch_size
self.sf = sf
if im.ndim == 2:
im = im[:, :, None]
height, width, chn = im.shape
self.height_starts_list = self.extract_starts(height)
self.width_starts_list = self.extract_starts(width)
self.length = self.__len__()
self.num_pchs = 0
self.im_ori = im
self.im_res = np.zeros([height*sf, width*sf, chn], dtype=im.dtype)
self.pixel_count = np.zeros([height*sf, width*sf, chn], dtype=im.dtype)
def extract_starts(self, length):
starts = list(range(0, length, self.stride))
if starts[-1] + self.pch_size > length:
starts[-1] = length - self.pch_size
return starts
def __len__(self):
return len(self.height_starts_list) * len(self.width_starts_list)
def __iter__(self):
return self
def __next__(self):
if self.num_pchs < self.length:
w_start_idx = self.num_pchs // len(self.height_starts_list)
w_start = self.width_starts_list[w_start_idx] * self.sf
w_end = w_start + self.pch_size * self.sf
h_start_idx = self.num_pchs % len(self.height_starts_list)
h_start = self.height_starts_list[h_start_idx] * self.sf
h_end = h_start + self.pch_size * self.sf
pch = self.im_ori[h_start:h_end, w_start:w_end,]
self.w_start, self.w_end = w_start, w_end
self.h_start, self.h_end = h_start, h_end
self.num_pchs += 1
else:
raise StopIteration(0)
return pch, (h_start, h_end, w_start, w_end)
def update(self, pch_res, index_infos):
'''
Input:
pch_res: pch_size x pch_size x 3, [0,1]
index_infos: (h_start, h_end, w_start, w_end)
'''
if index_infos is None:
w_start, w_end = self.w_start, self.w_end
h_start, h_end = self.h_start, self.h_end
else:
h_start, h_end, w_start, w_end = index_infos
self.im_res[h_start:h_end, w_start:w_end] += pch_res
self.pixel_count[h_start:h_end, w_start:w_end] += 1
def gather(self):
assert np.all(self.pixel_count != 0)
return self.im_res / self.pixel_count
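# Hedged usage sketch (illustrative only, not from the original code). With sf=1 and an
# identity "restorer", split -> update -> gather should reconstruct the input, since
# overlapping patches are accumulated in im_res and divided by the per-pixel count.
def _demo_image_spliter_np():
    im = np.random.rand(64, 64, 3).astype(np.float32)
    spliter = ImageSpliterNp(im, pch_size=32, stride=16, sf=1)
    for pch, index_infos in spliter:
        spliter.update(pch, index_infos)   # identity processing stands in for a real model
    recon = spliter.gather()
    print('ImageSpliterNp round-trip max error: {:.2e}'.format(np.abs(recon - im).max()))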
class ImageSpliterTh:
def __init__(self, im, pch_size, stride, sf=1, extra_bs=1):
'''
Input:
im: n x c x h x w, torch tensor, float, low-resolution image in SR
pch_size, stride: patch setting
sf: scale factor in image super-resolution
        extra_bs: number of patches aggregated into one batch per iteration; only useful when the input is a single image
'''
assert stride <= pch_size
self.stride = stride
self.pch_size = pch_size
self.sf = sf
self.extra_bs = extra_bs
bs, chn, height, width= im.shape
self.true_bs = bs
self.height_starts_list = self.extract_starts(height)
self.width_starts_list = self.extract_starts(width)
self.starts_list = []
for ii in self.height_starts_list:
for jj in self.width_starts_list:
self.starts_list.append([ii, jj])
self.length = self.__len__()
self.count_pchs = 0
self.im_ori = im
self.im_res = torch.zeros([bs, chn, height*sf, width*sf], dtype=im.dtype, device=im.device)
self.pixel_count = torch.zeros([bs, chn, height*sf, width*sf], dtype=im.dtype, device=im.device)
def extract_starts(self, length):
if length <= self.pch_size:
starts = [0,]
else:
starts = list(range(0, length, self.stride))
for ii in range(len(starts)):
if starts[ii] + self.pch_size > length:
starts[ii] = length - self.pch_size
starts = sorted(set(starts), key=starts.index)
return starts
def __len__(self):
return len(self.height_starts_list) * len(self.width_starts_list)
def __iter__(self):
return self
def __next__(self):
if self.count_pchs < self.length:
index_infos = []
current_starts_list = self.starts_list[self.count_pchs:self.count_pchs+self.extra_bs]
for ii, (h_start, w_start) in enumerate(current_starts_list):
w_end = w_start + self.pch_size
h_end = h_start + self.pch_size
current_pch = self.im_ori[:, :, h_start:h_end, w_start:w_end]
if ii == 0:
pch = current_pch
else:
pch = torch.cat([pch, current_pch], dim=0)
h_start *= self.sf
h_end *= self.sf
w_start *= self.sf
w_end *= self.sf
index_infos.append([h_start, h_end, w_start, w_end])
self.count_pchs += len(current_starts_list)
else:
raise StopIteration()
return pch, index_infos
def update(self, pch_res, index_infos):
'''
Input:
pch_res: (n*extra_bs) x c x pch_size x pch_size, float
index_infos: [(h_start, h_end, w_start, w_end),]
'''
assert pch_res.shape[0] % self.true_bs == 0
pch_list = torch.split(pch_res, self.true_bs, dim=0)
assert len(pch_list) == len(index_infos)
for ii, (h_start, h_end, w_start, w_end) in enumerate(index_infos):
current_pch = pch_list[ii]
self.im_res[:, :, h_start:h_end, w_start:w_end] += current_pch
self.pixel_count[:, :, h_start:h_end, w_start:w_end] += 1
def gather(self):
assert torch.all(self.pixel_count != 0)
return self.im_res.div(self.pixel_count)
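# Hedged usage sketch (illustrative only, not from the original code). Nearest-neighbour
# upsampling is purely local, so running it patch-by-patch through the spliter with sf=2
# should match upsampling the whole tensor at once (up to floating-point averaging).
def _demo_image_spliter_th():
    import torch.nn.functional as F        # local import to keep the sketch self-contained
    im = torch.rand(1, 3, 64, 64)
    spliter = ImageSpliterTh(im, pch_size=32, stride=16, sf=2, extra_bs=2)
    for pch, index_infos in spliter:
        sr_pch = F.interpolate(pch, scale_factor=2, mode='nearest')  # stand-in for an SR model
        spliter.update(sr_pch, index_infos)
    recon = spliter.gather()
    ref = F.interpolate(im, scale_factor=2, mode='nearest')
    print('ImageSpliterTh vs. full-image upsampling max error: {:.2e}'.format(
        (recon - ref).abs().max().item()))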
# ----------------------Clamping----------------------------
class Clamper:
def __init__(self, min_max=(-1, 1)):
self.min_bound, self.max_bound = min_max[0], min_max[1]
def __call__(self, im):
if isinstance(im, np.ndarray):
return np.clip(im, a_min=self.min_bound, a_max=self.max_bound)
elif isinstance(im, torch.Tensor):
return torch.clamp(im, min=self.min_bound, max=self.max_bound)
else:
raise TypeError(f'ndarray or Tensor expected, got {type(im)}')
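# Hedged usage note (illustrative only, not from the original code): Clamper accepts both
# numpy arrays and torch tensors, e.g. to keep inputs inside a model's expected [-1, 1] range.
def _demo_clamper():
    clamp = Clamper(min_max=(-1, 1))
    print(clamp(np.array([-2.0, 0.5, 3.0])))       # values clipped to [-1, 1]
    print(clamp(torch.tensor([-2.0, 0.5, 3.0])))   # same behaviour for tensors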
if __name__ == '__main__':
im = np.random.randn(64, 64, 3).astype(np.float32)
grad1 = imgrad(im)['grad']
grad2 = imgrad_fft(im)['grad']
    error = np.abs(grad1 - grad2).max()
    mean_error = np.abs(grad1 - grad2).mean()
print('The largest error is {:.2e}'.format(error))
print('The mean error is {:.2e}'.format(mean_error))