Commit f7544730 authored by 0x3f3f3f3fun

first commit
dataset:
target: dataset.codeformer.CodeformerDataset
params:
# Path to the file list.
file_list:
out_size: 512
crop_type: center
use_hflip: False
blur_kernel_size: 41
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [0.1, 10]
downsample_range: [0.8, 8]
noise_range: [0, 20]
jpeg_range: [60, 100]
data_loader:
batch_size: 16
shuffle: true
num_workers: 16
drop_last: true
batch_transform:
target: dataset.batch_transform.IdentityBatchTransform
dataset:
target: dataset.codeformer.CodeformerDataset
params:
# Path to the file list.
file_list:
out_size: 512
crop_type: center
use_hflip: False
blur_kernel_size: 41
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [0.1, 10]
downsample_range: [0.8, 8]
noise_range: [0, 20]
jpeg_range: [60, 100]
data_loader:
batch_size: 16
shuffle: false
num_workers: 16
drop_last: true
batch_transform:
target: dataset.batch_transform.IdentityBatchTransform
dataset:
target: dataset.codeformer.CodeformerDataset
params:
# Path to the file list.
file_list:
out_size: 512
crop_type: center
use_hflip: False
blur_kernel_size: 41
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [0.1, 12]
downsample_range: [1, 12]
noise_range: [0, 15]
jpeg_range: [30, 100]
data_loader:
batch_size: 16
shuffle: true
num_workers: 16
drop_last: true
batch_transform:
target: dataset.batch_transform.IdentityBatchTransform
dataset:
target: dataset.codeformer.CodeformerDataset
params:
# Path to the file list.
file_list:
out_size: 512
crop_type: center
use_hflip: False
blur_kernel_size: 41
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [0.1, 12]
downsample_range: [1, 12]
noise_range: [0, 15]
jpeg_range: [30, 100]
data_loader:
batch_size: 16
shuffle: false
num_workers: 16
drop_last: true
batch_transform:
target: dataset.batch_transform.IdentityBatchTransform
dataset:
target: dataset.realesrgan.RealESRGANDataset
params:
# Path to the file list.
file_list:
out_size: 512
crop_type: center
use_hflip: false
use_rot: false
blur_kernel_size: 21
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob: 0.1
blur_sigma: [0.2, 3]
betag_range: [0.5, 4]
betap_range: [1, 2]
blur_kernel_size2: 21
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob2: 0.1
blur_sigma2: [0.2, 1.5]
betag_range2: [0.5, 4]
betap_range2: [1, 2]
final_sinc_prob: 0.8
data_loader:
batch_size: 32
shuffle: true
num_workers: 16
prefetch_factor: 2
drop_last: true
batch_transform:
target: dataset.batch_transform.RealESRGANBatchTransform
params:
use_sharpener: false
resize_hq: false
# Queue size of the training pool; this should be a multiple of batch_size.
queue_size: 256
# the first degradation process
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
resize_range: [0.15, 1.5]
gaussian_noise_prob: 0.5
noise_range: [1, 30]
poisson_scale_range: [0.05, 3]
gray_noise_prob: 0.4
jpeg_range: [30, 95]
# the second degradation process
stage2_scale: 4
second_blur_prob: 0.8
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
resize_range2: [0.3, 1.2]
gaussian_noise_prob2: 0.5
noise_range2: [1, 25]
poisson_scale_range2: [0.05, 2.5]
gray_noise_prob2: 0.4
jpeg_range2: [30, 95]
dataset:
target: dataset.realesrgan.RealESRGANDataset
params:
# Path to the file list.
file_list:
out_size: 512
crop_type: center
use_hflip: false
use_rot: false
blur_kernel_size: 21
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob: 0.1
blur_sigma: [0.2, 3]
betag_range: [0.5, 4]
betap_range: [1, 2]
blur_kernel_size2: 21
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob2: 0.1
blur_sigma2: [0.2, 1.5]
betag_range2: [0.5, 4]
betap_range2: [1, 2]
final_sinc_prob: 0.8
data_loader:
batch_size: 32
shuffle: false
num_workers: 16
prefetch_factor: 2
drop_last: true
batch_transform:
target: dataset.batch_transform.RealESRGANBatchTransform
params:
use_sharpener: false
resize_hq: false
# Queue size of the training pool; this should be a multiple of batch_size.
queue_size: 256
# the first degradation process
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
resize_range: [0.15, 1.5]
gaussian_noise_prob: 0.5
noise_range: [1, 30]
poisson_scale_range: [0.05, 3]
gray_noise_prob: 0.4
jpeg_range: [30, 95]
# the second degradation process
stage2_scale: 4
second_blur_prob: 0.8
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
resize_range2: [0.3, 1.2]
gaussian_noise_prob2: 0.5
noise_range2: [1, 25]
poisson_scale_range2: [0.05, 2.5]
gray_noise_prob2: 0.4
jpeg_range2: [30, 95]
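Every configuration block above (and the model configurations below) follows the same target/params layout, which utils.common.instantiate_from_config resolves into a Python object; that helper is imported by the data module and the inference scripts in this commit, but its body is not included here. The following is a minimal sketch of the assumed behavior, following the usual ldm.util convention; only the helper name and call signature are taken from the imports, everything else is an assumption.

import importlib
from typing import Any, Mapping

def get_obj_from_str(string: str) -> Any:
    # e.g. "dataset.codeformer.CodeformerDataset" -> the class object
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config: Mapping[str, Any]) -> Any:
    # Build config["target"] with config["params"] passed as keyword arguments.
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))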
target: model.cldm.ControlLDM
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
control_key: "hint"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
sd_locked: True
only_mid_control: False
# Learning rate.
learning_rate: 1e-4
control_stage_config:
target: model.cldm.ControlNet
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 4
hint_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
unet_config:
target: model.cldm.ControlledUnetModel
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
preprocess_config:
target: model.swinir.SwinIR
params:
img_size: 64
patch_size: 1
in_chans: 3
embed_dim: 180
depths: [6, 6, 6, 6, 6, 6, 6, 6]
num_heads: [6, 6, 6, 6, 6, 6, 6, 6]
window_size: 8
mlp_ratio: 2
sf: 8
img_range: 1.0
upsampler: "nearest+conv"
resi_connection: "1conv"
unshuffle: True
unshuffle_scale: 8
target: model.swinir.SwinIR
params:
img_size: 64
patch_size: 1
in_chans: 3
embed_dim: 180
depths: [6, 6, 6, 6, 6, 6, 6, 6]
num_heads: [6, 6, 6, 6, 6, 6, 6, 6]
window_size: 8
mlp_ratio: 2
sf: 8
img_range: 1.0
upsampler: "nearest+conv"
resi_connection: "1conv"
unshuffle: True
unshuffle_scale: 8
hq_key: jpg
lq_key: hint
# Learning rate.
learning_rate: 1e-4
weight_decay: 0
data:
target: dataset.data_module.BIRDataModule
params:
# Path to training set configuration file.
train_config:
# Path to validation set configuration file.
val_config:
model:
# You can set learning rate in the following configuration file.
config: configs/model/cldm.yaml
# Path to the checkpoint or weights you want to resume from. At the beginning,
# this should be set to the initial weights created by scripts/make_stage2_init_weight.py.
resume:
lightning:
seed: 231
trainer:
accelerator: ddp
precision: 32
# Indices of GPUs used for training.
gpus: [0, 1, 2, 3, 4, 5]
# Path to save logs and checkpoints.
default_root_dir:
# Max number of training steps (batches).
max_steps: 25001
# Validation frequency in terms of training steps.
val_check_interval: 500
log_every_n_steps: 50
# Accumulate gradients from multiple batches so as to increase batch size.
accumulate_grad_batches: 1
callbacks:
- target: model.callbacks.ImageLogger
params:
# Log frequency of image logger.
log_every_n_steps: 1000
max_images_each_step: 4
log_images_kwargs: ~
- target: model.callbacks.ModelCheckpoint
params:
# Frequency of saving checkpoints.
every_n_train_steps: 5000
save_top_k: -1
filename: "{step}"
data:
target: dataset.data_module.BIRDataModule
params:
# Path to training set configuration file.
train_config:
# Path to validation set configuration file.
val_config:
model:
# You can set learning rate in the following configuration file.
config: configs/model/swinir.yaml
# Path to the checkpoints or weights you want to resume.
resume: ~
lightning:
seed: 231
trainer:
accelerator: ddp
precision: 32
# Indices of GPUs used for training.
gpus: [0, 1, 2]
# Path to save logs and checkpoints.
default_root_dir:
# Max number of training steps (batches).
max_steps: 150001
# Validation frequency in terms of training steps.
val_check_interval: 500
# Log frequency of tensorboard logger.
log_every_n_steps: 50
# Accumulate gradients from multiple batches so as to increase batch size.
accumulate_grad_batches: 1
callbacks:
- target: model.callbacks.ImageLogger
params:
# Log frequency of image logger.
log_every_n_steps: 1000
max_images_each_step: 4
log_images_kwargs: ~
- target: model.callbacks.ModelCheckpoint
params:
# Frequency of saving checkpoints.
every_n_train_steps: 10000
save_top_k: -1
filename: "{step}"
from typing import Any, overload, Dict, Union, List, Sequence
import random
import torch
from torch.nn import functional as F
import numpy as np
from utils.image import USMSharp, DiffJPEG, filter2D
from utils.degradation import (
random_add_gaussian_noise_pt, random_add_poisson_noise_pt
)
class BatchTransform:
@overload
def __call__(self, batch: Any) -> Any:
...
class IdentityBatchTransform(BatchTransform):
def __call__(self, batch: Any) -> Any:
return batch
class RealESRGANBatchTransform(BatchTransform):
"""
Processing a batch of images with the RealESRGAN degradation model on the CPU
(inside the dataloader) is too slow, costing roughly 0.2 ~ 1 second per image.
So we execute the degradation process on the GPU after loading a batch of images
and kernels from the dataloader.
"""
def __init__(
self,
use_sharpener: bool,
resize_hq: bool,
queue_size: int,
resize_prob: Sequence[float],
resize_range: Sequence[float],
gray_noise_prob: float,
gaussian_noise_prob: float,
noise_range: Sequence[float],
poisson_scale_range: Sequence[float],
jpeg_range: Sequence[int],
second_blur_prob: float,
stage2_scale: Union[float, Sequence[Union[float, int]]],
resize_prob2: Sequence[float],
resize_range2: Sequence[float],
gray_noise_prob2: float,
gaussian_noise_prob2: float,
noise_range2: Sequence[float],
poisson_scale_range2: Sequence[float],
jpeg_range2: Sequence[int]
) -> "RealESRGANBatchTransform":
super().__init__()
# resize settings for the first degradation process
self.resize_prob = resize_prob
self.resize_range = resize_range
# noise settings for the first degradation process
self.gray_noise_prob = gray_noise_prob
self.gaussian_noise_prob = gaussian_noise_prob
self.noise_range = noise_range
self.poisson_scale_range = poisson_scale_range
self.jpeg_range = jpeg_range
self.second_blur_prob = second_blur_prob
self.stage2_scale = stage2_scale
assert (
isinstance(stage2_scale, (float, int)) or (
isinstance(stage2_scale, Sequence) and len(stage2_scale) == 2 and
all(isinstance(x, (float, int)) for x in stage2_scale)
)
), f"stage2_scale can not be {type(stage2_scale)}"
# resize settings for the second degradation process
self.resize_prob2 = resize_prob2
self.resize_range2 = resize_range2
# noise settings for the second degradation process
self.gray_noise_prob2 = gray_noise_prob2
self.gaussian_noise_prob2 = gaussian_noise_prob2
self.noise_range2 = noise_range2
self.poisson_scale_range2 = poisson_scale_range2
self.jpeg_range2 = jpeg_range2
self.use_sharpener = use_sharpener
if self.use_sharpener:
self.usm_sharpener = USMSharp()
else:
self.usm_sharpener = None
self.resize_hq = resize_hq
self.queue_size = queue_size
self.jpeger = DiffJPEG(differentiable=False)
@torch.no_grad()
def _dequeue_and_enqueue(self):
"""It is the training pair pool for increasing the diversity in a batch.
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
to increase the degradation diversity in a batch.
"""
# initialize
b, c, h, w = self.lq.size()
if not hasattr(self, "queue_lr"):
# TODO: queue_size may not actually need to be a multiple of batch_size
assert self.queue_size % b == 0, f"queue size {self.queue_size} should be divisible by batch size {b}"
self.queue_lr = torch.zeros(self.queue_size, c, h, w).to(self.lq)
_, c, h, w = self.gt.size()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).to(self.lq)
self.queue_ptr = 0
if self.queue_ptr == self.queue_size: # the pool is full
# do dequeue and enqueue
# shuffle
idx = torch.randperm(self.queue_size)
self.queue_lr = self.queue_lr[idx]
self.queue_gt = self.queue_gt[idx]
# get first b samples
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
# update the queue
self.queue_lr[0:b, :, :, :] = self.lq.clone()
self.queue_gt[0:b, :, :, :] = self.gt.clone()
self.lq = lq_dequeue
self.gt = gt_dequeue
else:
# only do enqueue
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
self.queue_ptr = self.queue_ptr + b
@torch.no_grad()
def __call__(self, batch: Dict[str, Union[torch.Tensor, str]]) -> Dict[str, Union[torch.Tensor, List[str]]]:
# training data synthesis
hq = batch["hq"]
if self.use_sharpener:
self.usm_sharpener.to(hq)
hq = self.usm_sharpener(hq)
self.jpeger.to(hq)
kernel1 = batch["kernel1"]
kernel2 = batch["kernel2"]
sinc_kernel = batch["sinc_kernel"]
ori_h, ori_w = hq.size()[2:4]
# ----------------------- The first degradation process ----------------------- #
# blur
out = filter2D(hq, kernel1)
# random resize
updown_type = random.choices(["up", "down", "keep"], self.resize_prob)[0]
if updown_type == "up":
scale = np.random.uniform(1, self.resize_range[1])
elif updown_type == "down":
scale = np.random.uniform(self.resize_range[0], 1)
else:
scale = 1
mode = random.choice(["area", "bilinear", "bicubic"])
out = F.interpolate(out, scale_factor=scale, mode=mode)
# add noise
if np.random.uniform() < self.gaussian_noise_prob:
out = random_add_gaussian_noise_pt(
out, sigma_range=self.noise_range, clip=True,
rounds=False, gray_prob=self.gray_noise_prob
)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.poisson_scale_range,
gray_prob=self.gray_noise_prob,
clip=True,
rounds=False
)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range)
# clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
# ----------------------- The second degradation process ----------------------- #
# blur
if np.random.uniform() < self.second_blur_prob:
out = filter2D(out, kernel2)
# select scale of second degradation stage
if isinstance(self.stage2_scale, Sequence):
min_scale, max_scale = self.stage2_scale
stage2_scale = np.random.uniform(min_scale, max_scale)
else:
stage2_scale = self.stage2_scale
stage2_h, stage2_w = int(ori_h / stage2_scale), int(ori_w / stage2_scale)
# print(f"stage2 scale = {stage2_scale}")
# random resize
updown_type = random.choices(["up", "down", "keep"], self.resize_prob2)[0]
if updown_type == "up":
scale = np.random.uniform(1, self.resize_range2[1])
elif updown_type == "down":
scale = np.random.uniform(self.resize_range2[0], 1)
else:
scale = 1
mode = random.choice(["area", "bilinear", "bicubic"])
out = F.interpolate(
out, size=(int(stage2_h * scale), int(stage2_w * scale)), mode=mode
)
# add noise
if np.random.uniform() < self.gaussian_noise_prob2:
out = random_add_gaussian_noise_pt(
out, sigma_range=self.noise_range2, clip=True,
rounds=False, gray_prob=self.gray_noise_prob2
)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.poisson_scale_range2,
gray_prob=self.gray_noise_prob2,
clip=True,
rounds=False
)
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
# 2. JPEG compression + [resize back + sinc filter]
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
if np.random.uniform() < 0.5:
# resize back + the final sinc filter
mode = random.choice(["area", "bilinear", "bicubic"])
out = F.interpolate(out, size=(stage2_h, stage2_w), mode=mode)
out = filter2D(out, sinc_kernel)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range2)
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
else:
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.jpeg_range2)
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
# resize back + the final sinc filter
mode = random.choice(["area", "bilinear", "bicubic"])
out = F.interpolate(out, size=(stage2_h, stage2_w), mode=mode)
out = filter2D(out, sinc_kernel)
# resize back to gt_size since we are doing a restoration task
if stage2_scale != 1:
out = F.interpolate(out, size=(ori_h, ori_w), mode="bicubic")
# clamp and round
lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
if self.resize_hq and stage2_scale != 1:
# resize hq
hq = F.interpolate(hq, size=(stage2_h, stage2_w), mode="bicubic", antialias=True)
hq = F.interpolate(hq, size=(ori_h, ori_w), mode="bicubic", antialias=True)
self.gt = hq
self.lq = lq
self._dequeue_and_enqueue()
# [0, 1], float32, rgb, nhwc
lq = self.lq.float().permute(0, 2, 3, 1).contiguous()
# [-1, 1], float32, rgb, nhwc
hq = (self.gt * 2 - 1).float().permute(0, 2, 3, 1).contiguous()
return dict(jpg=hq, hint=lq, txt=batch["txt"])
from typing import Sequence, Dict, Union
import math
import time
import numpy as np
import cv2
from PIL import Image
import torch.utils.data as data
from utils.file import load_file_list
from utils.image import center_crop_arr, augment, random_crop_arr
from utils.degradation import (
random_mixed_kernels, random_add_gaussian_noise, random_add_jpg_compression
)
class CodeformerDataset(data.Dataset):
def __init__(
self,
file_list: str,
out_size: int,
crop_type: str,
use_hflip: bool,
blur_kernel_size: int,
kernel_list: Sequence[str],
kernel_prob: Sequence[float],
blur_sigma: Sequence[float],
downsample_range: Sequence[float],
noise_range: Sequence[float],
jpeg_range: Sequence[int]
) -> "CodeformerDataset":
super(CodeformerDataset, self).__init__()
self.file_list = file_list
self.paths = load_file_list(file_list)
self.out_size = out_size
self.crop_type = crop_type
assert self.crop_type in ["none", "center", "random"]
self.use_hflip = use_hflip
# degradation configurations
self.blur_kernel_size = blur_kernel_size
self.kernel_list = kernel_list
self.kernel_prob = kernel_prob
self.blur_sigma = blur_sigma
self.downsample_range = downsample_range
self.noise_range = noise_range
self.jpeg_range = jpeg_range
def __getitem__(self, index: int) -> Dict[str, Union[np.ndarray, str]]:
# load gt image
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
gt_path = self.paths[index]
success = False
for _ in range(3):
try:
pil_img = Image.open(gt_path).convert("RGB")
success = True
break
except Exception:
time.sleep(1)
assert success, f"failed to load image {gt_path}"
if self.crop_type == "center":
pil_img_gt = center_crop_arr(pil_img, self.out_size)
elif self.crop_type == "random":
pil_img_gt = random_crop_arr(pil_img, self.out_size)
else:
pil_img_gt = np.array(pil_img)
assert pil_img_gt.shape[:2] == (self.out_size, self.out_size)
img_gt = (pil_img_gt[..., ::-1] / 255.0).astype(np.float32)
# random horizontal flip
img_gt = augment(img_gt, hflip=self.use_hflip, rotation=False, return_status=False)
h, w, _ = img_gt.shape
# ------------------------ generate lq image ------------------------ #
# blur
kernel = random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
self.blur_kernel_size,
self.blur_sigma,
self.blur_sigma,
[-math.pi, math.pi],
noise_range=None
)
img_lq = cv2.filter2D(img_gt, -1, kernel)
# downsample
scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
# noise
if self.noise_range is not None:
img_lq = random_add_gaussian_noise(img_lq, self.noise_range)
# jpeg compression
if self.jpeg_range is not None:
img_lq = random_add_jpg_compression(img_lq, self.jpeg_range)
# resize to original size
img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
# BGR to RGB, [-1, 1]
target = (img_gt[..., ::-1] * 2 - 1).astype(np.float32)
# BGR to RGB, [0, 1]
source = img_lq[..., ::-1].astype(np.float32)
return dict(jpg=target, txt="", hint=source)
def __len__(self) -> int:
return len(self.paths)
from typing import Any, Tuple, Mapping
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from omegaconf import OmegaConf
from utils.common import instantiate_from_config
from dataset.batch_transform import BatchTransform, IdentityBatchTransform
class BIRDataModule(pl.LightningDataModule):
def __init__(
self,
train_config: str,
val_config: str = None
) -> None:
super().__init__()
self.train_config = OmegaConf.load(train_config)
self.val_config = OmegaConf.load(val_config) if val_config else None
def load_dataset(self, config: Mapping[str, Any]) -> Tuple[Dataset, BatchTransform]:
dataset = instantiate_from_config(config["dataset"])
batch_transform = (
instantiate_from_config(config["batch_transform"])
if config.get("batch_transform") else IdentityBatchTransform()
)
return dataset, batch_transform
def setup(self, stage: str) -> None:
if stage == "fit":
self.train_dataset, self.train_batch_transform = self.load_dataset(self.train_config)
if self.val_config:
self.val_dataset, self.val_batch_transform = self.load_dataset(self.val_config)
else:
self.val_dataset, self.val_batch_transform = None, None
else:
raise NotImplementedError(stage)
def train_dataloader(self) -> TRAIN_DATALOADERS:
return DataLoader(
dataset=self.train_dataset, **self.train_config["data_loader"]
)
def val_dataloader(self) -> EVAL_DATALOADERS:
if self.val_dataset is None:
return None
return DataLoader(
dataset=self.val_dataset, **self.val_config["data_loader"]
)
def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
self.trainer: pl.Trainer
if self.trainer.training:
return self.train_batch_transform(batch)
elif self.trainer.validating or self.trainer.sanity_checking:
return self.val_batch_transform(batch)
else:
raise RuntimeError(
"Trainer state: \n"
f"training: {self.trainer.training}\n"
f"validating: {self.trainer.validating}\n"
f"testing: {self.trainer.testing}\n"
f"predicting: {self.trainer.predicting}\n"
f"sanity_checking: {self.trainer.sanity_checking}"
)
from typing import Dict, Sequence
import math
import random
import time
import numpy as np
import torch
from torch.utils import data
from PIL import Image
from utils.degradation import circular_lowpass_kernel, random_mixed_kernels
from utils.image import augment, random_crop_arr, center_crop_arr
from utils.file import load_file_list
class RealESRGANDataset(data.Dataset):
"""
# TODO: add comment
"""
def __init__(
self,
file_list: str,
out_size: int,
crop_type: str,
use_hflip: bool,
use_rot: bool,
# blur kernel settings of the first degradation stage
blur_kernel_size: int,
kernel_list: Sequence[str],
kernel_prob: Sequence[float],
blur_sigma: Sequence[float],
betag_range: Sequence[float],
betap_range: Sequence[float],
sinc_prob: float,
# blur kernel settings of the second degradation stage
blur_kernel_size2: int,
kernel_list2: Sequence[str],
kernel_prob2: Sequence[float],
blur_sigma2: Sequence[float],
betag_range2: Sequence[float],
betap_range2: Sequence[float],
sinc_prob2: float,
final_sinc_prob: float
) -> "RealESRGANDataset":
super(RealESRGANDataset, self).__init__()
self.paths = load_file_list(file_list)
self.out_size = out_size
self.crop_type = crop_type
assert self.crop_type in ["center", "random", "none"], f"invalid crop type: {self.crop_type}"
self.blur_kernel_size = blur_kernel_size
self.kernel_list = kernel_list
# a list for each kernel probability
self.kernel_prob = kernel_prob
self.blur_sigma = blur_sigma
# betag used in generalized Gaussian blur kernels
self.betag_range = betag_range
# betap used in plateau blur kernels
self.betap_range = betap_range
# the probability for sinc filters
self.sinc_prob = sinc_prob
self.blur_kernel_size2 = blur_kernel_size2
self.kernel_list2 = kernel_list2
self.kernel_prob2 = kernel_prob2
self.blur_sigma2 = blur_sigma2
self.betag_range2 = betag_range2
self.betap_range2 = betap_range2
self.sinc_prob2 = sinc_prob2
# a final sinc filter
self.final_sinc_prob = final_sinc_prob
self.use_hflip = use_hflip
self.use_rot = use_rot
# kernel size ranges from 7 to 21
self.kernel_range = [2 * v + 1 for v in range(3, 11)]
# TODO: kernel range is now hard-coded, should be in the config file
# convolving with the pulse tensor produces no blur
self.pulse_tensor = torch.zeros(21, 21).float()
self.pulse_tensor[10, 10] = 1
@torch.no_grad()
def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
# -------------------------------- Load hq images -------------------------------- #
hq_path = self.paths[index]
success = False
for _ in range(3):
try:
pil_img = Image.open(hq_path).convert("RGB")
success = True
break
except Exception:
time.sleep(1)
assert success, f"failed to load image {hq_path}"
if self.crop_type == "random":
pil_img = random_crop_arr(pil_img, self.out_size)
elif self.crop_type == "center":
pil_img = center_crop_arr(pil_img, self.out_size)
# self.crop_type is "none"
else:
pil_img = np.array(pil_img)
assert pil_img.shape[:2] == (self.out_size, self.out_size)
# hwc, rgb to bgr, [0, 255] to [0, 1], float32
img_hq = (pil_img[..., ::-1] / 255.0).astype(np.float32)
# -------------------- Do augmentation for training: flip, rotation -------------------- #
img_hq = augment(img_hq, self.use_hflip, self.use_rot)
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.sinc_prob:
# this sinc filter setting is for kernels ranging from [7, 21]
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel = random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
kernel_size,
self.blur_sigma,
self.blur_sigma, [-math.pi, math.pi],
self.betag_range,
self.betap_range,
noise_range=None
)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.sinc_prob2:
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel2 = random_mixed_kernels(
self.kernel_list2,
self.kernel_prob2,
kernel_size,
self.blur_sigma2,
self.blur_sigma2, [-math.pi, math.pi],
self.betag_range2,
self.betap_range2,
noise_range=None
)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------------------- the final sinc kernel ------------------------------------- #
if np.random.uniform() < self.final_sinc_prob:
kernel_size = random.choice(self.kernel_range)
omega_c = np.random.uniform(np.pi / 3, np.pi)
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
sinc_kernel = torch.FloatTensor(sinc_kernel)
else:
sinc_kernel = self.pulse_tensor
# [0, 1], BGR to RGB, HWC to CHW
img_hq = torch.from_numpy(
img_hq[..., ::-1].transpose(2, 0, 1).copy()
).float()
kernel = torch.FloatTensor(kernel)
kernel2 = torch.FloatTensor(kernel2)
return {
"hq": img_hq, "kernel1": kernel, "kernel2": kernel2,
"sinc_kernel": sinc_kernel, "txt": ""
}
def __len__(self) -> int:
return len(self.paths)
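As the RealESRGANBatchTransform docstring notes, the Real-ESRGAN degradation is applied to whole batches after they reach the GPU; during training this happens inside BIRDataModule.on_after_batch_transfer. The sketch below pairs the two classes up manually, with the hyper-parameters copied from the Real-ESRGAN training config above; the file-list path, the smaller worker count, and the explicit .cuda() move are placeholders for illustration only.

import torch
from torch.utils.data import DataLoader
from dataset.realesrgan import RealESRGANDataset
from dataset.batch_transform import RealESRGANBatchTransform

dataset = RealESRGANDataset(
    file_list="path/to/file_list.list",  # placeholder
    out_size=512, crop_type="center", use_hflip=False, use_rot=False,
    blur_kernel_size=21,
    kernel_list=['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
    kernel_prob=[0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    blur_sigma=[0.2, 3], betag_range=[0.5, 4], betap_range=[1, 2], sinc_prob=0.1,
    blur_kernel_size2=21,
    kernel_list2=['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
    kernel_prob2=[0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    blur_sigma2=[0.2, 1.5], betag_range2=[0.5, 4], betap_range2=[1, 2], sinc_prob2=0.1,
    final_sinc_prob=0.8
)
transform = RealESRGANBatchTransform(
    use_sharpener=False, resize_hq=False, queue_size=256,
    resize_prob=[0.2, 0.7, 0.1], resize_range=[0.15, 1.5],
    gray_noise_prob=0.4, gaussian_noise_prob=0.5, noise_range=[1, 30],
    poisson_scale_range=[0.05, 3], jpeg_range=[30, 95],
    second_blur_prob=0.8, stage2_scale=4,
    resize_prob2=[0.3, 0.4, 0.3], resize_range2=[0.3, 1.2],
    gray_noise_prob2=0.4, gaussian_noise_prob2=0.5, noise_range2=[1, 25],
    poisson_scale_range2=[0.05, 2.5], jpeg_range2=[30, 95]
)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, drop_last=True)
batch = next(iter(loader))
# batch holds "hq" plus the per-sample blur/sinc kernels and an empty "txt" prompt.
batch = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
out = transform(batch)
# out["jpg"]: HQ target in [-1, 1], out["hint"]: degraded LQ in [0, 1], both NHWC; out["txt"]: prompts.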
from typing import List
import math
from argparse import ArgumentParser
import numpy as np
import torch
import einops
import pytorch_lightning as pl
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf
from model.spaced_sampler import SpacedSampler
from model.cldm import ControlLDM
from utils.image import (
wavelet_reconstruction, auto_resize, pad
)
from utils.common import instantiate_from_config, load_state_dict
parser = ArgumentParser()
parser.add_argument("--config", required=True, type=str)
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--reload_swinir", action="store_true")
parser.add_argument("--swinir_ckpt", type=str, default="")
args = parser.parse_args()
# load model
device = "cuda" if torch.cuda.is_available() else "cpu"
model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
# reload preprocess model if specified
if args.reload_swinir:
print(f"reload swinir model from {args.swinir_ckpt}")
load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
model.freeze()
model.to(device)
# load sampler
sampler = SpacedSampler(model, var_type="fixed_small")
@torch.no_grad()
def process(
control_img: Image.Image,
num_samples: int,
sr_scale: int,
image_size: int,
disable_preprocess_model: bool,
strength: float,
positive_prompt: str,
negative_prompt: str,
cond_scale: float,
steps: int,
use_color_fix: bool,
keep_original_size: bool,
seed: int
) -> List[np.ndarray]:
print(
f"control image shape={control_img.size}\n"
f"num_samples={num_samples}, sr_scale={sr_scale}, image_size={image_size}\n"
f"disable_preprocess_model={disable_preprocess_model}, strength={strength}\n"
f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
f"prompt scale={cond_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
f"seed={seed}"
)
pl.seed_everything(seed)
# prepare condition
if sr_scale != 1:
control_img = control_img.resize(
tuple(math.ceil(x * sr_scale) for x in control_img.size),
Image.BICUBIC
)
input_size = control_img.size
control_img = auto_resize(control_img, image_size)
h, w = control_img.height, control_img.width
control_img = pad(np.array(control_img), scale=64) # HWC, RGB, [0, 255]
control_imgs = [control_img] * num_samples
control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
if not disable_preprocess_model:
control = model.preprocess_model(control)
height, width = control.size(-2), control.size(-1)
cond = {
"c_latent": [model.apply_condition_encoder(control)],
"c_crossattn": [model.get_learned_conditioning([positive_prompt] * num_samples)]
}
uncond = {
"c_latent": [model.apply_condition_encoder(control)],
"c_crossattn": [model.get_learned_conditioning([negative_prompt] * num_samples)]
}
model.control_scales = [strength] * 13
shape = (num_samples, 4, height // 8, width // 8)
print(f"latent shape = {shape}")
x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
samples = sampler.sample(
steps, shape, cond,
unconditional_guidance_scale=cond_scale,
unconditional_conditioning=uncond,
cond_fn=None, x_T=x_T
)
x_samples = model.decode_first_stage(samples)
x_samples = ((x_samples + 1) / 2).clamp(0, 1)
# apply color correction
if use_color_fix:
x_samples = wavelet_reconstruction(x_samples, control)
x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
preds = []
for img in x_samples:
if keep_original_size:
# remove padding and resize to input size
img = Image.fromarray(img[:h, :w, :]).resize(input_size, Image.LANCZOS)
preds.append(np.array(img))
else:
# remove padding
preds.append(img[:h, :w, :])
return preds
block = gr.Blocks().queue()
with block:
with gr.Row():
gr.Markdown("## DiffBIR")
with gr.Row():
with gr.Column():
input_image = gr.Image(source="upload", type="pil")
run_button = gr.Button(label="Run")
with gr.Accordion("Options", open=True):
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
sr_scale = gr.Number(label="SR Scale", value=1)
image_size = gr.Slider(label="Image size", minimum=256, maximum=768, value=512, step=64)
positive_prompt = gr.Textbox(label="Positive Prompt", value="")
# It's worth noting that if your positive prompt is short while the negative prompt
# is long, the positive prompt will lose its effectiveness.
# Example (control strength = 0):
# positive prompt: cat
# negative prompt: longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality
# I ran some experiments and found that sd_v2.1 suffers from this problem while sd_v1.5 doesn't.
negative_prompt = gr.Textbox(
label="Negative Prompt",
value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
)
cond_scale = gr.Slider(label="Prompt Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
disable_preprocess_model = gr.Checkbox(label="Disable Preprocess Model", value=False)
use_color_fix = gr.Checkbox(label="Use Color Correction", value=True)
keep_original_size = gr.Checkbox(label="Keep Original Size", value=True)
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
with gr.Column():
result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery").style(grid=2, height="auto")
inputs = [
input_image,
num_samples,
sr_scale,
image_size,
disable_preprocess_model,
strength,
positive_prompt,
negative_prompt,
cond_scale,
steps,
use_color_fix,
keep_original_size,
seed
]
run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
block.launch(server_name='0.0.0.0')
from typing import List, Tuple
import os
import math
from argparse import ArgumentParser, Namespace
import numpy as np
import torch
import einops
import pytorch_lightning as pl
from PIL import Image
from omegaconf import OmegaConf
from model.spaced_sampler import SpacedSampler
from model.ddim_sampler import DDIMSampler
from model.cldm import ControlLDM
from utils.image import (
wavelet_reconstruction, adaptive_instance_normalization, auto_resize, pad
)
from utils.common import instantiate_from_config, load_state_dict
from utils.file import list_image_files, get_file_name_parts
@torch.no_grad()
def process(
model: ControlLDM,
control_imgs: List[np.ndarray],
sampler: str,
steps: int,
strength: float,
color_fix_type: str,
disable_preprocess_model: bool
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""
Apply DiffBIR model on a list of low-quality images.
Args:
model (ControlLDM): Model.
control_imgs (List[np.ndarray]): A list of low-quality images (HWC, RGB, range in [0, 255])
sampler (str): Sampler name.
steps (int): Sampling steps.
strength (float): Control strength. Set to 1.0 during training.
color_fix_type (str): Type of color correction for samples.
disable_preprocess_model (bool): If specified, preprocess model (SwinIR) will not be used.
Returns:
preds (List[np.ndarray]): Restoration results (HWC, RGB, range in [0, 255]).
stage1_preds (List[np.ndarray]): Outputs of preprocess model (HWC, RGB, range in [0, 255]).
If `disable_preprocess_model` is specified, the preprocess model's outputs are the same
as the low-quality inputs.
"""
n_samples = len(control_imgs)
if sampler == "ddpm":
sampler = SpacedSampler(model, var_type="fixed_small")
else:
sampler = DDIMSampler(model)
control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
# TODO: model.preprocess_model = lambda x: x
if not disable_preprocess_model and hasattr(model, "preprocess_model"):
control = model.preprocess_model(control)
elif disable_preprocess_model and not hasattr(model, "preprocess_model"):
raise ValueError(f"model doesn't have a preprocess model.")
height, width = control.size(-2), control.size(-1)
cond = {
"c_latent": [model.apply_condition_encoder(control)],
"c_crossattn": [model.get_learned_conditioning([""] * n_samples)]
}
model.control_scales = [strength] * 13
shape = (n_samples, 4, height // 8, width // 8)
x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
if isinstance(sampler, SpacedSampler):
samples = sampler.sample(
steps, shape, cond,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
cond_fn=None, x_T=x_T
)
else:
sampler: DDIMSampler
samples, _ = sampler.sample(
S=steps, batch_size=shape[0], shape=shape[1:],
conditioning=cond, unconditional_conditioning=None,
x_T=x_T, eta=0
)
x_samples = model.decode_first_stage(samples)
x_samples = ((x_samples + 1) / 2).clamp(0, 1)
# apply color correction (borrowed from StableSR)
if color_fix_type == "adain":
x_samples = adaptive_instance_normalization(x_samples, control)
elif color_fix_type == "wavelet":
x_samples = wavelet_reconstruction(x_samples, control)
else:
assert color_fix_type == "none", f"unexpected color fix type: {color_fix_type}"
x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
control = (einops.rearrange(control, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
preds = [x_samples[i] for i in range(n_samples)]
stage1_preds = [control[i] for i in range(n_samples)]
return preds, stage1_preds
def parse_args() -> Namespace:
parser = ArgumentParser()
parser.add_argument("--ckpt", required=True, type=str)
parser.add_argument("--config", required=True, type=str)
parser.add_argument("--reload_swinir", action="store_true")
parser.add_argument("--swinir_ckpt", type=str, default="")
parser.add_argument("--input", type=str, required=True)
parser.add_argument("--sampler", type=str, default="ddpm", choices=["ddpm", "ddim"])
parser.add_argument("--steps", required=True, type=int)
parser.add_argument("--sr_scale", type=float, default=1)
parser.add_argument("--image_size", type=int, default=512)
parser.add_argument("--repeat_times", type=int, default=1)
parser.add_argument("--disable_preprocess_model", action="store_true")
parser.add_argument("--color_fix_type", type=str, default="wavelet", choices=["wavelet", "adain", "none"])
parser.add_argument("--resize_back", action="store_true")
parser.add_argument("--output", type=str, required=True)
parser.add_argument("--show_lq", action="store_true")
parser.add_argument("--skip_if_exist", action="store_true")
parser.add_argument("--seed", type=int, default=231)
return parser.parse_args()
def main() -> None:
args = parse_args()
pl.seed_everything(args.seed)
device = "cuda" if torch.cuda.is_available() else "cpu"
model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
# reload preprocess model if specified
if args.reload_swinir:
if not hasattr(model, "preprocess_model"):
raise ValueError(f"model don't have a preprocess model.")
print(f"reload swinir model from {args.swinir_ckpt}")
load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
model.freeze()
model.to(device)
assert os.path.isdir(args.input)
print(f"sampling {args.steps} steps using ddpm sampler")
for file_path in list_image_files(args.input, follow_links=True):
lq = Image.open(file_path).convert("RGB")
if args.sr_scale != 1:
lq = lq.resize(
tuple(math.ceil(x * args.sr_scale) for x in lq.size),
Image.BICUBIC
)
lq_resized = auto_resize(lq, args.image_size)
x = pad(np.array(lq_resized), scale=64)
for i in range(args.repeat_times):
save_path = os.path.join(args.output, os.path.relpath(file_path, args.input))
parent_path, stem, _ = get_file_name_parts(save_path)
save_path = os.path.join(parent_path, f"{stem}_{i}.png")
if os.path.exists(save_path):
if args.skip_if_exist:
print(f"skip {save_path}")
continue
else:
raise RuntimeError(f"{save_path} already exist")
os.makedirs(parent_path, exist_ok=True)
try:
preds, stage1_preds = process(
model, [x], steps=args.steps, sampler=args.sampler,
strength=1,
color_fix_type=args.color_fix_type,
disable_preprocess_model=args.disable_preprocess_model
)
except RuntimeError as e:
# Skip this image on runtime errors (e.g. CUDA out of memory) instead of aborting the whole run.
print(f"{file_path}, error: {e}")
continue
pred, stage1_pred = preds[0], stage1_preds[0]
# remove padding
pred = pred[:lq_resized.height, :lq_resized.width, :]
stage1_pred = stage1_pred[:lq_resized.height, :lq_resized.width, :]
if args.show_lq:
if args.resize_back:
if lq_resized.size != lq.size:
pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
stage1_pred = np.array(Image.fromarray(stage1_pred).resize(lq.size, Image.LANCZOS))
lq = np.array(lq)
else:
lq = np.array(lq_resized)
images = [lq, pred] if args.disable_preprocess_model else [lq, stage1_pred, pred]
Image.fromarray(np.concatenate(images, axis=1)).save(save_path)
else:
if args.resize_back and lq_resized.size != lq.size:
Image.fromarray(pred).resize(lq.size, Image.LANCZOS).save(save_path)
else:
Image.fromarray(pred).save(save_path)
print(f"save to {save_path}")
if __name__ == "__main__":
main()
pip install pytorch_lightning==1.4.2
pip install einops
conda install transformers
conda install chardet
pip install open-clip-torch
pip install omegaconf
pip install torchmetrics==0.6.0
pip install triton
pip install opencv-python-headless
conda install scipy
conda install matplotlib
pip install lpips
pip install gradio
import torch
from ldm.modules.midas.api import load_midas_transform
class AddMiDaS(object):
def __init__(self, model_type):
super().__init__()
self.transform = load_midas_transform(model_type)
def pt2np(self, x):
x = ((x + 1.0) * .5).detach().cpu().numpy()
return x
def np2pt(self, x):
x = torch.from_numpy(x) * 2 - 1.
return x
def __call__(self, sample):
# sample['jpg'] is tensor hwc in [-1, 1] at this point
x = self.pt2np(sample['jpg'])
x = self.transform({"image": x})["image"]
sample['midas_in'] = x
return sample
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.util import instantiate_from_config
from ldm.modules.ema import LitEma
class AutoencoderKL(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False
):
super().__init__()
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, postfix=""):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
if self.learn_logvar:
print(f"{self.__class__.__name__}: Learning logvar")
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list,
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
if log_ema or self.use_ema:
with self.ema_scope():
xrec_ema, posterior_ema = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
log["reconstructions_ema"] = xrec_ema
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
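In the ControlLDM configuration above, this autoencoder is the first_stage_config, and its latents are scaled by scale_factor: 0.18215 before entering the diffusion UNet. The round trip below is a minimal sketch of that convention using only the methods defined above; the helper names are illustrative, and the exact call sites live in model.cldm and the LDM base class, which are not shown in this excerpt.

import torch

SCALE_FACTOR = 0.18215  # from the ControlLDM config above

@torch.no_grad()
def image_to_latent(vae: AutoencoderKL, img: torch.Tensor) -> torch.Tensor:
    # img: (B, 3, H, W) in [-1, 1]
    posterior = vae.encode(img)          # DiagonalGaussianDistribution
    z = posterior.sample()               # (B, 4, H/8, W/8) with the ddconfig above
    return z * SCALE_FACTOR              # scaled latent consumed by the diffusion model

@torch.no_grad()
def latent_to_image(vae: AutoencoderKL, z: torch.Tensor) -> torch.Tensor:
    # Undo the scaling, then decode back to image space (roughly [-1, 1]).
    return vae.decode(z / SCALE_FACTOR)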