dongchy920 / instruct_pix2pix / Commits

Commit 9cfc6603, authored Nov 26, 2024 by dongchy920

    instruct first commit

Pipeline #1969 canceled with stages · Changes: 200 · Pipelines: 1
Showing 20 changed files with 4746 additions and 0 deletions (+4746, -0):
  stable_diffusion/ldm/modules/ema.py                                    +76   -0
  stable_diffusion/ldm/modules/encoders/__init__.py                       +0   -0
  stable_diffusion/ldm/modules/encoders/modules.py                      +235   -0
  stable_diffusion/ldm/modules/image_degradation/__init__.py              +2   -0
  stable_diffusion/ldm/modules/image_degradation/bsrgan.py              +730   -0
  stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py        +650   -0
  stable_diffusion/ldm/modules/image_degradation/utils/test.png           +0   -0
  stable_diffusion/ldm/modules/image_degradation/utils_image.py         +917   -0
  stable_diffusion/ldm/modules/losses/__init__.py                         +2   -0
  stable_diffusion/ldm/modules/losses/contperceptual.py                 +111   -0
  stable_diffusion/ldm/modules/losses/vqperceptual.py                   +167   -0
  stable_diffusion/ldm/modules/x_transformer.py                         +641   -0
  stable_diffusion/ldm/util.py                                          +203   -0
  stable_diffusion/main.py                                              +744   -0
  stable_diffusion/models/first_stage_models/kl-f16/config.yaml          +44   -0
  stable_diffusion/models/first_stage_models/kl-f32/config.yaml          +46   -0
  stable_diffusion/models/first_stage_models/kl-f4/config.yaml           +41   -0
  stable_diffusion/models/first_stage_models/kl-f8/config.yaml           +42   -0
  stable_diffusion/models/first_stage_models/vq-f16/config.yaml          +49   -0
  stable_diffusion/models/first_stage_models/vq-f4-noattn/config.yaml    +46   -0
stable_diffusion/ldm/modules/ema.py (new file, mode 0 → 100644)
import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)
                             if use_num_upates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
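
The EMA module above is meant to be called on the live model once per optimizer step; validation then swaps the shadow weights in and out via store/copy_to/restore. A minimal self-contained sketch of that pattern (everything except LitEma is illustrative and not part of this commit):

import torch
from torch import nn

net = nn.Linear(4, 1)                      # stand-in for a real model
ema = LitEma(net, decay=0.999)
opt = torch.optim.SGD(net.parameters(), lr=1e-2)

for _ in range(10):
    x = torch.randn(8, 4)
    loss = net(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema(net)                               # update the shadow buffers after each step

# evaluate with EMA weights without losing the live training weights
ema.store(net.parameters())                # stash the current weights
ema.copy_to(net)                           # swap the EMA weights into the model
with torch.no_grad():
    print(net(torch.randn(2, 4)))
ema.restore(net.parameters())              # put the live weights back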

stable_diffusion/ldm/modules/encoders/__init__.py (new file, mode 0 → 100644; empty)
stable_diffusion/ldm/modules/encoders/modules.py (new file, mode 0 → 100644)
import torch
import torch.nn as nn
from functools import partial
import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia

from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key='class'):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        c = self.embedding(c)
        return c


class TransformerEmbedder(AbstractEncoder):
    """Some transformer encoder layers"""
    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
        super().__init__()
        self.device = device
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer))

    def forward(self, tokens):
        tokens = tokens.to(self.device)  # meh
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, x):
        return self(x)


class BERTTokenizer(AbstractEncoder):
    """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
    def __init__(self, device="cuda", vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast  # TODO: add to requirements
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.device = device
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        return tokens

    @torch.no_grad()
    def encode(self, text):
        tokens = self(text)
        if not self.vq_interface:
            return tokens
        return None, None, [None, None, tokens]

    def decode(self, text):
        return text


class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizer model and adds some transformer encoder layers"""
    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 device="cuda", use_tokenizer=True, embedding_dropout=0.0):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
        self.device = device
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)

    def forward(self, text):
        if self.use_tknz_fn:
            tokens = self.tknz_fn(text)  #.to(self.device)
        else:
            tokens = text
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, text):
        # output of length 77
        return self(text)


class SpatialRescaler(nn.Module):
    def __init__(self,
                 n_stages=1,
                 method='bilinear',
                 multiplier=0.5,
                 in_channels=3,
                 out_channels=None,
                 bias=False):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area']
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None
        if self.remap_output:
            print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
            self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        for stage in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)

        if self.remap_output:
            x = self.channel_mapper(x)
        return x

    def encode(self, x):
        return self(x)


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)


class FrozenCLIPTextEmbedder(nn.Module):
    """
    Uses the CLIP transformer encoder for text.
    """
    def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
        super().__init__()
        self.model, _ = clip.load(version, jit=False, device="cpu")
        self.device = device
        self.max_length = max_length
        self.n_repeat = n_repeat
        self.normalize = normalize

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = clip.tokenize(text).to(self.device)
        z = self.model.encode_text(tokens)
        if self.normalize:
            z = z / torch.linalg.norm(z, dim=1, keepdim=True)
        return z

    def encode(self, text):
        z = self(text)
        if z.ndim == 2:
            z = z[:, None, :]
        z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
        return z


class FrozenClipImageEmbedder(nn.Module):
    """
    Uses the CLIP image encoder.
    """
    def __init__(
            self,
            model,
            jit=False,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            antialias=False,
        ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)

        self.antialias = antialias

        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)

    def preprocess(self, x):
        # normalize to [0,1]
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic', align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        return self.model.encode_image(self.preprocess(x))


if __name__ == "__main__":
    from ldm.util import count_params
    model = FrozenCLIPEmbedder()
    count_params(model, verbose=True)
\ No newline at end of file
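
For orientation, a short sketch of calling the main text encoder above. The device="cpu" override is an assumption for CUDA-less machines (the class defaults to "cuda"), and the pretrained weights are fetched from the Hugging Face hub on first use:

# Illustrative sketch, not part of this commit: encoding prompts with FrozenCLIPEmbedder.
import torch

encoder = FrozenCLIPEmbedder(device="cpu")   # assumption: override "cuda" default for CPU-only boxes
with torch.no_grad():
    z = encoder.encode(["a photo of a cat", "a photo of a dog"])
print(z.shape)  # expected torch.Size([2, 77, 768]): 77 padded tokens, CLIP ViT-L/14 hidden width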

stable_diffusion/ldm/modules/image_degradation/__init__.py (new file, mode 0 → 100644)
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
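
These two re-exports are what the superresolution dataloaders import. A hedged usage sketch, with a random uint8 array standing in for a real high-resolution crop:

# Illustrative sketch, not part of this commit: applying the BSRGAN degradation
# to an HxWxC uint8 image, as the superresolution dataloaders do.
import numpy as np
from ldm.modules.image_degradation import degradation_fn_bsr_light

hr = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)  # stand-in for a real HR crop
example = degradation_fn_bsr_light(image=hr, sf=4)          # returns {"image": degraded uint8 array}
lr = example["image"]
print(hr.shape, "->", lr.shape)  # the variant downsamples by sf, so roughly (64, 64, 3)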

stable_diffusion/ldm/modules/image_degradation/bsrgan.py (new file, mode 0 → 100644)
# -*- coding: utf-8 -*-
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""

import numpy as np
import cv2
import torch

from functools import partial
import random
from scipy import ndimage
import scipy
import scipy.stats as ss
from scipy.interpolate import interp2d
from scipy.linalg import orth
import albumentations

import ldm.modules.image_degradation.utils_image as util


def modcrop_np(img, sf):
    '''
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        cropped image
    '''
    w, h = img.shape[:2]
    im = np.copy(img)
    return im[:w - w % sf, :h - h % sf, ...]


"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""


def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
    k_size = k.shape[0]
    # Calculate the big kernels size
    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to ignore very small values and increase run time of SR
    crop = k_size // 2
    cropped_big_k = big_k[crop:-crop, crop:-crop]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()


def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1, 50], scaling of eigenvalues
        l2    : [0.1, l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel
    """

    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)

    return k


def gm_blur_kernel(mean, cov, size=15):
    center = size / 2.0 + 0.5
    k = np.zeros([size, size])
    for y in range(size):
        for x in range(size):
            cy = y - center + 1
            cx = x - center + 1
            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)

    k = k / np.sum(k)
    return k


def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x


def blur(x, k):
    '''
    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    '''
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])

    return x


def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    """
    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Set COV matrix using Lambdas and Theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # shift the kernel so it will be centered
    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)

    # Normalize the kernel and return
    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel


def fspecial_gaussian(hsize, sigma):
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    h[h < scipy.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h


def fspecial_laplacian(alpha):
    alpha = max([0, min([alpha, 1])])
    h1 = alpha / (alpha + 1)
    h2 = (1 - alpha) / (alpha + 1)
    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
    h = np.array(h)
    return h


def fspecial(filter_type, *args, **kwargs):
    '''
    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    '''
    if filter_type == 'gaussian':
        return fspecial_gaussian(*args, **kwargs)
    if filter_type == 'laplacian':
        return fspecial_laplacian(*args, **kwargs)


"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""


def bicubic_degradation(x, sf=3):
    '''
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubically downsampled LR image
    '''
    x = util.imresize_np(x, scale=1 / sf)
    return x


def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x


def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x


def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    '''
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0
    return x[st::sf, st::sf, ...]


def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 1.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int):
    """
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype('float32')
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    K = img + weight * residual
    K = np.clip(K, 0, 1)
    return soft_mask * K + (1 - soft_mask) * img


def add_blur(img, sf=4):
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
    img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')

    return img


def add_resize(img, sf=4):
    rnum = np.random.rand()
    if rnum > 0.8:  # up
        sf1 = random.uniform(1, 2)
    elif rnum < 0.7:  # down
        sf1 = random.uniform(0.5 / sf, 1)
    else:
        sf1 = 1.0
    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
    img = np.clip(img, 0.0, 1.0)

    return img


# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
#     noise_level = random.randint(noise_level1, noise_level2)
#     rnum = np.random.rand()
#     if rnum > 0.6:  # add color Gaussian noise
#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
#     elif rnum < 0.4:  # add grayscale Gaussian noise
#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
#     else:  # add noise
#         L = noise_level2 / 255.
#         D = np.diag(np.random.rand(3))
#         U = orth(np.random.rand(3, 3))
#         conv = np.dot(np.dot(np.transpose(U), D), U)
#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
#     img = np.clip(img, 0.0, 1.0)
#     return img

def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    if rnum > 0.6:  # add color Gaussian noise
        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:  # add grayscale Gaussian noise
        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:  # add noise
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img


def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    rnum = random.random()
    if rnum > 0.6:
        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:
        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img


def add_Poisson_noise(img):
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
        img += noise_gray[:, :, np.newaxis]
    img = np.clip(img, 0.0, 1.0)
    return img


def add_JPEG_noise(img):
    quality_factor = random.randint(30, 95)
    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    img = cv2.imdecode(encimg, 1)
    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
    return img


def random_crop(lq, hq, sf=4, lq_patchsize=64):
    h, w = lq.shape[:2]
    rnd_h = random.randint(0, h - lq_patchsize)
    rnd_w = random.randint(0, w - lq_patchsize)
    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]

    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
    return lq, hq


def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                             interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                                 interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq


# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    image = util.uint2single(image)
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = image.shape[:2]
    image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
    h, w = image.shape[:2]

    hq = image.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            image = add_blur(image, sf=sf)

        elif i == 1:
            image = add_blur(image, sf=sf)

        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

        # elif i == 6:
        #     # add processed camera sensor noise
        #     if random.random() < isp_prob and isp_model is not None:
        #         with torch.no_grad():
        #             img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example


# TODO in case there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
    """
    This is an extended degradation model by combining
    the degradation models of BSRGAN and Real-ESRGAN
    ----------
    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    use_shuffle: the degradation shuffle
    use_sharp: sharpening the img
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """

    h1, w1 = img.shape[:2]
    img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    if use_sharp:
        img = add_sharpening(img)
    hq = img.copy()

    if random.random() < shuffle_prob:
        shuffle_order = random.sample(range(13), 13)
    else:
        shuffle_order = list(range(13))
        # local shuffle for noise, JPEG is always the last one
        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))

    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)
        elif i == 1:
            img = add_resize(img, sf=sf)
        elif i == 2:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 3:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 4:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 5:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        elif i == 6:
            img = add_JPEG_noise(img)
        elif i == 7:
            img = add_blur(img, sf=sf)
        elif i == 8:
            img = add_resize(img, sf=sf)
        elif i == 9:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 10:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 11:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 12:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        else:
            print('check the shuffle!')

    # resize to desired size
    img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
                     interpolation=random.choice([1, 2, 3]))

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf, lq_patchsize)

    return img, hq


if __name__ == '__main__':
    print("hey")
    img = util.imread_uint('utils/test.png', 3)
    print(img)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        img_hq = img  # keep the HR reference for the side-by-side strip below
        img_lq = deg_fn(img)["image"]  # degradation_bsrgan_variant returns {"image": uint8 array}
        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                        interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')
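
The kernel helpers above can be sanity-checked in isolation; a small sketch, assuming the module is importable under the path used in this commit:

# Illustrative sketch, not part of this commit: inspecting an anisotropic blur kernel.
import numpy as np
from ldm.modules.image_degradation.bsrgan import anisotropic_Gaussian

k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6, l2=1)  # Gaussian elongated at 45 degrees
print(k.shape, round(k.sum(), 6))  # (15, 15), sums to 1.0 since gm_blur_kernel normalizes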

stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py (new file, mode 0 → 100644)
# -*- coding: utf-8 -*-
import
numpy
as
np
import
cv2
import
torch
from
functools
import
partial
import
random
from
scipy
import
ndimage
import
scipy
import
scipy.stats
as
ss
from
scipy.interpolate
import
interp2d
from
scipy.linalg
import
orth
import
albumentations
import
ldm.modules.image_degradation.utils_image
as
util
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
def
modcrop_np
(
img
,
sf
):
'''
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
'''
w
,
h
=
img
.
shape
[:
2
]
im
=
np
.
copy
(
img
)
return
im
[:
w
-
w
%
sf
,
:
h
-
h
%
sf
,
...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def
analytic_kernel
(
k
):
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
k_size
=
k
.
shape
[
0
]
# Calculate the big kernels size
big_k
=
np
.
zeros
((
3
*
k_size
-
2
,
3
*
k_size
-
2
))
# Loop over the small kernel to fill the big one
for
r
in
range
(
k_size
):
for
c
in
range
(
k_size
):
big_k
[
2
*
r
:
2
*
r
+
k_size
,
2
*
c
:
2
*
c
+
k_size
]
+=
k
[
r
,
c
]
*
k
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop
=
k_size
//
2
cropped_big_k
=
big_k
[
crop
:
-
crop
,
crop
:
-
crop
]
# Normalize to 1
return
cropped_big_k
/
cropped_big_k
.
sum
()
def
anisotropic_Gaussian
(
ksize
=
15
,
theta
=
np
.
pi
,
l1
=
6
,
l2
=
6
):
""" generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
l1 : [0.1,50], scaling of eigenvalues
l2 : [0.1,l1], scaling of eigenvalues
If l1 = l2, will get an isotropic Gaussian kernel.
Returns:
k : kernel
"""
v
=
np
.
dot
(
np
.
array
([[
np
.
cos
(
theta
),
-
np
.
sin
(
theta
)],
[
np
.
sin
(
theta
),
np
.
cos
(
theta
)]]),
np
.
array
([
1.
,
0.
]))
V
=
np
.
array
([[
v
[
0
],
v
[
1
]],
[
v
[
1
],
-
v
[
0
]]])
D
=
np
.
array
([[
l1
,
0
],
[
0
,
l2
]])
Sigma
=
np
.
dot
(
np
.
dot
(
V
,
D
),
np
.
linalg
.
inv
(
V
))
k
=
gm_blur_kernel
(
mean
=
[
0
,
0
],
cov
=
Sigma
,
size
=
ksize
)
return
k
def
gm_blur_kernel
(
mean
,
cov
,
size
=
15
):
center
=
size
/
2.0
+
0.5
k
=
np
.
zeros
([
size
,
size
])
for
y
in
range
(
size
):
for
x
in
range
(
size
):
cy
=
y
-
center
+
1
cx
=
x
-
center
+
1
k
[
y
,
x
]
=
ss
.
multivariate_normal
.
pdf
([
cx
,
cy
],
mean
=
mean
,
cov
=
cov
)
k
=
k
/
np
.
sum
(
k
)
return
k
def
shift_pixel
(
x
,
sf
,
upper_left
=
True
):
"""shift pixel for super-resolution with different scale factors
Args:
x: WxHxC or WxH
sf: scale factor
upper_left: shift direction
"""
h
,
w
=
x
.
shape
[:
2
]
shift
=
(
sf
-
1
)
*
0.5
xv
,
yv
=
np
.
arange
(
0
,
w
,
1.0
),
np
.
arange
(
0
,
h
,
1.0
)
if
upper_left
:
x1
=
xv
+
shift
y1
=
yv
+
shift
else
:
x1
=
xv
-
shift
y1
=
yv
-
shift
x1
=
np
.
clip
(
x1
,
0
,
w
-
1
)
y1
=
np
.
clip
(
y1
,
0
,
h
-
1
)
if
x
.
ndim
==
2
:
x
=
interp2d
(
xv
,
yv
,
x
)(
x1
,
y1
)
if
x
.
ndim
==
3
:
for
i
in
range
(
x
.
shape
[
-
1
]):
x
[:,
:,
i
]
=
interp2d
(
xv
,
yv
,
x
[:,
:,
i
])(
x1
,
y1
)
return
x
def
blur
(
x
,
k
):
'''
x: image, NxcxHxW
k: kernel, Nx1xhxw
'''
n
,
c
=
x
.
shape
[:
2
]
p1
,
p2
=
(
k
.
shape
[
-
2
]
-
1
)
//
2
,
(
k
.
shape
[
-
1
]
-
1
)
//
2
x
=
torch
.
nn
.
functional
.
pad
(
x
,
pad
=
(
p1
,
p2
,
p1
,
p2
),
mode
=
'replicate'
)
k
=
k
.
repeat
(
1
,
c
,
1
,
1
)
k
=
k
.
view
(
-
1
,
1
,
k
.
shape
[
2
],
k
.
shape
[
3
])
x
=
x
.
view
(
1
,
-
1
,
x
.
shape
[
2
],
x
.
shape
[
3
])
x
=
torch
.
nn
.
functional
.
conv2d
(
x
,
k
,
bias
=
None
,
stride
=
1
,
padding
=
0
,
groups
=
n
*
c
)
x
=
x
.
view
(
n
,
c
,
x
.
shape
[
2
],
x
.
shape
[
3
])
return
x
def
gen_kernel
(
k_size
=
np
.
array
([
15
,
15
]),
scale_factor
=
np
.
array
([
4
,
4
]),
min_var
=
0.6
,
max_var
=
10.
,
noise_level
=
0
):
""""
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
# max_var = 2.5 * sf
"""
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
lambda_1
=
min_var
+
np
.
random
.
rand
()
*
(
max_var
-
min_var
)
lambda_2
=
min_var
+
np
.
random
.
rand
()
*
(
max_var
-
min_var
)
theta
=
np
.
random
.
rand
()
*
np
.
pi
# random theta
noise
=
-
noise_level
+
np
.
random
.
rand
(
*
k_size
)
*
noise_level
*
2
# Set COV matrix using Lambdas and Theta
LAMBDA
=
np
.
diag
([
lambda_1
,
lambda_2
])
Q
=
np
.
array
([[
np
.
cos
(
theta
),
-
np
.
sin
(
theta
)],
[
np
.
sin
(
theta
),
np
.
cos
(
theta
)]])
SIGMA
=
Q
@
LAMBDA
@
Q
.
T
INV_SIGMA
=
np
.
linalg
.
inv
(
SIGMA
)[
None
,
None
,
:,
:]
# Set expectation position (shifting kernel for aligned image)
MU
=
k_size
//
2
-
0.5
*
(
scale_factor
-
1
)
# - 0.5 * (scale_factor - k_size % 2)
MU
=
MU
[
None
,
None
,
:,
None
]
# Create meshgrid for Gaussian
[
X
,
Y
]
=
np
.
meshgrid
(
range
(
k_size
[
0
]),
range
(
k_size
[
1
]))
Z
=
np
.
stack
([
X
,
Y
],
2
)[:,
:,
:,
None
]
# Calcualte Gaussian for every pixel of the kernel
ZZ
=
Z
-
MU
ZZ_t
=
ZZ
.
transpose
(
0
,
1
,
3
,
2
)
raw_kernel
=
np
.
exp
(
-
0.5
*
np
.
squeeze
(
ZZ_t
@
INV_SIGMA
@
ZZ
))
*
(
1
+
noise
)
# shift the kernel so it will be centered
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
# Normalize the kernel and return
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
kernel
=
raw_kernel
/
np
.
sum
(
raw_kernel
)
return
kernel
def
fspecial_gaussian
(
hsize
,
sigma
):
hsize
=
[
hsize
,
hsize
]
siz
=
[(
hsize
[
0
]
-
1.0
)
/
2.0
,
(
hsize
[
1
]
-
1.0
)
/
2.0
]
std
=
sigma
[
x
,
y
]
=
np
.
meshgrid
(
np
.
arange
(
-
siz
[
1
],
siz
[
1
]
+
1
),
np
.
arange
(
-
siz
[
0
],
siz
[
0
]
+
1
))
arg
=
-
(
x
*
x
+
y
*
y
)
/
(
2
*
std
*
std
)
h
=
np
.
exp
(
arg
)
h
[
h
<
scipy
.
finfo
(
float
).
eps
*
h
.
max
()]
=
0
sumh
=
h
.
sum
()
if
sumh
!=
0
:
h
=
h
/
sumh
return
h
def
fspecial_laplacian
(
alpha
):
alpha
=
max
([
0
,
min
([
alpha
,
1
])])
h1
=
alpha
/
(
alpha
+
1
)
h2
=
(
1
-
alpha
)
/
(
alpha
+
1
)
h
=
[[
h1
,
h2
,
h1
],
[
h2
,
-
4
/
(
alpha
+
1
),
h2
],
[
h1
,
h2
,
h1
]]
h
=
np
.
array
(
h
)
return
h
def
fspecial
(
filter_type
,
*
args
,
**
kwargs
):
'''
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
'''
if
filter_type
==
'gaussian'
:
return
fspecial_gaussian
(
*
args
,
**
kwargs
)
if
filter_type
==
'laplacian'
:
return
fspecial_laplacian
(
*
args
,
**
kwargs
)
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def
bicubic_degradation
(
x
,
sf
=
3
):
'''
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
'''
x
=
util
.
imresize_np
(
x
,
scale
=
1
/
sf
)
return
x
def
srmd_degradation
(
x
,
k
,
sf
=
3
):
''' blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2018learning,
title={Learning a single convolutional super-resolution network for multiple degradations},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={3262--3271},
year={2018}
}
'''
x
=
ndimage
.
filters
.
convolve
(
x
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'wrap'
)
# 'nearest' | 'mirror'
x
=
bicubic_degradation
(
x
,
sf
=
sf
)
return
x
def
dpsr_degradation
(
x
,
k
,
sf
=
3
):
''' bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2019deep,
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={1671--1681},
year={2019}
}
'''
x
=
bicubic_degradation
(
x
,
sf
=
sf
)
x
=
ndimage
.
filters
.
convolve
(
x
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'wrap'
)
return
x
def
classical_degradation
(
x
,
k
,
sf
=
3
):
''' blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
'''
x
=
ndimage
.
filters
.
convolve
(
x
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'wrap'
)
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st
=
0
return
x
[
st
::
sf
,
st
::
sf
,
...]
def
add_sharpening
(
img
,
weight
=
0.5
,
radius
=
50
,
threshold
=
10
):
"""USM sharpening. borrowed from real-ESRGAN
Input image: I; Blurry image: B.
1. K = I + weight * (I - B)
2. Mask = 1 if abs(I - B) > threshold, else: 0
3. Blur mask:
4. Out = Mask * K + (1 - Mask) * I
Args:
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
weight (float): Sharp weight. Default: 1.
radius (float): Kernel size of Gaussian blur. Default: 50.
threshold (int):
"""
if
radius
%
2
==
0
:
radius
+=
1
blur
=
cv2
.
GaussianBlur
(
img
,
(
radius
,
radius
),
0
)
residual
=
img
-
blur
mask
=
np
.
abs
(
residual
)
*
255
>
threshold
mask
=
mask
.
astype
(
'float32'
)
soft_mask
=
cv2
.
GaussianBlur
(
mask
,
(
radius
,
radius
),
0
)
K
=
img
+
weight
*
residual
K
=
np
.
clip
(
K
,
0
,
1
)
return
soft_mask
*
K
+
(
1
-
soft_mask
)
*
img
def
add_blur
(
img
,
sf
=
4
):
wd2
=
4.0
+
sf
wd
=
2.0
+
0.2
*
sf
wd2
=
wd2
/
4
wd
=
wd
/
4
if
random
.
random
()
<
0.5
:
l1
=
wd2
*
random
.
random
()
l2
=
wd2
*
random
.
random
()
k
=
anisotropic_Gaussian
(
ksize
=
random
.
randint
(
2
,
11
)
+
3
,
theta
=
random
.
random
()
*
np
.
pi
,
l1
=
l1
,
l2
=
l2
)
else
:
k
=
fspecial
(
'gaussian'
,
random
.
randint
(
2
,
4
)
+
3
,
wd
*
random
.
random
())
img
=
ndimage
.
filters
.
convolve
(
img
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'mirror'
)
return
img
def
add_resize
(
img
,
sf
=
4
):
rnum
=
np
.
random
.
rand
()
if
rnum
>
0.8
:
# up
sf1
=
random
.
uniform
(
1
,
2
)
elif
rnum
<
0.7
:
# down
sf1
=
random
.
uniform
(
0.5
/
sf
,
1
)
else
:
sf1
=
1.0
img
=
cv2
.
resize
(
img
,
(
int
(
sf1
*
img
.
shape
[
1
]),
int
(
sf1
*
img
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
return
img
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def
add_Gaussian_noise
(
img
,
noise_level1
=
2
,
noise_level2
=
25
):
noise_level
=
random
.
randint
(
noise_level1
,
noise_level2
)
rnum
=
np
.
random
.
rand
()
if
rnum
>
0.6
:
# add color Gaussian noise
img
=
img
+
np
.
random
.
normal
(
0
,
noise_level
/
255.0
,
img
.
shape
).
astype
(
np
.
float32
)
elif
rnum
<
0.4
:
# add grayscale Gaussian noise
img
=
img
+
np
.
random
.
normal
(
0
,
noise_level
/
255.0
,
(
*
img
.
shape
[:
2
],
1
)).
astype
(
np
.
float32
)
else
:
# add noise
L
=
noise_level2
/
255.
D
=
np
.
diag
(
np
.
random
.
rand
(
3
))
U
=
orth
(
np
.
random
.
rand
(
3
,
3
))
conv
=
np
.
dot
(
np
.
dot
(
np
.
transpose
(
U
),
D
),
U
)
img
=
img
+
np
.
random
.
multivariate_normal
([
0
,
0
,
0
],
np
.
abs
(
L
**
2
*
conv
),
img
.
shape
[:
2
]).
astype
(
np
.
float32
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
return
img
def
add_speckle_noise
(
img
,
noise_level1
=
2
,
noise_level2
=
25
):
noise_level
=
random
.
randint
(
noise_level1
,
noise_level2
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
rnum
=
random
.
random
()
if
rnum
>
0.6
:
img
+=
img
*
np
.
random
.
normal
(
0
,
noise_level
/
255.0
,
img
.
shape
).
astype
(
np
.
float32
)
elif
rnum
<
0.4
:
img
+=
img
*
np
.
random
.
normal
(
0
,
noise_level
/
255.0
,
(
*
img
.
shape
[:
2
],
1
)).
astype
(
np
.
float32
)
else
:
L
=
noise_level2
/
255.
D
=
np
.
diag
(
np
.
random
.
rand
(
3
))
U
=
orth
(
np
.
random
.
rand
(
3
,
3
))
conv
=
np
.
dot
(
np
.
dot
(
np
.
transpose
(
U
),
D
),
U
)
img
+=
img
*
np
.
random
.
multivariate_normal
([
0
,
0
,
0
],
np
.
abs
(
L
**
2
*
conv
),
img
.
shape
[:
2
]).
astype
(
np
.
float32
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
return
img
def
add_Poisson_noise
(
img
):
img
=
np
.
clip
((
img
*
255.0
).
round
(),
0
,
255
)
/
255.
vals
=
10
**
(
2
*
random
.
random
()
+
2.0
)
# [2, 4]
if
random
.
random
()
<
0.5
:
img
=
np
.
random
.
poisson
(
img
*
vals
).
astype
(
np
.
float32
)
/
vals
else
:
img_gray
=
np
.
dot
(
img
[...,
:
3
],
[
0.299
,
0.587
,
0.114
])
img_gray
=
np
.
clip
((
img_gray
*
255.0
).
round
(),
0
,
255
)
/
255.
noise_gray
=
np
.
random
.
poisson
(
img_gray
*
vals
).
astype
(
np
.
float32
)
/
vals
-
img_gray
img
+=
noise_gray
[:,
:,
np
.
newaxis
]
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
return
img
def
add_JPEG_noise
(
img
):
quality_factor
=
random
.
randint
(
80
,
95
)
img
=
cv2
.
cvtColor
(
util
.
single2uint
(
img
),
cv2
.
COLOR_RGB2BGR
)
result
,
encimg
=
cv2
.
imencode
(
'.jpg'
,
img
,
[
int
(
cv2
.
IMWRITE_JPEG_QUALITY
),
quality_factor
])
img
=
cv2
.
imdecode
(
encimg
,
1
)
img
=
cv2
.
cvtColor
(
util
.
uint2single
(
img
),
cv2
.
COLOR_BGR2RGB
)
return
img
def
random_crop
(
lq
,
hq
,
sf
=
4
,
lq_patchsize
=
64
):
h
,
w
=
lq
.
shape
[:
2
]
rnd_h
=
random
.
randint
(
0
,
h
-
lq_patchsize
)
rnd_w
=
random
.
randint
(
0
,
w
-
lq_patchsize
)
lq
=
lq
[
rnd_h
:
rnd_h
+
lq_patchsize
,
rnd_w
:
rnd_w
+
lq_patchsize
,
:]
rnd_h_H
,
rnd_w_H
=
int
(
rnd_h
*
sf
),
int
(
rnd_w
*
sf
)
hq
=
hq
[
rnd_h_H
:
rnd_h_H
+
lq_patchsize
*
sf
,
rnd_w_H
:
rnd_w_H
+
lq_patchsize
*
sf
,
:]
return
lq
,
hq
def
degradation_bsrgan
(
img
,
sf
=
4
,
lq_patchsize
=
72
,
isp_model
=
None
):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
isp_prob
,
jpeg_prob
,
scale2_prob
=
0.25
,
0.9
,
0.25
sf_ori
=
sf
h1
,
w1
=
img
.
shape
[:
2
]
img
=
img
.
copy
()[:
w1
-
w1
%
sf
,
:
h1
-
h1
%
sf
,
...]
# mod crop
h
,
w
=
img
.
shape
[:
2
]
if
h
<
lq_patchsize
*
sf
or
w
<
lq_patchsize
*
sf
:
raise
ValueError
(
f
'img size (
{
h1
}
X
{
w1
}
) is too small!'
)
hq
=
img
.
copy
()
if
sf
==
4
and
random
.
random
()
<
scale2_prob
:
# downsample1
if
np
.
random
.
rand
()
<
0.5
:
img
=
cv2
.
resize
(
img
,
(
int
(
1
/
2
*
img
.
shape
[
1
]),
int
(
1
/
2
*
img
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
else
:
img
=
util
.
imresize_np
(
img
,
1
/
2
,
True
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
sf
=
2
shuffle_order
=
random
.
sample
(
range
(
7
),
7
)
idx1
,
idx2
=
shuffle_order
.
index
(
2
),
shuffle_order
.
index
(
3
)
if
idx1
>
idx2
:
# keep downsample3 last
shuffle_order
[
idx1
],
shuffle_order
[
idx2
]
=
shuffle_order
[
idx2
],
shuffle_order
[
idx1
]
for
i
in
shuffle_order
:
if
i
==
0
:
img
=
add_blur
(
img
,
sf
=
sf
)
elif
i
==
1
:
img
=
add_blur
(
img
,
sf
=
sf
)
elif
i
==
2
:
a
,
b
=
img
.
shape
[
1
],
img
.
shape
[
0
]
# downsample2
if
random
.
random
()
<
0.75
:
sf1
=
random
.
uniform
(
1
,
2
*
sf
)
img
=
cv2
.
resize
(
img
,
(
int
(
1
/
sf1
*
img
.
shape
[
1
]),
int
(
1
/
sf1
*
img
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
else
:
k
=
fspecial
(
'gaussian'
,
25
,
random
.
uniform
(
0.1
,
0.6
*
sf
))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq


# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    sf: scale factor
    isp_model: camera ISP model

    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    image = util.uint2single(image)
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = image.shape[:2]
    image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
    h, w = image.shape[:2]

    hq = image.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:
        if i == 0:
            image = add_blur(image, sf=sf)

        # elif i == 1:
        #     image = add_blur(image, sf=sf)

        if i == 0:
            pass

        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.8:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)
        #
        # elif i == 6:
        #     # add processed camera sensor noise
        #     if random.random() < isp_prob and isp_model is not None:
        #         with torch.no_grad():
        #             img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example


if __name__ == '__main__':
    print("hey")
    img = util.imread_uint('utils/test.png', 3)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        img_hq = img
        img_lq = deg_fn(img)["image"]
        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                        interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')
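For quick orientation, a minimal usage sketch of degradation_bsrgan_variant follows (an editor's aside, not part of the commit). The import path mirrors this repository's layout and 'my_image.png' is a placeholder; running it assumes the repo root is on sys.path and the module's dependencies (cv2, scipy, albumentations) are installed.

# --- usage sketch (aside, not part of the file above) ---
import cv2
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant

hq = cv2.cvtColor(cv2.imread('my_image.png'), cv2.COLOR_BGR2RGB)  # HxWx3 uint8, RGB
out = degradation_bsrgan_variant(hq, sf=4)   # random blur / resize / noise / JPEG pipeline
lq = out["image"]                            # uint8 low-quality image, roughly 1/sf per side
print(hq.shape, lq.shape)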
stable_diffusion/ldm/modules/image_degradation/utils/test.png
0 → 100644
431 KB
stable_diffusion/ldm/modules/image_degradation/utils_image.py
0 → 100644
import os
import math
import random
import numpy as np
import torch
import cv2
from torchvision.utils import make_grid
from datetime import datetime
#import matplotlib.pyplot as plt   # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py


os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''


IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def get_timestamp():
    return datetime.now().strftime('%y%m%d-%H%M%S')


# NOTE: imshow/surf below require the matplotlib import commented out above.
def imshow(x, title=None, cbar=False, figsize=None):
    plt.figure(figsize=figsize)
    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()


def surf(Z, cmap='rainbow', figsize=None):
    plt.figure(figsize=figsize)
    ax3 = plt.axes(projection='3d')

    w, h = Z.shape[:2]
    xx = np.arange(0, w, 1)
    yy = np.arange(0, h, 1)
    X, Y = np.meshgrid(xx, yy)
    ax3.plot_surface(X, Y, Z, cmap=cmap)
    #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
    plt.show()


'''
# --------------------------------------------
# get image paths
# --------------------------------------------
'''


def get_image_paths(dataroot):
    paths = None  # return None if dataroot is None
    if dataroot is not None:
        paths = sorted(_get_paths_from_images(dataroot))
    return paths


def _get_paths_from_images(path):
    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
    images = []
    for dirpath, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                img_path = os.path.join(dirpath, fname)
                images.append(img_path)
    assert images, '{:s} has no valid image file'.format(path)
    return images


'''
# --------------------------------------------
# split large images into small images
# --------------------------------------------
'''


def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
    w, h = img.shape[:2]
    patches = []
    if w > p_max and h > p_max:
        w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=int))  # was dtype=np.int, removed in NumPy 1.24
        h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=int))
        w1.append(w - p_size)
        h1.append(h - p_size)
        # print(w1)
        # print(h1)
        for i in w1:
            for j in h1:
                patches.append(img[i:i + p_size, j:j + p_size, :])
    else:
        patches.append(img)

    return patches


def imssave(imgs, img_path):
    """
    imgs: list, N images of size WxHxC
    """
    img_name, ext = os.path.splitext(os.path.basename(img_path))

    for i, img in enumerate(imgs):
        if img.ndim == 3:
            img = img[:, :, [2, 1, 0]]
        new_path = os.path.join(os.path.dirname(img_path), img_name + str('_s{:04d}'.format(i)) + '.png')
        cv2.imwrite(new_path, img)


def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
    """
    split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
    and save them into taget_dataroot; only the images with size larger than (p_max)x(p_max)
    will be split.
    Args:
        original_dataroot:
        taget_dataroot:
        p_size: size of small images
        p_overlap: patch size in training is a good choice
        p_max: images with size smaller than (p_max)x(p_max) are kept unchanged.
    """
    paths = get_image_paths(original_dataroot)
    for img_path in paths:
        # img_name, ext = os.path.splitext(os.path.basename(img_path))
        img = imread_uint(img_path, n_channels=n_channels)
        patches = patches_from_image(img, p_size, p_overlap, p_max)
        imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path)))
        #if original_dataroot == taget_dataroot:
        #del img_path


'''
# --------------------------------------------
# makedir
# --------------------------------------------
'''


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def mkdirs(paths):
    if isinstance(paths, str):
        mkdir(paths)
    else:
        for path in paths:
            mkdir(path)


def mkdir_and_rename(path):
    if os.path.exists(path):
        new_name = path + '_archived_' + get_timestamp()
        print('Path already exists. Rename it to [{:s}]'.format(new_name))
        os.rename(path, new_name)
    os.makedirs(path)


'''
# --------------------------------------------
# read image from path
# opencv is fast, but read BGR numpy image
# --------------------------------------------
'''


# --------------------------------------------
# get uint8 image of size HxWxn_channels (RGB)
# --------------------------------------------
def imread_uint(path, n_channels=3):
    #  input: path
    # output: HxWx3(RGB or GGG), or HxWx1 (G)
    if n_channels == 1:
        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
        img = np.expand_dims(img, axis=2)  # HxWx1
    elif n_channels == 3:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
    return img


# --------------------------------------------
# matlab's imwrite
# --------------------------------------------
def imsave(img, img_path):
    img = np.squeeze(img)
    if img.ndim == 3:
        img = img[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, img)


def imwrite(img, img_path):
    img = np.squeeze(img)
    if img.ndim == 3:
        img = img[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, img)


# --------------------------------------------
# get single image of size HxWxn_channels (BGR)
# --------------------------------------------
def read_img(path):
    # read image by cv2
    # return: Numpy float32, HWC, BGR, [0,1]
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img


'''
# --------------------------------------------
# image format conversion
# --------------------------------------------
# numpy(single) <--->  numpy(uint)
# numpy(single) <--->  tensor
# numpy(uint)   <--->  tensor
# --------------------------------------------
'''


# --------------------------------------------
# numpy(single) [0, 1] <--->  numpy(uint)
# --------------------------------------------


def uint2single(img):
    return np.float32(img / 255.)


def single2uint(img):
    return np.uint8((img.clip(0, 1) * 255.).round())


def uint162single(img):
    return np.float32(img / 65535.)


def single2uint16(img):
    return np.uint16((img.clip(0, 1) * 65535.).round())


# --------------------------------------------
# numpy(uint) (HxWxC or HxW) <--->  tensor
# --------------------------------------------


# convert uint to 4-dimensional torch tensor
def uint2tensor4(img):
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)


# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)


# convert 2/3/4-dimensional torch tensor to uint
def tensor2uint(img):
    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return np.uint8((img * 255.0).round())


# --------------------------------------------
# numpy(single) (HxWxC) <--->  tensor
# --------------------------------------------


# convert single (HxWxC) to 3-dimensional torch tensor
def single2tensor3(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()


# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)


# convert torch tensor to single
def tensor2single(img):
    img = img.data.squeeze().float().cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))

    return img


# convert torch tensor to single
def tensor2single3(img):
    img = img.data.squeeze().float().cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    elif img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return img


def single2tensor5(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)


def single32tensor5(img):
    return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)


def single42tensor4(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()


# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array of BGR channel order
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    '''
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
    return img_np.astype(out_type)


'''
# --------------------------------------------
# Augmentation, flip and/or rotate
# --------------------------------------------
# The following two are enough.
# (1) augment_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
'''


def augment_img(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)
    '''
    if mode == 0:
        return img
    elif mode == 1:
        return np.flipud(np.rot90(img))
    elif mode == 2:
        return np.flipud(img)
    elif mode == 3:
        return np.rot90(img, k=3)
    elif mode == 4:
        return np.flipud(np.rot90(img, k=2))
    elif mode == 5:
        return np.rot90(img)
    elif mode == 6:
        return np.rot90(img, k=2)
    elif mode == 7:
        return np.flipud(np.rot90(img, k=3))


def augment_img_tensor4(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)
    '''
    if mode == 0:
        return img
    elif mode == 1:
        return img.rot90(1, [2, 3]).flip([2])
    elif mode == 2:
        return img.flip([2])
    elif mode == 3:
        return img.rot90(3, [2, 3])
    elif mode == 4:
        return img.rot90(2, [2, 3]).flip([2])
    elif mode == 5:
        return img.rot90(1, [2, 3])
    elif mode == 6:
        return img.rot90(2, [2, 3])
    elif mode == 7:
        return img.rot90(3, [2, 3]).flip([2])


def augment_img_tensor(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)
    '''
    img_size = img.size()
    img_np = img.data.cpu().numpy()
    if len(img_size) == 3:
        img_np = np.transpose(img_np, (1, 2, 0))
    elif len(img_size) == 4:
        img_np = np.transpose(img_np, (2, 3, 1, 0))
    img_np = augment_img(img_np, mode=mode)
    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
    if len(img_size) == 3:
        img_tensor = img_tensor.permute(2, 0, 1)
    elif len(img_size) == 4:
        img_tensor = img_tensor.permute(3, 2, 0, 1)

    return img_tensor.type_as(img)


def augment_img_np3(img, mode=0):
    if mode == 0:
        return img
    elif mode == 1:
        return img.transpose(1, 0, 2)
    elif mode == 2:
        return img[::-1, :, :]
    elif mode == 3:
        img = img[::-1, :, :]
        img = img.transpose(1, 0, 2)
        return img
    elif mode == 4:
        return img[:, ::-1, :]
    elif mode == 5:
        img = img[:, ::-1, :]
        img = img.transpose(1, 0, 2)
        return img
    elif mode == 6:
        img = img[:, ::-1, :]
        img = img[::-1, :, :]
        return img
    elif mode == 7:
        img = img[:, ::-1, :]
        img = img[::-1, :, :]
        img = img.transpose(1, 0, 2)
        return img


def augment_imgs(img_list, hflip=True, rot=True):
    # horizontal flip OR rotate
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    return [_augment(img) for img in img_list]


'''
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
'''


def modcrop(img_in, scale):
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    if img.ndim == 2:
        H, W = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[:H - H_r, :W - W_r]
    elif img.ndim == 3:
        H, W, C = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[:H - H_r, :W - W_r, :]
    else:
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    return img


def shave(img_in, border=0):
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    h, w = img.shape[:2]
    img = img[border:h - border, border:w - border]
    return img


'''
# --------------------------------------------
# image processing process on numpy image
# channel_convert(in_c, tar_type, img_list):
# rgb2ycbcr(img, only_y=True):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
'''


def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img = img.astype(np.float32)  # was a no-op `img.astype(...)`, which mutated the caller's array below
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img = img.astype(np.float32)  # same no-op fix as in rgb2ycbcr
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img = img.astype(np.float32)  # same no-op fix as in rgb2ycbcr
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def channel_convert(in_c, tar_type, img_list):
    # conversion among BGR, gray and y
    if in_c == 3 and tar_type == 'gray':  # BGR to gray
        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in gray_list]
    elif in_c == 3 and tar_type == 'y':  # BGR to y
        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in y_list]
    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    else:
        return img_list


'''
# --------------------------------------------
# metric, PSNR and SSIM
# --------------------------------------------
'''


# --------------------------------------------
# PSNR
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
    # img1 and img2 have range [0, 255]
    #img1 = img1.squeeze()
    #img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h - border, border:w - border]
    img2 = img2[border:h - border, border:w - border]

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


# --------------------------------------------
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]
    '''
    #img1 = img1.squeeze()
    #img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h - border, border:w - border]
    img2 = img2[border:h - border, border:w - border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')


def ssim(img1, img2):
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


'''
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
'''


# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
    absx = torch.abs(x)
    absx2 = absx ** 2
    absx3 = absx ** 3
    return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + \
        (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * (absx <= 2)).type_as(absx))


def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation?  Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand(
        out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices

    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)


# --------------------------------------------
# imresize for tensor image [0, 1]
# --------------------------------------------
def imresize(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: pytorch tensor, CHW or HW [0,1]
    # output: CHW or HW [0,1] w/o round
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2


# --------------------------------------------
# imresize for numpy image [0, 1]
# --------------------------------------------
def imresize_np(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(2)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()

    return out_2.numpy()


if __name__ == '__main__':
    print('---')
#    img = imread_uint('test.bmp', 3)
#    img = uint2single(img)
#    img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
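As a sanity check on the conversion and metric helpers above, a short sketch follows (an editor's aside, not part of the commit; 'frame.png' is a placeholder path): uint8 images map to float32 in [0, 1] via uint2single, and calculate_psnr expects inputs in the [0, 255] range.

# --- usage sketch (aside, not part of the file above) ---
import numpy as np
from ldm.modules.image_degradation import utils_image as util

img = util.imread_uint('frame.png', 3)            # HxWx3 uint8, RGB
x = util.uint2single(img)                         # float32 in [0, 1]
noisy = np.clip(x + np.random.normal(0, 5 / 255., x.shape), 0., 1.).astype(np.float32)
psnr = util.calculate_psnr(util.single2uint(x), util.single2uint(noisy))
print('PSNR: {:.2f} dB'.format(psnr))             # roughly 34 dB for sigma = 5/255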
stable_diffusion/ldm/modules/losses/__init__.py
0 → 100644
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
\ No newline at end of file
stable_diffusion/ldm/modules/losses/contperceptual.py
0 → 100644
import torch
import torch.nn as nn

from taming.modules.losses.vqperceptual import *  # TODO: taming dependency yes/no?


class LPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge"):

        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train",
                weights=None):
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/logvar".format(split): self.logvar.detach(),
                   "{}/kl_loss".format(split): kl_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
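The key trick in calculate_adaptive_weight above is balancing the GAN term against the reconstruction term by the ratio of their gradient norms at the decoder's last layer. A standalone sketch of just that rule follows (an aside, not part of the commit; the two toy losses stand in for the real NLL and generator losses):

# --- standalone sketch of the adaptive-weight rule (aside, not part of the file above) ---
import torch

last_layer = torch.nn.Parameter(torch.randn(8, 8))   # stands in for the decoder's last weight
x = torch.randn(8)
y = last_layer @ x
nll_loss = y.pow(2).mean()                           # toy stand-in for the reconstruction/NLL term
g_loss = y.sin().mean()                              # toy stand-in for the generator (GAN) term

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
print(d_weight)  # this factor multiplies g_loss in the total loss, as in forward() above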
stable_diffusion/ldm/modules/losses/vqperceptual.py
0 → 100644
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss


def exists(val):
    # small helper used in forward() below; not covered by the selective
    # taming imports above, so defined here
    return val is not None


def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
    loss_real = torch.mean(F.relu(1. - logits_real), dim=[1, 2, 3])
    loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1, 2, 3])
    loss_real = (weights * loss_real).sum() / weights.sum()
    loss_fake = (weights * loss_fake).sum() / weights.sum()
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def adopt_weight(weight, global_step, threshold=0, value=0.):
    if global_step < threshold:
        weight = value
    return weight


def measure_perplexity(predicted_indices, n_embed):
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
    encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    avg_probs = encodings.mean(0)
    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
    cluster_use = torch.sum(avg_probs > 0)
    return perplexity, cluster_use


def l1(x, y):
    return torch.abs(x - y)


def l2(x, y):
    return torch.pow((x - y), 2)


class VQLPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
                 pixel_loss="l1"):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight

        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            self.perceptual_loss = LPIPS().eval()
        else:
            raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
        self.perceptual_weight = perceptual_weight

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
        if not exists(codebook_loss):
            codebook_loss = torch.tensor([0.]).to(inputs.device)
        #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss
        #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        nll_loss = torch.mean(nll_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                assert not self.training
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/p_loss".format(split): p_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
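measure_perplexity is easy to probe in isolation. A small sketch follows (an aside, not part of the commit; importing this module pulls in the taming-transformers dependency at the top of the file): uniform code usage yields perplexity close to n_embed, while codebook collapse yields close to 1.

# --- usage sketch (aside, not part of the file above) ---
import torch
from ldm.modules.losses.vqperceptual import measure_perplexity

n_embed = 16
uniform_idx = torch.arange(n_embed).repeat(64)       # every code used equally often
collapsed_idx = torch.zeros(1024, dtype=torch.long)  # codebook collapse onto one code
for name, idx in (('uniform', uniform_idx), ('collapsed', collapsed_idx)):
    perplexity, used = measure_perplexity(idx, n_embed)
    print(name, round(float(perplexity), 2), int(used))  # ~16 / 16 vs ~1 / 1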
stable_diffusion/ldm/modules/x_transformer.py
0 → 100644
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
import
torch
from
torch
import
nn
,
einsum
import
torch.nn.functional
as
F
from
functools
import
partial
from
inspect
import
isfunction
from
collections
import
namedtuple
from
einops
import
rearrange
,
repeat
,
reduce
# constants
DEFAULT_DIM_HEAD
=
64
Intermediates
=
namedtuple
(
'Intermediates'
,
[
'pre_softmax_attn'
,
'post_softmax_attn'
])
LayerIntermediates
=
namedtuple
(
'Intermediates'
,
[
'hiddens'
,
'attn_intermediates'
])
class
AbsolutePositionalEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
max_seq_len
):
super
().
__init__
()
self
.
emb
=
nn
.
Embedding
(
max_seq_len
,
dim
)
self
.
init_
()
def
init_
(
self
):
nn
.
init
.
normal_
(
self
.
emb
.
weight
,
std
=
0.02
)
def
forward
(
self
,
x
):
n
=
torch
.
arange
(
x
.
shape
[
1
],
device
=
x
.
device
)
return
self
.
emb
(
n
)[
None
,
:,
:]
class
FixedPositionalEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
inv_freq
=
1.
/
(
10000
**
(
torch
.
arange
(
0
,
dim
,
2
).
float
()
/
dim
))
self
.
register_buffer
(
'inv_freq'
,
inv_freq
)
def
forward
(
self
,
x
,
seq_dim
=
1
,
offset
=
0
):
t
=
torch
.
arange
(
x
.
shape
[
seq_dim
],
device
=
x
.
device
).
type_as
(
self
.
inv_freq
)
+
offset
sinusoid_inp
=
torch
.
einsum
(
'i , j -> i j'
,
t
,
self
.
inv_freq
)
emb
=
torch
.
cat
((
sinusoid_inp
.
sin
(),
sinusoid_inp
.
cos
()),
dim
=-
1
)
return
emb
[
None
,
:,
:]
# helpers
def
exists
(
val
):
return
val
is
not
None
def
default
(
val
,
d
):
if
exists
(
val
):
return
val
return
d
()
if
isfunction
(
d
)
else
d
def
always
(
val
):
def
inner
(
*
args
,
**
kwargs
):
return
val
return
inner
def
not_equals
(
val
):
def
inner
(
x
):
return
x
!=
val
return
inner
def
equals
(
val
):
def
inner
(
x
):
return
x
==
val
return
inner
def
max_neg_value
(
tensor
):
return
-
torch
.
finfo
(
tensor
.
dtype
).
max
# keyword argument helpers
def
pick_and_pop
(
keys
,
d
):
values
=
list
(
map
(
lambda
key
:
d
.
pop
(
key
),
keys
))
return
dict
(
zip
(
keys
,
values
))
def
group_dict_by_key
(
cond
,
d
):
return_val
=
[
dict
(),
dict
()]
for
key
in
d
.
keys
():
match
=
bool
(
cond
(
key
))
ind
=
int
(
not
match
)
return_val
[
ind
][
key
]
=
d
[
key
]
return
(
*
return_val
,)
def
string_begins_with
(
prefix
,
str
):
return
str
.
startswith
(
prefix
)
def
group_by_key_prefix
(
prefix
,
d
):
return
group_dict_by_key
(
partial
(
string_begins_with
,
prefix
),
d
)
def
groupby_prefix_and_trim
(
prefix
,
d
):
kwargs_with_prefix
,
kwargs
=
group_dict_by_key
(
partial
(
string_begins_with
,
prefix
),
d
)
kwargs_without_prefix
=
dict
(
map
(
lambda
x
:
(
x
[
0
][
len
(
prefix
):],
x
[
1
]),
tuple
(
kwargs_with_prefix
.
items
())))
return
kwargs_without_prefix
,
kwargs
# classes
class
Scale
(
nn
.
Module
):
def
__init__
(
self
,
value
,
fn
):
super
().
__init__
()
self
.
value
=
value
self
.
fn
=
fn
def
forward
(
self
,
x
,
**
kwargs
):
x
,
*
rest
=
self
.
fn
(
x
,
**
kwargs
)
return
(
x
*
self
.
value
,
*
rest
)
class
Rezero
(
nn
.
Module
):
def
__init__
(
self
,
fn
):
super
().
__init__
()
self
.
fn
=
fn
self
.
g
=
nn
.
Parameter
(
torch
.
zeros
(
1
))
def
forward
(
self
,
x
,
**
kwargs
):
x
,
*
rest
=
self
.
fn
(
x
,
**
kwargs
)
return
(
x
*
self
.
g
,
*
rest
)
class
ScaleNorm
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
eps
=
1e-5
):
super
().
__init__
()
self
.
scale
=
dim
**
-
0.5
self
.
eps
=
eps
self
.
g
=
nn
.
Parameter
(
torch
.
ones
(
1
))
def
forward
(
self
,
x
):
norm
=
torch
.
norm
(
x
,
dim
=-
1
,
keepdim
=
True
)
*
self
.
scale
return
x
/
norm
.
clamp
(
min
=
self
.
eps
)
*
self
.
g
class
RMSNorm
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
eps
=
1e-8
):
super
().
__init__
()
self
.
scale
=
dim
**
-
0.5
self
.
eps
=
eps
self
.
g
=
nn
.
Parameter
(
torch
.
ones
(
dim
))
def
forward
(
self
,
x
):
norm
=
torch
.
norm
(
x
,
dim
=-
1
,
keepdim
=
True
)
*
self
.
scale
return
x
/
norm
.
clamp
(
min
=
self
.
eps
)
*
self
.
g
class
Residual
(
nn
.
Module
):
def
forward
(
self
,
x
,
residual
):
return
x
+
residual
class
GRUGating
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
self
.
gru
=
nn
.
GRUCell
(
dim
,
dim
)
def
forward
(
self
,
x
,
residual
):
gated_output
=
self
.
gru
(
rearrange
(
x
,
'b n d -> (b n) d'
),
rearrange
(
residual
,
'b n d -> (b n) d'
)
)
return
gated_output
.
reshape_as
(
x
)
# feedforward
class
GEGLU
(
nn
.
Module
):
def
__init__
(
self
,
dim_in
,
dim_out
):
super
().
__init__
()
self
.
proj
=
nn
.
Linear
(
dim_in
,
dim_out
*
2
)
def
forward
(
self
,
x
):
x
,
gate
=
self
.
proj
(
x
).
chunk
(
2
,
dim
=-
1
)
return
x
*
F
.
gelu
(
gate
)
class
FeedForward
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
dim_out
=
None
,
mult
=
4
,
glu
=
False
,
dropout
=
0.
):
super
().
__init__
()
inner_dim
=
int
(
dim
*
mult
)
dim_out
=
default
(
dim_out
,
dim
)
project_in
=
nn
.
Sequential
(
nn
.
Linear
(
dim
,
inner_dim
),
nn
.
GELU
()
)
if
not
glu
else
GEGLU
(
dim
,
inner_dim
)
self
.
net
=
nn
.
Sequential
(
project_in
,
nn
.
Dropout
(
dropout
),
nn
.
Linear
(
inner_dim
,
dim_out
)
)
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
# attention.
class
Attention
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
dim_head
=
DEFAULT_DIM_HEAD
,
heads
=
8
,
causal
=
False
,
mask
=
None
,
talking_heads
=
False
,
sparse_topk
=
None
,
use_entmax15
=
False
,
num_mem_kv
=
0
,
dropout
=
0.
,
on_attn
=
False
):
super
().
__init__
()
if
use_entmax15
:
raise
NotImplementedError
(
"Check out entmax activation instead of softmax activation!"
)
self
.
scale
=
dim_head
**
-
0.5
self
.
heads
=
heads
self
.
causal
=
causal
self
.
mask
=
mask
inner_dim
=
dim_head
*
heads
self
.
to_q
=
nn
.
Linear
(
dim
,
inner_dim
,
bias
=
False
)
self
.
to_k
=
nn
.
Linear
(
dim
,
inner_dim
,
bias
=
False
)
self
.
to_v
=
nn
.
Linear
(
dim
,
inner_dim
,
bias
=
False
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
# talking heads
self
.
talking_heads
=
talking_heads
if
talking_heads
:
self
.
pre_softmax_proj
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
heads
))
self
.
post_softmax_proj
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
heads
))
# explicit topk sparse attention
self
.
sparse_topk
=
sparse_topk
# entmax
#self.attn_fn = entmax15 if use_entmax15 else F.softmax
self
.
attn_fn
=
F
.
softmax
# add memory key / values
self
.
num_mem_kv
=
num_mem_kv
if
num_mem_kv
>
0
:
self
.
mem_k
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
num_mem_kv
,
dim_head
))
self
.
mem_v
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
num_mem_kv
,
dim_head
))
# attention on attention
self
.
attn_on_attn
=
on_attn
self
.
to_out
=
nn
.
Sequential
(
nn
.
Linear
(
inner_dim
,
dim
*
2
),
nn
.
GLU
())
if
on_attn
else
nn
.
Linear
(
inner_dim
,
dim
)
def
forward
(
self
,
x
,
context
=
None
,
mask
=
None
,
context_mask
=
None
,
rel_pos
=
None
,
sinusoidal_emb
=
None
,
prev_attn
=
None
,
mem
=
None
):
b
,
n
,
_
,
h
,
talking_heads
,
device
=
*
x
.
shape
,
self
.
heads
,
self
.
talking_heads
,
x
.
device
kv_input
=
default
(
context
,
x
)
q_input
=
x
k_input
=
kv_input
v_input
=
kv_input
if
exists
(
mem
):
k_input
=
torch
.
cat
((
mem
,
k_input
),
dim
=-
2
)
v_input
=
torch
.
cat
((
mem
,
v_input
),
dim
=-
2
)
if
exists
(
sinusoidal_emb
):
# in shortformer, the query would start at a position offset depending on the past cached memory
offset
=
k_input
.
shape
[
-
2
]
-
q_input
.
shape
[
-
2
]
q_input
=
q_input
+
sinusoidal_emb
(
q_input
,
offset
=
offset
)
k_input
=
k_input
+
sinusoidal_emb
(
k_input
)
q
=
self
.
to_q
(
q_input
)
k
=
self
.
to_k
(
k_input
)
v
=
self
.
to_v
(
v_input
)
q
,
k
,
v
=
map
(
lambda
t
:
rearrange
(
t
,
'b n (h d) -> b h n d'
,
h
=
h
),
(
q
,
k
,
v
))
input_mask
=
None
if
any
(
map
(
exists
,
(
mask
,
context_mask
))):
q_mask
=
default
(
mask
,
lambda
:
torch
.
ones
((
b
,
n
),
device
=
device
).
bool
())
k_mask
=
q_mask
if
not
exists
(
context
)
else
context_mask
k_mask
=
default
(
k_mask
,
lambda
:
torch
.
ones
((
b
,
k
.
shape
[
-
2
]),
device
=
device
).
bool
())
q_mask
=
rearrange
(
q_mask
,
'b i -> b () i ()'
)
k_mask
=
rearrange
(
k_mask
,
'b j -> b () () j'
)
input_mask
=
q_mask
*
k_mask
if
self
.
num_mem_kv
>
0
:
mem_k
,
mem_v
=
map
(
lambda
t
:
repeat
(
t
,
'h n d -> b h n d'
,
b
=
b
),
(
self
.
mem_k
,
self
.
mem_v
))
k
=
torch
.
cat
((
mem_k
,
k
),
dim
=-
2
)
v
=
torch
.
cat
((
mem_v
,
v
),
dim
=-
2
)
if
exists
(
input_mask
):
input_mask
=
F
.
pad
(
input_mask
,
(
self
.
num_mem_kv
,
0
),
value
=
True
)
dots
=
einsum
(
'b h i d, b h j d -> b h i j'
,
q
,
k
)
*
self
.
scale
mask_value
=
max_neg_value
(
dots
)
if
exists
(
prev_attn
):
dots
=
dots
+
prev_attn
pre_softmax_attn
=
dots
if
talking_heads
:
dots
=
einsum
(
'b h i j, h k -> b k i j'
,
dots
,
self
.
pre_softmax_proj
).
contiguous
()
if
exists
(
rel_pos
):
dots
=
rel_pos
(
dots
)
if
exists
(
input_mask
):
dots
.
masked_fill_
(
~
input_mask
,
mask_value
)
del
input_mask
if
self
.
causal
:
i
,
j
=
dots
.
shape
[
-
2
:]
r
=
torch
.
arange
(
i
,
device
=
device
)
mask
=
rearrange
(
r
,
'i -> () () i ()'
)
<
rearrange
(
r
,
'j -> () () () j'
)
mask
=
F
.
pad
(
mask
,
(
j
-
i
,
0
),
value
=
False
)
dots
.
masked_fill_
(
mask
,
mask_value
)
del
mask
if
exists
(
self
.
sparse_topk
)
and
self
.
sparse_topk
<
dots
.
shape
[
-
1
]:
top
,
_
=
dots
.
topk
(
self
.
sparse_topk
,
dim
=-
1
)
vk
=
top
[...,
-
1
].
unsqueeze
(
-
1
).
expand_as
(
dots
)
mask
=
dots
<
vk
dots
.
masked_fill_
(
mask
,
mask_value
)
del
mask
attn
=
self
.
attn_fn
(
dots
,
dim
=-
1
)
post_softmax_attn
=
attn
attn
=
self
.
dropout
(
attn
)
if
talking_heads
:
attn
=
einsum
(
'b h i j, h k -> b k i j'
,
attn
,
self
.
post_softmax_proj
).
contiguous
()
out
=
einsum
(
'b h i j, b h j d -> b h i d'
,
attn
,
v
)
out
=
rearrange
(
out
,
'b h n d -> b n (h d)'
)
intermediates
=
Intermediates
(
pre_softmax_attn
=
pre_softmax_attn
,
post_softmax_attn
=
post_softmax_attn
)
return
self
.
to_out
(
out
),
intermediates
class
AttentionLayers
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
depth
,
heads
=
8
,
causal
=
False
,
cross_attend
=
False
,
only_cross
=
False
,
use_scalenorm
=
False
,
use_rmsnorm
=
False
,
use_rezero
=
False
,
rel_pos_num_buckets
=
32
,
rel_pos_max_distance
=
128
,
position_infused_attn
=
False
,
custom_layers
=
None
,
sandwich_coef
=
None
,
par_ratio
=
None
,
residual_attn
=
False
,
cross_residual_attn
=
False
,
macaron
=
False
,
pre_norm
=
True
,
gate_residual
=
False
,
**
kwargs
):
super
().
__init__
()
ff_kwargs
,
kwargs
=
groupby_prefix_and_trim
(
'ff_'
,
kwargs
)
attn_kwargs
,
_
=
groupby_prefix_and_trim
(
'attn_'
,
kwargs
)
dim_head
=
attn_kwargs
.
get
(
'dim_head'
,
DEFAULT_DIM_HEAD
)
self
.
dim
=
dim
self
.
depth
=
depth
self
.
layers
=
nn
.
ModuleList
([])
self
.
has_pos_emb = position_infused_attn
        self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
        self.rotary_pos_emb = always(None)

        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
        self.rel_pos = None

        self.pre_norm = pre_norm

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn

        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
        norm_class = RMSNorm if use_rmsnorm else norm_class
        norm_fn = partial(norm_class, dim)

        norm_fn = nn.Identity if use_rezero else norm_fn
        branch_fn = Rezero if use_rezero else None

        if cross_attend and not only_cross:
            default_block = ('a', 'c', 'f')
        elif cross_attend and only_cross:
            default_block = ('c', 'f')
        else:
            default_block = ('a', 'f')

        if macaron:
            default_block = ('f',) + default_block

        if exists(custom_layers):
            layer_types = custom_layers
        elif exists(par_ratio):
            par_depth = depth * len(default_block)
            assert 1 < par_ratio <= par_depth, 'par ratio out of range'
            default_block = tuple(filter(not_equals('f'), default_block))
            par_attn = par_depth // par_ratio
            depth_cut = par_depth * 2 // 3  # 2 / 3 attention layer cutoff suggested by PAR paper
            par_width = (depth_cut + depth_cut // par_attn) // par_attn
            assert len(default_block) <= par_width, 'default block is too large for par_ratio'
            par_block = default_block + ('f',) * (par_width - len(default_block))
            par_head = par_block * par_attn
            layer_types = par_head + ('f',) * (par_depth - len(par_head))
        elif exists(sandwich_coef):
            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
            layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
        else:
            layer_types = default_block * depth

        self.layer_types = layer_types
        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

        for layer_type in self.layer_types:
            if layer_type == 'a':
                layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
            elif layer_type == 'c':
                layer = Attention(dim, heads=heads, **attn_kwargs)
            elif layer_type == 'f':
                layer = FeedForward(dim, **ff_kwargs)
                layer = layer if not macaron else Scale(0.5, layer)
            else:
                raise Exception(f'invalid layer type {layer_type}')

            if isinstance(layer, Attention) and exists(branch_fn):
                layer = branch_fn(layer)

            if gate_residual:
                residual_fn = GRUGating(dim)
            else:
                residual_fn = Residual()

            self.layers.append(nn.ModuleList([
                norm_fn(),
                layer,
                residual_fn
            ]))

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
            mems=None,
            return_hiddens=False
    ):
        hiddens = []
        intermediates = []
        prev_attn = None
        prev_cross_attn = None

        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers

        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
            is_last = ind == (len(self.layers) - 1)

            if layer_type == 'a':
                hiddens.append(x)
                layer_mem = mems.pop(0)

            residual = x

            if self.pre_norm:
                x = norm(x)

            if layer_type == 'a':
                out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
                                   prev_attn=prev_attn, mem=layer_mem)
            elif layer_type == 'c':
                out, inter = block(x, context=context, mask=mask, context_mask=context_mask,
                                   prev_attn=prev_cross_attn)
            elif layer_type == 'f':
                out = block(x)

            x = residual_fn(out, residual)

            if layer_type in ('a', 'c'):
                intermediates.append(inter)

            if layer_type == 'a' and self.residual_attn:
                prev_attn = inter.pre_softmax_attn
            elif layer_type == 'c' and self.cross_residual_attn:
                prev_cross_attn = inter.pre_softmax_attn

            if not self.pre_norm and not is_last:
                x = norm(x)

        if return_hiddens:
            intermediates = LayerIntermediates(
                hiddens=hiddens,
                attn_intermediates=intermediates
            )

            return x, intermediates

        return x


class Encoder(AttentionLayers):
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        super().__init__(causal=False, **kwargs)


class TransformerWrapper(nn.Module):
    def __init__(
            self,
            *,
            num_tokens,
            max_seq_len,
            attn_layers,
            emb_dim=None,
            max_mem_len=0.,
            emb_dropout=0.,
            num_memory_tokens=None,
            tie_embedding=False,
            use_pos_emb=True
    ):
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.num_tokens = num_tokens

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
                use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.init_()

        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

        # memory tokens (like [cls]) from Memory Transformers paper
        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

            # let funnel encoder know number of memory tokens, if specified
            if hasattr(attn_layers, 'num_memory_tokens'):
                attn_layers.num_memory_tokens = num_memory_tokens

    def init_(self):
        nn.init.normal_(self.token_emb.weight, std=0.02)

    def forward(
            self,
            x,
            return_embeddings=False,
            mask=None,
            return_mems=False,
            return_attn=False,
            mems=None,
            **kwargs
    ):
        b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
        x = self.token_emb(x)
        x += self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)

        if num_mem > 0:
            mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
            x = torch.cat((mem, x), dim=1)

            # auto-handle masking after appending memory tokens
            if exists(mask):
                mask = F.pad(mask, (num_mem, 0), value=True)

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        mem, x = x[:, :num_mem], x[:, num_mem:]

        out = self.to_logits(x) if not return_embeddings else x

        if return_mems:
            hiddens = intermediates.hiddens
            new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
            return out, new_mems

        if return_attn:
            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
            return out, attn_maps

        return out
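The sketch below shows how these classes are typically wired together. It is an illustrative example, not code from this commit: the vocabulary size, sequence length, and model dimensions are made up, and it assumes the full ldm.modules.x_transformer module (including Attention, FeedForward, and the positional-embedding helpers defined earlier in this file) is importable.

# Hypothetical usage sketch for Encoder / TransformerWrapper above.
import torch
from ldm.modules.x_transformer import Encoder, TransformerWrapper

model = TransformerWrapper(
    num_tokens=30522,                              # vocabulary size (illustrative)
    max_seq_len=77,                                # context length (illustrative)
    attn_layers=Encoder(dim=512, depth=6, heads=8),
)

tokens = torch.randint(0, 30522, (2, 77))          # (batch, seq_len)
embeddings = model(tokens, return_embeddings=True)  # -> (2, 77, 512)

With return_embeddings=True the final LayerNorm output is returned directly instead of the logits projection, which is how such a wrapper can serve as a text encoder rather than a language model head.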
stable_diffusion/ldm/util.py
0 → 100644
import importlib

import torch
import numpy as np
from collections import abc
from einops import rearrange
from functools import partial

import multiprocessing as mp
from threading import Thread
from queue import Queue

from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont


def log_txt_as_img(wh, xc, size=10):
    # wh a tuple of (width, height)
    # xc a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
        nc = int(40 * (wh[0] / 256))
        lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Cant encode string for logging. Skipping.")

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts


def ismap(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params


def instantiate_from_config(config):
    if not "target" in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    # create dummy dataset instance

    # run prefetching
    if idx_to_fn:
        res = func(data, worker_id=idx)
    else:
        res = func(data)
    Q.put([idx, res])
    Q.put("Done")


def parallel_data_prefetch(
        func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
):
    # if target_data_type not in ["ndarray", "list"]:
    #     raise ValueError(
    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
    #     )
    if isinstance(data, np.ndarray) and target_data_type == "list":
        raise ValueError("list expected but function got ndarray.")
    elif isinstance(data, abc.Iterable):
        if isinstance(data, dict):
            print(
                f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
            )
            data = list(data.values())
        if target_data_type == "ndarray":
            data = np.asarray(data)
        else:
            data = list(data)
    else:
        raise TypeError(
            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
        )

    if cpu_intensive:
        Q = mp.Queue(1000)
        proc = mp.Process
    else:
        Q = Queue(1000)
        proc = Thread
    # spawn processes
    if target_data_type == "ndarray":
        arguments = [
            [func, Q, part, i, use_worker_id]
            for i, part in enumerate(np.array_split(data, n_proc))
        ]
    else:
        step = (
            int(len(data) / n_proc + 1)
            if len(data) % n_proc != 0
            else int(len(data) / n_proc)
        )
        arguments = [
            [func, Q, part, i, use_worker_id]
            for i, part in enumerate(
                [data[i: i + step] for i in range(0, len(data), step)]
            )
        ]
    processes = []
    for i in range(n_proc):
        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
        processes += [p]

    # start processes
    print(f"Start prefetching...")
    import time

    start = time.time()
    gather_res = [[] for _ in range(n_proc)]
    try:
        for p in processes:
            p.start()

        k = 0
        while k < n_proc:
            # get result
            res = Q.get()
            if res == "Done":
                k += 1
            else:
                gather_res[res[0]] = res[1]

    except Exception as e:
        print("Exception: ", e)
        for p in processes:
            p.terminate()

        raise e
    finally:
        for p in processes:
            p.join()
        print(f"Prefetching complete. [{time.time() - start} sec.]")

    if target_data_type == 'ndarray':
        if not isinstance(gather_res[0], np.ndarray):
            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)

        # order outputs
        return np.concatenate(gather_res, axis=0)
    elif target_data_type == 'list':
        out = []
        for r in gather_res:
            out.extend(r)
        return out
    else:
        return gather_res
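instantiate_from_config is the glue used throughout this repo: any mapping with a "target" import path and an optional "params" dict becomes a live object. A minimal sketch of how it behaves (torch.nn.Linear is just a stand-in target, not something this repo instantiates this way):

# Hypothetical demonstration of config-driven instantiation.
from ldm.util import instantiate_from_config

cfg = {
    "target": "torch.nn.Linear",
    "params": {"in_features": 16, "out_features": 4},
}
layer = instantiate_from_config(cfg)  # equivalent to torch.nn.Linear(16, 4)

The same mechanism is what lets the YAML configs further down name model classes, loss modules, and datasets purely by import path.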
stable_diffusion/main.py
0 → 100644
import argparse, os, sys, datetime, glob, importlib, csv
import numpy as np
import time
import torch
import torchvision
import pytorch_lightning as pl

from packaging import version
from omegaconf import OmegaConf
from torch.utils.data import random_split, DataLoader, Dataset, Subset
from functools import partial
from PIL import Image

from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info

from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config


def get_parser(**parser_kwargs):
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
             "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="train",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-p",
        "--project",
        help="name of new or path to existing project"
    )
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        help="directory for logging dat shit",
    )
    parser.add_argument(
        "--scale_lr",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="scale base-lr by ngpu * batch_size * n_accumulate",
    )
    return parser


def nondefault_trainer_args(opt):
    parser = argparse.ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args([])
    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        split_size = dataset.num_records // worker_info.num_workers
        # reset num_records to the true number to retain reliable length information
        dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
    else:
        return np.random.seed(np.random.get_state()[1][0] + worker_id)


class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        self.datasets = dict(
            (k, instantiate_from_config(self.dataset_configs[k]))
            for k in self.dataset_configs)
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _train_dataloader(self):
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["train"], batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
                          worker_init_fn=init_fn)

    def _val_dataloader(self, shuffle=False):
        if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["validation"],
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          worker_init_fn=init_fn,
                          shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoader(self.datasets["test"], batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn)
class SetupCallback(Callback):
    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def on_keyboard_interrupt(self, trainer, pl_module):
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)

    def on_pretrain_routine_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            if "callbacks" in self.lightning_config:
                if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
                    os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(self.config,
                           os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
                           os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))

        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass


class ImageLogger(Callback):
    def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
                 log_images_kwargs=None):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            pl.loggers.TestTubeLogger: self._testtube,
        }
        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only
    def _testtube(self, pl_module, images, batch_idx, split):
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid,
                global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self, save_dir, split, images,
                  global_step, current_epoch, batch_idx):
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k,
                global_step,
                current_epoch,
                batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1., 1.)

            self.log_local(pl_module.logger.save_dir, split, images,
                           pl_module.global_step, pl_module.current_epoch, batch_idx)

            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
                check_idx > 0 or self.log_first_step):
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
                pass
            return True
        return False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
                # note: relies on a log_gradients helper that is not defined in this file
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)


class CUDACallback(Callback):
    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
        torch.cuda.synchronize(trainer.root_gpu)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        torch.cuda.synchronize(trainer.root_gpu)
        max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
        epoch_time = time.time() - self.start_time

        try:
            max_memory = trainer.training_type_plugin.reduce(max_memory)
            epoch_time = trainer.training_type_plugin.reduce(epoch_time)

            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
        except AttributeError:
            pass
if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot be specified both."
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed)

    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())
        # default to ddp
        trainer_config["accelerator"] = "ddp"
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        if not "gpus" in trainer_config:
            del trainer_config["accelerator"]
            cpu = True
        else:
            gpuinfo = trainer_config["gpus"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        model = instantiate_from_config(config.model)

        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                }
            },
            "testtube": {
                "target": "pytorch_lightning.loggers.TestTubeLogger",
                "params": {
                    "name": "testtube",
                    "save_dir": logdir,
                }
            },
        }
        default_logger_cfg = default_logger_cfgs["testtube"]
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
        else:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "save_last": True,
            }
        }
        if hasattr(model, "monitor"):
            print(f"Monitoring {model.monitor} as checkpoint metric.")
            default_modelckpt_cfg["params"]["monitor"] = model.monitor
            default_modelckpt_cfg["params"]["save_top_k"] = 3

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
        if version.parse(pl.__version__) < version.parse('1.4.0'):
            trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                }
            },
            "image_logger": {
                "target": "main.ImageLogger",
                "params": {
                    "batch_frequency": 750,
                    "max_images": 4,
                    "clamp": True
                }
            },
            "learning_rate_logger": {
                "target": "main.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",
                    # "log_momentum": True
                }
            },
            "cuda_callback": {
                "target": "main.CUDACallback"
            },
        }
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})

        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
            print(
                'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
            default_metrics_over_trainsteps_ckpt_dict = {
                'metrics_over_trainsteps_checkpoint': {
                    "target": 'pytorch_lightning.callbacks.ModelCheckpoint',
                    'params': {
                        "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                        "filename": "{epoch:06}-{step:09}",
                        "verbose": True,
                        'save_top_k': -1,
                        'every_n_train_steps': 10000,
                        'save_weights_only': True
                    }
                }
            }
            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
            callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
        elif 'ignore_keys_callback' in callbacks_cfg:
            del callbacks_cfg['ignore_keys_callback']

        trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir  ###

        # data
        data = instantiate_from_config(config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()
        print("#### Data #####")
        for k in data.datasets:
            print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

        # configure learning rate
        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        if not cpu:
            ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
        else:
            ngpu = 1
        if 'accumulate_grad_batches' in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            print(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
        else:
            model.learning_rate = base_lr
            print("++++ NOT USING LR SCALING ++++")
            print(f"Setting learning rate to {model.learning_rate:.2e}")


        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)


        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb
                pudb.set_trace()


        import signal

        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        # run
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except Exception:
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # move newly created debug project to debug_runs
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)
        try:
            if trainer.global_rank == 0:
                print(trainer.profiler.summary())
        except:
            pass
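Note the learning-rate scaling step above: with --scale_lr at its default of True, main.py multiplies the base learning rate from the config by the effective batch size. A worked example using the kl-f16 config that follows (the GPU count and gradient accumulation here are hypothetical, not values from this commit):

# Worked example of main.py's lr scaling.
base_lr = 4.5e-06            # config.model.base_learning_rate (kl-f16)
bs = 6                       # config.data.params.batch_size (kl-f16)
ngpu = 2                     # e.g. running with --gpus "0,1" (hypothetical)
accumulate_grad_batches = 2  # lightning.trainer setting (hypothetical)

learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print(f"{learning_rate:.2e}")  # 1.08e-04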
stable_diffusion/models/first_stage_models/kl-f16/config.yaml
0 → 100644
model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: val/rec_loss
    embed_dim: 16
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 1.0e-06
        disc_weight: 0.5
    ddconfig:
      double_z: true
      z_channels: 16
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 1
      - 2
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions:
      - 16
      dropout: 0.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 6
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        size: 384
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        size: 384
        crop_size: 256
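A sketch of how main.py consumes a config like the one above, via the instantiate_from_config mechanism from ldm/util.py. This assumes the ldm package, its OpenImages dataset code, and this repo's main module are importable, and that the path is resolved relative to the stable_diffusion directory:

# Hypothetical illustration of config-driven model/data construction.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("models/first_stage_models/kl-f16/config.yaml")
model = instantiate_from_config(config.model)  # AutoencoderKL with the params above
data = instantiate_from_config(config.data)    # main.DataModuleFromConfig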
stable_diffusion/models/first_stage_models/kl-f32/config.yaml
0 → 100644
model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: val/rec_loss
    embed_dim: 64
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 1.0e-06
        disc_weight: 0.5
    ddconfig:
      double_z: true
      z_channels: 64
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 1
      - 2
      - 2
      - 4
      - 4
      num_res_blocks: 2
      attn_resolutions:
      - 16
      - 8
      dropout: 0.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 6
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        size: 384
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        size: 384
        crop_size: 256
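The f in these directory names (kl-f4, kl-f8, kl-f16, kl-f32) is the spatial downsampling factor of the first-stage autoencoder. In the ldm encoder it follows from the length of ch_mult, since the resolution is halved once per level after the first. A quick check against the configs in this commit:

# Downsampling factor implied by a ddconfig's ch_mult list.
def downsampling_factor(ch_mult):
    # one halving per resolution level after the first
    return 2 ** (len(ch_mult) - 1)

assert downsampling_factor([1, 2, 4]) == 4             # kl-f4
assert downsampling_factor([1, 2, 4, 4]) == 8          # kl-f8
assert downsampling_factor([1, 1, 2, 2, 4]) == 16      # kl-f16
assert downsampling_factor([1, 1, 2, 2, 4, 4]) == 32   # kl-f32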
stable_diffusion/models/first_stage_models/kl-f4/config.yaml
0 → 100644
model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: val/rec_loss
    embed_dim: 3
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 1.0e-06
        disc_weight: 0.5
    ddconfig:
      double_z: true
      z_channels: 3
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 10
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        size: 384
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        size: 384
        crop_size: 256
stable_diffusion/models/first_stage_models/kl-f8/config.yaml
0 → 100644
model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: val/rec_loss
    embed_dim: 4
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 1.0e-06
        disc_weight: 0.5
    ddconfig:
      double_z: true
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 2
      - 4
      - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        size: 384
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        size: 384
        crop_size: 256
stable_diffusion/models/first_stage_models/vq-f16/config.yaml
0 → 100644
model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.VQModel
  params:
    embed_dim: 8
    n_embed: 16384
    ddconfig:
      double_z: false
      z_channels: 8
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 1
      - 2
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions:
      - 16
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: false
        disc_in_channels: 3
        disc_start: 250001
        disc_weight: 0.75
        disc_num_layers: 2
        codebook_weight: 1.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 14
    num_workers: 20
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        size: 384
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        size: 384
        crop_size: 256
stable_diffusion/models/first_stage_models/vq-f4-noattn/config.yaml
0 → 100644
model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.VQModel
  params:
    embed_dim: 3
    n_embed: 8192
    monitor: val/rec_loss
    ddconfig:
      attn_type: none
      double_z: false
      z_channels: 3
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: false
        disc_in_channels: 3
        disc_start: 11
        disc_weight: 0.75
        codebook_weight: 1.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 8
    num_workers: 12
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        crop_size: 256