Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a7e8159d
Commit
a7e8159d
authored
Nov 08, 2022
by
Maruyama_Aya
Browse files
add ColoDiffusion codes: /ldm/module/, /ldm/data/, /scripts/test/
parent
441d584e
Changes
30
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
3271 additions
and
0 deletions
+3271
-0
examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py
.../images/diffusion/ldm/modules/image_degradation/bsrgan.py
+730
-0
examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py
...s/diffusion/ldm/modules/image_degradation/bsrgan_light.py
+650
-0
examples/images/diffusion/ldm/modules/image_degradation/utils/test.png
...es/diffusion/ldm/modules/image_degradation/utils/test.png
+0
-0
examples/images/diffusion/ldm/modules/image_degradation/utils_image.py
...es/diffusion/ldm/modules/image_degradation/utils_image.py
+916
-0
examples/images/diffusion/ldm/modules/losses/__init__.py
examples/images/diffusion/ldm/modules/losses/__init__.py
+1
-0
examples/images/diffusion/ldm/modules/losses/contperceptual.py
...les/images/diffusion/ldm/modules/losses/contperceptual.py
+111
-0
examples/images/diffusion/ldm/modules/losses/vqperceptual.py
examples/images/diffusion/ldm/modules/losses/vqperceptual.py
+167
-0
examples/images/diffusion/ldm/modules/x_transformer.py
examples/images/diffusion/ldm/modules/x_transformer.py
+641
-0
examples/images/diffusion/scripts/tests/test_checkpoint.py
examples/images/diffusion/scripts/tests/test_checkpoint.py
+37
-0
examples/images/diffusion/scripts/tests/test_watermark.py
examples/images/diffusion/scripts/tests/test_watermark.py
+18
-0
No files found.
examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py
0 → 100644
View file @
a7e8159d
# -*- coding: utf-8 -*-
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
import
numpy
as
np
import
cv2
import
torch
from
functools
import
partial
import
random
from
scipy
import
ndimage
import
scipy
import
scipy.stats
as
ss
from
scipy.interpolate
import
interp2d
from
scipy.linalg
import
orth
import
albumentations
import
ldm.modules.image_degradation.utils_image
as
util
def modcrop_np(img, sf):
    """Crop the two leading axes of an image down to multiples of ``sf``.

    Args:
        img: numpy image with two or three dimensions.
        sf: scale factor.

    Returns:
        A slice of a fresh copy of ``img``; the input is left untouched.
    """
    dim0, dim1 = img.shape[:2]
    keep0 = dim0 - dim0 % sf
    keep1 = dim1 - dim1 % sf
    duplicate = np.copy(img)
    return duplicate[:keep0, :keep1, ...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper).

    Args:
        k: square 2-D numpy kernel.

    Returns:
        The combined kernel, cropped by ``k_size // 2`` on every side and
        normalised so that it sums to 1.
    """
    k_size = k.shape[0]
    # Calculate the big kernel's size
    big_size = 3 * k_size - 2
    big_k = np.zeros((big_size, big_size))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to ignore very small values and reduce the run time of SR.
    # An explicit end index is used because ``-crop`` is ``-0`` for a 1x1 kernel,
    # which would slice everything away.
    crop = k_size // 2
    end = big_size - crop
    cropped_big_k = big_k[crop:end, crop:end]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """Generate an anisotropic Gaussian kernel.

    Args:
        ksize : e.g., 15, kernel size
        theta : [0,  pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.

    Returns:
        k     : kernel produced by ``gm_blur_kernel`` for the assembled covariance
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    # Rotate the unit x-axis to obtain the principal direction.
    axis = np.dot(rotation, np.array([1., 0.]))
    basis = np.array([[axis[0], axis[1]], [axis[1], -axis[0]]])
    eigenvalues = np.array([[l1, 0], [0, l2]])
    covariance = np.dot(np.dot(basis, eigenvalues), np.linalg.inv(basis))
    return gm_blur_kernel(mean=[0, 0], cov=covariance, size=ksize)
def gm_blur_kernel(mean, cov, size=15):
    """Sample a 2-D Gaussian density on a ``size`` x ``size`` grid and normalise it.

    Args:
        mean: length-2 mean passed to ``scipy.stats.multivariate_normal.pdf``.
        cov: 2x2 covariance matrix.
        size: side length of the returned kernel.

    Returns:
        ``(size, size)`` array summing to 1; entry ``[y, x]`` is the density
        at ``(x - center + 1, y - center + 1)``.
    """
    center = size / 2.0 + 0.5
    coords = np.arange(size) - center + 1
    # With the default 'xy' indexing gx[y, x] == coords[x] and gy[y, x] == coords[y].
    gx, gy = np.meshgrid(coords, coords)
    points = np.stack([gx, gy], axis=-1)
    # One vectorised pdf evaluation instead of size**2 separate Python-level calls.
    k = ss.multivariate_normal.pdf(points, mean=mean, cov=cov)
    # pdf squeezes singleton dimensions; restore the 2-D shape (matters for size == 1).
    k = np.reshape(k, (size, size))
    k = k / np.sum(k)
    return k
def shift_pixel(x, sf, upper_left=True):
    """Shift pixels for super-resolution with different scale factors.

    The image is resampled bilinearly at coordinates offset by ``(sf - 1) / 2``
    (clipped to the image extent).

    Args:
        x: 2-D array, or 3-D array whose last axis is resampled plane by plane.
        sf: scale factor
        upper_left: shift direction (``True`` adds the offset, ``False`` subtracts it)

    Returns:
        The shifted array. The input array is never modified; arrays of any
        other dimensionality are returned unchanged.
    """
    # ``interp2d`` was removed from recent SciPy releases; a degree-1
    # RectBivariateSpline is the bilinear replacement for data on a regular grid.
    from scipy.interpolate import RectBivariateSpline

    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    def _resample(plane):
        # Rows are indexed by y and columns by x, hence the (y, x) argument order.
        return RectBivariateSpline(yv, xv, plane, kx=1, ky=1)(y1, x1)

    if x.ndim == 2:
        return _resample(x)
    if x.ndim == 3:
        # Work on a copy so the caller's array is not overwritten.
        shifted = x.copy()
        for i in range(x.shape[-1]):
            shifted[:, :, i] = _resample(x[:, :, i])
        return shifted
    return x
def blur(x, k):
    '''Blur every image of a batch with its own kernel.

    x: image, NxcxHxW
    k: kernel, Nx1xhxw

    Returns a tensor of shape NxcxHxW (for odd kernel sizes); borders are
    handled by replicate padding.
    '''
    n, c = x.shape[:2]
    pad_h, pad_w = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    # F.pad takes (left, right, top, bottom): width pads first, then height.
    # The former (p1, p2, p1, p2) ordering was only correct for square kernels.
    x = torch.nn.functional.pad(x, pad=(pad_w, pad_w, pad_h, pad_h), mode='replicate')
    # One single-channel filter per (sample, channel) pair, applied as a grouped convolution.
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])

    return x
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """Draw a random anisotropic Gaussian kernel.

    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf

    Returns:
        Kernel normalised to sum to 1. Uses the global ``np.random`` state.
    """
    # Random eigenvalues and rotation angle of the covariance matrix
    span = max_var - min_var
    eig_a = min_var + np.random.rand() * span
    eig_b = min_var + np.random.rand() * span
    angle = np.random.rand() * np.pi
    # Multiplicative jitter in [-noise_level, noise_level)
    jitter = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Covariance from eigenvalues and rotation
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rot = np.array([[cos_a, -sin_a],
                    [sin_a, cos_a]])
    cov = rot @ np.diag([eig_a, eig_b]) @ rot.T
    precision = np.linalg.inv(cov)[None, None, :, :]

    # Expectation position (shifting kernel for aligned image)
    centre = k_size // 2 - 0.5 * (scale_factor - 1)
    centre = centre[None, None, :, None]

    # Coordinate grid, one column vector per kernel pixel
    grid_x, grid_y = np.meshgrid(range(k_size[0]), range(k_size[1]))
    coords = np.stack([grid_x, grid_y], 2)[:, :, :, None]

    # Evaluate the (unnormalised) Gaussian at every pixel of the kernel
    offsets = coords - centre
    offsets_t = offsets.transpose(0, 1, 3, 2)
    quad_form = np.squeeze(offsets_t @ precision @ offsets)
    unnormalised = np.exp(-0.5 * quad_form) * (1 + jitter)

    return unnormalised / np.sum(unnormalised)
def fspecial_gaussian(hsize, sigma):
    """Build a normalised ``hsize`` x ``hsize`` isotropic Gaussian filter.

    Args:
        hsize: side length of the square kernel.
        sigma: standard deviation of the Gaussian.

    Returns:
        2-D array; it sums to 1 unless every sample was zeroed by the threshold.
    """
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # Zero out values that are negligible relative to the peak.
    # ``scipy.finfo`` was a NumPy alias that newer SciPy no longer provides; use NumPy directly.
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h
def fspecial_laplacian(alpha):
    """Return a 3x3 Laplacian filter; ``alpha`` is clamped to [0, 1]."""
    alpha = max(0, min(alpha, 1))
    denom = alpha + 1
    corner = alpha / denom
    edge = (1 - alpha) / denom
    centre = -4 / denom
    rows = [
        [corner, edge, corner],
        [edge, centre, edge],
        [corner, edge, corner],
    ]
    return np.array(rows)
def fspecial(filter_type, *args, **kwargs):
    '''Dispatch to the filter builder named by ``filter_type``.

    Supports 'gaussian' and 'laplacian'; any other name yields ``None``.

    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    '''
    if filter_type == 'gaussian':
        builder = fspecial_gaussian
    elif filter_type == 'laplacian':
        builder = fspecial_laplacian
    else:
        return None
    return builder(*args, **kwargs)
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def bicubic_degradation(x, sf=3):
    '''Downscale an image by ``sf`` via ``util.imresize_np``.

    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubicly downsampled LR image
    '''
    return util.imresize_np(x, scale=1 / sf)
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    # ``ndimage.convolve`` is the supported spelling; the ``ndimage.filters``
    # namespace is deprecated. The kernel gets a singleton channel axis so every
    # channel is filtered independently.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    # ``ndimage.convolve`` is the supported spelling; the ``ndimage.filters``
    # namespace is deprecated.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image (every ``sf``-th pixel of the blurred image,
        starting at index 0)
    '''
    # ``ndimage.convolve`` is the supported spelling; the ``ndimage.filters``
    # namespace is deprecated.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0
    return x[st::sf, st::sf, ...]
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) * 255 > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur (bumped to the next odd
            number when even). Default: 50.
        threshold (int): Residual threshold on the 0-255 scale. Default: 10.
    """
    if radius % 2 == 0:
        radius += 1
    window = (radius, radius)
    blurred = cv2.GaussianBlur(img, window, 0)
    residual = img - blurred
    hard_mask = (np.abs(residual) * 255 > threshold).astype('float32')
    soft_mask = cv2.GaussianBlur(hard_mask, window, 0)

    sharpened = np.clip(img + weight * residual, 0, 1)
    return soft_mask * sharpened + (1 - soft_mask) * img
def add_blur(img, sf=4):
    """Blur an image with a randomly drawn Gaussian kernel.

    With probability 0.5 an anisotropic kernel is used, otherwise an isotropic
    one from ``fspecial``; the kernel size is ``2 * randint(2, 11) + 3``.

    Args:
        img: HxWxC image (the kernel is expanded with a singleton channel axis).
        sf: scale factor; controls the upper bound of the kernel widths.

    Returns:
        The blurred image ('mirror' boundary handling).
    """
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
    # ``ndimage.convolve`` is the supported spelling; the ``ndimage.filters``
    # namespace is deprecated.
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')

    return img
def add_resize(img, sf=4):
    """Randomly rescale an image and clip the result to [0, 1].

    A draw above 0.8 upscales by a factor in [1, 2], a draw below 0.7
    downscales by a factor in [0.5 / sf, 1], anything in between keeps the
    size. The OpenCV interpolation code is picked at random from {1, 2, 3}.
    """
    draw = np.random.rand()
    if draw > 0.8:  # up
        factor = random.uniform(1, 2)
    elif draw < 0.7:  # down
        factor = random.uniform(0.5 / sf, 1)
    else:
        factor = 1.0
    target = (int(factor * img.shape[1]), int(factor * img.shape[0]))
    resized = cv2.resize(img, target, interpolation=random.choice([1, 2, 3]))
    return np.clip(resized, 0.0, 1.0)
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add one of three kinds of Gaussian noise and clip to [0, 1].

    The noise level is an integer drawn from ``[noise_level1, noise_level2]``
    (on the 0-255 scale). Depending on a second draw the noise is independent
    per channel, shared across channels, or 3-channel correlated.
    The input array is not modified.
    """
    noise_level = random.randint(noise_level1, noise_level2)
    selector = np.random.rand()
    sigma = noise_level / 255.0
    if selector > 0.6:
        # colour noise: independent sample for every element
        noise = np.random.normal(0, sigma, img.shape)
    elif selector < 0.4:
        # grayscale noise: one sample per pixel, broadcast over channels
        noise = np.random.normal(0, sigma, (*img.shape[:2], 1))
    else:
        # correlated noise built from a random 3x3 covariance
        scale = noise_level2 / 255.
        diag = np.diag(np.random.rand(3))
        basis = orth(np.random.rand(3, 3))
        covariance = np.dot(np.dot(np.transpose(basis), diag), basis)
        noise = np.random.multivariate_normal([0, 0, 0], np.abs(scale ** 2 * covariance), img.shape[:2])
    noisy = img + noise.astype(np.float32)
    return np.clip(noisy, 0.0, 1.0)
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Add multiplicative (speckle) Gaussian noise and clip to [0, 1].

    The image is clipped first (which also yields a working copy), then
    ``img * noise`` is added, where the noise is per-element, per-pixel, or
    3-channel correlated depending on a random draw.
    """
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    selector = random.random()
    sigma = noise_level / 255.0
    if selector > 0.6:
        noise = np.random.normal(0, sigma, img.shape)
    elif selector < 0.4:
        noise = np.random.normal(0, sigma, (*img.shape[:2], 1))
    else:
        scale = noise_level2 / 255.
        diag = np.diag(np.random.rand(3))
        basis = orth(np.random.rand(3, 3))
        covariance = np.dot(np.dot(np.transpose(basis), diag), basis)
        noise = np.random.multivariate_normal([0, 0, 0], np.abs(scale ** 2 * covariance), img.shape[:2])
    # In-place add on the clipped copy keeps the copy's dtype.
    img += img * noise.astype(np.float32)
    return np.clip(img, 0.0, 1.0)
def add_Poisson_noise(img):
    """Add Poisson noise (per channel or luminance-based) and clip to [0, 1].

    The image is first quantised to 256 levels. With probability 0.5 each
    element is replaced by a scaled Poisson sample; otherwise Poisson noise is
    computed on a weighted gray image of the first three channels and added
    to every channel.
    """

    def _quantise(arr):
        # snap to the 8-bit grid while staying in [0, 1]
        return np.clip((arr * 255.0).round(), 0, 255) / 255.

    img = _quantise(img)
    vals = 10 ** (2 * random.random() + 2.0)  # exponent in [2, 4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        gray = _quantise(np.dot(img[..., :3], [0.299, 0.587, 0.114]))
        gray_noise = np.random.poisson(gray * vals).astype(np.float32) / vals - gray
        img += gray_noise[:, :, np.newaxis]
    return np.clip(img, 0.0, 1.0)
def add_JPEG_noise(img):
    """Round-trip an image through in-memory JPEG encoding at a random quality.

    The quality factor is drawn from [30, 95]. The image is converted with
    ``util.single2uint`` and RGB->BGR before encoding, and converted back
    with ``util.uint2single`` and BGR->RGB after decoding.
    """
    quality_factor = random.randint(30, 95)
    bgr = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    _, encoded = cv2.imencode('.jpg', bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    decoded = cv2.imdecode(encoded, 1)
    return cv2.cvtColor(util.uint2single(decoded), cv2.COLOR_BGR2RGB)
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Cut a random low-quality patch and the matching high-quality patch.

    Args:
        lq: low-quality image, HxWxC.
        hq: high-quality image, indexed at ``sf`` times the lq coordinates.
        sf: scale factor between the two images.
        lq_patchsize: side length of the low-quality patch.

    Returns:
        ``(lq_patch, hq_patch)``; the hq patch has side ``lq_patchsize * sf``.
    """
    height, width = lq.shape[:2]
    top = random.randint(0, height - lq_patchsize)
    left = random.randint(0, width - lq_patchsize)
    lq_rows = slice(top, top + lq_patchsize)
    lq_cols = slice(left, left + lq_patchsize)

    top_hq, left_hq = int(top * sf), int(left * sf)
    hq_side = lq_patchsize * sf
    hq_rows = slice(top_hq, top_hq + hq_side)
    hq_cols = slice(left_hq, left_hq + hq_side)
    return lq[lq_rows, lq_cols, :], hq[hq_rows, hq_cols, :]
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    lq_patchsize: side length of the returned low-quality patch
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    Raises
    ------
    ValueError: if the mod-cropped image is smaller than lq_patchsize*sf in either dimension
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    # mod crop: rows are bounded by the height and columns by the width.
    # (The two were previously swapped, which truncated non-square images.)
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                             interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            # remember the size before downsample2 so downsample3 can target 1/sf of it
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                                 interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    image: input image; converted with ``util.uint2single`` before processing
    sf: scale factor
    isp_model: camera ISP model (accepted for signature compatibility; not used here)
    Returns
    -------
    example: dict with key "image" holding the degraded image after ``util.single2uint``
    """
    image = util.uint2single(image)
    jpeg_prob, scale2_prob = 0.9, 0.25

    h1, w1 = image.shape[:2]
    # mod crop: rows are bounded by the height and columns by the width.
    # (The two were previously swapped, which truncated non-square images.)
    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            image = add_blur(image, sf=sf)

        elif i == 1:
            image = add_blur(image, sf=sf)

        elif i == 2:
            # remember the size before downsample2 so downsample3 can target 1/sf of it
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

        # i == 6 (processed camera sensor noise via isp_model) is intentionally
        # a no-op in this variant.

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example
# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
    """
    This is an extended degradation model by combining
    the degradation models of BSRGAN and Real-ESRGAN
    ----------
    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    shuffle_prob: probability of fully shuffling the degradation order
    use_sharp: sharpening the img
    lq_patchsize: side length of the returned low-quality patch
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    Raises
    ------
    ValueError: if the mod-cropped image is smaller than lq_patchsize*sf in either dimension
    """
    h1, w1 = img.shape[:2]
    # mod crop: rows are bounded by the height and columns by the width.
    # (The two were previously swapped, which truncated non-square images.)
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    if use_sharp:
        img = add_sharpening(img)
    hq = img.copy()

    if random.random() < shuffle_prob:
        shuffle_order = random.sample(range(13), 13)
    else:
        shuffle_order = list(range(13))
        # local shuffle for noise, JPEG is always the last one
        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))

    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)
        elif i == 1:
            img = add_resize(img, sf=sf)
        elif i == 2:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 3:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 4:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 5:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        elif i == 6:
            img = add_JPEG_noise(img)
        elif i == 7:
            img = add_blur(img, sf=sf)
        elif i == 8:
            img = add_resize(img, sf=sf)
        elif i == 9:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 10:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 11:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 12:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        else:
            print('check the shuffle!')

    # resize to desired size
    img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
                     interpolation=random.choice([1, 2, 3]))

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf, lq_patchsize)

    return img, hq
if __name__ == '__main__':
    # Demo: degrade a test image 20 times and save side-by-side comparisons
    # (bicubic | degraded | original) as 0.png ... 19.png.
    print("hey")
    img = util.imread_uint('utils/test.png', 3)
    print(img)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        # degradation_bsrgan_variant converts its input with uint2single itself
        # and returns a dict, so feed it the uint8 image and unpack "image".
        img_hq = util.uint2single(img)
        img_lq = util.uint2single(deg_fn(img)["image"])
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                        interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')
examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py
0 → 100644
View file @
a7e8159d
# -*- coding: utf-8 -*-
import
numpy
as
np
import
cv2
import
torch
from
functools
import
partial
import
random
from
scipy
import
ndimage
import
scipy
import
scipy.stats
as
ss
from
scipy.interpolate
import
interp2d
from
scipy.linalg
import
orth
import
albumentations
import
ldm.modules.image_degradation.utils_image
as
util
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
def modcrop_np(img, sf):
    """Crop the two leading axes of an image down to multiples of ``sf``.

    Args:
        img: numpy image with two or three dimensions.
        sf: scale factor.

    Returns:
        A slice of a fresh copy of ``img``; the input is left untouched.
    """
    dim0, dim1 = img.shape[:2]
    keep0 = dim0 - dim0 % sf
    keep1 = dim1 - dim1 % sf
    duplicate = np.copy(img)
    return duplicate[:keep0, :keep1, ...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper).

    Args:
        k: square 2-D numpy kernel.

    Returns:
        The combined kernel, cropped by ``k_size // 2`` on every side and
        normalised so that it sums to 1.
    """
    k_size = k.shape[0]
    # Calculate the big kernel's size
    big_size = 3 * k_size - 2
    big_k = np.zeros((big_size, big_size))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to ignore very small values and reduce the run time of SR.
    # An explicit end index is used because ``-crop`` is ``-0`` for a 1x1 kernel,
    # which would slice everything away.
    crop = k_size // 2
    end = big_size - crop
    cropped_big_k = big_k[crop:end, crop:end]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """Generate an anisotropic Gaussian kernel.

    Args:
        ksize : e.g., 15, kernel size
        theta : [0,  pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.

    Returns:
        k     : kernel produced by ``gm_blur_kernel`` for the assembled covariance
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    # Rotate the unit x-axis to obtain the principal direction.
    axis = np.dot(rotation, np.array([1., 0.]))
    basis = np.array([[axis[0], axis[1]], [axis[1], -axis[0]]])
    eigenvalues = np.array([[l1, 0], [0, l2]])
    covariance = np.dot(np.dot(basis, eigenvalues), np.linalg.inv(basis))
    return gm_blur_kernel(mean=[0, 0], cov=covariance, size=ksize)
def gm_blur_kernel(mean, cov, size=15):
    """Sample a 2-D Gaussian density on a ``size`` x ``size`` grid and normalise it.

    Args:
        mean: length-2 mean passed to ``scipy.stats.multivariate_normal.pdf``.
        cov: 2x2 covariance matrix.
        size: side length of the returned kernel.

    Returns:
        ``(size, size)`` array summing to 1; entry ``[y, x]`` is the density
        at ``(x - center + 1, y - center + 1)``.
    """
    center = size / 2.0 + 0.5
    coords = np.arange(size) - center + 1
    # With the default 'xy' indexing gx[y, x] == coords[x] and gy[y, x] == coords[y].
    gx, gy = np.meshgrid(coords, coords)
    points = np.stack([gx, gy], axis=-1)
    # One vectorised pdf evaluation instead of size**2 separate Python-level calls.
    k = ss.multivariate_normal.pdf(points, mean=mean, cov=cov)
    # pdf squeezes singleton dimensions; restore the 2-D shape (matters for size == 1).
    k = np.reshape(k, (size, size))
    k = k / np.sum(k)
    return k
def shift_pixel(x, sf, upper_left=True):
    """Shift pixels for super-resolution with different scale factors.

    The image is resampled bilinearly at coordinates offset by ``(sf - 1) / 2``
    (clipped to the image extent).

    Args:
        x: 2-D array, or 3-D array whose last axis is resampled plane by plane.
        sf: scale factor
        upper_left: shift direction (``True`` adds the offset, ``False`` subtracts it)

    Returns:
        The shifted array. The input array is never modified; arrays of any
        other dimensionality are returned unchanged.
    """
    # ``interp2d`` was removed from recent SciPy releases; a degree-1
    # RectBivariateSpline is the bilinear replacement for data on a regular grid.
    from scipy.interpolate import RectBivariateSpline

    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    def _resample(plane):
        # Rows are indexed by y and columns by x, hence the (y, x) argument order.
        return RectBivariateSpline(yv, xv, plane, kx=1, ky=1)(y1, x1)

    if x.ndim == 2:
        return _resample(x)
    if x.ndim == 3:
        # Work on a copy so the caller's array is not overwritten.
        shifted = x.copy()
        for i in range(x.shape[-1]):
            shifted[:, :, i] = _resample(x[:, :, i])
        return shifted
    return x
def blur(x, k):
    '''Blur every image of a batch with its own kernel.

    x: image, NxcxHxW
    k: kernel, Nx1xhxw

    Returns a tensor of shape NxcxHxW (for odd kernel sizes); borders are
    handled by replicate padding.
    '''
    n, c = x.shape[:2]
    pad_h, pad_w = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    # F.pad takes (left, right, top, bottom): width pads first, then height.
    # The former (p1, p2, p1, p2) ordering was only correct for square kernels.
    x = torch.nn.functional.pad(x, pad=(pad_w, pad_w, pad_h, pad_h), mode='replicate')
    # One single-channel filter per (sample, channel) pair, applied as a grouped convolution.
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])

    return x
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """Draw a random anisotropic Gaussian kernel.

    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf

    Returns:
        Kernel normalised to sum to 1. Uses the global ``np.random`` state.
    """
    # Random eigenvalues and rotation angle of the covariance matrix
    span = max_var - min_var
    eig_a = min_var + np.random.rand() * span
    eig_b = min_var + np.random.rand() * span
    angle = np.random.rand() * np.pi
    # Multiplicative jitter in [-noise_level, noise_level)
    jitter = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Covariance from eigenvalues and rotation
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rot = np.array([[cos_a, -sin_a],
                    [sin_a, cos_a]])
    cov = rot @ np.diag([eig_a, eig_b]) @ rot.T
    precision = np.linalg.inv(cov)[None, None, :, :]

    # Expectation position (shifting kernel for aligned image)
    centre = k_size // 2 - 0.5 * (scale_factor - 1)
    centre = centre[None, None, :, None]

    # Coordinate grid, one column vector per kernel pixel
    grid_x, grid_y = np.meshgrid(range(k_size[0]), range(k_size[1]))
    coords = np.stack([grid_x, grid_y], 2)[:, :, :, None]

    # Evaluate the (unnormalised) Gaussian at every pixel of the kernel
    offsets = coords - centre
    offsets_t = offsets.transpose(0, 1, 3, 2)
    quad_form = np.squeeze(offsets_t @ precision @ offsets)
    unnormalised = np.exp(-0.5 * quad_form) * (1 + jitter)

    return unnormalised / np.sum(unnormalised)
def fspecial_gaussian(hsize, sigma):
    """Build a normalised ``hsize`` x ``hsize`` isotropic Gaussian filter.

    Args:
        hsize: side length of the square kernel.
        sigma: standard deviation of the Gaussian.

    Returns:
        2-D array; it sums to 1 unless every sample was zeroed by the threshold.
    """
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # Zero out values that are negligible relative to the peak.
    # ``scipy.finfo`` was a NumPy alias that newer SciPy no longer provides; use NumPy directly.
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h
def fspecial_laplacian(alpha):
    """Return a 3x3 Laplacian filter; ``alpha`` is clamped to [0, 1]."""
    alpha = max(0, min(alpha, 1))
    denom = alpha + 1
    corner = alpha / denom
    edge = (1 - alpha) / denom
    centre = -4 / denom
    rows = [
        [corner, edge, corner],
        [edge, centre, edge],
        [corner, edge, corner],
    ]
    return np.array(rows)
def fspecial(filter_type, *args, **kwargs):
    '''Dispatch to the filter builder named by ``filter_type``.

    Supports 'gaussian' and 'laplacian'; any other name yields ``None``.

    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    '''
    if filter_type == 'gaussian':
        builder = fspecial_gaussian
    elif filter_type == 'laplacian':
        builder = fspecial_laplacian
    else:
        return None
    return builder(*args, **kwargs)
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def bicubic_degradation(x, sf=3):
    '''Downscale an image by ``sf`` via ``util.imresize_np``.

    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubicly downsampled LR image
    '''
    return util.imresize_np(x, scale=1 / sf)
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    # ``ndimage.convolve`` is the supported spelling; the ``ndimage.filters``
    # namespace is deprecated. The kernel gets a singleton channel axis so every
    # channel is filtered independently.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    # Blur after downsampling; the kernel is broadcast over channels via a
    # trailing singleton axis. ``ndimage.convolve`` replaces the deprecated
    # ``ndimage.filters.convolve`` alias.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    '''
    # ``ndimage.convolve`` is the supported entry point; the
    # ``ndimage.filters`` namespace is deprecated.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0  # offset of the first retained pixel
    return x[st::sf, st::sf, ...]
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) * 255 > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur; bumped to the next odd
            number when even. Default: 50.
        threshold (int): Residual magnitude (on a 0-255 scale) above which a
            pixel is sharpened. Default: 10.
    """
    if radius % 2 == 0:
        radius += 1
    ksize = (radius, radius)
    blurred = cv2.GaussianBlur(img, ksize, 0)
    detail = img - blurred
    hard_mask = (np.abs(detail) * 255 > threshold).astype('float32')
    soft_mask = cv2.GaussianBlur(hard_mask, ksize, 0)
    sharpened = np.clip(img + weight * detail, 0, 1)
    return soft_mask * sharpened + (1 - soft_mask) * img
def add_blur(img, sf=4):
    """Blur ``img`` with a randomly drawn Gaussian kernel.

    With probability 0.5 an anisotropic Gaussian (random size, angle and
    axis lengths) is used, otherwise an isotropic one from ``fspecial``.
    Kernel widths scale with ``sf``.

    Args:
        img: image array with a trailing channel axis (the kernel is given a
            singleton third axis before convolution).
        sf: scale factor controlling the maximum blur width.

    Returns:
        The blurred image.
    """
    wd2 = (4.0 + sf) / 4
    wd = (2.0 + 0.2 * sf) / 4
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
    # ``ndimage.convolve`` replaces the deprecated ``ndimage.filters`` alias.
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
    return img
def add_resize(img, sf=4):
    """Randomly rescale ``img`` and clip the result to [0, 1].

    A uniform draw selects the mode: upscaling by a factor in [1, 2] (draw
    above 0.8), downscaling by a factor in [0.5 / sf, 1] (draw below 0.7),
    or keeping the size. The cv2 interpolation flag is picked from 1, 2, 3.
    """
    draw = np.random.rand()
    if draw > 0.8:  # up
        ratio = random.uniform(1, 2)
    elif draw < 0.7:  # down
        ratio = random.uniform(0.5 / sf, 1)
    else:
        ratio = 1.0
    target = (int(ratio * img.shape[1]), int(ratio * img.shape[0]))
    resized = cv2.resize(img, target, interpolation=random.choice([1, 2, 3]))
    return np.clip(resized, 0.0, 1.0)
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add one of three kinds of Gaussian noise and clip to [0, 1].

    A noise level is drawn from [noise_level1, noise_level2] (0-255 scale).
    A second draw selects per-channel noise, noise shared across channels,
    or 3-channel noise with a random covariance scaled by ``noise_level2``.
    """
    noise_level = random.randint(noise_level1, noise_level2)
    draw = np.random.rand()
    if draw > 0.6:
        # independent noise for every channel
        noise = np.random.normal(0, noise_level / 255.0, img.shape)
    elif draw < 0.4:
        # a single noise plane broadcast over the channels
        noise = np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1))
    else:
        # channel-correlated noise built from a random covariance matrix
        L = noise_level2 / 255.
        eigvals = np.diag(np.random.rand(3))
        basis = orth(np.random.rand(3, 3))
        cov = np.dot(np.dot(np.transpose(basis), eigvals), basis)
        noise = np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * cov), img.shape[:2])
    noisy = img + noise.astype(np.float32)
    return np.clip(noisy, 0.0, 1.0)
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Add multiplicative (speckle) Gaussian noise and clip to [0, 1].

    The image is clipped first, then ``img * noise`` is added, where the
    noise is per-channel, shared across channels, or channel-correlated
    depending on a random draw.
    """
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)  # fresh array, so the in-place add below is local
    draw = random.random()
    if draw > 0.6:
        gain = np.random.normal(0, noise_level / 255.0, img.shape)
    elif draw < 0.4:
        gain = np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1))
    else:
        L = noise_level2 / 255.
        eigvals = np.diag(np.random.rand(3))
        basis = orth(np.random.rand(3, 3))
        cov = np.dot(np.dot(np.transpose(basis), eigvals), basis)
        gain = np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * cov), img.shape[:2])
    img += img * gain.astype(np.float32)
    return np.clip(img, 0.0, 1.0)
def add_Poisson_noise(img):
    """Add Poisson (shot) noise, either per channel or luminance-only.

    The image is first quantised to 8-bit levels. A random scale ``vals`` in
    [1e2, 1e4] sets the noise strength. With probability 0.5 every channel is
    resampled from a Poisson distribution; otherwise the noise is computed on
    the grey-scale image and added to all channels. Output is clipped to [0, 1].
    """
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]
    if random.random() < 0.5:
        resampled = np.random.poisson(img * vals).astype(np.float32) / vals
        return np.clip(resampled, 0.0, 1.0)
    luma = np.dot(img[..., :3], [0.299, 0.587, 0.114])
    luma = np.clip((luma * 255.0).round(), 0, 255) / 255.
    luma_noise = np.random.poisson(luma * vals).astype(np.float32) / vals - luma
    img += luma_noise[:, :, np.newaxis]
    return np.clip(img, 0.0, 1.0)
def add_JPEG_noise(img):
    """Round-trip ``img`` through JPEG encoding at a random quality in [80, 95].

    The image is converted to uint8 BGR for OpenCV, encoded and decoded in
    memory, then converted back to a float RGB image via ``util.uint2single``.
    """
    quality_factor = random.randint(80, 95)
    bgr_uint8 = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    _, encoded = cv2.imencode('.jpg', bgr_uint8, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    decoded = cv2.imdecode(encoded, 1)
    return cv2.cvtColor(util.uint2single(decoded), cv2.COLOR_BGR2RGB)
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Cut aligned random patches from a low-quality / high-quality pair.

    A ``lq_patchsize`` square is taken from ``lq`` at a random position; the
    matching square in ``hq`` starts at ``sf`` times that position and has
    side ``lq_patchsize * sf``.

    Returns:
        (lq_patch, hq_patch)
    """
    h, w = lq.shape[:2]
    top = random.randint(0, h - lq_patchsize)
    left = random.randint(0, w - lq_patchsize)
    lq_patch = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]

    top_hq, left_hq = int(top * sf), int(left * sf)
    hq_patch = hq[top_hq:top_hq + lq_patchsize * sf, left_hq:left_hq + lq_patchsize * sf, :]
    return lq_patch, hq_patch
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    Raises
    ------
    ValueError: if the mod-cropped image is smaller than lq_patchsize*sf on either side.
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    # mod crop: axis 0 is the height and axis 1 the width. Indexing them the
    # other way round cropped non-square images far more than necessary.
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                             interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    # Random order of the seven degradation steps, with step 2 (downsample2)
    # forced to run before step 3 (downsample3), which relies on the size
    # recorded by step 2.
    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                                 interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                # ``ndimage.convolve`` replaces the deprecated ``ndimage.filters`` alias.
                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3: bring the image to 1/sf of the size seen by step 2
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    image: image passed through ``util.uint2single`` before processing
    sf: scale factor
    isp_model: camera ISP model (accepted for interface compatibility; not used here)
    Returns
    -------
    example: dict with key "image" holding the degraded image after
        ``util.single2uint``. No high-quality patch is returned and no random
        crop is applied.
    """
    image = util.uint2single(image)
    jpeg_prob, scale2_prob = 0.9, 0.25

    h1, w1 = image.shape[:2]
    # mod crop: axis 0 is the height and axis 1 the width. Indexing them the
    # other way round cropped non-square images far more than necessary.
    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    # Random step order, with step 2 (downsample2) forced before step 3
    # (downsample3), which relies on the size recorded by step 2.
    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            image = add_blur(image, sf=sf)

        # Step 1 (second blur) is intentionally disabled in this variant:
        # elif i == 1:
        #     image = add_blur(image, sf=sf)

        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.8:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                # ``ndimage.convolve`` replaces the deprecated ``ndimage.filters`` alias.
                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3: bring the image to 1/sf of the size seen by step 2
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

        # Step 6 (camera ISP noise) is disabled in this variant:
        # elif i == 6:
        #     # add processed camera sensor noise
        #     if random.random() < isp_prob and isp_model is not None:
        #         with torch.no_grad():
        #             img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example
# Demo: degrade a test image 20 times and save side-by-side comparisons
# (bicubic baseline | degraded | original) as 0.png ... 19.png.
if __name__ == '__main__':
    print("hey")
    img = util.imread_uint('utils/test.png', 3)
    img = img[:448, :448]
    # target size for the bicubic baseline (input side length divided by 4)
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        img_hq = img
        img_lq = deg_fn(img)["image"]
        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
        print(img_lq)
        # baseline: plain bicubic downscale of the original image
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        # upscale both low-res images with nearest neighbour (interpolation=0)
        # so they can be concatenated next to the original
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                        interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')
examples/images/diffusion/ldm/modules/image_degradation/utils/test.png
0 → 100644
View file @
a7e8159d
431 KB
examples/images/diffusion/ldm/modules/image_degradation/utils_image.py
0 → 100644
View file @
a7e8159d
import
os
import
math
import
random
import
numpy
as
np
import
torch
import
cv2
from
torchvision.utils
import
make_grid
from
datetime
import
datetime
#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
# Module-level side effect: sets KMP_DUPLICATE_LIB_OK for the whole process.
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort when several libraries bundle their own OpenMP copy -- confirm it is
# still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
# File-name suffixes (matched case-sensitively) that count as images.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG',
    '.ppm', '.PPM', '.bmp', '.BMP', '.tif',
]


def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
def get_timestamp():
    """Return the current local time formatted as ``yymmdd-HHMMSS``."""
    now = datetime.now()
    return now.strftime('%y%m%d-%H%M%S')
def imshow(x, title=None, cbar=False, figsize=None):
    """Show ``x`` (squeezed) as a grey-scale image in a new figure.

    Args:
        x: array to display; singleton axes are removed first.
        title: optional figure title.
        cbar: when true, a colour bar is added.
        figsize: forwarded to ``plt.figure``.

    NOTE(review): ``plt`` is not bound by the imports visible in this file
    (the matplotlib import is commented out), so calling this presumably
    raises NameError unless ``plt`` is provided elsewhere -- confirm.
    """
    plt.figure(figsize=figsize)
    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()
def surf(Z, cmap='rainbow', figsize=None):
    """Draw ``Z`` as a 3-D surface over its index grid.

    Args:
        Z: 2-D array of surface heights.
        cmap: colour map name passed to ``plot_surface``.
        figsize: forwarded to ``plt.figure``.

    NOTE(review): ``plt`` is not bound by the imports visible in this file
    (the matplotlib import is commented out) -- confirm where it comes from.
    """
    plt.figure(figsize=figsize)
    ax3 = plt.axes(projection='3d')

    # NOTE(review): shape[0] is named ``w`` here; with the default meshgrid
    # indexing X and Y have shape (h, w) while Z is (w, h), which looks
    # consistent only for square Z -- verify for non-square input.
    w, h = Z.shape[:2]
    xx = np.arange(0,w,1)
    yy = np.arange(0,h,1)
    X, Y = np.meshgrid(xx, yy)
    ax3.plot_surface(X,Y,Z,cmap=cmap)
    #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
    plt.show()
'''
# --------------------------------------------
# get image paths
# --------------------------------------------
'''
def get_image_paths(dataroot):
    """Return the sorted image paths under ``dataroot``, or None if it is None."""
    if dataroot is None:
        return None
    return sorted(_get_paths_from_images(dataroot))
def
_get_paths_from_images
(
path
):
assert
os
.
path
.
isdir
(
path
),
'{:s} is not a valid directory'
.
format
(
path
)
images
=
[]
for
dirpath
,
_
,
fnames
in
sorted
(
os
.
walk
(
path
)):
for
fname
in
sorted
(
fnames
):
if
is_image_file
(
fname
):
img_path
=
os
.
path
.
join
(
dirpath
,
fname
)
images
.
append
(
img_path
)
assert
images
,
'{:s} has no valid image file'
.
format
(
path
)
return
images
'''
# --------------------------------------------
# split large images into small images
# --------------------------------------------
'''
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
    """Split a large image into overlapping square patches.

    Args:
        img: array indexed as ``img[i, j, :]`` (three axes).
        p_size: side length of each patch.
        p_overlap: overlap between neighbouring patches.
        p_max: the image is only split when both of its first two
            dimensions exceed this value.

    Returns:
        A list of patches; ``[img]`` when the image is not split. The last
        patch along each axis is aligned with the image border.
    """
    w, h = img.shape[:2]
    if not (w > p_max and h > p_max):
        return [img]

    stride = p_size - p_overlap
    # ``np.int`` was removed from NumPy; the builtin ``int`` is the same dtype.
    w1 = list(np.arange(0, w - p_size, stride, dtype=int))
    h1 = list(np.arange(0, h - p_size, stride, dtype=int))
    # make sure the far border is covered
    w1.append(w - p_size)
    h1.append(h - p_size)
    return [img[i:i + p_size, j:j + p_size, :] for i in w1 for j in h1]
def imssave(imgs, img_path):
    """
    Write every image in ``imgs`` next to ``img_path`` as
    ``<name>_sNNNN.png`` (NNNN is the zero-padded index).

    imgs: list, N images of size WxHxC
    """
    img_name, ext = os.path.splitext(os.path.basename(img_path))
    out_dir = os.path.dirname(img_path)
    for idx, patch in enumerate(imgs):
        if patch.ndim == 3:
            # reverse the channel order before handing the array to OpenCV
            patch = patch[:, :, [2, 1, 0]]
        target = os.path.join(out_dir, img_name + str('_s{:04d}'.format(idx)) + '.png')
        cv2.imwrite(target, patch)
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
    """
    Split the large images under original_dataroot into overlapping patches of
    size (p_size)x(p_size) and save them into taget_dataroot. Only images
    larger than (p_max)x(p_max) are split; smaller ones are saved as a single
    patch.
    Args:
        original_dataroot: directory scanned (recursively) for images
        taget_dataroot: directory the patches are written to
        n_channels: channel count passed to imread_uint
        p_size: size of small images
        p_overlap: patch size in training is a good choice
        p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
    """
    for img_path in get_image_paths(original_dataroot):
        # img_name, ext = os.path.splitext(os.path.basename(img_path))
        image = imread_uint(img_path, n_channels=n_channels)
        pieces = patches_from_image(image, p_size, p_overlap, p_max)
        destination = os.path.join(taget_dataroot, os.path.basename(img_path))
        imssave(pieces, destination)
        #if original_dataroot == taget_dataroot:
        #del img_path
'''
# --------------------------------------------
# makedir
# --------------------------------------------
'''
def mkdir(path):
    """Create ``path`` (including parents) unless something already exists there."""
    if os.path.exists(path):
        return
    os.makedirs(path)
def mkdirs(paths):
    """Create one directory (when given a string) or each directory in an iterable."""
    targets = [paths] if isinstance(paths, str) else paths
    for target in targets:
        mkdir(target)
def mkdir_and_rename(path):
    """Create a fresh directory at ``path``, archiving any existing one.

    An existing ``path`` is first renamed to ``<path>_archived_<timestamp>``
    and the new name is printed.
    """
    if os.path.exists(path):
        new_name = ''.join([path, '_archived_', get_timestamp()])
        print('Path already exists. Rename it to [{:s}]'.format(new_name))
        os.rename(path, new_name)
    os.makedirs(path)
'''
# --------------------------------------------
# read image from path
# opencv is fast, but read BGR numpy image
# --------------------------------------------
'''
# --------------------------------------------
# get uint8 image of size HxWxn_channels (RGB)
# --------------------------------------------
def imread_uint(path, n_channels=3):
    """Read an image with OpenCV and return it in RGB (or single-channel) layout.

    input: path
    output: HxWx3(RGB or GGG), or HxWx1 (G)

    Only ``n_channels`` values 1 and 3 are handled.
    """
    if n_channels == 1:
        gray = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
        img = np.expand_dims(gray, axis=2)  # HxWx1
    elif n_channels == 3:
        raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
        # grey images are replicated to GGG, colour images reordered to RGB
        code = cv2.COLOR_GRAY2RGB if raw.ndim == 2 else cv2.COLOR_BGR2RGB
        img = cv2.cvtColor(raw, code)
    return img
# --------------------------------------------
# matlab's imwrite
# --------------------------------------------
def imsave(img, img_path):
    """Write ``img`` to ``img_path`` with OpenCV.

    Singleton axes are squeezed out; for a 3-axis result the first three
    channels are reordered (index 2, 1, 0) before writing.
    """
    squeezed = np.squeeze(img)
    out = squeezed[:, :, [2, 1, 0]] if squeezed.ndim == 3 else squeezed
    cv2.imwrite(img_path, out)
def imwrite(img, img_path):
    """Write ``img`` to ``img_path`` with OpenCV (same behaviour as ``imsave``).

    Singleton axes are squeezed out; for a 3-axis result the first three
    channels are reordered (index 2, 1, 0) before writing.
    """
    squeezed = np.squeeze(img)
    out = squeezed[:, :, [2, 1, 0]] if squeezed.ndim == 3 else squeezed
    cv2.imwrite(img_path, out)
# --------------------------------------------
# get single image of size HxWxn_channels (BGR)
# --------------------------------------------
def read_img(path):
    """Read an image by cv2 and scale it by 1/255.

    return: Numpy float32, HWC, BGR, [0,1]
    Grey images get a trailing channel axis; channels beyond the third are
    dropped.
    """
    raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
    img = raw.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    channels = img.shape[2]
    # some images have 4 channels
    return img[:, :, :3] if channels > 3 else img
'''
# --------------------------------------------
# image format conversion
# --------------------------------------------
# numpy(single) <---> numpy(uint)
# numpy(single) <---> tensor
# numpy(uint) <---> tensor
# --------------------------------------------
'''
# --------------------------------------------
# numpy(single) [0, 1] <---> numpy(uint)
# --------------------------------------------
def uint2single(img):
    """Scale 8-bit values to float32 in [0, 1] (divides by 255)."""
    scaled = img / 255.
    return np.float32(scaled)
def single2uint(img):
    """Clip to [0, 1], scale by 255, round and convert to uint8."""
    clipped = img.clip(0, 1)
    rounded = (clipped * 255.).round()
    return np.uint8(rounded)
def uint162single(img):
    """Scale 16-bit values to float32 in [0, 1] (divides by 65535)."""
    scaled = img / 65535.
    return np.float32(scaled)
def single2uint16(img):
    """Clip to [0, 1], scale by 65535, round and convert to uint16."""
    clipped = img.clip(0, 1)
    rounded = (clipped * 65535.).round()
    return np.uint16(rounded)
# --------------------------------------------
# numpy(uint) (HxWxC or HxW) <---> tensor
# --------------------------------------------
# convert uint to 4-dimensional torch tensor
def uint2tensor4(img):
    """Convert an HxW or HxWxC uint array to a 1xCxHxW float tensor in [0, 1]."""
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    chw = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
    scaled = chw.float().div(255.)
    return scaled.unsqueeze(0)
# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
    """Convert an HxW or HxWxC uint array to a CxHxW float tensor in [0, 1]."""
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    chw = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
    return chw.float().div(255.)
# convert 2/3/4-dimensional torch tensor to uint
def tensor2uint(img):
    """Convert a 2/3/4-dimensional torch tensor to a uint8 numpy image.

    The tensor is squeezed, clamped to [0, 1], moved to the CPU and scaled to
    [0, 255] with rounding. A 3-axis result is transposed from CxHxW to HxWxC.

    The input tensor is left untouched: the clamp is out-of-place, because an
    in-place ``clamp_`` writes through the squeezed view into the caller's
    float32 tensor.
    """
    img = img.data.squeeze().float().clamp(0, 1).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return np.uint8((img * 255.0).round())
# --------------------------------------------
# numpy(single) (HxWxC) <---> tensor
# --------------------------------------------
# convert single (HxWxC) to 3-dimensional torch tensor
def single2tensor3(img):
    """Convert a float HxWxC array to a CxHxW float tensor (no rescaling)."""
    contiguous = np.ascontiguousarray(img)
    chw = torch.from_numpy(contiguous).permute(2, 0, 1)
    return chw.float()
# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
    """Convert a float HxWxC array to a 1xCxHxW float tensor (no rescaling)."""
    contiguous = np.ascontiguousarray(img)
    chw = torch.from_numpy(contiguous).permute(2, 0, 1)
    return chw.float().unsqueeze(0)
# convert torch tensor to single
def tensor2single(img):
    """Convert a torch tensor to a float numpy array.

    The tensor is squeezed first; a 3-axis result is transposed from CxHxW
    to HxWxC. Values are not rescaled or clamped.
    """
    arr = img.data.squeeze().float().cpu().numpy()
    return np.transpose(arr, (1, 2, 0)) if arr.ndim == 3 else arr
# convert torch tensor to single
def tensor2single3(img):
    """Convert a torch tensor to a float numpy array with a channel axis.

    After squeezing, a 3-axis result is transposed from CxHxW to HxWxC and a
    2-axis result gets a trailing singleton channel axis. Other ranks are
    returned as they are.
    """
    arr = img.data.squeeze().float().cpu().numpy()
    if arr.ndim == 3:
        return np.transpose(arr, (1, 2, 0))
    if arr.ndim == 2:
        return np.expand_dims(arr, axis=2)
    return arr
def single2tensor5(img):
    """Convert a 4-axis array to a 5-axis float tensor.

    Axes are reordered with ``permute(2, 0, 1, 3)`` and a leading singleton
    axis is added.
    """
    contiguous = np.ascontiguousarray(img)
    reordered = torch.from_numpy(contiguous).permute(2, 0, 1, 3)
    return reordered.float().unsqueeze(0)
def single32tensor5(img):
    """Convert an array to a float tensor with two leading singleton axes."""
    contiguous = np.ascontiguousarray(img)
    as_float = torch.from_numpy(contiguous).float()
    return as_float.unsqueeze(0).unsqueeze(0)
def single42tensor4(img):
    """Convert a 4-axis array to a float tensor reordered with ``permute(2, 0, 1, 3)``."""
    contiguous = np.ascontiguousarray(img)
    reordered = torch.from_numpy(contiguous).permute(2, 0, 1, 3)
    return reordered.float()
# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array of BGR channel order
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)

    The input tensor is not modified: the clamp is out-of-place, because an
    in-place ``clamp_`` writes through the squeezed view into the caller's
    float32 CPU tensor.

    Raises:
        TypeError: if the squeezed tensor is not 2-, 3- or 4-dimensional.
    '''
    tensor = tensor.squeeze().float().cpu().clamp(*min_max)  # squeeze first, then clamp
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        # tile the batch into one roughly square grid image
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
    return img_np.astype(out_type)
'''
# --------------------------------------------
# Augmentation, flip and/or rotate
# --------------------------------------------
# The following two are enough.
# (1) augment_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
'''
def augment_img(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Apply one of eight flip/rotate augmentations to a numpy image. Each mode
    is a number of 90-degree rotations followed by an optional up-down flip;
    mode 0 returns the input unchanged and unknown modes return None.
    '''
    # index = mode, value = (quarter turns, flip up-down afterwards)
    plans = ((0, False), (1, True), (0, True), (3, False),
             (2, True), (1, False), (2, False), (3, True))
    for candidate, (turns, flip) in enumerate(plans):
        if mode == candidate:
            out = img
            if turns:
                out = np.rot90(out, k=turns)
            if flip:
                out = np.flipud(out)
            return out
    return None
def augment_img_tensor4(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Tensor counterpart of ``augment_img``: rotate over axes 2 and 3 by a
    number of quarter turns, then optionally flip axis 2. Mode 0 returns the
    input unchanged and unknown modes return None.
    '''
    # index = mode, value = (quarter turns, flip axis 2 afterwards)
    plans = ((0, False), (1, True), (0, True), (3, False),
             (2, True), (1, False), (2, False), (3, True))
    for candidate, (turns, flip) in enumerate(plans):
        if mode == candidate:
            out = img
            if turns:
                out = out.rot90(turns, [2, 3])
            if flip:
                out = out.flip([2])
            return out
    return None
def augment_img_tensor(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Augment a tensor by round-tripping through ``augment_img`` on numpy.
    3-axis and 4-axis tensors are transposed so the two spatial axes come
    first, augmented, and transposed back. The result has the type of ``img``.
    '''
    rank = len(img.size())
    # axis orders used to move the spatial axes to the front and back again
    to_numpy_axes = {3: (1, 2, 0), 4: (2, 3, 1, 0)}
    to_tensor_axes = {3: (2, 0, 1), 4: (3, 2, 0, 1)}

    img_np = img.data.cpu().numpy()
    if rank in to_numpy_axes:
        img_np = np.transpose(img_np, to_numpy_axes[rank])
    img_np = augment_img(img_np, mode=mode)
    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
    if rank in to_tensor_axes:
        img_tensor = img_tensor.permute(*to_tensor_axes[rank])
    return img_tensor.type_as(img)
def augment_img_np3(img, mode=0):
    """Apply one of eight flip/transpose augmentations to an HxWxC array.

    Each mode is a combination of a left-right flip, an up-down flip and a
    transpose of the first two axes, applied in that order. Mode 0 returns
    the input unchanged; unknown modes return None.
    """
    # index = mode, value = (flip left-right, flip up-down, transpose)
    plans = (
        (False, False, False), (False, False, True),
        (False, True, False), (False, True, True),
        (True, False, False), (True, False, True),
        (True, True, False), (True, True, True),
    )
    for candidate, (flip_lr, flip_ud, swap) in enumerate(plans):
        if mode == candidate:
            out = img
            if flip_lr:
                out = out[:, ::-1, :]
            if flip_ud:
                out = out[::-1, :, :]
            if swap:
                out = out.transpose(1, 0, 2)
            return out
    return None
def augment_imgs(img_list, hflip=True, rot=True):
    """Apply one random flip/transpose combination to every image in a list.

    ``hflip`` enables a random left-right flip; ``rot`` enables a random
    up-down flip and a random transpose of the first two axes. Each enabled
    operation is applied with probability 0.5, and the same choice is used
    for all images.
    """
    # horizontal flip OR rotate
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rot and random.random() < 0.5
    do_swap = rot and random.random() < 0.5

    augmented = []
    for img in img_list:
        if do_hflip:
            img = img[:, ::-1, :]
        if do_vflip:
            img = img[::-1, :, :]
        if do_swap:
            img = img.transpose(1, 0, 2)
        augmented.append(img)
    return augmented
'''
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
'''
def modcrop(img_in, scale):
    """Crop a copy of the image so height and width are multiples of ``scale``.

    img_in: Numpy, HWC or HW
    Raises ValueError for any other number of axes.
    """
    img = np.copy(img_in)
    if img.ndim not in (2, 3):
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    H, W = img.shape[:2]
    return img[:H - H % scale, :W - W % scale]
def shave(img_in, border=0):
    """Return a copy of the image with ``border`` pixels removed on every side.

    img_in: Numpy, HWC or HW
    """
    img = np.copy(img_in)
    h, w = img.shape[:2]
    rows = slice(border, h - border)
    cols = slice(border, w - border)
    return img[rows, cols]
'''
# --------------------------------------------
# image processing process on numpy image
# channel_convert(in_c, tar_type, img_list):
# rgb2ycbcr(img, only_y=True):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
'''
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output:
        array of the input dtype; uint8 input stays in [0, 255], any other
        input is returned in [0, 1].

    The input array is never modified.
    '''
    in_img_type = img.dtype
    if in_img_type != np.uint8:
        # Out-of-place scaling: an in-place ``img *= 255.`` rescales the
        # caller's array as a side effect.
        img = img * 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output:
        array of the input dtype; uint8 input stays in [0, 255], any other
        input is returned in [0, 1].

    The input array is never modified.
    '''
    in_img_type = img.dtype
    if in_img_type != np.uint8:
        # Out-of-place scaling: an in-place ``img *= 255.`` rescales the
        # caller's array as a side effect.
        img = img * 255.
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output:
        array of the input dtype; uint8 input stays in [0, 255], any other
        input is returned in [0, 1].

    The input array is never modified.
    '''
    in_img_type = img.dtype
    if in_img_type != np.uint8:
        # Out-of-place scaling: an in-place ``img *= 255.`` rescales the
        # caller's array as a side effect.
        img = img * 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def channel_convert(in_c, tar_type, img_list):
    """Convert a list of images among BGR, gray and y.

    Supported combinations: 3 channels to 'gray', 3 channels to 'y', and
    1 channel to 'RGB' (done with OpenCV's GRAY2BGR conversion). Any other
    combination returns ``img_list`` unchanged.
    """
    if in_c == 3 and tar_type == 'gray':  # BGR to gray
        converted = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
    elif in_c == 3 and tar_type == 'y':  # BGR to y
        converted = [bgr2ycbcr(img, only_y=True) for img in img_list]
    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    else:
        return img_list
    # single-channel results get their channel axis back
    return [np.expand_dims(img, axis=2) for img in converted]
'''
# --------------------------------------------
# metric, PSNR and SSIM
# --------------------------------------------
'''
# --------------------------------------------
# PSNR
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
    """Peak signal-to-noise ratio (dB) between two images in range [0, 255].

    ``border`` pixels are discarded on every side before comparison.
    Returns ``inf`` for identical images; raises ValueError when the shapes
    differ.
    """
    #img1 = img1.squeeze()
    #img2 = img2.squeeze()
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    rows, cols = slice(border, h - border), slice(border, w - border)

    first = img1[rows, cols].astype(np.float64)
    second = img2[rows, cols].astype(np.float64)
    mse = np.mean((first - second) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
# --------------------------------------------
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]

    ``border`` pixels are discarded on every side first. 2-axis images are
    scored directly, 3-channel images are scored per channel and averaged,
    single-channel 3-axis images are squeezed. Raises ValueError when the
    shapes differ or the image has neither 2 nor 3 axes.
    '''
    #img1 = img1.squeeze()
    #img2 = img2.squeeze()
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    rows, cols = slice(border, h - border), slice(border, w - border)
    img1 = img1[rows, cols]
    img2 = img2[rows, cols]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim != 3:
        raise ValueError('Wrong input image dimensions.')

    channels = img1.shape[2]
    if channels == 3:
        scores = [ssim(img1[:, :, c], img2[:, :, c]) for c in range(3)]
        return np.array(scores).mean()
    if channels == 1:
        return ssim(np.squeeze(img1), np.squeeze(img2))
    # other channel counts are not scored
    return None
def ssim(img1, img2):
    """Return the mean SSIM of two single-channel images.

    Local statistics are gathered with an 11x11 window built as the outer
    product of ``cv2.getGaussianKernel(11, 1.5)`` with itself. After every
    filtering step 5 pixels are cropped from each side so only the "valid"
    region contributes.

    Args:
        img1: First image; cast to float64 before use.
        img2: Second image; cast to float64 before use.

    Returns:
        The mean of the SSIM map.
    """
    # Stabilising constants C1 and C2, both scaled by 255.
    c1 = (0.01 * 255) ** 2
    c2 = (0.03 * 255) ** 2

    first = img1.astype(np.float64)
    second = img2.astype(np.float64)

    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss_1d, gauss_1d.transpose())

    def windowed(values):
        # Filter with the window, then keep only the valid interior.
        return cv2.filter2D(values, -1, window)[5:-5, 5:-5]

    mu1 = windowed(first)
    mu2 = windowed(second)
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2

    sigma1_sq = windowed(first ** 2) - mu1_sq
    sigma2_sq = windowed(second ** 2) - mu2_sq
    sigma12 = windowed(first * second) - mu1_mu2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    ssim_map = numerator / denominator
    return ssim_map.mean()
'''
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
'''
# MATLAB 'imresize' function; currently only 'bicubic' is supported
def cubic(x):
    """Evaluate the piecewise cubic interpolation kernel elementwise.

    Args:
        x: Tensor of signed distances.

    Returns:
        Tensor shaped like ``x``: one polynomial where ``|x| <= 1``, a second
        polynomial where ``1 < |x| <= 2``, and zero everywhere else.
    """
    dist = torch.abs(x)
    dist_sq = dist ** 2
    dist_cu = dist ** 3

    # Region indicators, cast to the dtype of the distances.
    near = (dist <= 1).type_as(dist)
    far = ((dist > 1) * (dist <= 2)).type_as(dist)

    near_poly = 1.5 * dist_cu - 2.5 * dist_sq + 1
    far_poly = -0.5 * dist_cu + 2.5 * dist_sq - 4 * dist + 2
    return near_poly * near + far_poly * far
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Build the 1-D resampling weights and source indices for one image axis.

    Args:
        in_length: Number of input samples along the axis.
        out_length: Number of output samples along the axis.
        scale: Resize factor for the axis.
        kernel: Kernel name; not read anywhere in this function (the cubic
            kernel is always applied).
        kernel_width: Support of the kernel before any antialias widening.
        antialiasing: If truthy and ``scale < 1``, the kernel is widened by
            ``1 / scale`` so it interpolates and antialiases at once.

    Returns:
        Tuple ``(weights, indices, sym_len_s, sym_len_e)``: ``weights`` and
        ``indices`` each hold one row per output sample, and ``sym_len_s`` /
        ``sym_len_e`` are the ints derived from the smallest and largest
        index (``-min + 1`` and ``max - in_length``). ``indices`` are shifted
        by ``sym_len_s - 1`` before being returned.
    """
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias: larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates (1-based: 1, 2, ..., out_length)
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand(out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. Only consider the first and last column.
    # weights_zero_tmp holds, per column, the COUNT of exactly-zero weights.
    # NOTE(review): the tests below drop a column whenever that count is
    # non-zero (any zero entry), not only when the whole column is zero -
    # confirm this matches the intent stated above.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    # NOTE(review): if the first branch already ran, the tensors are P - 2
    # columns wide, so narrow(1, 0, P - 2) keeps every remaining column and
    # the last column is not removed - confirm.
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # Lengths derived from how far the indices reach before the first and
    # past the last input sample; callers use them to size mirrored padding.
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    # Shift so the smallest index becomes 0.
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
# --------------------------------------------
# imresize for tensor image [0, 1]
# --------------------------------------------
def imresize(img, scale, antialiasing=True):
    """Resize a tensor image with the bicubic kernel, one axis at a time.

    The same ``scale`` is applied to H and W. The H axis is resampled first,
    then the W axis; both passes pad the data by mirroring its edges.

    Args:
        img: PyTorch tensor laid out as CHW or HW (the original notes give
            the value range as [0, 1]). The tensor is not modified.
        scale: Resize factor; output sizes are ``ceil(size * scale)``.
        antialiasing: Forwarded to ``calculate_weights_indices``.

    Returns:
        A new float tensor laid out like the input (CHW or HW), not rounded.
    """
    need_squeeze = img.dim() == 2
    if need_squeeze:
        # Out-of-place on purpose: the previous in-place unsqueeze_ left the
        # caller's 2-D tensor permanently 3-D.
        img = img.unsqueeze(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)

    # process H dimension
    # symmetric copying: the buffer is uninitialised, so the three copies
    # below (centre, mirrored top, mirrored bottom) must fill all of it.
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    # From here on kernel_width is the number of taps per output row.
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])

    if need_squeeze:
        # Remove only the channel axis added above, never a size-1 H or W.
        out_2 = out_2.squeeze(0)
    return out_2
# --------------------------------------------
# imresize for numpy image [0, 1]
# --------------------------------------------
def imresize_np(img, scale, antialiasing=True):
    """Resize a NumPy image with the bicubic kernel, one axis at a time.

    The array is converted to a torch tensor, resampled along H and then
    along W (both passes pad by mirroring the edges), and converted back.

    Args:
        img: NumPy array laid out as HWC or HW (the original notes give the
            value range as [0, 1]).
        scale: Resize factor applied to both H and W; output sizes are
            ``ceil(size * scale)``.
        antialiasing: Forwarded to ``calculate_weights_indices``.

    Returns:
        NumPy array laid out as HWC or HW, not rounded. The intermediate
        buffers are ``torch.FloatTensor``s.
    """
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        # Give HW input a trailing channel axis so the code below can assume HWC.
        img.unsqueeze_(2)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying: img_aug is uninitialised, so the three copies below
    # (centre, mirrored top rows, mirrored bottom rows) must fill all of it.
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    # Top padding: the first sym_len_Hs rows, reversed.
    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    # Bottom padding: the last sym_len_He rows, reversed.
    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    # From here on kernel_width is the number of taps per output row.
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        # First padded-row index feeding output row i.
        idx = int(indices_H[i][0])
        for j in range(out_C):
            # Weighted sum of kernel_width consecutive rows, per column.
            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying (same scheme as above, now along the column axis)
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        # First padded-column index feeding output column i.
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
    if need_squeeze:
        # NOTE(review): squeeze_() with no argument drops every size-1 axis,
        # not only the channel axis added above - confirm that is acceptable.
        out_2.squeeze_()

    return out_2.numpy()
# Script entry point: running this module directly only prints a separator.
if __name__ == '__main__':
    print('---')
    # Disabled manual example (expects a local 'test.bmp'):
#    img = imread_uint('test.bmp', 3)
#    img = uint2single(img)
#    img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
examples/images/diffusion/ldm/modules/losses/__init__.py
0 → 100644
View file @
a7e8159d
from
ldm.modules.losses.contperceptual
import
LPIPSWithDiscriminator
\ No newline at end of file
examples/images/diffusion/ldm/modules/losses/contperceptual.py
0 → 100644
View file @
a7e8159d
import
torch
import
torch.nn
as
nn
from
taming.modules.losses.vqperceptual
import
*
# TODO: taming dependency yes/no?
class LPIPSWithDiscriminator(nn.Module):
    """Autoencoder training loss: reconstruction NLL, KL term and GAN term.

    The reconstruction term is the absolute pixel difference plus, when
    ``perceptual_weight > 0``, the weighted output of ``LPIPS``. It is divided
    by ``exp(logvar)`` and offset by ``logvar``, a learnable scalar. The GAN
    term uses an ``NLayerDiscriminator`` and is switched on through
    ``adopt_weight`` once ``global_step`` reaches ``disc_start``.

    ``LPIPS``, ``NLayerDiscriminator``, ``weights_init``, ``hinge_d_loss``,
    ``vanilla_d_loss`` and ``adopt_weight`` come from the module-level star
    import.
    """

    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge"):
        """Store the loss weights and build the perceptual and discriminator nets.

        Args:
            disc_start: Step threshold passed to ``adopt_weight``.
            logvar_init: Initial value of the learnable log-variance.
            kl_weight: Multiplier of the KL term.
            pixelloss_weight: Stored as ``pixel_weight``; not read by any
                method of this class.
            disc_num_layers: ``n_layers`` of the discriminator.
            disc_in_channels: ``input_nc`` of the discriminator.
            disc_factor: Base factor of the GAN term.
            disc_weight: Multiplier applied to the adaptive weight.
            perceptual_weight: Multiplier of the perceptual term.
            use_actnorm: Forwarded to the discriminator.
            disc_conditional: Whether ``forward`` expects ``cond``.
            disc_loss: ``"hinge"`` or ``"vanilla"``.
        """
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]

        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance (learnable scalar)
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                            n_layers=disc_num_layers,
                                            use_actnorm=use_actnorm)
        self.discriminator = discriminator.apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        else:
            self.disc_loss = vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the GAN term against the NLL term via gradient norms.

        Args:
            nll_loss: Scalar reconstruction loss.
            g_loss: Scalar generator loss.
            last_layer: Tensor both gradients are taken with respect to; when
                ``None``, ``self.last_layer[0]`` is used instead.

        Returns:
            ``||grad nll|| / (||grad g|| + 1e-4)``, clamped to ``[0, 1e4]``,
            detached and multiplied by ``discriminator_weight``.
        """
        if last_layer is None:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        ratio = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        ratio = torch.clamp(ratio, 0.0, 1e4).detach()
        return ratio * self.discriminator_weight

    def _discriminate(self, images, cond):
        # Run the discriminator, concatenating `cond` on dim 1 when it is given.
        if cond is None:
            return self.discriminator(images)
        return self.discriminator(torch.cat((images, cond), dim=1))

    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train",
                weights=None):
        """Compute the generator (index 0) or discriminator (index 1) loss.

        Args:
            inputs: Target images.
            reconstructions: Reconstructed images.
            posteriors: Object whose ``kl()`` returns the KL values.
            optimizer_idx: ``0`` for the generator update, ``1`` for the
                discriminator update; any other value returns ``None``.
            global_step: Current step, compared with the start threshold.
            last_layer: Forwarded to ``calculate_adaptive_weight``.
            cond: Optional conditioning concatenated to discriminator inputs.
            split: Prefix of the log keys.
            weights: Optional multiplier of the NLL term.

        Returns:
            ``(loss, log)`` where ``log`` is a dict of detached values.
        """
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss if weights is None else weights * nll_loss
        # Sum over everything, then divide by the size of the first dimension.
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
            else:
                assert self.disc_conditional
            logits_fake = self._discriminate(reconstructions.contiguous(), cond)
            g_loss = -torch.mean(logits_fake)

            d_weight = torch.tensor(0.0)
            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    # Only tolerated outside training; the weight stays 0.
                    assert not self.training

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            log = {
                "{}/total_loss".format(split): loss.clone().detach().mean(),
                "{}/logvar".format(split): self.logvar.detach(),
                "{}/kl_loss".format(split): kl_loss.detach().mean(),
                "{}/nll_loss".format(split): nll_loss.detach().mean(),
                "{}/rec_loss".format(split): rec_loss.detach().mean(),
                "{}/d_weight".format(split): d_weight.detach(),
                "{}/disc_factor".format(split): torch.tensor(disc_factor),
                "{}/g_loss".format(split): g_loss.detach().mean(),
            }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update; inputs are detached
            logits_real = self._discriminate(inputs.contiguous().detach(), cond)
            logits_fake = self._discriminate(reconstructions.contiguous().detach(), cond)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {
                "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                "{}/logits_real".format(split): logits_real.detach().mean(),
                "{}/logits_fake".format(split): logits_fake.detach().mean()
            }
            return d_loss, log
examples/images/diffusion/ldm/modules/losses/vqperceptual.py
0 → 100644
View file @
a7e8159d
import
torch
from
torch
import
nn
import
torch.nn.functional
as
F
from
einops
import
repeat
from
taming.modules.discriminator.model
import
NLayerDiscriminator
,
weights_init
from
taming.modules.losses.lpips
import
LPIPS
from
taming.modules.losses.vqperceptual
import
hinge_d_loss
,
vanilla_d_loss
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    """Hinge discriminator loss with one weight per batch element.

    Args:
        logits_real: Discriminator logits for real samples; averaged over
            dims 1, 2 and 3.
        logits_fake: Discriminator logits for fake samples; same layout.
        weights: Per-sample weights whose first dimension matches the logits.

    Returns:
        ``0.5 * (real + fake)``, where each term is the weighted average of
        the per-sample hinge losses.
    """
    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]

    # Per-sample hinge terms: relu(1 - real) and relu(1 + fake).
    real_per_sample = torch.mean(F.relu(1. - logits_real), dim=[1, 2, 3])
    fake_per_sample = torch.mean(F.relu(1. + logits_fake), dim=[1, 2, 3])

    real_term = (weights * real_per_sample).sum() / weights.sum()
    fake_term = (weights * fake_per_sample).sum() / weights.sum()
    return 0.5 * (real_term + fake_term)
def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Return ``value`` until ``global_step`` reaches ``threshold``.

    Args:
        weight: Weight returned once the threshold has been reached.
        global_step: Current step.
        threshold: First step at which ``weight`` is returned.
        value: Substitute returned while ``global_step < threshold``.

    Returns:
        ``value`` if ``global_step < threshold``, otherwise ``weight``.
    """
    return value if global_step < threshold else weight
def measure_perplexity(predicted_indices, n_embed):
    """Measure codebook usage from predicted code indices.

    Args:
        predicted_indices: Integer tensor of code indices.
        n_embed: Number of codebook entries (width of the one-hot encoding).

    Returns:
        ``(perplexity, cluster_use)``: the exponential of the entropy of the
        average code distribution, and the count of codes used at least once.
    """
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
    one_hot = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    usage = one_hot.mean(0)
    # 1e-10 keeps log() finite for codes that were never selected.
    entropy = -(usage * torch.log(usage + 1e-10)).sum()
    perplexity = entropy.exp()
    cluster_use = torch.sum(usage > 0)
    return perplexity, cluster_use
def l1(x, y):
    """Return the elementwise absolute difference ``|x - y|``."""
    difference = x - y
    return torch.abs(difference)
def l2(x, y):
    """Return the elementwise squared difference ``(x - y) ** 2``."""
    difference = x - y
    return torch.pow(difference, 2)
class VQLPIPSWithDiscriminator(nn.Module):
    """VQ autoencoder loss: pixel + perceptual reconstruction, codebook and GAN terms.

    The reconstruction term is ``pixel_loss`` (L1 or L2) plus, when
    ``perceptual_weight > 0``, the weighted ``LPIPS`` output; it is averaged
    with ``torch.mean``. The GAN term uses an ``NLayerDiscriminator`` and is
    switched on through ``adopt_weight`` once ``global_step`` reaches
    ``disc_start``.
    """

    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
                 pixel_loss="l1"):
        """Store the loss weights and build the perceptual and discriminator nets.

        Args:
            disc_start: Step threshold passed to ``adopt_weight``.
            codebook_weight: Multiplier of the codebook loss.
            pixelloss_weight: Stored as ``pixel_weight``; not read by any
                method of this class.
            disc_num_layers: ``n_layers`` of the discriminator.
            disc_in_channels: ``input_nc`` of the discriminator.
            disc_factor: Base factor of the GAN term.
            disc_weight: Multiplier applied to the adaptive weight.
            perceptual_weight: Multiplier of the perceptual term.
            use_actnorm: Forwarded to the discriminator.
            disc_conditional: Whether ``forward`` expects ``cond``.
            disc_ndf: ``ndf`` of the discriminator.
            disc_loss: ``"hinge"`` or ``"vanilla"``.
            n_classes: Codebook size, needed when ``forward`` receives
                ``predicted_indices``.
            perceptual_loss: The assert accepts ``"lpips"``, ``"clips"`` and
                ``"dists"``, but only ``"lpips"`` is implemented.
            pixel_loss: ``"l1"`` or ``"l2"``.

        Raises:
            ValueError: For a perceptual loss other than ``"lpips"``.
        """
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            self.perceptual_loss = LPIPS().eval()
        else:
            raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
        self.perceptual_weight = perceptual_weight

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the GAN term against the NLL term via gradient norms.

        Args:
            nll_loss: Scalar reconstruction loss.
            g_loss: Scalar generator loss.
            last_layer: Tensor both gradients are taken with respect to; when
                ``None``, ``self.last_layer[0]`` is used instead.

        Returns:
            ``||grad nll|| / (||grad g|| + 1e-4)``, clamped to ``[0, 1e4]``,
            detached and multiplied by ``discriminator_weight``.
        """
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
        """Compute the generator (index 0) or discriminator (index 1) loss.

        Args:
            codebook_loss: Quantisation loss tensor, or ``None`` to use zero.
            inputs: Target images.
            reconstructions: Reconstructed images.
            optimizer_idx: ``0`` for the generator update, ``1`` for the
                discriminator update; any other value returns ``None``.
            global_step: Current step, compared with the start threshold.
            last_layer: Forwarded to ``calculate_adaptive_weight``.
            cond: Optional conditioning concatenated to discriminator inputs.
            split: Prefix of the log keys.
            predicted_indices: Optional code indices; when given, perplexity
                and cluster usage are added to the log.

        Returns:
            ``(loss, log)`` where ``log`` is a dict of detached values.
        """
        # Plain None check: this module neither defines nor imports an
        # `exists` helper, so calling one here raised NameError.
        if codebook_loss is None:
            codebook_loss = torch.tensor([0.]).to(inputs.device)
        #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            # Placeholder so the log entry below always has a value.
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss
        #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        nll_loss = torch.mean(nll_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                # Only tolerated outside training; fall back to weight 0.
                assert not self.training
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/p_loss".format(split): p_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update; inputs are detached
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
examples/images/diffusion/ldm/modules/x_transformer.py
0 → 100644
View file @
a7e8159d
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
import
torch
from
torch
import
nn
,
einsum
import
torch.nn.functional
as
F
from
functools
import
partial
from
inspect
import
isfunction
from
collections
import
namedtuple
from
einops
import
rearrange
,
repeat
,
reduce
# constants
# Default per-head width used when no explicit dim_head is supplied.
DEFAULT_DIM_HEAD = 64

# Attention maps captured by a single attention call.
Intermediates = namedtuple('Intermediates', [
    'pre_softmax_attn',
    'post_softmax_attn'
])

# Per-stack record of hidden states and per-layer attention intermediates.
# The typename matches the bound name so repr() and pickling resolve to this
# class; it previously reused 'Intermediates'.
LayerIntermediates = namedtuple('LayerIntermediates', [
    'hiddens',
    'attn_intermediates'
])
class AbsolutePositionalEmbedding(nn.Module):
    """Learned embedding table indexed by absolute position."""

    def __init__(self, dim, max_seq_len):
        """Create a ``max_seq_len x dim`` table and initialise it.

        Args:
            dim: Embedding width.
            max_seq_len: Number of positions in the table.
        """
        super().__init__()
        self.emb = nn.Embedding(max_seq_len, dim)
        self.init_()

    def init_(self):
        """Re-initialise the table from a normal distribution with std 0.02."""
        nn.init.normal_(self.emb.weight, std=0.02)

    def forward(self, x):
        """Return embeddings for positions ``0 .. x.shape[1] - 1``.

        Args:
            x: Tensor whose dim 1 gives the sequence length; only its shape
                and device are read.

        Returns:
            Tensor of shape ``(1, x.shape[1], dim)``.
        """
        positions = torch.arange(x.shape[1], device=x.device)
        table = self.emb(positions)
        return table.unsqueeze(0)
class FixedPositionalEmbedding(nn.Module):
    """Non-learned sinusoidal position embedding."""

    def __init__(self, dim):
        """Register the inverse frequencies ``1 / 10000 ** (2i / dim)``.

        Args:
            dim: Embedding width; one frequency per even index below ``dim``.
        """
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1. / (10000 ** exponents)
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, seq_dim=1, offset=0):
        """Return the sinusoid table for the positions of ``x``.

        Args:
            x: Tensor; only its size along ``seq_dim`` and its device are read.
            seq_dim: Dimension of ``x`` that holds the sequence.
            offset: Value added to every position index.

        Returns:
            Tensor of shape ``(1, length, 2 * len(inv_freq))`` with all sines
            first, followed by all cosines, along the last dimension.
        """
        positions = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
        # Outer product: one row per position, one column per frequency.
        angles = torch.einsum('i , j -> i j', positions, self.inv_freq)
        table = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return table.unsqueeze(0)
# helpers
def exists(val):
    """Return True unless ``val`` is ``None`` (falsy values still count)."""
    return not (val is None)
def default(val, d):
    """Return ``val`` when it is not ``None``, otherwise the fallback ``d``.

    Args:
        val: Preferred value.
        d: Fallback; when it is a plain Python function (as detected by
            ``inspect.isfunction``) it is called and its result is returned.

    Returns:
        ``val``, ``d()`` or ``d``.
    """
    if not exists(val):
        # Function fallbacks are evaluated lazily, only when needed.
        return d() if isfunction(d) else d
    return val
def always(val):
    """Return a callable that ignores every argument and returns ``val``."""
    return lambda *args, **kwargs: val
def not_equals(val):
    """Return a one-argument predicate computing ``x != val``."""
    return lambda x: x != val
def equals(val):
    """Return a one-argument predicate computing ``x == val``."""
    return lambda x: x == val
def max_neg_value(tensor):
    """Return the most negative finite value of ``tensor``'s floating dtype."""
    limits = torch.finfo(tensor.dtype)
    return -limits.max
# keyword argument helpers
def pick_and_pop(keys, d):
    """Remove ``keys`` from ``d`` and return them as a new dict.

    Args:
        keys: Keys to extract; each must be present in ``d``.
        d: Dictionary mutated in place.

    Returns:
        Dict mapping every key to the value popped from ``d``.

    Raises:
        KeyError: If a key is missing from ``d``.
    """
    popped = [d.pop(key) for key in keys]
    return dict(zip(keys, popped))
def group_dict_by_key(cond, d):
    """Split ``d`` in two according to a predicate on its keys.

    Args:
        cond: Callable applied to every key; its truthiness picks the group.
        d: Dictionary to split (left unmodified).

    Returns:
        Tuple ``(matching, rest)`` of two new dicts.
    """
    matching, rest = dict(), dict()
    for key, value in d.items():
        bucket = matching if cond(key) else rest
        bucket[key] = value
    return matching, rest
def string_begins_with(prefix, str):
    """Return whether ``str`` starts with ``prefix``.

    The prefix comes first so the function can be partially applied with it.
    """
    begins = str.startswith(prefix)
    return begins
def group_by_key_prefix(prefix, d):
    """Split ``d`` into (keys starting with ``prefix``, all other keys)."""
    has_prefix = partial(string_begins_with, prefix)
    return group_dict_by_key(has_prefix, d)
def groupby_prefix_and_trim(prefix, d):
    """Split ``d`` by key prefix and strip the prefix from the matching keys.

    Args:
        prefix: Prefix that selects the first group.
        d: Dictionary to split (left unmodified).

    Returns:
        Tuple ``(trimmed, rest)``: the entries whose key starts with
        ``prefix``, re-keyed without it, and the remaining entries.
    """
    has_prefix = partial(string_begins_with, prefix)
    kwargs_with_prefix, kwargs = group_dict_by_key(has_prefix, d)
    cut = len(prefix)
    kwargs_without_prefix = {key[cut:]: value for key, value in kwargs_with_prefix.items()}
    return kwargs_without_prefix, kwargs
# classes
class Scale(nn.Module):
    """Wrap a callable and multiply its primary output by a constant."""

    def __init__(self, value, fn):
        """Store the factor and the wrapped callable.

        Args:
            value: Multiplier applied to the primary output.
            fn: Callable (typically a module) to wrap.
        """
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        """Call ``fn`` and scale its result.

        Args:
            x: Input forwarded to ``fn``.
            **kwargs: Extra keyword arguments forwarded to ``fn``.

        Returns:
            The scaled tensor when ``fn`` returns a single tensor. Otherwise
            a tuple whose first element is scaled and whose remaining
            elements are passed through unchanged.
        """
        out = self.fn(x, **kwargs)
        if isinstance(out, torch.Tensor):
            # A lone tensor must not be star-unpacked: that would split it
            # along dim 0 and return a tuple of slices.
            return out * self.value
        x, *rest = out
        return (x * self.value, *rest)
class Rezero(nn.Module):
    """Wrap a callable and gate its primary output with a learnable scalar.

    The gate ``g`` is initialised to zero.
    """

    def __init__(self, fn):
        """Store the wrapped callable and create the gate parameter.

        Args:
            fn: Callable returning an iterable whose first item is the
                primary output.
        """
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x, **kwargs):
        """Call ``fn`` and multiply its first output by ``g``.

        Returns:
            Tuple of the gated first output followed by the other outputs of
            ``fn``, unchanged.
        """
        outputs = self.fn(x, **kwargs)
        head, *tail = outputs
        return (head * self.g,) + tuple(tail)
class ScaleNorm(nn.Module):
    """Normalise by the scaled L2 norm of the last dim, with one scalar gain."""

    def __init__(self, dim, eps=1e-5):
        """Set up the norm scale, clamp floor and gain.

        Args:
            dim: Feature width; the norm is multiplied by ``dim ** -0.5``.
            eps: Lower bound applied to the scaled norm before dividing.
        """
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Return ``x / max(||x|| * scale, eps) * g`` over the last dim."""
        scaled_norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        divisor = scaled_norm.clamp(min=self.eps)
        return x / divisor * self.g
class RMSNorm(nn.Module):
    """Normalise by the scaled L2 norm of the last dim, with per-feature gains."""

    def __init__(self, dim, eps=1e-8):
        """Set up the norm scale, clamp floor and gain vector.

        Args:
            dim: Feature width; the norm is multiplied by ``dim ** -0.5`` and
                ``g`` holds ``dim`` gains.
            eps: Lower bound applied to the scaled norm before dividing.
        """
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x / max(||x|| * scale, eps) * g`` over the last dim."""
        scaled_norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        divisor = scaled_norm.clamp(min=self.eps)
        return x / divisor * self.g
class Residual(nn.Module):
    """Plain residual connection."""

    def forward(self, x, residual):
        """Return the sum of the branch output ``x`` and ``residual``."""
        combined = x + residual
        return combined
class GRUGating(nn.Module):
    """Residual connection that merges branch and skip through a GRU cell."""

    def __init__(self, dim):
        """Create a ``dim -> dim`` GRU cell.

        Args:
            dim: Feature width of both the branch output and the residual.
        """
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)

    def forward(self, x, residual):
        """Gate ``x`` against ``residual``.

        Both tensors are flattened from ``(b, n, d)`` to ``(b * n, d)``;
        ``x`` is the GRU input and ``residual`` the previous hidden state.

        Returns:
            The gated output reshaped like ``x``.
        """
        flat_x = rearrange(x, 'b n d -> (b n) d')
        flat_residual = rearrange(residual, 'b n d -> (b n) d')
        gated_output = self.gru(flat_x, flat_residual)
        return gated_output.reshape_as(x)
# feedforward
class GEGLU(nn.Module):
    """Linear projection whose output is gated by a GELU-activated half."""

    def __init__(self, dim_in, dim_out):
        """Create one projection producing both the value and the gate.

        Args:
            dim_in: Input feature width.
            dim_out: Output feature width; the projection emits twice that.
        """
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        """Return ``value * gelu(gate)`` for the two halves of the projection."""
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Module):
    """Position-wise feed-forward block: project in, dropout, project out."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        """Build the network.

        Args:
            dim: Input feature width.
            dim_out: Output feature width; falls back to ``dim`` when None.
            mult: Hidden width multiplier (``int(dim * mult)``).
            glu: Use a ``GEGLU`` input projection instead of Linear + GELU.
            dropout: Dropout probability applied after the input projection.
        """
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, inner_dim),
                nn.GELU()
            )

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        """Apply the feed-forward network to ``x``."""
        return self.net(x)
# attention.
class Attention(nn.Module):
    """Multi-head softmax attention with several optional variants.

    Supports self- and cross-attention (via ``context``), causal masking,
    talking-heads projections, top-k sparsification, learned memory
    key/values and an attention-on-attention output gate.
    """

    def __init__(
            self,
            dim,
            dim_head=DEFAULT_DIM_HEAD,
            heads=8,
            causal=False,
            mask=None,
            talking_heads=False,
            sparse_topk=None,
            use_entmax15=False,
            num_mem_kv=0,
            dropout=0.,
            on_attn=False
    ):
        """Build the projections and optional parameters.

        Args:
            dim: Model feature width (input and output of the layer).
            dim_head: Width of each attention head.
            heads: Number of heads.
            causal: If true, ``forward`` applies a causal mask.
            mask: Stored as ``self.mask``; not read by ``forward``.
            talking_heads: Add head-mixing projections before and after
                the softmax.
            sparse_topk: If set, keep only the top-k scores per query.
            use_entmax15: Not supported; a true value raises.
            num_mem_kv: Number of learned memory key/value slots.
            dropout: Dropout probability applied to the attention weights.
            on_attn: Use a Linear + GLU output projection instead of Linear.

        Raises:
            NotImplementedError: If ``use_entmax15`` is true.
        """
        super().__init__()
        if use_entmax15:
            raise NotImplementedError("Check out entmax activation instead of softmax activation!")
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.causal = causal
        self.mask = mask

        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

        # talking heads: (heads x heads) mixing matrices, randomly initialised
        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))

        # explicit topk sparse attention
        self.sparse_topk = sparse_topk

        # entmax
        #self.attn_fn = entmax15 if use_entmax15 else F.softmax
        self.attn_fn = F.softmax

        # add memory key / values
        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
            self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))

        # attention on attention
        self.attn_on_attn = on_attn
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
            rel_pos=None,
            sinusoidal_emb=None,
            prev_attn=None,
            mem=None
    ):
        """Run attention.

        Args:
            x: Query source; unpacked as exactly three dims ``(b, n, _)``.
            context: Optional key/value source; defaults to ``x``.
            mask: Optional boolean ``(b, n)`` mask over query positions.
            context_mask: Optional boolean mask over context positions; only
                used when ``context`` is given.
            rel_pos: Optional callable applied to the attention scores.
            sinusoidal_emb: Optional callable whose output is added to the
                query and key inputs.
            prev_attn: Optional scores added to this layer's scores.
            mem: Optional tensor prepended to the key/value inputs on dim -2.

        Returns:
            Tuple ``(out, intermediates)`` where ``out`` is the projected
            output and ``intermediates`` is an ``Intermediates`` holding the
            pre- and post-softmax attention.
        """
        b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
        kv_input = default(context, x)

        q_input = x
        k_input = kv_input
        v_input = kv_input

        if exists(mem):
            k_input = torch.cat((mem, k_input), dim=-2)
            v_input = torch.cat((mem, v_input), dim=-2)

        if exists(sinusoidal_emb):
            # in shortformer, the query would start at a position offset depending on the past cached memory
            offset = k_input.shape[-2] - q_input.shape[-2]
            q_input = q_input + sinusoidal_emb(q_input, offset=offset)
            k_input = k_input + sinusoidal_emb(k_input)

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        # Split the projected features into heads: (b, n, h*d) -> (b, h, n, d).
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        input_mask = None
        if any(map(exists, (mask, context_mask))):
            # A missing mask defaults to all-True; for self-attention the key
            # mask reuses the query mask.
            q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
            k_mask = q_mask if not exists(context) else context_mask
            k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
            q_mask = rearrange(q_mask, 'b i -> b () i ()')
            k_mask = rearrange(k_mask, 'b j -> b () () j')
            # Broadcast product: a (query, key) pair is kept only if both are.
            input_mask = q_mask * k_mask

        if self.num_mem_kv > 0:
            # Prepend the learned memory slots, repeated across the batch.
            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
            k = torch.cat((mem_k, k), dim=-2)
            v = torch.cat((mem_v, v), dim=-2)
            if exists(input_mask):
                # Memory slots are always attendable: left-pad the key axis with True.
                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        mask_value = max_neg_value(dots)

        if exists(prev_attn):
            dots = dots + prev_attn

        # NOTE(review): this binds the same tensor object as `dots`. Unless
        # `dots` is rebound below (talking heads, or a rel_pos that returns a
        # new tensor), the in-place masked_fill_ calls also change
        # pre_softmax_attn - confirm that is intended.
        pre_softmax_attn = dots

        if talking_heads:
            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()

        if exists(rel_pos):
            dots = rel_pos(dots)

        if exists(input_mask):
            dots.masked_fill_(~input_mask, mask_value)
            del input_mask

        if self.causal:
            i, j = dots.shape[-2:]
            r = torch.arange(i, device=device)
            # True where the query index is smaller than the key index; the
            # parameter name `mask` is reused here as a local.
            mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
            # Left-pad with False so the extra leading (j - i) keys stay visible.
            mask = F.pad(mask, (j - i, 0), value=False)
            dots.masked_fill_(mask, mask_value)
            del mask

        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
            top, _ = dots.topk(self.sparse_topk, dim=-1)
            # vk is the k-th largest score per query; anything below it is masked.
            vk = top[..., -1].unsqueeze(-1).expand_as(dots)
            mask = dots < vk
            dots.masked_fill_(mask, mask_value)
            del mask

        attn = self.attn_fn(dots, dim=-1)
        post_softmax_attn = attn

        attn = self.dropout(attn)

        if talking_heads:
            attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # Merge the heads back: (b, h, n, d) -> (b, n, h*d).
        out = rearrange(out, 'b h n d -> b n (h d)')

        intermediates = Intermediates(
            pre_softmax_attn=pre_softmax_attn,
            post_softmax_attn=post_softmax_attn
        )

        return self.to_out(out), intermediates
class AttentionLayers(nn.Module):
    """Configurable stack of attention and feed-forward layers.

    The stack is described by a sequence of one-letter layer types:

    * ``'a'`` - an ``Attention`` layer called on ``x`` alone (self-attention),
    * ``'c'`` - an ``Attention`` layer called with ``context`` (cross-attention),
    * ``'f'`` - a ``FeedForward`` layer.

    Every entry of ``self.layers`` is an ``nn.ModuleList`` holding
    ``[norm, layer, residual_fn]`` for one position of ``self.layer_types``.
    """

    def __init__(
        self,
        dim,
        depth,
        heads=8,
        causal=False,
        cross_attend=False,
        only_cross=False,
        use_scalenorm=False,
        use_rmsnorm=False,
        use_rezero=False,
        rel_pos_num_buckets=32,
        rel_pos_max_distance=128,
        position_infused_attn=False,
        custom_layers=None,
        sandwich_coef=None,
        par_ratio=None,
        residual_attn=False,
        cross_residual_attn=False,
        macaron=False,
        pre_norm=True,
        gate_residual=False,
        **kwargs
    ):
        """Build the layer stack.

        Args:
            dim: Model width; handed to the norm, ``Attention``,
                ``FeedForward``, ``GRUGating`` and positional-embedding
                constructors and stored as ``self.dim``.
            depth: Number of times the default block is repeated (unless
                ``custom_layers`` overrides the layout).
            heads: Forwarded to every ``Attention`` layer.
            causal: Forwarded to self-attention (``'a'``) layers only.
            cross_attend: Include cross-attention (``'c'``) in the default block.
            only_cross: With ``cross_attend``, drop self-attention from the
                default block.
            use_scalenorm: Use ``ScaleNorm`` instead of ``nn.LayerNorm``.
            use_rmsnorm: Use ``RMSNorm``; takes precedence over
                ``use_scalenorm``.
            use_rezero: Replace the norm by ``nn.Identity`` and wrap attention
                layers in ``Rezero``.
            rel_pos_num_buckets: Only validated against
                ``rel_pos_max_distance``; ``self.rel_pos`` is always ``None``
                in this implementation.
            rel_pos_max_distance: See ``rel_pos_num_buckets``.
            position_infused_attn: Create a ``FixedPositionalEmbedding(dim)``
                that is passed to self-attention layers as ``sinusoidal_emb``.
            custom_layers: Explicit sequence of layer-type letters; overrides
                every other layout option.
            sandwich_coef: If set, lay out ``sandwich_coef`` attention layers
                first and ``sandwich_coef`` feed-forward layers last, with
                default blocks in between.
            par_ratio: If set, build the attention-heavy-head layout computed
                below (see the inline comment about the PAR paper).
            residual_attn: Feed each self-attention layer the previous
                self-attention layer's ``pre_softmax_attn``.
            cross_residual_attn: Same as ``residual_attn`` for cross-attention.
            macaron: Prepend a feed-forward layer to the default block and
                wrap every feed-forward layer in ``Scale(0.5, ...)``.
            pre_norm: Apply the norm before each layer instead of after the
                residual connection.
            gate_residual: Use ``GRUGating(dim)`` instead of ``Residual()``.
            **kwargs: Split by ``groupby_prefix_and_trim`` on the ``'ff_'``
                and ``'attn_'`` prefixes and forwarded to ``FeedForward`` and
                ``Attention`` respectively (presumably with the prefix
                stripped - confirm against that helper).

        Raises:
            AssertionError: If one of the layout arguments is out of range.
            Exception: If the layout contains an unknown layer-type letter.
        """
        super().__init__()
        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
        attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)

        self.dim = dim
        self.depth = depth
        self.layers = nn.ModuleList([])

        self.has_pos_emb = position_infused_attn
        self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
        self.rotary_pos_emb = always(None)

        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
        self.rel_pos = None

        self.pre_norm = pre_norm

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn

        # Later choices override earlier ones: RMSNorm beats ScaleNorm, and
        # rezero disables normalisation altogether.
        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
        norm_class = RMSNorm if use_rmsnorm else norm_class
        norm_fn = partial(norm_class, dim)

        norm_fn = nn.Identity if use_rezero else norm_fn
        branch_fn = Rezero if use_rezero else None

        if cross_attend and not only_cross:
            default_block = ('a', 'c', 'f')
        elif cross_attend and only_cross:
            default_block = ('c', 'f')
        else:
            default_block = ('a', 'f')

        if macaron:
            default_block = ('f',) + default_block

        if exists(custom_layers):
            layer_types = custom_layers
        elif exists(par_ratio):
            par_depth = depth * len(default_block)
            assert 1 < par_ratio <= par_depth, 'par ratio out of range'
            default_block = tuple(filter(not_equals('f'), default_block))
            par_attn = par_depth // par_ratio
            depth_cut = par_depth * 2 // 3  # 2 / 3 attention layer cutoff suggested by PAR paper
            par_width = (depth_cut + depth_cut // par_attn) // par_attn
            assert len(default_block) <= par_width, 'default block is too large for par_ratio'
            par_block = default_block + ('f',) * (par_width - len(default_block))
            par_head = par_block * par_attn
            # Whatever depth remains after the attention-bearing head is
            # filled with feed-forward layers.
            layer_types = par_head + ('f',) * (par_depth - len(par_head))
        elif exists(sandwich_coef):
            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
            layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
        else:
            layer_types = default_block * depth

        self.layer_types = layer_types
        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

        for layer_type in self.layer_types:
            if layer_type == 'a':
                layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
            elif layer_type == 'c':
                layer = Attention(dim, heads=heads, **attn_kwargs)
            elif layer_type == 'f':
                layer = FeedForward(dim, **ff_kwargs)
                layer = layer if not macaron else Scale(0.5, layer)
            else:
                raise Exception(f'invalid layer type {layer_type}')

            if isinstance(layer, Attention) and exists(branch_fn):
                layer = branch_fn(layer)

            if gate_residual:
                residual_fn = GRUGating(dim)
            else:
                residual_fn = Residual()

            self.layers.append(nn.ModuleList([
                norm_fn(),
                layer,
                residual_fn
            ]))

    def forward(
        self,
        x,
        context=None,
        mask=None,
        context_mask=None,
        mems=None,
        return_hiddens=False
    ):
        """Run ``x`` through every layer of the stack.

        Args:
            x: Input to the first layer.
            context: Passed to cross-attention layers as ``context``.
            mask: Passed to self- and cross-attention layers as ``mask``.
            context_mask: Passed to cross-attention layers.
            mems: Optional list with one memory entry per self-attention
                layer, consumed front to back. The list is copied, so the
                caller's object is not modified. If it holds fewer entries
                than there are self-attention layers, the remaining layers
                receive ``None``.
            return_hiddens: Also return the collected intermediates.

        Returns:
            The output of the last layer, or - when ``return_hiddens`` is
            true - a tuple ``(output, LayerIntermediates)`` whose ``hiddens``
            are the inputs seen by each self-attention layer and whose
            ``attn_intermediates`` are the second return values of every
            attention layer.
        """
        hiddens = []
        intermediates = []
        prev_attn = None
        prev_cross_attn = None

        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers

        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
            is_last = ind == (len(self.layers) - 1)

            if layer_type == 'a':
                hiddens.append(x)
                # A short ``mems`` list must not abort the forward pass:
                # layers beyond its end simply run without memory.
                layer_mem = mems.pop(0) if mems else None

            residual = x

            if self.pre_norm:
                x = norm(x)

            if layer_type == 'a':
                out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
                                   prev_attn=prev_attn, mem=layer_mem)
            elif layer_type == 'c':
                out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
            elif layer_type == 'f':
                out = block(x)

            x = residual_fn(out, residual)

            if layer_type in ('a', 'c'):
                intermediates.append(inter)

            if layer_type == 'a' and self.residual_attn:
                prev_attn = inter.pre_softmax_attn
            elif layer_type == 'c' and self.cross_residual_attn:
                prev_cross_attn = inter.pre_softmax_attn

            # Post-norm variant: normalise after the residual, except after
            # the very last layer.
            if not self.pre_norm and not is_last:
                x = norm(x)

        if return_hiddens:
            intermediates = LayerIntermediates(
                hiddens=hiddens,
                attn_intermediates=intermediates
            )

            return x, intermediates

        return x
class Encoder(AttentionLayers):
    """``AttentionLayers`` stack whose ``causal`` flag is pinned to ``False``."""

    def __init__(self, **kwargs):
        """Build the stack, refusing any caller-supplied ``causal`` argument.

        All keyword arguments are passed through unchanged to
        ``AttentionLayers.__init__``.
        """
        causal_given = 'causal' in kwargs
        assert not causal_given, 'cannot set causality on encoder'
        super().__init__(**kwargs, causal=False)
class TransformerWrapper(nn.Module):
    """Token-level wrapper around an ``AttentionLayers`` stack.

    Adds a token embedding, an optional absolute positional embedding,
    embedding dropout, an optional projection from ``emb_dim`` to the stack's
    width, optional learned memory tokens, a final ``nn.LayerNorm`` and a
    projection to ``num_tokens`` logits.
    """

    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        attn_layers,
        emb_dim=None,
        max_mem_len=0.,
        emb_dropout=0.,
        num_memory_tokens=None,
        tie_embedding=False,
        use_pos_emb=True
    ):
        """Create the embeddings and output head around ``attn_layers``.

        Args:
            num_tokens: Vocabulary size of the embedding and of the logits.
            max_seq_len: Passed to ``AbsolutePositionalEmbedding``.
            attn_layers: An ``AttentionLayers`` instance; its ``dim`` defines
                the model width.
            emb_dim: Embedding width; defaults to ``attn_layers.dim``. A
                linear projection is inserted when it differs from that width.
            max_mem_len: How many trailing positions of each memory tensor are
                kept when ``forward`` is called with ``return_mems=True``.
                Converted with ``int()`` at use. NOTE: ``0`` produces the
                slice ``-0:``, which keeps the full tensor.
            emb_dropout: Dropout probability applied to the embeddings.
            num_memory_tokens: Number of learned memory tokens prepended to
                the sequence; ``None`` means ``0``.
            tie_embedding: Compute logits with the transposed token-embedding
                weight instead of a separate ``nn.Linear``.
            use_pos_emb: Add the absolute positional embedding, unless
                ``attn_layers.has_pos_emb`` is already true.

        Raises:
            AssertionError: If ``attn_layers`` is not an ``AttentionLayers``.
        """
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.num_tokens = num_tokens

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        # ``always(0)`` stands in for the positional embedding when it is
        # disabled, so ``forward`` can add its result unconditionally.
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
            use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.init_()

        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

        # memory tokens (like [cls]) from Memory Transformers paper
        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

            # let funnel encoder know number of memory tokens, if specified
            if hasattr(attn_layers, 'num_memory_tokens'):
                attn_layers.num_memory_tokens = num_memory_tokens

    def init_(self):
        """Initialise the token-embedding weight from N(0, 0.02**2)."""
        nn.init.normal_(self.token_emb.weight, std=0.02)

    def forward(
        self,
        x,
        return_embeddings=False,
        mask=None,
        return_mems=False,
        return_attn=False,
        mems=None,
        **kwargs
    ):
        """Embed ``x``, run the attention stack and project to logits.

        Args:
            x: Two-dimensional input (its shape is unpacked into exactly two
                values) that is fed to the token embedding.
            return_embeddings: Return the normalised hidden states instead of
                logits.
            mask: Passed to the attention stack; padded with ``True`` on the
                left for the memory tokens when there are any.
            return_mems: Also return updated memories. Takes precedence over
                ``return_attn``.
            return_attn: Also return each attention layer's
                ``post_softmax_attn``.
            mems: Memories forwarded to the attention stack and, with
                ``return_mems``, concatenated with the new hidden states.
            **kwargs: Forwarded to ``self.attn_layers``.

        Returns:
            ``out``, or ``(out, new_mems)`` when ``return_mems`` is true, or
            ``(out, attn_maps)`` when ``return_attn`` is true.
        """
        b, _ = x.shape
        num_mem = self.num_memory_tokens

        x = self.token_emb(x)
        x += self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)

        if num_mem > 0:
            mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
            x = torch.cat((mem, x), dim=1)

            # auto-handle masking after appending memory tokens
            if exists(mask):
                mask = F.pad(mask, (num_mem, 0), value=True)

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        # Drop the memory-token positions again before producing the output.
        x = x[:, num_mem:]

        out = self.to_logits(x) if not return_embeddings else x

        if return_mems:
            hiddens = intermediates.hiddens
            if exists(mems):
                new_mems = [torch.cat(pair, dim=-2) for pair in zip(mems, hiddens)]
            else:
                new_mems = hiddens
            # Slice bounds must be integers; the float default (``0.``) would
            # otherwise raise a TypeError here.
            mem_len = int(self.max_mem_len)
            new_mems = [t[..., -mem_len:, :].detach() for t in new_mems]
            return out, new_mems

        if return_attn:
            attn_maps = [inter.post_softmax_attn for inter in intermediates.attn_intermediates]
            return out, attn_maps

        return out
examples/images/diffusion/scripts/tests/test_checkpoint.py
0 → 100644
View file @
a7e8159d
import
os
import
sys
from
copy
import
deepcopy
import
yaml
from
datetime
import
datetime
from
diffusers
import
StableDiffusionPipeline
import
torch
from
ldm.util
import
instantiate_from_config
from
main
import
get_parser
if __name__ == "__main__":
    # Smoke test: push the same random inputs through the UNet built from the
    # ColossalAI training config and through the UNet of a diffusers
    # StableDiffusionPipeline checkpoint, then print both output shapes.
    device = "cuda:0"

    with torch.no_grad():
        # Path is relative to the current working directory.
        with open("../../train_colossalai.yaml", 'r', encoding='utf-8') as cfg_file:
            raw_cfg = cfg_file.read()
        full_cfg = yaml.load(raw_cfg, Loader=yaml.FullLoader)

        # Model under test, instantiated from the `unet_config` section.
        colossal_unet = instantiate_from_config(
            full_cfg['model']['params']['unet_config']
        ).to(device)

        # Reference model taken from the pretrained pipeline.
        reference_unet = StableDiffusionPipeline.from_pretrained(
            "/data/scratch/diffuser/stable-diffusion-v1-4"
        ).to(device).unet

        # Inputs are drawn on the CPU and then moved; each model gets its own
        # copy (a clone of a tensor already on `device` stays on `device`).
        latents_a = torch.rand((4, 4, 32, 32)).to(device)
        latents_b = latents_a.clone()

        steps_a = torch.randint(20, (4,)).to(device)
        steps_b = steps_a.clone()

        cond_a = torch.rand((4, 77, 768)).to(device)
        cond_b = cond_a.clone()

        result_a = colossal_unet(latents_a, steps_a, cond_a)
        result_b = reference_unet(latents_b, steps_b, cond_b)

        print(result_a.shape)
        print(result_b['sample'].shape)
\ No newline at end of file
examples/images/diffusion/scripts/tests/test_watermark.py
0 → 100644
View file @
a7e8159d
import
cv2
import
fire
from
imwatermark
import
WatermarkDecoder
def testit(img_path):
    """Decode the watermark embedded in an image and print it.

    The image is read with OpenCV, decoded with a ``WatermarkDecoder``
    configured for a ``'bytes'`` payload of length 136 (presumably bits -
    confirm against the imwatermark documentation) using the ``'dwtDct'``
    method, and the payload is printed as UTF-8 text. ``"null"`` is printed
    when the payload is not valid UTF-8.

    Args:
        img_path: Path of the image file to inspect.

    Raises:
        ValueError: If the image cannot be read.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising; fail with a clear message here.
        raise ValueError(f"could not read image: {img_path}")

    decoder = WatermarkDecoder('bytes', 136)
    watermark = decoder.decode(bgr, 'dwtDct')
    try:
        dec = watermark.decode('utf-8')
    except UnicodeDecodeError:
        # No readable watermark: the extracted bytes are not valid text.
        dec = "null"
    print(dec)
# When run as a script, hand `testit` to `fire.Fire` so that it is driven
# from the command line.
if __name__ == "__main__":
    fire.Fire(testit)
\ No newline at end of file
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment