Commit f7544730 authored by 0x3f3f3f3fun's avatar 0x3f3f3f3fun
Browse files

first commit

parents
# based on https://github.com/isl-org/MiDaS
import cv2
import torch
import torch.nn as nn
from torchvision.transforms import Compose
from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
from ldm.modules.midas.midas.midas_net import MidasNet
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
ISL_PATHS = {
"dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
"dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
"midas_v21": "",
"midas_v21_small": "",
}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def load_midas_transform(model_type):
# https://github.com/isl-org/MiDaS/blob/master/run.py
# load transform only
if model_type == "dpt_large": # DPT-Large
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_hybrid": # DPT-Hybrid
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "midas_v21":
net_w, net_h = 384, 384
resize_mode = "upper_bound"
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
elif model_type == "midas_v21_small":
net_w, net_h = 256, 256
resize_mode = "upper_bound"
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
else:
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
transform = Compose(
[
Resize(
net_w,
net_h,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method=resize_mode,
image_interpolation_method=cv2.INTER_CUBIC,
),
normalization,
PrepareForNet(),
]
)
return transform
def load_model(model_type):
# https://github.com/isl-org/MiDaS/blob/master/run.py
# load network
model_path = ISL_PATHS[model_type]
if model_type == "dpt_large": # DPT-Large
model = DPTDepthModel(
path=model_path,
backbone="vitl16_384",
non_negative=True,
)
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_hybrid": # DPT-Hybrid
model = DPTDepthModel(
path=model_path,
backbone="vitb_rn50_384",
non_negative=True,
)
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "midas_v21":
model = MidasNet(model_path, non_negative=True)
net_w, net_h = 384, 384
resize_mode = "upper_bound"
normalization = NormalizeImage(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
elif model_type == "midas_v21_small":
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
non_negative=True, blocks={'expand': True})
net_w, net_h = 256, 256
resize_mode = "upper_bound"
normalization = NormalizeImage(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
else:
print(f"model_type '{model_type}' not implemented, use: --model_type large")
assert False
transform = Compose(
[
Resize(
net_w,
net_h,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method=resize_mode,
image_interpolation_method=cv2.INTER_CUBIC,
),
normalization,
PrepareForNet(),
]
)
return model.eval(), transform
class MiDaSInference(nn.Module):
MODEL_TYPES_TORCH_HUB = [
"DPT_Large",
"DPT_Hybrid",
"MiDaS_small"
]
MODEL_TYPES_ISL = [
"dpt_large",
"dpt_hybrid",
"midas_v21",
"midas_v21_small",
]
def __init__(self, model_type):
super().__init__()
assert (model_type in self.MODEL_TYPES_ISL)
model, _ = load_model(model_type)
self.model = model
self.model.train = disabled_train
def forward(self, x):
# x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
# NOTE: we expect that the correct transform has been called during dataloading.
with torch.no_grad():
prediction = self.model(x)
prediction = torch.nn.functional.interpolate(
prediction.unsqueeze(1),
size=x.shape[2:],
mode="bicubic",
align_corners=False,
)
assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
return prediction
import torch
class BaseModel(torch.nn.Module):
def load(self, path):
"""Load model from file.
Args:
path (str): file path
"""
parameters = torch.load(path, map_location=torch.device('cpu'))
if "optimizer" in parameters:
parameters = parameters["model"]
self.load_state_dict(parameters)
import torch
import torch.nn as nn
from .vit import (
_make_pretrained_vitb_rn50_384,
_make_pretrained_vitl16_384,
_make_pretrained_vitb16_384,
forward_vit,
)
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
if backbone == "vitl16_384":
pretrained = _make_pretrained_vitl16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[256, 512, 1024, 1024], features, groups=groups, expand=expand
) # ViT-L/16 - 85.0% Top1 (backbone)
elif backbone == "vitb_rn50_384":
pretrained = _make_pretrained_vitb_rn50_384(
use_pretrained,
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
scratch = _make_scratch(
[256, 512, 768, 768], features, groups=groups, expand=expand
) # ViT-H/16 - 85.0% Top1 (backbone)
elif backbone == "vitb16_384":
pretrained = _make_pretrained_vitb16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[96, 192, 384, 768], features, groups=groups, expand=expand
) # ViT-B/16 - 84.6% Top1 (backbone)
elif backbone == "resnext101_wsl":
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
elif backbone == "efficientnet_lite3":
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
else:
print(f"Backbone '{backbone}' not implemented")
assert False
return pretrained, scratch
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
if expand==True:
out_shape1 = out_shape
out_shape2 = out_shape*2
out_shape3 = out_shape*4
out_shape4 = out_shape*8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
efficientnet = torch.hub.load(
"rwightman/gen-efficientnet-pytorch",
"tf_efficientnet_lite3",
pretrained=use_pretrained,
exportable=exportable
)
return _make_efficientnet_backbone(efficientnet)
def _make_efficientnet_backbone(effnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
)
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
return pretrained
def _make_resnet_backbone(resnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
)
pretrained.layer2 = resnet.layer2
pretrained.layer3 = resnet.layer3
pretrained.layer4 = resnet.layer4
return pretrained
def _make_pretrained_resnext101_wsl(use_pretrained):
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
return _make_resnet_backbone(resnet)
class Interpolate(nn.Module):
"""Interpolation module.
"""
def __init__(self, scale_factor, mode, align_corners=False):
"""Init.
Args:
scale_factor (float): scaling
mode (str): interpolation mode
"""
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: interpolated data
"""
x = self.interp(
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
)
return x
class ResidualConvUnit(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.relu(x)
out = self.conv1(out)
out = self.relu(out)
out = self.conv2(out)
return out + x
class FeatureFusionBlock(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.resConfUnit1 = ResidualConvUnit(features)
self.resConfUnit2 = ResidualConvUnit(features)
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
output += self.resConfUnit1(xs[1])
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=True
)
return output
class ResidualConvUnit_custom(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups=1
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
if self.bn==True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn==True:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn==True:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
# return out + x
class FeatureFusionBlock_custom(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock_custom, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups=1
self.expand = expand
out_features = features
if self.expand==True:
out_features = features//2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
# output += res
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
)
output = self.out_conv(output)
return output
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_model import BaseModel
from .blocks import (
FeatureFusionBlock,
FeatureFusionBlock_custom,
Interpolate,
_make_encoder,
forward_vit,
)
def _make_fusion_block(features, use_bn):
return FeatureFusionBlock_custom(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
)
class DPT(BaseModel):
def __init__(
self,
head,
features=256,
backbone="vitb_rn50_384",
readout="project",
channels_last=False,
use_bn=False,
):
super(DPT, self).__init__()
self.channels_last = channels_last
hooks = {
"vitb_rn50_384": [0, 1, 8, 11],
"vitb16_384": [2, 5, 8, 11],
"vitl16_384": [5, 11, 17, 23],
}
# Instantiate backbone and reassemble blocks
self.pretrained, self.scratch = _make_encoder(
backbone,
features,
False, # Set to true of you want to train from scratch, uses ImageNet weights
groups=1,
expand=False,
exportable=False,
hooks=hooks[backbone],
use_readout=readout,
)
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
self.scratch.output_conv = head
def forward(self, x):
if self.channels_last == True:
x.contiguous(memory_format=torch.channels_last)
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return out
class DPTDepthModel(DPT):
def __init__(self, path=None, non_negative=True, **kwargs):
features = kwargs["features"] if "features" in kwargs else 256
head = nn.Sequential(
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
nn.Identity(),
)
super().__init__(head, **kwargs)
if path is not None:
self.load(path)
def forward(self, x):
return super().forward(x).squeeze(dim=1)
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn
from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
class MidasNet(BaseModel):
"""Network for monocular depth estimation.
"""
def __init__(self, path=None, features=256, non_negative=True):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
"""
print("Loading weights: ", path)
super(MidasNet, self).__init__()
use_pretrained = False if path is None else True
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
self.scratch.refinenet4 = FeatureFusionBlock(features)
self.scratch.refinenet3 = FeatureFusionBlock(features)
self.scratch.refinenet2 = FeatureFusionBlock(features)
self.scratch.refinenet1 = FeatureFusionBlock(features)
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear"),
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
)
if path:
self.load(path)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
layer_1 = self.pretrained.layer1(x)
layer_2 = self.pretrained.layer2(layer_1)
layer_3 = self.pretrained.layer3(layer_2)
layer_4 = self.pretrained.layer4(layer_3)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return torch.squeeze(out, dim=1)
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn
from .base_model import BaseModel
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
class MidasNet_small(BaseModel):
"""Network for monocular depth estimation.
"""
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
blocks={'expand': True}):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
"""
print("Loading weights: ", path)
super(MidasNet_small, self).__init__()
use_pretrained = False if path else True
self.channels_last = channels_last
self.blocks = blocks
self.backbone = backbone
self.groups = 1
features1=features
features2=features
features3=features
features4=features
self.expand = False
if "expand" in self.blocks and self.blocks['expand'] == True:
self.expand = True
features1=features
features2=features*2
features3=features*4
features4=features*8
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
self.scratch.activation = nn.ReLU(False)
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
Interpolate(scale_factor=2, mode="bilinear"),
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
self.scratch.activation,
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
nn.Identity(),
)
if path:
self.load(path)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
if self.channels_last==True:
print("self.channels_last = ", self.channels_last)
x.contiguous(memory_format=torch.channels_last)
layer_1 = self.pretrained.layer1(x)
layer_2 = self.pretrained.layer2(layer_1)
layer_3 = self.pretrained.layer3(layer_2)
layer_4 = self.pretrained.layer4(layer_3)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return torch.squeeze(out, dim=1)
def fuse_model(m):
prev_previous_type = nn.Identity()
prev_previous_name = ''
previous_type = nn.Identity()
previous_name = ''
for name, module in m.named_modules():
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
# print("FUSED ", prev_previous_name, previous_name, name)
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
# print("FUSED ", prev_previous_name, previous_name)
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
# print("FUSED ", previous_name, name)
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
prev_previous_type = previous_type
prev_previous_name = previous_name
previous_type = type(module)
previous_name = name
\ No newline at end of file
import numpy as np
import cv2
import math
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
Args:
sample (dict): sample
size (tuple): image size
Returns:
tuple: new size
"""
shape = list(sample["disparity"].shape)
if shape[0] >= size[0] and shape[1] >= size[1]:
return sample
scale = [0, 0]
scale[0] = size[0] / shape[0]
scale[1] = size[1] / shape[1]
scale = max(scale)
shape[0] = math.ceil(scale * shape[0])
shape[1] = math.ceil(scale * shape[1])
# resize
sample["image"] = cv2.resize(
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
)
sample["disparity"] = cv2.resize(
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
tuple(shape[::-1]),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return tuple(shape)
class Resize(object):
"""Resize sample to given size (width, height).
"""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
# scale as least as possbile
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(
f"resize_method {self.__resize_method} not implemented"
)
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(
scale_height * height, min_val=self.__height
)
new_width = self.constrain_to_multiple_of(
scale_width * width, min_val=self.__width
)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(
scale_height * height, max_val=self.__height
)
new_width = self.constrain_to_multiple_of(
scale_width * width, max_val=self.__width
)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(
sample["image"].shape[1], sample["image"].shape[0]
)
# resize sample
sample["image"] = cv2.resize(
sample["image"],
(width, height),
interpolation=self.__image_interpolation_method,
)
if self.__resize_target:
if "disparity" in sample:
sample["disparity"] = cv2.resize(
sample["disparity"],
(width, height),
interpolation=cv2.INTER_NEAREST,
)
if "depth" in sample:
sample["depth"] = cv2.resize(
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
(width, height),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return sample
class NormalizeImage(object):
"""Normlize image by given mean and std.
"""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input.
"""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
if "disparity" in sample:
disparity = sample["disparity"].astype(np.float32)
sample["disparity"] = np.ascontiguousarray(disparity)
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
return sample
import torch
import torch.nn as nn
import timm
import types
import math
import torch.nn.functional as F
class Slice(nn.Module):
def __init__(self, start_index=1):
super(Slice, self).__init__()
self.start_index = start_index
def forward(self, x):
return x[:, self.start_index :]
class AddReadout(nn.Module):
def __init__(self, start_index=1):
super(AddReadout, self).__init__()
self.start_index = start_index
def forward(self, x):
if self.start_index == 2:
readout = (x[:, 0] + x[:, 1]) / 2
else:
readout = x[:, 0]
return x[:, self.start_index :] + readout.unsqueeze(1)
class ProjectReadout(nn.Module):
def __init__(self, in_features, start_index=1):
super(ProjectReadout, self).__init__()
self.start_index = start_index
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
def forward(self, x):
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
features = torch.cat((x[:, self.start_index :], readout), -1)
return self.project(features)
class Transpose(nn.Module):
def __init__(self, dim0, dim1):
super(Transpose, self).__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
x = x.transpose(self.dim0, self.dim1)
return x
def forward_vit(pretrained, x):
b, c, h, w = x.shape
glob = pretrained.model.forward_flex(x)
layer_1 = pretrained.activations["1"]
layer_2 = pretrained.activations["2"]
layer_3 = pretrained.activations["3"]
layer_4 = pretrained.activations["4"]
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
unflatten = nn.Sequential(
nn.Unflatten(
2,
torch.Size(
[
h // pretrained.model.patch_size[1],
w // pretrained.model.patch_size[0],
]
),
)
)
if layer_1.ndim == 3:
layer_1 = unflatten(layer_1)
if layer_2.ndim == 3:
layer_2 = unflatten(layer_2)
if layer_3.ndim == 3:
layer_3 = unflatten(layer_3)
if layer_4.ndim == 3:
layer_4 = unflatten(layer_4)
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
return layer_1, layer_2, layer_3, layer_4
def _resize_pos_embed(self, posemb, gs_h, gs_w):
posemb_tok, posemb_grid = (
posemb[:, : self.start_index],
posemb[0, self.start_index :],
)
gs_old = int(math.sqrt(len(posemb_grid)))
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
def forward_flex(self, x):
b, c, h, w = x.shape
pos_embed = self._resize_pos_embed(
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
)
B = x.shape[0]
if hasattr(self.patch_embed, "backbone"):
x = self.patch_embed.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
if getattr(self, "dist_token", None) is not None:
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
else:
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x
activations = {}
def get_activation(name):
def hook(model, input, output):
activations[name] = output
return hook
def get_readout_oper(vit_features, features, use_readout, start_index=1):
if use_readout == "ignore":
readout_oper = [Slice(start_index)] * len(features)
elif use_readout == "add":
readout_oper = [AddReadout(start_index)] * len(features)
elif use_readout == "project":
readout_oper = [
ProjectReadout(vit_features, start_index) for out_feat in features
]
else:
assert (
False
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
return readout_oper
def _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
size=[384, 384],
hooks=[2, 5, 8, 11],
vit_features=768,
use_readout="ignore",
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
# 32, 48, 136, 384
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
hooks = [5, 11, 17, 23] if hooks == None else hooks
return _make_vit_b16_backbone(
model,
features=[256, 512, 1024, 1024],
hooks=hooks,
vit_features=1024,
use_readout=use_readout,
)
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
)
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
)
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model(
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
hooks=hooks,
use_readout=use_readout,
start_index=2,
)
def _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=[0, 1, 8, 11],
vit_features=768,
use_vit_only=False,
use_readout="ignore",
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
if use_vit_only == True:
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
else:
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
get_activation("1")
)
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
get_activation("2")
)
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
if use_vit_only == True:
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
else:
pretrained.act_postprocess1 = nn.Sequential(
nn.Identity(), nn.Identity(), nn.Identity()
)
pretrained.act_postprocess2 = nn.Sequential(
nn.Identity(), nn.Identity(), nn.Identity()
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_pretrained_vitb_rn50_384(
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
hooks = [0, 1, 8, 11] if hooks == None else hooks
return _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
"""Utils for monoDepth."""
import sys
import re
import numpy as np
import cv2
import torch
def read_pfm(path):
"""Read pfm file.
Args:
path (str): path to file
Returns:
tuple: (data, scale)
"""
with open(path, "rb") as file:
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().rstrip()
if header.decode("ascii") == "PF":
color = True
elif header.decode("ascii") == "Pf":
color = False
else:
raise Exception("Not a PFM file: " + path)
dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
if dim_match:
width, height = list(map(int, dim_match.groups()))
else:
raise Exception("Malformed PFM header.")
scale = float(file.readline().decode("ascii").rstrip())
if scale < 0:
# little-endian
endian = "<"
scale = -scale
else:
# big-endian
endian = ">"
data = np.fromfile(file, endian + "f")
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
return data, scale
def write_pfm(path, image, scale=1):
"""Write pfm file.
Args:
path (str): pathto file
image (array): data
scale (int, optional): Scale. Defaults to 1.
"""
with open(path, "wb") as file:
color = None
if image.dtype.name != "float32":
raise Exception("Image dtype must be float32.")
image = np.flipud(image)
if len(image.shape) == 3 and image.shape[2] == 3: # color image
color = True
elif (
len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
): # greyscale
color = False
else:
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
file.write("PF\n" if color else "Pf\n".encode())
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
endian = image.dtype.byteorder
if endian == "<" or endian == "=" and sys.byteorder == "little":
scale = -scale
file.write("%f\n".encode() % scale)
image.tofile(file)
def read_image(path):
"""Read image and output RGB image (0-1).
Args:
path (str): path to file
Returns:
array: RGB image (0-1)
"""
img = cv2.imread(path)
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
return img
def resize_image(img):
"""Resize image and make it fit for network.
Args:
img (array): image
Returns:
tensor: data ready for network
"""
height_orig = img.shape[0]
width_orig = img.shape[1]
if width_orig > height_orig:
scale = width_orig / 384
else:
scale = height_orig / 384
height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
img_resized = (
torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
)
img_resized = img_resized.unsqueeze(0)
return img_resized
def resize_depth(depth, width, height):
"""Resize depth map and bring to CPU (numpy).
Args:
depth (tensor): depth
width (int): image width
height (int): image height
Returns:
array: processed depth
"""
depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
depth_resized = cv2.resize(
depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
)
return depth_resized
def write_depth(path, depth, bits=1):
"""Write depth map to pfm and png file.
Args:
path (str): filepath without extension
depth (array): depth
"""
write_pfm(path + ".pfm", depth.astype(np.float32))
depth_min = depth.min()
depth_max = depth.max()
max_val = (2**(8*bits))-1
if depth_max - depth_min > np.finfo("float").eps:
out = max_val * (depth - depth_min) / (depth_max - depth_min)
else:
out = np.zeros(depth.shape, dtype=depth.type)
if bits == 1:
cv2.imwrite(path + ".png", out.astype("uint8"))
elif bits == 2:
cv2.imwrite(path + ".png", out.astype("uint16"))
return
import importlib
import torch
from torch import optim
import numpy as np
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
# font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x,torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
return total_params
def instantiate_from_config(config):
if not "target" in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
class AdamWwithEMAandWings(optim.Optimizer):
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
ema_power=1., param_names=()):
"""AdamW that saves EMA versions of the parameters."""
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if not 0.0 <= ema_decay <= 1.0:
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
ema_power=ema_power, param_names=param_names)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
ema_params_with_grad = []
state_sums = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']
beta1, beta2 = group['betas']
ema_decay = group['ema_decay']
ema_power = group['ema_power']
for p in group['params']:
if p.grad is None:
continue
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of parameter values
state['param_exp_avg'] = p.detach().float().clone()
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
ema_params_with_grad.append(state['param_exp_avg'])
if amsgrad:
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
# update the steps for each param group update
state['step'] += 1
# record the step after step update
state_steps.append(state['step'])
optim._functional.adamw(params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
maximize=False)
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
return loss
\ No newline at end of file
from typing import Dict, Any
import os
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.types import STEP_OUTPUT
import torch
import torchvision
from PIL import Image
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.distributed import rank_zero_only
from .mixins import ImageLoggerMixin
__all__ = [
"ModelCheckpoint",
"ImageLogger"
]
class ImageLogger(Callback):
"""
Log images during training or validating.
TODO: Support validating.
"""
def __init__(
self,
log_every_n_steps: int=2000,
max_images_each_step: int=4,
log_images_kwargs: Dict[str, Any]=None
) -> "ImageLogger":
super().__init__()
self.log_every_n_steps = log_every_n_steps
self.max_images_each_step = max_images_each_step
self.log_images_kwargs = log_images_kwargs or dict()
def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
assert isinstance(pl_module, ImageLoggerMixin)
@rank_zero_only
def on_train_batch_end(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT,
batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
if pl_module.global_step % self.log_every_n_steps == 0:
is_train = pl_module.training
if is_train:
pl_module.freeze()
with torch.no_grad():
# returned images should be: nchw, rgb, [0, 1]
images: Dict[str, torch.Tensor] = pl_module.log_images(batch, **self.log_images_kwargs)
# save images
save_dir = os.path.join(pl_module.logger.save_dir, "image_log", "train")
os.makedirs(save_dir, exist_ok=True)
for image_key in images:
image = images[image_key].detach().cpu()
N = min(self.max_images_each_step, len(image))
grid = torchvision.utils.make_grid(image[:N], nrow=4)
# chw -> hwc (hw if gray)
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1).numpy()
grid = (grid * 255).clip(0, 255).astype(np.uint8)
filename = "{}_step-{:06}_e-{:06}_b-{:06}.png".format(
image_key, pl_module.global_step, pl_module.current_epoch, batch_idx
)
path = os.path.join(save_dir, filename)
Image.fromarray(grid).save(path)
if is_train:
pl_module.unfreeze()
from typing import Mapping, Any
import copy
from collections import OrderedDict
import einops
import torch
import torch as th
import torch.nn as nn
from ldm.modules.diffusionmodules.util import (
conv_nd,
linear,
zero_module,
timestep_embedding,
)
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock, UNetModel
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import log_txt_as_img, exists, instantiate_from_config
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from utils.common import frozen_module
from .spaced_sampler import SpacedSampler
class ControlledUnetModel(UNetModel):
def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
hs = []
with torch.no_grad():
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
if control is not None:
h += control.pop()
for i, module in enumerate(self.output_blocks):
if only_mid_control or control is None:
h = torch.cat([h, hs.pop()], dim=1)
else:
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
return self.out(h)
class ControlNet(nn.Module):
def __init__(
self,
image_size,
in_channels,
model_channels,
hint_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.dims = dims
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels + hint_channels, model_channels, 3, padding=1)
)
]
)
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch))
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self.middle_block_out = self.make_zero_conv(ch)
self._feature_size += ch
def make_zero_conv(self, channels):
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
def forward(self, x, hint, timesteps, context, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
x = torch.cat((x, hint), dim=1)
outs = []
h = x.type(self.dtype)
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
h = module(h, emb, context)
outs.append(zero_conv(h, emb, context))
h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context))
return outs
class ControlLDM(LatentDiffusion):
def __init__(
self,
control_stage_config: Mapping[str, Any],
control_key: str,
sd_locked: bool,
only_mid_control: bool,
learning_rate: float,
preprocess_config,
*args,
**kwargs
) -> "ControlLDM":
super().__init__(*args, **kwargs)
# instantiate control module
self.control_model: ControlNet = instantiate_from_config(control_stage_config)
self.control_key = control_key
self.sd_locked = sd_locked
self.only_mid_control = only_mid_control
self.learning_rate = learning_rate
self.control_scales = [1.0] * 13
# instantiate preprocess module (SwinIR)
self.preprocess_model = instantiate_from_config(preprocess_config)
frozen_module(self.preprocess_model)
# instantiate condition encoder, since our condition encoder has the same
# structure with AE encoder, we just make a copy of AE encoder. please
# note that AE encoder's parameters has not been initialized here.
self.cond_encoder = nn.Sequential(OrderedDict([
("encoder", copy.deepcopy(self.first_stage_model.encoder)), # cond_encoder.encoder
("quant_conv", copy.deepcopy(self.first_stage_model.quant_conv)) # cond_encoder.quant_conv
]))
frozen_module(self.cond_encoder)
def apply_condition_encoder(self, control):
c_latent_meanvar = self.cond_encoder(control * 2 - 1)
c_latent = DiagonalGaussianDistribution(c_latent_meanvar).mode() # only use mode
c_latent = c_latent * self.scale_factor
return c_latent
@torch.no_grad()
def get_input(self, batch, k, bs=None, *args, **kwargs):
x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
control = batch[self.control_key]
if bs is not None:
control = control[:bs]
control = control.to(self.device)
control = einops.rearrange(control, 'b h w c -> b c h w')
control = control.to(memory_format=torch.contiguous_format).float()
lq = control
# apply preprocess model
control = self.preprocess_model(control)
# apply condition encoder
c_latent = self.apply_condition_encoder(control)
return x, dict(c_crossattn=[c], c_latent=[c_latent], lq=[lq], c_concat=[control])
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
assert isinstance(cond, dict)
diffusion_model = self.model.diffusion_model
cond_txt = torch.cat(cond['c_crossattn'], 1)
if cond['c_latent'] is None:
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
else:
control = self.control_model(
x=x_noisy, hint=torch.cat(cond['c_latent'], 1),
timesteps=t, context=cond_txt
)
control = [c * scale for c, scale in zip(control, self.control_scales)]
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
return eps
@torch.no_grad()
def get_unconditional_conditioning(self, N):
return self.get_learned_conditioning([""] * N)
@torch.no_grad()
def log_images(self, batch, sample_steps=50):
log = dict()
z, c = self.get_input(batch, self.first_stage_key)
c_lq = c["lq"][0]
c_latent = c["c_latent"][0]
c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
log["hq"] = (self.decode_first_stage(z) + 1) / 2
log["control"] = c_cat
log["decoded_control"] = (self.decode_first_stage(c_latent) + 1) / 2
log["lq"] = c_lq
log["text"] = (log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16) + 1) / 2
samples = self.sample_log(
# TODO: remove c_concat from cond
cond={"c_concat": [c_cat], "c_crossattn": [c], "c_latent": [c_latent]},
steps=sample_steps
)
x_samples = self.decode_first_stage(samples)
log["samples"] = (x_samples + 1) / 2
return log
@torch.no_grad()
def sample_log(self, cond, steps):
sampler = SpacedSampler(self)
b, c, h, w = cond["c_concat"][0].shape
shape = (b, self.channels, h // 8, w // 8)
samples = sampler.sample(
steps, shape, cond, unconditional_guidance_scale=1.0,
unconditional_conditioning=None
)
return samples
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.control_model.parameters())
if not self.sd_locked:
params += list(self.model.diffusion_model.output_blocks.parameters())
params += list(self.model.diffusion_model.out.parameters())
opt = torch.optim.AdamW(params, lr=lr)
return opt
def validation_step(self, batch, batch_idx):
# TODO:
pass
from typing import overload
import torch
from torch.nn import functional as F
class Guidance:
def __init__(self, scale, type, t_start, t_stop, space, repeat, loss_type):
self.scale = scale
self.type = type
self.t_start = t_start
self.t_stop = t_stop
self.target = None
self.space = space
self.repeat = repeat
self.loss_type = loss_type
def load_target(self, target):
self.target = target
def __call__(self, target_x0, pred_x0, t):
if self.t_stop < t and t < self.t_start:
# print("sampling with classifier guidance")
# avoid propagating gradient out of this scope
pred_x0 = pred_x0.detach().clone()
target_x0 = target_x0.detach().clone()
return self.scale * self._forward(target_x0, pred_x0)
else:
return None
@overload
def _forward(self, target_x0, pred_x0): ...
class MSEGuidance(Guidance):
def __init__(self, scale, type, t_start, t_stop, space, repeat, loss_type) -> None:
super().__init__(
scale, type, t_start, t_stop, space, repeat, loss_type
)
@torch.enable_grad()
def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor):
# inputs: [-1, 1], nchw, rgb
pred_x0.requires_grad_(True)
if self.loss_type == "mse":
loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum()
elif self.loss_type == "downsample_mse":
# FIXME: scale_factor should be 1/4, not 4
lr_pred_x0 = F.interpolate(pred_x0, scale_factor=4, mode="bicubic")
lr_target_x0 = F.interpolate(target_x0, scale_factor=4, mode="bicubic")
loss = (lr_pred_x0 - lr_target_x0).pow(2).mean((1, 2, 3)).sum()
else:
raise ValueError(self.loss_type)
print(f"loss = {loss.item()}")
return -torch.autograd.grad(loss, pred_x0)[0]
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
model_t = self.model.apply_model(x, t, c)
model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
return x_dec
\ No newline at end of file
from typing import overload, Any, Dict
import torch
class ImageLoggerMixin:
@overload
def log_images(self, batch: Any, **kwargs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
...
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from ldm.modules.diffusionmodules.util import make_beta_schedule
# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
def space_timesteps(num_timesteps, section_counts):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
For example, if there's 300 timesteps and the section counts are [10,15,20]
then the first 100 timesteps are strided to be 10 timesteps, the second 100
are strided to be 15 timesteps, and the final 100 are strided to be 20.
If the stride is a string starting with "ddim", then the fixed striding
from the DDIM paper is used, and only one section is allowed.
:param num_timesteps: the number of diffusion steps in the original
process to divide up.
:param section_counts: either a list of numbers, or a string containing
comma-separated numbers, indicating the step count
per section. As a special case, use "ddimN" where N
is a number of steps to use the striding from the
DDIM paper.
:return: a set of diffusion steps from the original process to use.
"""
if isinstance(section_counts, str):
if section_counts.startswith("ddim"):
desired_count = int(section_counts[len("ddim") :])
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(
f"cannot create exactly {num_timesteps} steps with an integer stride"
)
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(
f"cannot divide section of {size} steps into {section_count}"
)
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/gaussian_diffusion.py
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
class SpacedSampler:
def __init__(self, model, schedule="linear", var_type: str="fixed_small"):
self.model = model
self.original_num_steps = model.num_timesteps
self.schedule = schedule
self.var_type = var_type
def make_schedule(self, num_steps):
# NOTE: this schedule, which generates betas linearly in log space, is a little different
# from guided diffusion.
original_betas = make_beta_schedule(self.schedule, self.original_num_steps, linear_start=self.model.linear_start,
linear_end=self.model.linear_end)
original_alphas = 1.0 - original_betas
original_alphas_cumprod = np.cumprod(original_alphas, axis=0)
# calcualte betas for spaced sampling
# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
used_timesteps = space_timesteps(self.original_num_steps, str(num_steps))
# print(f"timesteps used in spaced sampler: \n\t{used_timesteps}")
betas = []
last_alpha_cumprod = 1.0
for i, alpha_cumprod in enumerate(original_alphas_cumprod):
if i in used_timesteps:
# marginal distribution is the same as q(x_{S_t}|x_0)
betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
assert len(betas) == num_steps
betas = np.array(betas, dtype=np.float64)
self.betas = betas
self.timesteps = np.array(sorted(list(used_timesteps)), dtype=np.int32) # e.g. [0, 10, 20, ...]
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
assert self.alphas_cumprod_prev.shape == (num_steps,)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = (
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
# log calculation clipped because the posterior variance is 0 at the
# beginning of the diffusion chain.
self.posterior_log_variance_clipped = np.log(
np.append(self.posterior_variance[1], self.posterior_variance[1:])
)
self.posterior_mean_coef1 = (
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
self.posterior_mean_coef2 = (
(1.0 - self.alphas_cumprod_prev)
* np.sqrt(alphas)
/ (1.0 - self.alphas_cumprod)
)
def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
assert noise.shape == x_start.shape
return (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
@torch.no_grad()
def sample(
self,
steps,
shape,
conditioning=None,
x_T=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
cond_fn=None # for classifier guidance
):
self.make_schedule(num_steps=steps)
samples = self.sapced_sampling(
conditioning, shape, x_T=x_T,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
cond_fn=cond_fn
)
return samples
@torch.no_grad()
def sapced_sampling(
self, cond, shape, x_T,
unconditional_guidance_scale, unconditional_conditioning,
cond_fn
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
print("start to sample from a given noise")
img = x_T
time_range = np.flip(self.timesteps) # [1000, 950, 900, ...]
total_steps = len(self.timesteps)
print(f"Running Spaced Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Spaced Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1 # t in guided diffusion
ts = torch.full((b,), step, device=device, dtype=torch.long)
img = self.p_sample_spaced(img, cond, ts, index=index, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
cond_fn=cond_fn)
return img
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
def predict_noise(self, x, t, c, unconditional_guidance_scale, unconditional_conditioning):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
model_t = self.model.apply_model(x, t, c)
model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
return e_t
def apply_cond_fn(self, x, c, t, index, cond_fn, unconditional_guidance_scale,
unconditional_conditioning):
device = x.device
t_now = int(t[0].item()) + 1
# ----------------- predict noise and x0 ----------------- #
e_t = self.predict_noise(
x, t, c, unconditional_guidance_scale, unconditional_conditioning
)
pred_x0: torch.Tensor = self._predict_xstart_from_eps(x_t=x, t=index, eps=e_t)
model_mean, _, _ = self.q_posterior_mean_variance(
x_start=pred_x0, x_t=x, t=index
)
# apply classifier guidance for multiple times
for _ in range(cond_fn.repeat):
# ----------------- compute gradient for x0 in latent space ----------------- #
target, pred = None, None
if cond_fn.space == "latent":
target = self.model.get_first_stage_encoding(
self.model.encode_first_stage(cond_fn.target.to(device))
)
pred = pred_x0
elif cond_fn.space == "rgb":
# We need to backward gradient to x0 in latent space, so it's required
# to trace the computation graph while decoding the latent.
with torch.enable_grad():
pred_x0.requires_grad_(True)
target = cond_fn.target.to(device)
pred = self.model.decode_first_stage_with_grad(pred_x0)
else:
raise NotImplementedError(cond_fn.space)
delta_pred = cond_fn(target, pred, t_now)
# ----------------- apply classifier guidance ----------------- #
if delta_pred is not None:
if cond_fn.space == "rgb":
# compute gradient for pred_x0
pred.backward(delta_pred)
delta_pred_x0 = pred_x0.grad
# update prex_x0
pred_x0 += delta_pred_x0
# our classifier guidance is equivalent to multiply delta_pred_x0
# by a constant and then add it to model_mean, We set the constant
# to 0.5
model_mean += 0.5 * delta_pred_x0
pred_x0.grad.zero_()
else:
delta_pred_x0 = delta_pred
pred_x0 += delta_pred_x0
model_mean += 0.5 * delta_pred_x0
else:
# means stop guidance
break
return model_mean.detach().clone(), pred_x0.detach().clone()
@torch.no_grad()
def p_sample_spaced(
self, x: torch.Tensor, c, t, index,
unconditional_guidance_scale,
unconditional_conditioning, cond_fn
):
index = torch.full_like(t, fill_value=index)
model_variance = {
"fixed_large": np.append(self.posterior_variance[1], self.betas[1:]),
"fixed_small": self.posterior_variance
}[self.var_type]
model_variance = _extract_into_tensor(model_variance, index, x.shape)
if cond_fn is not None:
model_mean, pred_x0 = self.apply_cond_fn(
x, c, t, index, cond_fn,
unconditional_guidance_scale, unconditional_conditioning
)
else:
e_t = self.predict_noise(
x, t, c,
unconditional_guidance_scale, unconditional_conditioning
)
pred_x0 = self._predict_xstart_from_eps(x_t=x, t=index, eps=e_t)
model_mean, _, _ = self.q_posterior_mean_variance(
x_start=pred_x0, x_t=x, t=index
)
noise = torch.randn_like(x)
nonzero_mask = (
(index != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
# TODO: use log variance ?
x_prev = model_mean + nonzero_mask * torch.sqrt(model_variance) * noise
return x_prev
# -----------------------------------------------------------------------------------
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally Written by Ze Liu, Modified by Jingyun Liang.
# -----------------------------------------------------------------------------------
import math
from typing import Any, Dict, Set
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import pytorch_lightning as pl
from pytorch_lightning.utilities.types import STEP_OUTPUT
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
from utils.metrics import calculate_psnr_pt, LPIPS
from .mixins import ImageLoggerMixin
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
attn_mask = self.calculate_mask(self.input_resolution)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x, x_size):
H, W = x_size
B, L, C = x.shape
# assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
if self.input_resolution == x_size:
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, x_size):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, x_size)
else:
x = blk(x, x_size)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
img_size: Input image size.
patch_size: Patch size.
resi_connection: The convolutional block before residual connection.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
img_size=224, patch_size=4, resi_connection='1conv'):
super(RSTB, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.residual_group = BasicLayer(dim=dim,
input_resolution=input_resolution,
depth=depth,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
downsample=downsample,
use_checkpoint=use_checkpoint)
if resi_connection == '1conv':
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim, 3, 1, 1))
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
norm_layer=None)
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
norm_layer=None)
def forward(self, x, x_size):
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
def flops(self):
flops = 0
flops += self.residual_group.flops()
H, W = self.input_resolution
flops += H * W * self.dim * self.dim * 9
flops += self.patch_embed.flops()
flops += self.patch_unembed.flops()
return flops
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
flops = 0
H, W = self.img_size
if self.norm is not None:
flops += H * W * self.embed_dim
return flops
class PatchUnEmbed(nn.Module):
r""" Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
B, HW, C = x.shape
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
return x
def flops(self):
flops = 0
return flops
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
Used in lightweight SR to save parameters.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
self.num_feat = num_feat
self.input_resolution = input_resolution
m = []
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
super(UpsampleOneStep, self).__init__(*m)
def flops(self):
H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
class SwinIR(pl.LightningModule, ImageLoggerMixin):
r""" SwinIR
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
Args:
img_size (int | tuple(int)): Input image size. Default 64
patch_size (int | tuple(int)): Patch size. Default: 1
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
sf: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
img_range: Image range. 1. or 255.
upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
"""
def __init__(
self,
img_size=64,
patch_size=1,
in_chans=3,
embed_dim=96,
depths=[6, 6, 6, 6],
num_heads=[6, 6, 6, 6],
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
use_checkpoint=False,
sf=4,
img_range=1.,
upsampler='',
resi_connection='1conv',
unshuffle=False,
unshuffle_scale=None,
hq_key: str="jpg",
lq_key: str="hint",
learning_rate: float=None,
weight_decay: float=None
) -> "SwinIR":
super(SwinIR, self).__init__()
num_in_ch = in_chans * (unshuffle_scale**2) if unshuffle else in_chans
num_out_ch = in_chans
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
self.upscale = sf
self.upsampler = upsampler
self.window_size = window_size
self.unshuffle_scale = unshuffle_scale
self.unshuffle = unshuffle
#####################################################################################################
################################### 1, shallow feature extraction ###################################
if unshuffle:
assert unshuffle_scale is not None
self.conv_first = nn.Sequential(
nn.PixelUnshuffle(sf),
nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1),
)
else:
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
#####################################################################################################
################################### 2, deep feature extraction ######################################
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None
)
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build Residual Swin Transformer blocks (RSTB)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RSTB(
dim=embed_dim,
input_resolution=(patches_resolution[0], patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection
)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
if resi_connection == '1conv':
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv_after_body = nn.Sequential(
nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)
)
#####################################################################################################
################################ 3, high quality image reconstruction ################################
if self.upsampler == 'pixelshuffle':
# for classical SR
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True)
)
self.upsample = Upsample(sf, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(
sf, embed_dim, num_out_ch,
(patches_resolution[0], patches_resolution[1])
)
elif self.upsampler == 'nearest+conv':
# for real-world SR (less artifacts)
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True)
)
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
if self.upscale == 4:
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
elif self.upscale == 8:
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
# for image denoising and JPEG compression artifact reduction
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
self.apply(self._init_weights)
self.hq_key = hq_key
self.lq_key = lq_key
self.learning_rate = learning_rate
self.weight_decay = weight_decay
self.lpips_metric = LPIPS(net="alex")
def _init_weights(self, m: nn.Module) -> None:
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
# TODO: What's this ?
@torch.jit.ignore
def no_weight_decay(self) -> Set[str]:
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self) -> Set[str]:
return {'relative_position_bias_table'}
def check_image_size(self, x: torch.Tensor) -> torch.Tensor:
_, _, h, w = x.size()
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
return x
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
H, W = x.shape[2:]
x = self.check_image_size(x)
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
if self.upsampler == 'pixelshuffle':
# for classical SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.conv_last(self.upsample(x))
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.upsample(x)
elif self.upsampler == 'nearest+conv':
# for real-world SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
if self.upscale == 4:
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
elif self.upscale == 8:
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.lrelu(self.conv_up3(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
# for image denoising and JPEG compression artifact reduction
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
x = x / self.img_range + self.mean
return x[:, :, :H*self.upscale, :W*self.upscale]
def flops(self) -> int:
flops = 0
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
return flops
def get_loss(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
"""
Compute loss between model predictions and labels.
Args:
pred (torch.Tensor): Batch model predictions.
label (torch.Tensor): Batch labels.
Returns:
loss (torch.Tensor): The loss tensor.
"""
return F.mse_loss(input=pred, target=label, reduction="sum")
def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT:
"""
Args:
batch (Dict[str, torch.Tensor]): A dict contains LQ and HQ (NHWC, RGB,
LQ range in [0, 1] and HQ range in [-1, 1]).
batch_idx (int): Index of this batch.
Returns:
outputs (torch.Tensor): The loss tensor.
"""
hq, lq = batch[self.hq_key], batch[self.lq_key]
hq = rearrange(((hq + 1) / 2).clamp_(0, 1), "n h w c -> n c h w")
lq = rearrange(lq, "n h w c -> n c h w")
pred = self(lq)
loss = self.get_loss(pred, hq)
self.log("train_loss", loss, on_step=True)
return loss
def on_validation_start(self) -> None:
self.lpips_metric.to(self.device)
def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
hq, lq = batch[self.hq_key], batch[self.lq_key]
lq = rearrange(lq, "n h w c -> n c h w")
pred = self(lq)
hq = rearrange(((hq + 1) / 2).clamp_(0, 1), "n h w c -> n c h w")
# requiremtns for lpips model inputs:
# https://github.com/richzhang/PerceptualSimilarity
lpips = self.lpips_metric(pred, hq, normalize=True).mean()
self.log("val_lpips", lpips)
pnsr = calculate_psnr_pt(pred, hq, crop_border=0).mean()
self.log("val_pnsr", pnsr)
loss = self.get_loss(pred, hq)
self.log("val_loss", loss)
def configure_optimizers(self) -> optim.AdamW:
"""
Configure optimizer for this model.
Returns:
optimizer (optim.AdamW): The optimizer for this model.
"""
optimizer = optim.AdamW(
[p for p in self.parameters() if p.requires_grad], lr=self.learning_rate,
weight_decay=self.weight_decay
)
return optimizer
@torch.no_grad()
def log_images(self, batch: Any) -> Dict[str, torch.Tensor]:
hq, lq = batch[self.hq_key], batch[self.lq_key]
hq = rearrange(((hq + 1) / 2).clamp_(0, 1), "n h w c -> n c h w")
lq = rearrange(lq, "n h w c -> n c h w")
pred = self(lq)
return dict(lq=lq, pred=pred, hq=hq)
import sys
sys.path.append(".")
import os
from argparse import ArgumentParser, Namespace
import pytorch_lightning as pl
from omegaconf import OmegaConf
import torch
from PIL import Image
import numpy as np
from tqdm import tqdm
from utils.image import auto_resize, pad
from utils.common import load_state_dict, instantiate_from_config
from utils.file import list_image_files, get_file_name_parts
def parse_args() -> Namespace:
parser = ArgumentParser()
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--input", type=str, required=True)
parser.add_argument("--sr_scale", type=float, default=1)
parser.add_argument("--image_size", type=int, default=512)
parser.add_argument("--show_lq", action="store_true")
parser.add_argument("--resize_back", action="store_true")
parser.add_argument("--output", type=str, required=True)
parser.add_argument("--skip_if_exist", action="store_true")
parser.add_argument("--seed", type=int, default=231)
return parser.parse_args()
@torch.no_grad()
def main():
args = parse_args()
pl.seed_everything(args.seed)
device = "cuda" if torch.cuda.is_available() else "cpu"
model: pl.LightningModule = instantiate_from_config(OmegaConf.load(args.config))
load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
model.freeze()
model.to(device)
assert os.path.isdir(args.input)
pbar = tqdm(list_image_files(args.input, follow_links=True))
for file_path in pbar:
pbar.set_description(file_path)
save_path = os.path.join(args.output, os.path.relpath(file_path, args.input))
parent_path, stem, _ = get_file_name_parts(save_path)
save_path = os.path.join(parent_path, f"{stem}.png")
if os.path.exists(save_path):
if args.skip_if_exist:
print(f"skip {save_path}")
continue
else:
raise RuntimeError(f"{save_path} already exist")
os.makedirs(parent_path, exist_ok=True)
# load low-quality image and resize
lq = Image.open(file_path).convert("RGB")
if args.sr_scale != 1:
lq = lq.resize(
tuple(int(x * args.sr_scale) for x in lq.size), Image.BICUBIC
)
lq_resized = auto_resize(lq, args.image_size)
# padding
x = pad(np.array(lq_resized), scale=64)
x = torch.tensor(x, dtype=torch.float32, device=device) / 255.0
x = x.permute(2, 0, 1).unsqueeze(0).contiguous()
try:
pred = model(x).detach().squeeze(0).permute(1, 2, 0) * 255
pred = pred.clamp(0, 255).to(torch.uint8).cpu().numpy()
except RuntimeError as e:
print(f"inference failed, error: {e}")
continue
# remove padding
pred = pred[:lq_resized.height, :lq_resized.width, :]
if args.show_lq:
if args.resize_back:
lq = np.array(lq)
if lq_resized.size != lq.size:
pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
else:
lq = np.array(lq_resized)
final_image = Image.fromarray(np.concatenate([lq, pred], axis=1))
else:
if args.resize_back and lq_resized.size != lq.size:
final_image = Image.fromarray(pred).resize(lq.size, Image.LANCZOS)
else:
final_image = Image.fromarray(pred)
final_image.save(save_path)
if __name__ == "__main__":
main()
import sys
sys.path.append(".")
import os
from argparse import ArgumentParser
from utils.file import list_image_files
parser = ArgumentParser()
parser.add_argument("--img_folder", type=str, required=True)
parser.add_argument("--val_size", type=int, required=True)
parser.add_argument("--save_folder", type=str, required=True)
parser.add_argument("--follow_links", action="store_true")
args = parser.parse_args()
files = list_image_files(
args.img_folder, exts=(".jpg", ".png", ".jpeg"), follow_links=args.follow_links,
log_progress=True, log_every_n_files=10000
)
print(f"find {len(files)} images in {args.img_folder}")
assert args.val_size < len(files)
val_files = files[:args.val_size]
train_files = files[args.val_size:]
os.makedirs(args.save_folder, exist_ok=True)
with open(os.path.join(args.save_folder, "train.list"), "w") as fp:
for file_path in train_files:
fp.write(f"{file_path}\n")
with open(os.path.join(args.save_folder, "val.list"), "w") as fp:
for file_path in val_files:
fp.write(f"{file_path}\n")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment