Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e532679c
Commit
e532679c
authored
Jan 10, 2023
by
oahzxl
Browse files
Merge branch 'main' of
https://github.com/oahzxl/ColossalAI
into chunk
parents
c1492e50
7d5640b9
Changes
461
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4250 additions
and
0 deletions
+4250
-0
examples/images/diffusion/ldm/modules/distributions/__init__.py
...es/images/diffusion/ldm/modules/distributions/__init__.py
+0
-0
examples/images/diffusion/ldm/modules/distributions/distributions.py
...ages/diffusion/ldm/modules/distributions/distributions.py
+92
-0
examples/images/diffusion/ldm/modules/ema.py
examples/images/diffusion/ldm/modules/ema.py
+80
-0
examples/images/diffusion/ldm/modules/encoders/__init__.py
examples/images/diffusion/ldm/modules/encoders/__init__.py
+0
-0
examples/images/diffusion/ldm/modules/encoders/modules.py
examples/images/diffusion/ldm/modules/encoders/modules.py
+213
-0
examples/images/diffusion/ldm/modules/image_degradation/__init__.py
...mages/diffusion/ldm/modules/image_degradation/__init__.py
+2
-0
examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py
.../images/diffusion/ldm/modules/image_degradation/bsrgan.py
+730
-0
examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py
...s/diffusion/ldm/modules/image_degradation/bsrgan_light.py
+651
-0
examples/images/diffusion/ldm/modules/image_degradation/utils/test.png
...es/diffusion/ldm/modules/image_degradation/utils/test.png
+0
-0
examples/images/diffusion/ldm/modules/image_degradation/utils_image.py
...es/diffusion/ldm/modules/image_degradation/utils_image.py
+916
-0
examples/images/diffusion/ldm/modules/midas/__init__.py
examples/images/diffusion/ldm/modules/midas/__init__.py
+0
-0
examples/images/diffusion/ldm/modules/midas/api.py
examples/images/diffusion/ldm/modules/midas/api.py
+170
-0
examples/images/diffusion/ldm/modules/midas/midas/__init__.py
...ples/images/diffusion/ldm/modules/midas/midas/__init__.py
+0
-0
examples/images/diffusion/ldm/modules/midas/midas/base_model.py
...es/images/diffusion/ldm/modules/midas/midas/base_model.py
+16
-0
examples/images/diffusion/ldm/modules/midas/midas/blocks.py
examples/images/diffusion/ldm/modules/midas/midas/blocks.py
+342
-0
examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py
...les/images/diffusion/ldm/modules/midas/midas/dpt_depth.py
+109
-0
examples/images/diffusion/ldm/modules/midas/midas/midas_net.py
...les/images/diffusion/ldm/modules/midas/midas/midas_net.py
+76
-0
examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py
...ges/diffusion/ldm/modules/midas/midas/midas_net_custom.py
+128
-0
examples/images/diffusion/ldm/modules/midas/midas/transforms.py
...es/images/diffusion/ldm/modules/midas/midas/transforms.py
+234
-0
examples/images/diffusion/ldm/modules/midas/midas/vit.py
examples/images/diffusion/ldm/modules/midas/midas/vit.py
+491
-0
No files found.
Too many changes to show.
To preserve performance only
461 of 461+
files are displayed.
Plain diff
Email patch
examples/images/diffusion/ldm/modules/distributions/__init__.py
0 → 100644
View file @
e532679c
examples/images/diffusion/ldm/modules/distributions/distributions.py
0 → 100644
View file @
e532679c
import
torch
import
numpy
as
np
class AbstractDistribution:
    """Minimal distribution interface: subclasses provide sampling and mode."""

    def sample(self):
        """Draw one sample from the distribution; must be overridden."""
        raise NotImplementedError()

    def mode(self):
        """Return the most likely value; must be overridden."""
        raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
    """Degenerate distribution concentrated entirely on one value."""

    def __init__(self, value):
        # the single point carrying all probability mass
        self.value = value

    def sample(self):
        # sampling a point mass always yields the stored value
        return self.value

    def mode(self):
        # the mode of a point mass is the value itself
        return self.value
class DiagonalGaussianDistribution(object):
    """Gaussian with diagonal covariance.

    Parameterized by a single tensor whose channel dimension (dim 1)
    concatenates the mean and the log-variance.

    Args:
        parameters: tensor of shape (B, 2*C, ...); first half along dim 1 is
            the mean, second half the log-variance.
        deterministic: if True, variance is forced to zero so ``sample()``
            returns the mean and ``kl``/``nll`` return zero.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # clamp keeps exp() below numerically well-behaved
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zero variance -> sampling degenerates to the mean
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        """Draw one reparameterized sample: mean + std * eps, eps ~ N(0, I)."""
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        """KL divergence to ``other`` (another diagonal Gaussian), or to the
        standard normal when ``other`` is None; summed over dims 1, 2, 3."""
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample`` under this Gaussian, summed
        over ``dims``.

        Note: the default was changed from the mutable list ``[1, 2, 3]`` to
        an equivalent tuple (mutable-default pitfall); ``torch.sum`` accepts
        either, so callers are unaffected.
        """
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
                               dim=dims)

    def mode(self):
        """The mode of a Gaussian is its mean."""
        return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    # find any Tensor argument to borrow dtype/device from
    tensor = next((obj for obj in (mean1, logvar1, mean2, logvar2)
                   if isinstance(obj, torch.Tensor)), None)
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = (v if isinstance(v, torch.Tensor) else torch.tensor(v).to(tensor)
                        for v in (logvar1, logvar2))

    var_ratio = torch.exp(logvar1 - logvar2)
    sq_mean_diff = (mean1 - mean2) ** 2
    return 0.5 * (-1.0 + logvar2 - logvar1 + var_ratio + sq_mean_diff * torch.exp(-logvar2))
examples/images/diffusion/ldm/modules/ema.py
0 → 100644
View file @
e532679c
import
torch
from
torch
import
nn
class LitEma(nn.Module):
    """Tracks an exponential moving average (EMA) of a model's trainable
    parameters, storing one shadow buffer per tracked parameter."""

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        """
        Args:
            model: module whose ``requires_grad`` parameters are shadowed.
            decay: EMA decay rate, must lie in [0, 1].
            use_num_upates: if True (note: name typo kept for API
                compatibility), a warm-up schedule caps the effective decay
                early in training; if False, ``decay`` is used as-is.
        """
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        # maps original parameter names -> buffer-safe shadow names
        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        # num_updates == -1 disables the warm-up schedule in forward()
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)
                             if use_num_upates else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        # scratch storage used by store()/restore()
        self.collected_params = []

    def reset_num_updates(self):
        """Reset the update counter to zero (re-enables the warm-up ramp)."""
        del self.num_updates
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))

    def forward(self, model):
        """Fold the model's current parameters into the shadow buffers."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # warm-up: effective decay ramps up towards `decay` as updates grow
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    # in-place EMA step: shadow <- shadow - (1 - d) * (shadow - param)
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    # frozen parameters must never have been registered as shadows
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        """Overwrite the model's trainable parameters with the EMA values."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
examples/images/diffusion/ldm/modules/encoders/__init__.py
0 → 100644
View file @
e532679c
examples/images/diffusion/ldm/modules/encoders/modules.py
0 → 100644
View file @
e532679c
import
torch
import
torch.nn
as
nn
from
torch.utils.checkpoint
import
checkpoint
from
transformers
import
T5Tokenizer
,
T5EncoderModel
,
CLIPTokenizer
,
CLIPTextModel
import
open_clip
from
ldm.util
import
default
,
count_params
class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders; subclasses implement ``encode``."""

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        # subclasses must provide the actual encoding
        raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
    """Encoder that returns its input unchanged."""

    def encode(self, x):
        # no-op: the conditioning is already in its final form
        return x
class ClassEmbedder(nn.Module):
    """Embeds integer class labels; with probability ``ucg_rate`` a label is
    replaced by the extra "unconditional" class (``n_classes - 1``), as used
    for classifier-free guidance."""

    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        key = self.key if key is None else key
        # add a length-1 sequence axis so the result can be used in crossattn
        c = batch[key][:, None]
        if self.ucg_rate > 0. and not disable_dropout:
            # randomly swap labels for the unconditional class
            keep = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = keep * c + (1 - keep) * torch.ones_like(c) * (self.n_classes - 1)
            c = c.long()
        return self.embedding(c)

    def get_unconditional_conditioning(self, bs, device="cuda"):
        """Return a batch of ``bs`` unconditional labels keyed like ``forward``'s input."""
        uc_class = self.n_classes - 1    # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        labels = torch.ones((bs,), device=device) * uc_class
        return {self.key: labels}
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    # intentionally ignores `mode` and returns the module unchanged
    return self
class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""

    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True):
        # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length    # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the transformer in eval mode and stop all gradients."""
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        encoded = self.tokenizer(text,
                                 truncation=True,
                                 max_length=self.max_length,
                                 return_length=True,
                                 return_overflowing_tokens=False,
                                 padding="max_length",
                                 return_tensors="pt")
        token_ids = encoded["input_ids"].to(self.device)
        model_out = self.transformer(input_ids=token_ids)
        return model_out.last_hidden_state

    def encode(self, text):
        return self(text)
class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                 freeze=True, layer="last", layer_idx=None):    # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            # "hidden" requires a concrete (bounded) hidden-state index
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        """Put the transformer in eval mode and stop all gradients."""
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        encoded = self.tokenizer(text,
                                 truncation=True,
                                 max_length=self.max_length,
                                 return_length=True,
                                 return_overflowing_tokens=False,
                                 padding="max_length",
                                 return_tensors="pt")
        token_ids = encoded["input_ids"].to(self.device)
        model_out = self.transformer(input_ids=token_ids,
                                     output_hidden_states=self.layer == "hidden")
        if self.layer == "last":
            z = model_out.last_hidden_state
        elif self.layer == "pooled":
            # add a length-1 sequence axis to match the other layouts
            z = model_out.pooler_output[:, None, :]
        else:
            z = model_out.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        return self(text)
class FrozenOpenCLIPEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    # which transformer layer supplies the output embedding
    LAYERS = [
        #"pooled",
        "last",
        "penultimate"
    ]

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda",
                 max_length=77, freeze=True, layer="last"):
        super().__init__()
        assert layer in self.LAYERS
        # load on CPU first; the visual tower is unused for text encoding
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
                                                            pretrained=version)
        del model.visual
        self.model = model
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        # layer_idx counts blocks skipped from the end in text_transformer_forward
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        """Put the model in eval mode and stop all gradients."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)    # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)    # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)    # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run the residual blocks, stopping `self.layer_idx` blocks early."""
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
                # trade compute for memory when the model enables checkpointing
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)
class FrozenCLIPT5Encoder(AbstractEncoder):
    """Joint conditioning encoder that runs both a frozen CLIP text encoder
    and a frozen T5 encoder and returns both embeddings."""

    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl",
                 device="cuda", clip_max_length=77, t5_max_length=77):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")

    def encode(self, text):
        return self(text)

    def forward(self, text):
        # encode with both backbones and hand back the pair
        clip_z = self.clip_encoder.encode(text)
        t5_z = self.t5_encoder.encode(text)
        return [clip_z, t5_z]
examples/images/diffusion/ldm/modules/image_degradation/__init__.py
0 → 100644
View file @
e532679c
from
ldm.modules.image_degradation.bsrgan
import
degradation_bsrgan_variant
as
degradation_fn_bsr
from
ldm.modules.image_degradation.bsrgan_light
import
degradation_bsrgan_variant
as
degradation_fn_bsr_light
examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py
0 → 100644
View file @
e532679c
# -*- coding: utf-8 -*-
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
import
numpy
as
np
import
cv2
import
torch
from
functools
import
partial
import
random
from
scipy
import
ndimage
import
scipy
import
scipy.stats
as
ss
from
scipy.interpolate
import
interp2d
from
scipy.linalg
import
orth
import
albumentations
import
ldm.modules.image_degradation.utils_image
as
util
def modcrop_np(img, sf):
    '''
    Crop so that both leading dimensions are exact multiples of sf.
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        cropped image (a copy)
    '''
    dim0, dim1 = img.shape[:2]
    cropped = np.copy(img)
    return cropped[:dim0 - dim0 % sf, :dim1 - dim1 % sf, ...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
    k_size = k.shape[0]
    # the composed kernel before cropping is (3*k_size - 2) on a side
    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
    # superimpose a scaled copy of k at every (stride-2) position
    for row in range(k_size):
        for col in range(k_size):
            big_k[2 * row: 2 * row + k_size, 2 * col: 2 * col + k_size] += k[row, col] * k
    # Crop the edges of the big kernel to ignore very small values and increase run time of SR
    crop = k_size // 2
    cropped = big_k[crop:-crop, crop:-crop]
    # Normalize to 1
    return cropped / cropped.sum()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0,  pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel
    """
    # first principal axis: the unit x-vector rotated by theta
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta), np.cos(theta)]])
    v = np.dot(rot, np.array([1., 0.]))
    # orthogonal basis built from v; second column is v rotated 90 degrees
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])    # eigenvalue scaling
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    return gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
def gm_blur_kernel(mean, cov, size=15):
    """Evaluate a 2-D Gaussian pdf on a size x size grid, normalized to sum 1."""
    center = size / 2.0 + 0.5
    kernel = np.zeros([size, size])
    for row in range(size):
        for col in range(size):
            # grid offsets relative to the kernel center
            off_y = row - center + 1
            off_x = col - center + 1
            kernel[row, col] = ss.multivariate_normal.pdf([off_x, off_y], mean=mean, cov=cov)
    return kernel / np.sum(kernel)
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    # sub-pixel offset that aligns the kernel with the downsampling grid
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    # keep sample coordinates inside the original grid
    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    # NOTE(review): scipy.interpolate.interp2d is deprecated and removed in
    # SciPy >= 1.14 — this will break on modern SciPy; consider
    # RegularGridInterpolator as a replacement (behavior should be verified).
    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        # interpolate each channel independently, writing back in place
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
    return x
def blur(x, k):
    '''
    Per-sample blur via grouped convolution.
    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    '''
    n, c = x.shape[:2]
    # half-sizes of the kernel used for replicate padding
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    # NOTE: pad order (p1, p2, p1, p2) matches the original; for non-square
    # kernels torch pads (left, right, top, bottom) — verify if h != w.
    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
    # replicate each sample's kernel across channels, then flatten to groups
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    return x.view(n, c, x.shape[2], x.shape[3])
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """Sample a random anisotropic Gaussian kernel.

    modified version of https://github.com/assafshocher/BlindSR_dataset_generator (Kai Zhang);
    eigenvalues of the covariance are drawn uniformly from [min_var, max_var]
    (e.g. min_var = 0.175 * sf, max_var = 2.5 * sf), and optional multiplicative
    noise perturbs the raw kernel before normalization.
    """
    # random eigenvalues (lambdas) and rotation angle (theta) for the covariance
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # covariance and its inverse, shaped to broadcast over the kernel grid
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5 * (scale_factor - 1)    # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # meshgrid of all kernel coordinates, one column vector per pixel
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # un-normalized Gaussian at every pixel, with multiplicative noise
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # normalize the kernel to sum 1 and return
    return raw_kernel / np.sum(raw_kernel)
def fspecial_gaussian(hsize, sigma):
    """Matlab-style ``fspecial('gaussian')``: hsize x hsize kernel with std ``sigma``,
    normalized to sum 1 (tiny tail values are zeroed first)."""
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # Fix: the original called scipy.finfo, a NumPy alias that no longer
    # exists in modern SciPy; np.finfo is the correct source of eps.
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h
def fspecial_laplacian(alpha):
    """Matlab-style ``fspecial('laplacian')``: 3x3 Laplacian with shape
    parameter ``alpha`` clamped into [0, 1]."""
    alpha = max([0, min([alpha, 1])])
    corner = alpha / (alpha + 1)
    edge = (1 - alpha) / (alpha + 1)
    stencil = [[corner, edge, corner],
               [edge, -4 / (alpha + 1), edge],
               [corner, edge, corner]]
    return np.array(stencil)
def fspecial(filter_type, *args, **kwargs):
    '''
    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    '''
    # dispatch on filter name; unknown types fall through to None, as before
    builders = {'gaussian': fspecial_gaussian, 'laplacian': fspecial_laplacian}
    builder = builders.get(filter_type)
    if builder is not None:
        return builder(*args, **kwargs)
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def bicubic_degradation(x, sf=3):
    '''
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubicly downsampled LR image
    '''
    return util.imresize_np(x, scale=1 / sf)
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    # Fix: call scipy.ndimage.convolve directly — the scipy.ndimage.filters
    # namespace is deprecated and removed in recent SciPy releases.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')    # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    # Fix: call scipy.ndimage.convolve directly — the scipy.ndimage.filters
    # namespace is deprecated and removed in recent SciPy releases.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    '''
    # Fix: call scipy.ndimage.convolve directly — the scipy.ndimage.filters
    # namespace is deprecated and removed in recent SciPy releases.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0    # sampling phase: keep every sf-th pixel starting at 0
    return x[st::sf, st::sf, ...]
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 1.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int):
    """
    if radius % 2 == 0:
        radius += 1    # GaussianBlur needs an odd kernel size
    blurred = cv2.GaussianBlur(img, (radius, radius), 0)    # renamed from `blur`, which shadowed the module-level function
    residual = img - blurred
    hard_mask = (np.abs(residual) * 255 > threshold).astype('float32')
    soft_mask = cv2.GaussianBlur(hard_mask, (radius, radius), 0)
    sharpened = np.clip(img + weight * residual, 0, 1)
    return soft_mask * sharpened + (1 - soft_mask) * img
def add_blur(img, sf=4):
    """Blur ``img`` with a random Gaussian kernel: 50% anisotropic (random
    orientation/eigenvalues), otherwise isotropic; kernel strength grows
    with the scale factor ``sf``."""
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3,
                                 theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
    # Fix: call scipy.ndimage.convolve directly — the scipy.ndimage.filters
    # namespace is deprecated and removed in recent SciPy releases.
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
    return img
def add_resize(img, sf=4):
    """Randomly rescale ``img``: ~20% chance upscale (1-2x), ~70% downscale
    (down to 0.5/sf), ~10% unchanged; result clipped to [0, 1]."""
    rnum = np.random.rand()
    if rnum > 0.8:
        scale = random.uniform(1, 2)            # up
    elif rnum < 0.7:
        scale = random.uniform(0.5 / sf, 1)     # down
    else:
        scale = 1.0
    new_size = (int(scale * img.shape[1]), int(scale * img.shape[0]))
    img = cv2.resize(img, new_size, interpolation=random.choice([1, 2, 3]))
    return np.clip(img, 0.0, 1.0)
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add random Gaussian noise: 40% per-channel color noise, 40% shared
    grayscale noise, 20% correlated-channel noise; result clipped to [0, 1]."""
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    if rnum > 0.6:
        # independent noise per channel (color noise)
        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:
        # one noise plane broadcast over channels (grayscale noise)
        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        # channel-correlated noise via a random covariance U^T D U
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv),
                                                  img.shape[:2]).astype(np.float32)
    return np.clip(img, 0.0, 1.0)
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Add multiplicative (speckle) noise: the Gaussian perturbation is scaled
    by the image itself; same 40/40/20 color/gray/correlated split as
    add_Gaussian_noise. Result clipped to [0, 1]."""
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    rnum = random.random()
    if rnum > 0.6:
        # per-channel speckle
        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:
        # single plane broadcast over channels
        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        # channel-correlated speckle via random covariance U^T D U
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv),
                                                   img.shape[:2]).astype(np.float32)
    return np.clip(img, 0.0, 1.0)
def add_Poisson_noise(img):
    """Add Poisson (shot) noise at a random intensity; 50% of the time the
    noise is applied per channel, otherwise a single luminance-derived noise
    plane is added to all channels. Result clipped to [0, 1]."""
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = 10 ** (2 * random.random() + 2.0)    # photon scale in [1e2, 1e4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        # luminance-only noise, broadcast over the color channels
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
        img += noise_gray[:, :, np.newaxis]
    return np.clip(img, 0.0, 1.0)
def add_JPEG_noise(img):
    """Round-trip an RGB float image through JPEG at a random quality in
    [30, 95] to simulate compression artifacts."""
    quality_factor = random.randint(30, 95)
    bgr = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    _, encimg = cv2.imencode('.jpg', bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    decoded = cv2.imdecode(encimg, 1)
    return cv2.cvtColor(util.uint2single(decoded), cv2.COLOR_BGR2RGB)
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Take a random lq_patchsize crop from the LQ image and the
    corresponding (sf-scaled) aligned crop from the HQ image."""
    h, w = lq.shape[:2]
    top = random.randint(0, h - lq_patchsize)
    left = random.randint(0, w - lq_patchsize)
    lq = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]
    # the HQ crop starts at the scaled offset and spans sf times the patch
    top_hq, left_hq = int(top * sf), int(left * sf)
    hq = hq[top_hq:top_hq + lq_patchsize * sf, left_hq:left_hq + lq_patchsize * sf, :]
    return lq, hq
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    # probabilities for the ISP branch, JPEG branch, and the extra 2x pre-scale
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf    # remember the requested factor; sf may be halved below

    h1, w1 = img.shape[:2]
    # mod crop so both dimensions divide by sf
    # NOTE(review): w1 (width) is used on axis 0 and h1 on axis 1 here — this
    # looks like a swapped-variable bug for non-square inputs; confirm intent.
    img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()    # the clean reference patch before any degradation

    # optional extra 2x downscale first, so the remaining ops run at sf=2
    if sf == 4 and random.random() < scale2_prob:    # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                             interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    # random order of the 7 degradation ops, except op 2 is forced before op 3
    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:    # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            # record pre-downsample size; branch 3 relies on a, b existing,
            # which the reorder above guarantees (2 always runs before 3)
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                                 interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()    # blur with shifted kernel
                img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]    # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3: resize back to the size recorded in branch 2, / sf
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)),
                             interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop at the originally requested scale factor
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    image: HXWXC uint8 image (converted to [0, 1] internally)
    sf: scale factor
    isp_model: camera ISP model (unused in this variant)
    Returns
    -------
    example: dict with key "image" holding the degraded uint8 image
    """
    image = util.uint2single(image)
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf  # remember the requested factor; sf may drop to 2 below

    h1, w1 = image.shape[:2]
    # mod crop: make both spatial dims divisible by sf.
    # (fixed: the original sliced the height axis with w1 and the width axis
    # with h1, which only worked for square inputs)
    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]
    h, w = image.shape[:2]

    hq = image.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1: do 4x as 2x + 2x
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:
        if i == 0:
            image = add_blur(image, sf=sf)
        elif i == 1:
            image = add_blur(image, sf=sf)
        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                # fixed: ndimage.filters.convolve -> ndimage.convolve
                # (scipy.ndimage.filters namespace is deprecated/removed)
                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)
        elif i == 3:
            # downsample3: final LR size, based on (a, b) captured in step 2
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)
        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)
        # elif i == 6:
        #     # add processed camera sensor noise
        #     if random.random() < isp_prob and isp_model is not None:
        #         with torch.no_grad():
        #             img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example
# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
    """
    This is an extended degradation model by combining
    the degradation models of BSRGAN and Real-ESRGAN
    ----------
    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    shuffle_prob: probability of fully shuffling the degradation order
    use_sharp: sharpening the img
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    h1, w1 = img.shape[:2]
    # mod crop: make both spatial dims divisible by sf.
    # (fixed: the original sliced the height axis with w1 and the width axis
    # with h1, which only worked for square inputs)
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    if use_sharp:
        img = add_sharpening(img)
    hq = img.copy()

    if random.random() < shuffle_prob:
        shuffle_order = random.sample(range(13), 13)
    else:
        shuffle_order = list(range(13))
        # local shuffle for noise, JPEG is always the last one
        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))

    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)
        elif i == 1:
            img = add_resize(img, sf=sf)
        elif i == 2:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 3:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 4:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 5:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        elif i == 6:
            img = add_JPEG_noise(img)
        elif i == 7:
            img = add_blur(img, sf=sf)
        elif i == 8:
            img = add_resize(img, sf=sf)
        elif i == 9:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 10:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 11:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 12:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        else:
            print('check the shuffle!')

    # resize to desired size
    img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
                     interpolation=random.choice([1, 2, 3]))

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf, lq_patchsize)

    return img, hq
if __name__ == '__main__':
    # Smoke test: degrade a local test image 20 times and save side-by-side
    # comparisons (bicubic LR | degraded LR | HQ) as 0.png ... 19.png.
    print("hey")
    img = util.imread_uint('utils/test.png', 3)
    print(img)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        img_hq = img
        # degradation_bsrgan_variant expects a uint8 image (it converts to
        # single internally) and returns a {"image": ...} dict.
        # (fixed: the original passed an already-converted single image,
        # treated the returned dict as an array, and referenced an undefined
        # img_hq variable)
        img_lq = deg_fn(img)["image"]
        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        lq_nearest = cv2.resize(util.single2uint(img_lq),
                                (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                        interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')
examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py
0 → 100644
View file @
e532679c
# -*- coding: utf-8 -*-
import
numpy
as
np
import
cv2
import
torch
from
functools
import
partial
import
random
from
scipy
import
ndimage
import
scipy
import
scipy.stats
as
ss
from
scipy.interpolate
import
interp2d
from
scipy.linalg
import
orth
import
albumentations
import
ldm.modules.image_degradation.utils_image
as
util
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
def modcrop_np(img, sf):
    '''Crop the trailing rows/cols of *img* so both leading dims are divisible by sf.

    Args:
        img: numpy image, WxH or WxHxC (axis 0 named "w" here, per file convention)
        sf: scale factor
    Return:
        cropped copy of the image
    '''
    w, h = img.shape[:2]
    return np.copy(img)[:w - w % sf, :h - h % sf, ...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)."""
    k_size = k.shape[0]
    # The composed kernel is (3*k_size - 2) on each side
    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
    # Accumulate a shifted copy of k weighted by each entry of k
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
    # Crop tiny edge values to speed up later SR computation
    crop = k_size // 2
    cropped = big_k[crop:-crop, crop:-crop]
    # Normalize to sum 1
    return cropped / cropped.sum()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k : kernel
    """
    # First eigenvector: unit x-axis rotated by theta
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta), np.cos(theta)]])
    v = np.dot(rot, np.array([1., 0.]))
    # Orthonormal eigenbasis and diagonal eigenvalue matrix
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    # Covariance: V D V^-1
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    return gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
def gm_blur_kernel(mean, cov, size=15):
    """Sample a size x size blur kernel from a 2-D Gaussian pdf, normalized to sum 1."""
    center = size / 2.0 + 0.5
    k = np.zeros([size, size])
    for y in range(size):
        for x in range(size):
            # pixel offset from the kernel center (note [cx, cy] ordering)
            cy = y - center + 1
            cx = x - center + 1
            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
    return k / np.sum(k)
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors

    Resamples x on a grid shifted by (sf - 1) / 2 pixels, clamped at the
    border, so the kernel/image aligns with the downsampled grid.

    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction (True shifts toward the upper-left)

    NOTE(review): the 3-D branch writes channels back into `x` in place, so a
    3-D input array is mutated; the 2-D branch rebinds instead. Confirm
    callers do not rely on the input staying unchanged.
    NOTE(review): scipy.interpolate.interp2d is deprecated and removed in
    SciPy >= 1.14 — confirm the pinned SciPy version still provides it.
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5  # sub-pixel shift amount
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    # clamp the shifted sample positions to the valid coordinate range
    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        # interpolate each channel independently (in place)
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x
def blur(x, k):
    '''
    Depthwise blur of a batch of images, one kernel per image.

    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    Returns a tensor of the same spatial size as x (replicate padding).
    '''
    n, c = x.shape[:2]
    # half-sizes of the kernel: p1 along height, p2 along width
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    # F.pad takes (left, right, top, bottom): pad width by p2, height by p1.
    # (fixed: the original passed (p1, p2, p1, p2), mixing the axes, which
    # produced wrongly sized output for non-square kernels)
    x = torch.nn.functional.pad(x, pad=(p2, p2, p1, p1), mode='replicate')
    # replicate each image's kernel across its channels, then flatten so a
    # grouped conv applies kernel (i) to channel group (i)
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])
    return x
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """"
    Sample a random anisotropic Gaussian kernel, optionally perturbed by
    multiplicative noise.
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    """
    # Random eigenvalues (lambdas) and rotation angle for the covariance
    lam1 = min_var + np.random.rand() * (max_var - min_var)
    lam2 = min_var + np.random.rand() * (max_var - min_var)
    angle = np.random.rand() * np.pi  # random theta
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Covariance from eigen-decomposition: Q diag(lam) Q^T
    LAMBDA = np.diag([lam1, lam2])
    Q = np.array([[np.cos(angle), -np.sin(angle)],
                  [np.sin(angle), np.cos(angle)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Expectation position (shifts the kernel so the image stays aligned)
    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # Evaluate the (un-normalized) Gaussian at every kernel pixel
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # Normalize and return (centering shift is intentionally not applied)
    return raw_kernel / np.sum(raw_kernel)
def fspecial_gaussian(hsize, sigma):
    """Matlab-style fspecial('gaussian'): an hsize x hsize Gaussian kernel
    with standard deviation sigma, normalized to sum 1."""
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma

    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # zero out negligible tail values, Matlab-style
    # (fixed: `scipy.finfo` was a deprecated alias removed in SciPy 1.12 —
    # np.finfo provides the same machine epsilon)
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h
def fspecial_laplacian(alpha):
    """Matlab-style fspecial('laplacian'): 3x3 Laplacian kernel with shape
    parameter alpha clamped to [0, 1]."""
    alpha = min(max(alpha, 0), 1)
    scale = alpha + 1
    h1 = alpha / scale
    h2 = (1 - alpha) / scale
    return np.array([[h1, h2, h1],
                     [h2, -4 / scale, h2],
                     [h1, h2, h1]])
def fspecial(filter_type, *args, **kwargs):
    '''
    Matlab-style fspecial dispatcher ('gaussian' or 'laplacian').
    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    '''
    builders = {
        'gaussian': fspecial_gaussian,
        'laplacian': fspecial_laplacian,
    }
    if filter_type in builders:
        return builders[filter_type](*args, **kwargs)
    # unknown filter types fall through and return None, as before
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def bicubic_degradation(x, sf=3):
    '''
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubicly downsampled LR image
    '''
    return util.imresize_np(x, scale=1 / sf)
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    blurred = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    return bicubic_degradation(blurred, sf=sf)
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    down = bicubic_degradation(x, sf=sf)
    return ndimage.convolve(down, np.expand_dims(k, axis=2), mode='wrap')
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    '''
    blurred = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0  # sub-sampling offset
    return blurred[st::sf, st::sf, ...]
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 1.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int): residual threshold on a 0-255 scale.
    """
    if radius % 2 == 0:
        radius += 1  # GaussianBlur needs an odd kernel size
    blurred = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blurred
    # binary mask of pixels whose residual exceeds the threshold
    mask = (np.abs(residual) * 255 > threshold).astype('float32')
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
    sharp = np.clip(img + weight * residual, 0, 1)
    return soft_mask * sharp + (1 - soft_mask) * img
def add_blur(img, sf=4):
    """Blur *img* with a random Gaussian kernel (anisotropic or isotropic),
    with widths scaled by the super-resolution factor sf."""
    wd2 = (4.0 + sf) / 4
    wd = (2.0 + 0.2 * sf) / 4
    if random.random() < 0.5:
        # anisotropic kernel with random eigenvalue scales and angle
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        # isotropic kernel of random odd-ish size and width
        k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
    return ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
def add_resize(img, sf=4):
    """Randomly rescale *img*: up (20%), down (70%), or keep size (10%),
    with a random interpolation mode; result is clipped to [0, 1]."""
    rnum = np.random.rand()
    if rnum > 0.8:  # up
        sf1 = random.uniform(1, 2)
    elif rnum < 0.7:  # down
        sf1 = random.uniform(0.5 / sf, 1)
    else:
        sf1 = 1.0
    new_size = (int(sf1 * img.shape[1]), int(sf1 * img.shape[0]))
    img = cv2.resize(img, new_size, interpolation=random.choice([1, 2, 3]))
    return np.clip(img, 0.0, 1.0)
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add random Gaussian noise to *img* (HxWx3 float in [0, 1]).

    The noise level is drawn uniformly from [noise_level1, noise_level2]
    (interpreted on a 0-255 scale). One of three modes is chosen at random:
    per-channel noise, a single broadcast noise plane, or channel-correlated
    noise with a random covariance matrix.
    """
    level = random.randint(noise_level1, noise_level2)
    branch = np.random.rand()
    sigma = level / 255.0
    if branch > 0.6:  # add color Gaussian noise
        img = img + np.random.normal(0, sigma, img.shape).astype(np.float32)
    elif branch < 0.4:  # add grayscale Gaussian noise
        img = img + np.random.normal(0, sigma, (*img.shape[:2], 1)).astype(np.float32)
    else:  # add channel-correlated noise
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    return np.clip(img, 0.0, 1.0)
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Add multiplicative (speckle) noise to *img* (HxWx3 float in [0, 1]).

    Same three random modes as add_Gaussian_noise, but the noise is scaled
    by the image itself before being added.
    """
    level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    branch = random.random()
    sigma = level / 255.0
    if branch > 0.6:
        img += img * np.random.normal(0, sigma, img.shape).astype(np.float32)
    elif branch < 0.4:
        img += img * np.random.normal(0, sigma, (*img.shape[:2], 1)).astype(np.float32)
    else:
        # channel-correlated speckle with a random covariance matrix
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    return np.clip(img, 0.0, 1.0)
def add_Poisson_noise(img):
    """Add Poisson (shot) noise to *img* (HxWx3 float in [0, 1]).

    Half of the time the noise is sampled per channel; otherwise a single
    luminance noise plane is computed and added to all channels.
    """
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = 10 ** (2 * random.random() + 2.0)  # exponent in [2, 4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        # luminance-only noise, broadcast over channels
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
        img += noise_gray[:, :, np.newaxis]
    return np.clip(img, 0.0, 1.0)
def add_JPEG_noise(img):
    """Round-trip *img* (RGB float in [0, 1]) through JPEG compression at a
    random quality factor in [80, 95], returning the decoded RGB float image."""
    quality_factor = random.randint(80, 95)
    bgr = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    result, encimg = cv2.imencode('.jpg', bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    decoded = cv2.imdecode(encimg, 1)
    return cv2.cvtColor(util.uint2single(decoded), cv2.COLOR_BGR2RGB)
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Cut a random, spatially aligned patch pair from a low/high-quality pair.

    lq: HxWxC low-quality image; hq: corresponding (H*sf)x(W*sf)xC image.
    Returns (lq_patch, hq_patch) of sizes lq_patchsize and lq_patchsize*sf.
    """
    h, w = lq.shape[:2]
    top = random.randint(0, h - lq_patchsize)
    left = random.randint(0, w - lq_patchsize)
    lq = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]

    # same crop on the HQ image, scaled up by sf
    top_hq, left_hq = int(top * sf), int(left * sf)
    hq = hq[top_hq:top_hq + lq_patchsize * sf, left_hq:left_hq + lq_patchsize * sf, :]
    return lq, hq
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf  # remember the requested factor; sf may drop to 2 below

    h1, w1 = img.shape[:2]
    # mod crop: make both spatial dims divisible by sf.
    # (fixed: the original sliced the height axis with w1 and the width axis
    # with h1, which only worked for square inputs)
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1: do 4x as 2x + 2x
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                             interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                                 interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3: final LR size, based on (a, b) captured in step 2
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise (note: this light variant uses a lower max level)
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    image: HXWXC uint8 image (converted to [0, 1] internally)
    sf: scale factor
    isp_model: camera ISP model (unused in this variant)
    up: if True, bicubically resize the result back to the input size
    Returns
    -------
    example: dict with key "image" holding the degraded uint8 image
    """
    image = util.uint2single(image)
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf  # remember the requested factor; sf may drop to 2 below

    h1, w1 = image.shape[:2]
    # mod crop: make both spatial dims divisible by sf.
    # (fixed: the original sliced the height axis with w1 and the width axis
    # with h1, which only worked for square inputs)
    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]
    h, w = image.shape[:2]

    hq = image.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1: do 4x as 2x + 2x
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    # steps 1 and 6 are intentionally disabled in this light variant
    # (originally: elif i == 1: image = add_blur(image, sf=sf))
    for i in shuffle_order:
        if i == 0:
            image = add_blur(image, sf=sf)
        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.8:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)
        elif i == 3:
            # downsample3: final LR size, based on (a, b) captured in step 2
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)
        elif i == 4:
            # add (very light) Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)
        # elif i == 6:
        #     # add processed camera sensor noise
        #     if random.random() < isp_prob and isp_model is not None:
        #         with torch.no_grad():
        #             img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    if up:
        # todo: random, as above? want to condition on it then
        image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC)
    example = {"image": image}
    return example
if __name__ == '__main__':
    # Smoke test: degrade a local test image 20 times and save side-by-side
    # comparisons (bicubic LR | degraded LR | HQ) as 0.png ... 19.png.
    print("hey")
    base = util.imread_uint('utils/test.png', 3)
    base = base[:448, :448]
    small_h = base.shape[0] // 4
    print("resizing to", small_h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for idx in range(20):
        print(idx)
        img_hq = base
        img_lq = deg_fn(base)["image"]
        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=small_h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        # upscale both LR images back with nearest-neighbor for visual comparison
        up_size = (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0]))
        lq_nearest = cv2.resize(util.single2uint(img_lq), up_size, interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), up_size, interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(idx) + '.png')
examples/images/diffusion/ldm/modules/image_degradation/utils/test.png
0 → 100644
View file @
e532679c
431 KB
examples/images/diffusion/ldm/modules/image_degradation/utils_image.py
0 → 100644
View file @
e532679c
import
os
import
math
import
random
import
numpy
as
np
import
torch
import
cv2
from
torchvision.utils
import
make_grid
from
datetime
import
datetime
#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
os
.
environ
[
"KMP_DUPLICATE_LIB_OK"
]
=
"TRUE"
'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
# Recognized image file extensions (both cases listed explicitly).
# (fixed: '.TIF' was missing even though every other extension had an
# upper-case counterpart)
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
                  '.tif', '.TIF']


def is_image_file(filename):
    """Return True if *filename* ends with one of the known image extensions."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
def get_timestamp():
    """Return the current local time formatted as 'yymmdd-HHMMSS'."""
    now = datetime.now()
    return now.strftime('%y%m%d-%H%M%S')
def imshow(x, title=None, cbar=False, figsize=None):
    """Display array *x* (squeezed) as a grayscale image, with an optional
    title and colorbar.

    NOTE(review): `plt` is not defined anywhere in this module — the
    matplotlib import at the top of the file is commented out, so calling
    this raises NameError. Confirm whether the import should be restored or
    these plotting helpers removed.
    """
    plt.figure(figsize=figsize)
    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()
def surf(Z, cmap='rainbow', figsize=None):
    """Render a 3-D surface plot of the 2-D array *Z*.

    NOTE(review): relies on `plt`, but the matplotlib import at module top is
    commented out — calling this raises NameError as the file stands.
    """
    plt.figure(figsize=figsize)
    ax3 = plt.axes(projection='3d')

    w, h = Z.shape[:2]
    # build the (X, Y) grid matching Z's shape
    xx = np.arange(0, w, 1)
    yy = np.arange(0, h, 1)
    X, Y = np.meshgrid(xx, yy)
    ax3.plot_surface(X, Y, Z, cmap=cmap)
    #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
    plt.show()
'''
# --------------------------------------------
# get image pathes
# --------------------------------------------
'''
def get_image_paths(dataroot):
    """Return the sorted list of image paths under *dataroot*.

    Returns None when *dataroot* is None, mirroring the optional-root
    convention used by the dataset loaders.
    """
    if dataroot is None:
        return None
    return sorted(_get_paths_from_images(dataroot))
def _get_paths_from_images(path):
    """Recursively collect image-file paths below *path* in deterministic order."""
    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
    images = [
        os.path.join(dirpath, fname)
        for dirpath, _, fnames in sorted(os.walk(path))
        for fname in sorted(fnames)
        if is_image_file(fname)
    ]
    assert images, '{:s} has no valid image file'.format(path)
    return images
'''
# --------------------------------------------
# split large images into small images
# --------------------------------------------
'''
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
    """Split a large image into overlapping square patches.

    Images with both sides larger than ``p_max`` are tiled into
    ``p_size`` x ``p_size`` patches whose starts are ``p_size - p_overlap``
    apart (plus one final patch flush with each edge); smaller images are
    returned unchanged as a single-element list.

    Args:
        img: HxWxC numpy array.
        p_size: side length of each patch.
        p_overlap: overlap in pixels between neighbouring patches.
        p_max: images with either side <= p_max are not split.

    Returns:
        list of numpy arrays (views into *img* when split).
    """
    w, h = img.shape[:2]
    patches = []
    if w > p_max and h > p_max:
        # BUGFIX: dtype=np.int was removed in NumPy 1.24 (deprecated since
        # 1.20); use the builtin int instead.
        w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=int))
        h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=int))
        # Always include a patch flush with the bottom/right edge.
        w1.append(w - p_size)
        h1.append(h - p_size)
        for i in w1:
            for j in h1:
                patches.append(img[i:i + p_size, j:j + p_size, :])
    else:
        patches.append(img)
    return patches
def imssave(imgs, img_path):
    """
    imgs: list, N images of size WxHxC
    """
    # Derive "<dir>/<name>_s0000.png" style output names from img_path.
    base = os.path.basename(img_path)
    img_name, ext = os.path.splitext(base)
    out_dir = os.path.dirname(img_path)
    for i, img in enumerate(imgs):
        if img.ndim == 3:
            # RGB -> BGR channel order for cv2.imwrite
            img = img[:, :, [2, 1, 0]]
        new_path = os.path.join(out_dir, img_name + str('_s{:04d}'.format(i)) + '.png')
        cv2.imwrite(new_path, img)
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
    """
    split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
    and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
    will be splitted.

    Args:
        original_dataroot: directory scanned recursively for images.
        taget_dataroot: output directory (NOTE: "taget" is a typo, but the
            parameter name is part of the public signature — renaming would
            break keyword callers).
        n_channels: channels to load per image (1 or 3).
        p_size: size of small images
        p_overlap: patch size in training is a good choice
        p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
    """
    paths = get_image_paths(original_dataroot)
    for img_path in paths:
        # img_name, ext = os.path.splitext(os.path.basename(img_path))
        img = imread_uint(img_path, n_channels=n_channels)
        patches = patches_from_image(img, p_size, p_overlap, p_max)
        # Patches are written as "<basename>_s0000.png", "<basename>_s0001.png", ...
        imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path)))
        #if original_dataroot == taget_dataroot:
        #del img_path
'''
# --------------------------------------------
# makedir
# --------------------------------------------
'''
def mkdir(path):
    """Create directory *path* (including parents) if it does not exist.

    Uses ``exist_ok=True`` instead of the original exists-then-create
    sequence, which had a time-of-check/time-of-use race: a concurrent
    process creating the directory between the check and the makedirs call
    would raise FileExistsError.
    """
    os.makedirs(path, exist_ok=True)
def mkdirs(paths):
    """Create one directory (str argument) or every directory in an iterable."""
    targets = [paths] if isinstance(paths, str) else paths
    for p in targets:
        mkdir(p)
def mkdir_and_rename(path):
    """Create *path*, archiving any existing directory of the same name.

    An existing directory is renamed to "<path>_archived_<timestamp>" before
    the fresh one is created.
    """
    if os.path.exists(path):
        archived = path + '_archived_' + get_timestamp()
        print('Path already exists. Rename it to [{:s}]'.format(archived))
        os.rename(path, archived)
    os.makedirs(path)
'''
# --------------------------------------------
# read image from path
# opencv is fast, but read BGR numpy image
# --------------------------------------------
'''
# --------------------------------------------
# get uint8 image of size HxWxn_channles (RGB)
# --------------------------------------------
def imread_uint(path, n_channels=3):
    """Read an image from *path* as uint8.

    # input: path
    # output: HxWx3(RGB or GGG), or HxWx1 (G)

    NOTE(review): cv2.imread returns None for unreadable paths; this function
    would then fail on .ndim / expand_dims — callers must pass valid files.
    """
    if n_channels == 1:
        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
        img = np.expand_dims(img, axis=2)  # HxWx1
    elif n_channels == 3:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
    return img
# --------------------------------------------
# matlab's imwrite
# --------------------------------------------
def imsave(img, img_path):
    """Write image array *img* to *img_path*.

    The array is squeezed first; 3-D arrays are reordered from RGB to BGR
    because cv2.imwrite expects BGR channel order.
    """
    arr = np.squeeze(img)
    if arr.ndim == 3:
        arr = arr[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, arr)
def imwrite(img, img_path):
    """Write image array *img* to *img_path* (RGB -> BGR for cv2).

    NOTE(review): byte-for-byte duplicate of imsave(); kept because both
    names may be referenced by callers.
    """
    arr = np.squeeze(img)
    if arr.ndim == 3:
        arr = arr[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, arr)
# --------------------------------------------
# get single image of size HxWxn_channles (BGR)
# --------------------------------------------
def read_img(path):
    """Read an image as float32 HWC BGR in [0, 1].

    # read image by cv2
    # return: Numpy float32, HWC, BGR, [0,1]
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels; drop the alpha channel
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img
'''
# --------------------------------------------
# image format conversion
# --------------------------------------------
# numpy(single) <---> numpy(unit)
# numpy(single) <---> tensor
# numpy(unit) <---> tensor
# --------------------------------------------
'''
# --------------------------------------------
# numpy(single) [0, 1] <---> numpy(unit)
# --------------------------------------------
def uint2single(img):
    """Convert a uint8 image in [0, 255] to float32 in [0, 1]."""
    scaled = np.asarray(img) / 255.
    return scaled.astype(np.float32)
def single2uint(img):
    """Convert a float image in [0, 1] to uint8 in [0, 255], clipping out-of-range values."""
    bounded = np.clip(img, 0, 1) * 255.
    return bounded.round().astype(np.uint8)
def uint162single(img):
    """Convert a uint16 image in [0, 65535] to float32 in [0, 1]."""
    scaled = np.asarray(img) / 65535.
    return scaled.astype(np.float32)
def single2uint16(img):
    """Convert a float image in [0, 1] to uint16 in [0, 65535], clipping out-of-range values."""
    bounded = np.clip(img, 0, 1) * 65535.
    return bounded.round().astype(np.uint16)
# --------------------------------------------
# numpy(unit) (HxWxC or HxW) <---> tensor
# --------------------------------------------
# convert uint to 4-dimensional torch tensor
def uint2tensor4(img):
    """uint8 HxWxC (or HxW) ndarray -> float tensor 1xCxHxW scaled to [0, 1]."""
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    chw = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
    return chw.float().div(255.).unsqueeze(0)
# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
    """uint8 HxWxC (or HxW) ndarray -> float tensor CxHxW scaled to [0, 1]."""
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    chw = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
    return chw.float().div(255.)
# convert 2/3/4-dimensional torch tensor to uint
def tensor2uint(img):
    """Clamp a 2/3/4-D tensor to [0, 1] and return a uint8 HxWxC (or HxW) ndarray."""
    arr = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
    if arr.ndim == 3:
        arr = arr.transpose(1, 2, 0)  # CHW -> HWC
    return (arr * 255.0).round().astype(np.uint8)
# --------------------------------------------
# numpy(single) (HxWxC) <---> tensor
# --------------------------------------------
# convert single (HxWxC) to 3-dimensional torch tensor
def single2tensor3(img):
    """float HxWxC ndarray -> CxHxW float tensor (no scaling)."""
    t = torch.from_numpy(np.ascontiguousarray(img))
    return t.permute(2, 0, 1).float()
# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
    """float HxWxC ndarray -> 1xCxHxW float tensor (no scaling)."""
    t = torch.from_numpy(np.ascontiguousarray(img))
    return t.permute(2, 0, 1).float().unsqueeze(0)
# convert torch tensor to single
def tensor2single(img):
    """Squeeze a tensor and return a float32 ndarray, CHW transposed to HWC when 3-D."""
    arr = img.data.squeeze().float().cpu().numpy()
    if arr.ndim == 3:
        arr = arr.transpose(1, 2, 0)
    return arr
# convert torch tensor to single
def tensor2single3(img):
    """Like tensor2single, but 2-D results gain a trailing singleton channel axis."""
    arr = img.data.squeeze().float().cpu().numpy()
    if arr.ndim == 3:
        arr = arr.transpose(1, 2, 0)  # CHW -> HWC
    elif arr.ndim == 2:
        arr = arr[..., np.newaxis]    # HxW -> HxWx1
    return arr
def single2tensor5(img):
    """4-D ndarray -> 5-D float tensor via permute(2, 0, 1, 3) plus a leading batch axis."""
    t = torch.from_numpy(np.ascontiguousarray(img))
    return t.permute(2, 0, 1, 3).float().unsqueeze(0)
def single32tensor5(img):
    """3-D (or lower) ndarray -> float tensor with two leading singleton axes."""
    t = torch.from_numpy(np.ascontiguousarray(img)).float()
    return t.unsqueeze(0).unsqueeze(0)
def single42tensor4(img):
    """4-D ndarray -> 4-D float tensor via permute(2, 0, 1, 3), no batch axis."""
    t = torch.from_numpy(np.ascontiguousarray(img))
    return t.permute(2, 0, 1, 3).float()
# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array of BGR channel order
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    '''
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        # Batched input: tile the batch into one grid image.
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError('Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
    return img_np.astype(out_type)
'''
# --------------------------------------------
# Augmentation, flipe and/or rotate
# --------------------------------------------
# The following two are enough.
# (1) augmet_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
'''
def augment_img(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Apply one of 8 flip/rotation augmentations (the dihedral group of the
    square) to a numpy image. Unknown modes return None.
    '''
    ops = {
        0: lambda a: a,
        1: lambda a: np.flipud(np.rot90(a)),
        2: lambda a: np.flipud(a),
        3: lambda a: np.rot90(a, k=3),
        4: lambda a: np.flipud(np.rot90(a, k=2)),
        5: lambda a: np.rot90(a),
        6: lambda a: np.rot90(a, k=2),
        7: lambda a: np.flipud(np.rot90(a, k=3)),
    }
    op = ops.get(mode)
    return op(img) if op is not None else None
def augment_img_tensor4(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Flip/rotate a 4-D tensor (NxCxHxW) over its spatial dims [2, 3].
    Unknown modes return None.
    '''
    ops = {
        0: lambda t: t,
        1: lambda t: t.rot90(1, [2, 3]).flip([2]),
        2: lambda t: t.flip([2]),
        3: lambda t: t.rot90(3, [2, 3]),
        4: lambda t: t.rot90(2, [2, 3]).flip([2]),
        5: lambda t: t.rot90(1, [2, 3]),
        6: lambda t: t.rot90(2, [2, 3]),
        7: lambda t: t.rot90(3, [2, 3]).flip([2]),
    }
    op = ops.get(mode)
    return op(img) if op is not None else None
def augment_img_tensor(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Flip/rotate a 3-D (CxHxW) or 4-D (NxCxHxW) tensor by round-tripping
    through numpy and augment_img(), then restoring the original layout,
    device and dtype via type_as.
    '''
    img_size = img.size()
    img_np = img.data.cpu().numpy()
    # Move spatial dims to the front so augment_img's HxWx... flips apply.
    if len(img_size) == 3:
        img_np = np.transpose(img_np, (1, 2, 0))
    elif len(img_size) == 4:
        img_np = np.transpose(img_np, (2, 3, 1, 0))
    img_np = augment_img(img_np, mode=mode)
    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
    # Undo the layout change.
    if len(img_size) == 3:
        img_tensor = img_tensor.permute(2, 0, 1)
    elif len(img_size) == 4:
        img_tensor = img_tensor.permute(3, 2, 0, 1)
    return img_tensor.type_as(img)
def augment_img_np3(img, mode=0):
    """Flip/rotate an HxWxC array; modes 0-7 compose horizontal flip,
    vertical flip and a H<->W transpose. Unknown modes return None."""
    if mode not in (0, 1, 2, 3, 4, 5, 6, 7):
        return None
    if mode in (4, 5, 6, 7):
        img = img[:, ::-1, :]        # horizontal flip
    if mode in (2, 3, 6, 7):
        img = img[::-1, :, :]        # vertical flip
    if mode in (1, 3, 5, 7):
        img = img.transpose(1, 0, 2)  # swap H and W
    return img
def augment_imgs(img_list, hflip=True, rot=True):
    """Randomly flip/rotate every HxWxC image in *img_list*.

    The flags are drawn once, so paired lists (e.g. LR/HR) stay aligned.
    The short-circuit `flag and random.random() < 0.5` pattern is kept so
    the RNG is consumed exactly as before.
    """
    # horizontal flip OR rotate
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rot and random.random() < 0.5
    do_rot90 = rot and random.random() < 0.5

    def _apply(img):
        if do_hflip:
            img = img[:, ::-1, :]
        if do_vflip:
            img = img[::-1, :, :]
        if do_rot90:
            img = img.transpose(1, 0, 2)
        return img

    return [_apply(img) for img in img_list]
'''
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
'''
def modcrop(img_in, scale):
    """Crop a copy of *img_in* (HW or HWC) so H and W are multiples of *scale*.

    Raises:
        ValueError: if the array is not 2-D or 3-D.
    """
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    if img.ndim not in (2, 3):
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    H, W = img.shape[:2]
    return img[:H - H % scale, :W - W % scale, ...]
def shave(img_in, border=0):
    """Return a copy of *img_in* with *border* pixels trimmed from each side of H and W."""
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    h, w = img.shape[:2]
    return img[border:h - border, border:w - border]
'''
# --------------------------------------------
# image processing process on numpy image
# channel_convert(in_c, tar_type, img_list):
# rgb2ycbcr(img, only_y=True):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
'''
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output has the same dtype as the input (uint8 rounded, float in [0, 1]).
    '''
    in_img_type = img.dtype
    # BUGFIX: the result of astype() was previously discarded, so the
    # subsequent in-place `img *= 255.` mutated the CALLER's float array.
    # Assigning the copy keeps the caller's data intact.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (ITU-R BT.601 coefficients, MATLAB scaling)
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output has the same dtype as the input (uint8 rounded, float in [0, 1]).
    '''
    in_img_type = img.dtype
    # BUGFIX: astype() result was discarded, so `img *= 255.` mutated the
    # caller's float array in place. Assign the converted copy instead.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (inverse of the MATLAB-scaled BT.601 matrix)
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output has the same dtype as the input (uint8 rounded, float in [0, 1]).
    '''
    in_img_type = img.dtype
    # BUGFIX: astype() result was discarded, so `img *= 255.` mutated the
    # caller's float array in place. Assign the converted copy instead.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (BT.601 coefficients with B and R columns swapped for BGR input)
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def channel_convert(in_c, tar_type, img_list):
    """Convert a list of images between BGR, gray and Y-channel representations.

    # conversion among BGR, gray and y
    Unsupported (in_c, tar_type) combinations return the list unchanged.
    """
    if in_c == 3 and tar_type == 'gray':  # BGR to gray
        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
        # cvtColor drops the channel axis; restore HxWx1
        return [np.expand_dims(img, axis=2) for img in gray_list]
    elif in_c == 3 and tar_type == 'y':  # BGR to y
        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in y_list]
    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    else:
        return img_list
'''
# --------------------------------------------
# metric, PSNR and SSIM
# --------------------------------------------
'''
# --------------------------------------------
# PSNR
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
    """Compute PSNR in dB between two images with values in [0, 255].

    *border* pixels are cropped from every side before comparison.
    Returns float('inf') for identical crops.

    Raises:
        ValueError: if the images differ in shape.
    """
    # img1 and img2 have range [0, 255]
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    a = img1[border:h - border, border:w - border].astype(np.float64)
    b = img2[border:h - border, border:w - border].astype(np.float64)
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
# --------------------------------------------
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]

    For 3-channel images the per-channel SSIMs are averaged; HxWx1 images
    are squeezed to 2-D first. *border* pixels are cropped from every side.
    '''
    #img1 = img1.squeeze()
    #img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h - border, border:w - border]
    img2 = img2[border:h - border, border:w - border]
    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
        # NOTE(review): a 3-D image with a channel count other than 1 or 3
        # silently returns None — confirm whether that should raise instead.
    else:
        raise ValueError('Wrong input image dimensions.')
def ssim(img1, img2):
    """Single-channel SSIM between two 2-D images in [0, 255].

    Uses an 11x11 Gaussian window (sigma = 1.5) and the standard stability
    constants C1 = (0.01*255)^2, C2 = (0.03*255)^2, matching the reference
    MATLAB implementation. Only the 'valid' region (5-pixel margin removed)
    contributes to the mean.
    """
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())
    # Local means over the window; crop the filter's border effects.
    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    # Local variances and covariance via E[x^2] - E[x]^2.
    sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()
'''
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
'''
# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
    """Bicubic interpolation kernel (a = -0.5), as used by MATLAB's imresize.

    Piecewise cubic: one polynomial on |x| <= 1, another on 1 < |x| <= 2,
    zero elsewhere.
    """
    absx = torch.abs(x)
    absx2 = absx ** 2
    absx3 = absx ** 3
    near = (absx <= 1).type_as(absx)                       # |x| <= 1
    far = ((absx > 1) * (absx <= 2)).type_as(absx)         # 1 < |x| <= 2
    return (1.5 * absx3 - 2.5 * absx2 + 1) * near + \
           (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * far
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Precompute per-output-pixel sample indices and cubic weights for 1-D resize.

    Returns (weights, indices, sym_len_s, sym_len_e): weights/indices are
    (out_length, P) tensors, and sym_len_s/sym_len_e give how many pixels of
    symmetric padding the caller must add at the start/end of the input.

    NOTE(review): the ``kernel`` argument is accepted but never read here —
    only the cubic kernel is implemented.
    """
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand(
        out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices

    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # Amount of symmetric padding required so all indices are in range,
    # then shift indices to address the padded input.
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
# --------------------------------------------
# imresize for tensor image [0, 1]
# --------------------------------------------
def imresize(img, scale, antialiasing=True):
    """MATLAB-compatible bicubic resize for a torch tensor.

    # Now the scale should be the same for H and W
    # input: img: pytorch tensor, CHW or HW [0,1]
    # output: CHW or HW [0,1] w/o round

    NOTE(review): a 2-D input is modified in place by ``unsqueeze_`` (and
    the result squeezed in place) — callers passing HW tensors share that
    mutation; confirm this is acceptable at the call sites.
    """
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying: pad top/bottom with mirrored rows so boundary
    # output pixels have a full kernel support.
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    # Vertical pass: each output row is a weighted sum (mv) of kernel_width
    # padded input rows.
    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying (left/right mirrored columns)
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    # Horizontal pass.
    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2
# --------------------------------------------
# imresize for numpy image [0, 1]
# --------------------------------------------
def imresize_np(img, scale, antialiasing=True):
    """MATLAB-compatible bicubic resize for a numpy array (HWC layout).

    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    """
    img = torch.from_numpy(img)
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(2)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying: mirror-pad top and bottom rows so boundary pixels
    # have full kernel support.
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    # Vertical pass: weighted sums of kernel_width padded rows.
    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying (left/right mirrored columns)
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    # Horizontal pass.
    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()

    return out_2.numpy()
if __name__ == '__main__':
    # Minimal smoke-test entry point; the real usage examples are kept
    # below as comments.
    print('---')
    # img = imread_uint('test.bmp', 3)
    # img = uint2single(img)
    # img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
examples/images/diffusion/ldm/modules/midas/__init__.py
0 → 100644
View file @
e532679c
examples/images/diffusion/ldm/modules/midas/api.py
0 → 100644
View file @
e532679c
# based on https://github.com/isl-org/MiDaS
import
cv2
import
torch
import
torch.nn
as
nn
from
torchvision.transforms
import
Compose
from
ldm.modules.midas.midas.dpt_depth
import
DPTDepthModel
from
ldm.modules.midas.midas.midas_net
import
MidasNet
from
ldm.modules.midas.midas.midas_net_custom
import
MidasNet_small
from
ldm.modules.midas.midas.transforms
import
Resize
,
NormalizeImage
,
PrepareForNet
# Default checkpoint locations for each supported MiDaS variant, keyed by the
# model_type strings accepted by load_model()/load_midas_transform().
# NOTE(review): the empty strings for the midas_v21 variants mean no bundled
# checkpoint path — verify that callers supply weights externally before
# constructing those models.
ISL_PATHS = {
    "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
    "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
    "midas_v21": "",
    "midas_v21_small": "",
}
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore.

    The *mode* argument is accepted (to match nn.Module.train's signature)
    but deliberately ignored.
    """
    return self
def load_midas_transform(model_type):
    """Build the input preprocessing pipeline for a MiDaS variant.

    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load transform only

    Args:
        model_type: one of "dpt_large", "dpt_hybrid", "midas_v21",
            "midas_v21_small".

    Returns:
        A torchvision Compose: Resize -> NormalizeImage -> PrepareForNet.
    """
    if model_type == "dpt_large":  # DPT-Large
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        # ImageNet statistics for the v21 backbones.
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    elif model_type == "midas_v21_small":
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    else:
        assert False, f"model_type '{model_type}' not implemented, use: --model_type large"

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return transform
def load_model(model_type):
    """Instantiate a MiDaS depth model plus its matching input transform.

    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load network

    Args:
        model_type: one of "dpt_large", "dpt_hybrid", "midas_v21",
            "midas_v21_small"; selects both the network architecture and
            its checkpoint path from ISL_PATHS.

    Returns:
        (model.eval(), transform) tuple.
    """
    model_path = ISL_PATHS[model_type]
    if model_type == "dpt_large":  # DPT-Large
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        model = MidasNet(model_path, non_negative=True)
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        # ImageNet statistics for the v21 backbones.
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    elif model_type == "midas_v21_small":
        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
                               non_negative=True, blocks={'expand': True})
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    else:
        print(f"model_type '{model_type}' not implemented, use: --model_type large")
        assert False

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return model.eval(), transform
class MiDaSInference(nn.Module):
    """Frozen wrapper around a MiDaS depth model for inference only.

    The wrapped model's ``train`` method is replaced with
    ``disabled_train`` so it can never be flipped back into training mode.
    """
    MODEL_TYPES_TORCH_HUB = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"]
    MODEL_TYPES_ISL = [
        "dpt_large",
        "dpt_hybrid",
        "midas_v21",
        "midas_v21_small",
    ]

    def __init__(self, model_type):
        super().__init__()
        assert (model_type in self.MODEL_TYPES_ISL)
        model, _ = load_model(model_type)
        self.model = model
        # Pin the model in eval mode permanently.
        self.model.train = disabled_train

    def forward(self, x):
        # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
        # NOTE: we expect that the correct transform has been called during dataloading.
        with torch.no_grad():
            prediction = self.model(x)
            # Resize the depth map back to the spatial size of the input.
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=x.shape[2:],
                mode="bicubic",
                align_corners=False,
            )
        assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
        return prediction
examples/images/diffusion/ldm/modules/midas/midas/__init__.py
0 → 100644
View file @
e532679c
examples/images/diffusion/ldm/modules/midas/midas/base_model.py
0 → 100644
View file @
e532679c
import
torch
class BaseModel(torch.nn.Module):
    """Common base class providing checkpoint loading for MiDaS models."""

    def load(self, path):
        """Load model weights from a checkpoint file.

        Args:
            path (str): file path of the checkpoint.
        """
        checkpoint = torch.load(path, map_location=torch.device('cpu'))

        # Full training checkpoints store the weights under "model";
        # plain state dicts are used as-is.
        state_dict = checkpoint["model"] if "optimizer" in checkpoint else checkpoint

        self.load_state_dict(state_dict)
examples/images/diffusion/ldm/modules/midas/midas/blocks.py
0 → 100644
View file @
e532679c
import
torch
import
torch.nn
as
nn
from
.vit
import
(
_make_pretrained_vitb_rn50_384
,
_make_pretrained_vitl16_384
,
_make_pretrained_vitb16_384
,
forward_vit
,
)
def
_make_encoder
(
backbone
,
features
,
use_pretrained
,
groups
=
1
,
expand
=
False
,
exportable
=
True
,
hooks
=
None
,
use_vit_only
=
False
,
use_readout
=
"ignore"
,):
if
backbone
==
"vitl16_384"
:
pretrained
=
_make_pretrained_vitl16_384
(
use_pretrained
,
hooks
=
hooks
,
use_readout
=
use_readout
)
scratch
=
_make_scratch
(
[
256
,
512
,
1024
,
1024
],
features
,
groups
=
groups
,
expand
=
expand
)
# ViT-L/16 - 85.0% Top1 (backbone)
elif
backbone
==
"vitb_rn50_384"
:
pretrained
=
_make_pretrained_vitb_rn50_384
(
use_pretrained
,
hooks
=
hooks
,
use_vit_only
=
use_vit_only
,
use_readout
=
use_readout
,
)
scratch
=
_make_scratch
(
[
256
,
512
,
768
,
768
],
features
,
groups
=
groups
,
expand
=
expand
)
# ViT-H/16 - 85.0% Top1 (backbone)
elif
backbone
==
"vitb16_384"
:
pretrained
=
_make_pretrained_vitb16_384
(
use_pretrained
,
hooks
=
hooks
,
use_readout
=
use_readout
)
scratch
=
_make_scratch
(
[
96
,
192
,
384
,
768
],
features
,
groups
=
groups
,
expand
=
expand
)
# ViT-B/16 - 84.6% Top1 (backbone)
elif
backbone
==
"resnext101_wsl"
:
pretrained
=
_make_pretrained_resnext101_wsl
(
use_pretrained
)
scratch
=
_make_scratch
([
256
,
512
,
1024
,
2048
],
features
,
groups
=
groups
,
expand
=
expand
)
# efficientnet_lite3
elif
backbone
==
"efficientnet_lite3"
:
pretrained
=
_make_pretrained_efficientnet_lite3
(
use_pretrained
,
exportable
=
exportable
)
scratch
=
_make_scratch
([
32
,
48
,
136
,
384
],
features
,
groups
=
groups
,
expand
=
expand
)
# efficientnet_lite3
else
:
print
(
f
"Backbone '
{
backbone
}
' not implemented"
)
assert
False
return
pretrained
,
scratch
def
_make_scratch
(
in_shape
,
out_shape
,
groups
=
1
,
expand
=
False
):
scratch
=
nn
.
Module
()
out_shape1
=
out_shape
out_shape2
=
out_shape
out_shape3
=
out_shape
out_shape4
=
out_shape
if
expand
==
True
:
out_shape1
=
out_shape
out_shape2
=
out_shape
*
2
out_shape3
=
out_shape
*
4
out_shape4
=
out_shape
*
8
scratch
.
layer1_rn
=
nn
.
Conv2d
(
in_shape
[
0
],
out_shape1
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
,
groups
=
groups
)
scratch
.
layer2_rn
=
nn
.
Conv2d
(
in_shape
[
1
],
out_shape2
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
,
groups
=
groups
)
scratch
.
layer3_rn
=
nn
.
Conv2d
(
in_shape
[
2
],
out_shape3
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
,
groups
=
groups
)
scratch
.
layer4_rn
=
nn
.
Conv2d
(
in_shape
[
3
],
out_shape4
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
,
groups
=
groups
)
return
scratch
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
    # Download tf_efficientnet_lite3 via torch.hub; `exportable` selects the
    # ONNX-exportable variant of the gen-efficientnet-pytorch package.
    efficientnet = torch.hub.load("rwightman/gen-efficientnet-pytorch",
                                  "tf_efficientnet_lite3",
                                  pretrained=use_pretrained,
                                  exportable=exportable)
    return _make_efficientnet_backbone(efficientnet)
def
_make_efficientnet_backbone
(
effnet
):
pretrained
=
nn
.
Module
()
pretrained
.
layer1
=
nn
.
Sequential
(
effnet
.
conv_stem
,
effnet
.
bn1
,
effnet
.
act1
,
*
effnet
.
blocks
[
0
:
2
]
)
pretrained
.
layer2
=
nn
.
Sequential
(
*
effnet
.
blocks
[
2
:
3
])
pretrained
.
layer3
=
nn
.
Sequential
(
*
effnet
.
blocks
[
3
:
5
])
pretrained
.
layer4
=
nn
.
Sequential
(
*
effnet
.
blocks
[
5
:
9
])
return
pretrained
def
_make_resnet_backbone
(
resnet
):
pretrained
=
nn
.
Module
()
pretrained
.
layer1
=
nn
.
Sequential
(
resnet
.
conv1
,
resnet
.
bn1
,
resnet
.
relu
,
resnet
.
maxpool
,
resnet
.
layer1
)
pretrained
.
layer2
=
resnet
.
layer2
pretrained
.
layer3
=
resnet
.
layer3
pretrained
.
layer4
=
resnet
.
layer4
return
pretrained
def _make_pretrained_resnext101_wsl(use_pretrained):
    # NOTE(review): `use_pretrained` is accepted for API symmetry but not
    # forwarded — torch.hub always loads the pretrained WSL weights here.
    resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
    return _make_resnet_backbone(resnet)
class Interpolate(nn.Module):
    """Wrap ``nn.functional.interpolate`` as a module with fixed settings.

    Args:
        scale_factor (float): spatial scaling factor.
        mode (str): interpolation mode (e.g. "bilinear").
        align_corners (bool, optional): forwarded to interpolate. Defaults to False.
    """

    def __init__(self, scale_factor, mode, align_corners=False):
        super(Interpolate, self).__init__()

        # Keep a handle on the functional op so forward is a single call.
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resample ``x`` by the configured scale factor.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """
        return self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
class ResidualConvUnit(nn.Module):
    """Residual block: two ReLU->conv stages plus an identity skip.

    Args:
        features (int): number of input/output channels.
    """

    def __init__(self, features):
        super().__init__()

        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
        # In-place ReLU shared by both stages.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply relu->conv1->relu->conv2 and add the input back.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        residual = self.conv1(self.relu(x))
        residual = self.conv2(self.relu(residual))
        return residual + x
class FeatureFusionBlock(nn.Module):
    """Fuse one or two equally-shaped feature maps and upsample 2x.

    Args:
        features (int): number of channels of every input.
    """

    def __init__(self, features):
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Fuse ``xs`` (one or two tensors).

        Returns:
            tensor: fused map, spatially doubled by bilinear upsampling.
        """
        fused = xs[0]

        # Optional skip path: refine the second input and add it in place.
        if len(xs) == 2:
            fused += self.resConfUnit1(xs[1])

        fused = self.resConfUnit2(fused)

        return nn.functional.interpolate(fused, scale_factor=2, mode="bilinear", align_corners=True)
class ResidualConvUnit_custom(nn.Module):
    """Residual block with configurable activation and optional BatchNorm.

    Args:
        features (int): number of input/output channels.
        activation (nn.Module): activation applied before each conv.
        bn (bool): add BatchNorm after each conv when True.
    """

    def __init__(self, features, activation, bn):
        super().__init__()

        self.bn = bn
        self.groups = 1

        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        if self.bn == True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        # FloatFunctional keeps the skip-add quantization-aware.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Return ``f(x) + x`` where f is act->conv(->bn) applied twice.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.conv1(self.activation(x))
        if self.bn == True:
            out = self.bn1(out)

        out = self.conv2(self.activation(out))
        if self.bn == True:
            out = self.bn2(out)

        if self.groups > 1:
            # NOTE(review): `self.groups` is hard-coded to 1 above and
            # `conv_merge` is never defined, so this branch is dead code;
            # it would raise AttributeError if groups were ever changed.
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)
class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block with configurable activation/bn and a 1x1 output conv.

    Args:
        features (int): number of channels of every input.
        activation (nn.Module): activation for the residual units.
        deconv (bool, optional): stored flag, unused by forward. Defaults to False.
        bn (bool, optional): BatchNorm inside the residual units. Defaults to False.
        expand (bool, optional): halve the output channels. Defaults to False.
        align_corners (bool, optional): for the bilinear upsampling. Defaults to True.
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1
        self.expand = expand

        # With `expand`, the block narrows the channel width by half.
        out_features = features
        if self.expand == True:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        # Quantization-aware elementwise add.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Fuse ``xs`` (one or two tensors), upsample 2x, project channels.

        Returns:
            tensor: output
        """
        fused = xs[0]

        if len(xs) == 2:
            skip = self.resConfUnit1(xs[1])
            fused = self.skip_add.add(fused, skip)    # fused += skip

        fused = self.resConfUnit2(fused)

        fused = nn.functional.interpolate(fused, scale_factor=2, mode="bilinear", align_corners=self.align_corners)

        return self.out_conv(fused)
examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py
0 → 100644
View file @
e532679c
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
.base_model
import
BaseModel
from
.blocks
import
(
FeatureFusionBlock
,
FeatureFusionBlock_custom
,
Interpolate
,
_make_encoder
,
forward_vit
,
)
def _make_fusion_block(features, use_bn):
    # Standard DPT fusion block: no deconv, no channel expansion, 2x bilinear
    # upsampling with align_corners=True; BatchNorm is the only knob.
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
    )
class DPT(BaseModel):
    """DPT trunk: hooked ViT encoder + RefineNet-style fusion decoder,
    with a task-specific ``head`` supplied by subclasses."""

    def __init__(
        self,
        head,
        features=256,
        backbone="vitb_rn50_384",
        readout="project",
        channels_last=False,
        use_bn=False,
    ):
        """Init.

        Args:
            head (nn.Module): output head applied to the finest fused map.
            features (int, optional): decoder channel width. Defaults to 256.
            backbone (str, optional): ViT backbone name. Defaults to "vitb_rn50_384".
            readout (str, optional): readout-token handling. Defaults to "project".
            channels_last (bool, optional): use channels-last memory format. Defaults to False.
            use_bn (bool, optional): BatchNorm inside fusion blocks. Defaults to False.
        """
        super(DPT, self).__init__()

        self.channels_last = channels_last

        # Transformer blocks tapped for multi-scale features, per backbone.
        hooks = {
            "vitb_rn50_384": [0, 1, 8, 11],
            "vitb16_384": [2, 5, 8, 11],
            "vitl16_384": [5, 11, 17, 23],
        }

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            False,    # Set to true if you want to train from scratch, uses ImageNet weights
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head

    def forward(self, x):
        """Run encoder, top-down fusion, and the output head on ``x``."""
        if self.channels_last == True:
            # NOTE(review): the result of contiguous() is discarded, so this
            # line has no effect; presumably `x = x.contiguous(...)` was meant.
            x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        # Project each encoder stage onto the decoder width.
        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Top-down fusion: each refinenet upsamples and merges with the skip.
        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return out
class DPTDepthModel(DPT):
    """DPT configured with a monocular-depth regression head."""

    def __init__(self, path=None, non_negative=True, **kwargs):
        """Init.

        Args:
            path (str, optional): checkpoint to load. Defaults to None.
            non_negative (bool, optional): clamp output at 0 via a final ReLU.
                Defaults to True.
            **kwargs: forwarded to ``DPT`` (``features``, ``backbone``, ...).
        """
        features = kwargs["features"] if "features" in kwargs else 256

        # Depth head: halve channels, upsample 2x, reduce to one channel,
        # optionally clamp non-negative.
        head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, x):
        """Return depth of shape (B, H, W) — the channel dim is squeezed."""
        return super().forward(x).squeeze(dim=1)
examples/images/diffusion/ldm/modules/midas/midas/midas_net.py
0 → 100644
View file @
e532679c
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import
torch
import
torch.nn
as
nn
from
.base_model
import
BaseModel
from
.blocks
import
FeatureFusionBlock
,
Interpolate
,
_make_encoder
class MidasNet(BaseModel):
    """Network for monocular depth estimation (ResNeXt-101 WSL backbone)."""

    def __init__(self, path=None, features=256, non_negative=True):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
            non_negative (bool, optional): clamp output at 0 via ReLU. Defaults to True.
        """
        print("Loading weights: ", path)

        super(MidasNet, self).__init__()

        # Only fetch pretrained encoder weights when no checkpoint is given.
        use_pretrained = False if path is None else True

        self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)

        # Top-down fusion decoder.
        self.scratch.refinenet4 = FeatureFusionBlock(features)
        self.scratch.refinenet3 = FeatureFusionBlock(features)
        self.scratch.refinenet2 = FeatureFusionBlock(features)
        self.scratch.refinenet1 = FeatureFusionBlock(features)

        # Head: 2x upsample and convs down to a single depth channel.
        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        # Encoder stages.
        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        # Project each stage to the decoder width.
        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Top-down fusion with skip connections.
        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        # Drop the singleton channel dim -> (B, H, W).
        return torch.squeeze(out, dim=1)
examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py
0 → 100644
View file @
e532679c
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import
torch
import
torch.nn
as
nn
from
.base_model
import
BaseModel
from
.blocks
import
FeatureFusionBlock
,
FeatureFusionBlock_custom
,
Interpolate
,
_make_encoder
class MidasNet_small(BaseModel):
    """Network for monocular depth estimation (lightweight EfficientNet backbone)."""

    def __init__(self,
                 path=None,
                 features=64,
                 backbone="efficientnet_lite3",
                 non_negative=True,
                 exportable=True,
                 channels_last=False,
                 align_corners=True,
                 blocks={'expand': True}):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 64.
            backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
            non_negative (bool, optional): clamp output at 0 via ReLU. Defaults to True.
            exportable (bool, optional): ONNX-exportable backbone variant. Defaults to True.
            channels_last (bool, optional): channels-last memory format flag. Defaults to False.
            align_corners (bool, optional): for the fusion-block upsampling. Defaults to True.
            blocks (dict, optional): decoder options; only 'expand' is read.
                NOTE(review): mutable default argument — it is only read here,
                but one dict instance is shared across all calls.
        """
        print("Loading weights: ", path)

        super(MidasNet_small, self).__init__()

        # Only download pretrained encoder weights when not loading a checkpoint.
        use_pretrained = False if path else True

        self.channels_last = channels_last
        self.blocks = blocks
        self.backbone = backbone

        self.groups = 1

        # Per-stage decoder widths; widened 2x/4x/8x when 'expand' is set.
        features1 = features
        features2 = features
        features3 = features
        features4 = features
        self.expand = False
        if "expand" in self.blocks and self.blocks['expand'] == True:
            self.expand = True
            features1 = features
            features2 = features * 2
            features3 = features * 4
            features4 = features * 8

        self.pretrained, self.scratch = _make_encoder(self.backbone,
                                                      features,
                                                      use_pretrained,
                                                      groups=self.groups,
                                                      expand=self.expand,
                                                      exportable=exportable)

        # Shared activation instance used by every fusion block and the head.
        self.scratch.activation = nn.ReLU(False)

        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4,
                                                            self.scratch.activation,
                                                            deconv=False,
                                                            bn=False,
                                                            expand=self.expand,
                                                            align_corners=align_corners)
        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3,
                                                            self.scratch.activation,
                                                            deconv=False,
                                                            bn=False,
                                                            expand=self.expand,
                                                            align_corners=align_corners)
        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2,
                                                            self.scratch.activation,
                                                            deconv=False,
                                                            bn=False,
                                                            expand=self.expand,
                                                            align_corners=align_corners)
        # refinenet1 never expands: it must emit the full base width for the head.
        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1,
                                                            self.scratch.activation,
                                                            deconv=False,
                                                            bn=False,
                                                            align_corners=align_corners)

        # Head: 2x upsample and convs down to a single depth channel.
        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1, groups=self.groups),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            self.scratch.activation,
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        if self.channels_last == True:
            print("self.channels_last = ", self.channels_last)
            # NOTE(review): the result of contiguous() is discarded, so this
            # line has no effect; presumably `x = x.contiguous(...)` was meant.
            x.contiguous(memory_format=torch.channels_last)

        # Encoder stages.
        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        # Project each stage to its decoder width.
        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Top-down fusion with skip connections.
        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        # Drop the singleton channel dim -> (B, H, W).
        return torch.squeeze(out, dim=1)
def fuse_model(m):
    """Fuse Conv2d+BatchNorm2d(+ReLU) runs of ``m`` in place for quantization.

    Walks the module tree in registration order while tracking a sliding
    window of the previous two (type, name) pairs, and fuses every
    Conv-BN-ReLU triple (or Conv-BN pair) via
    ``torch.quantization.fuse_modules``.
    """
    # Sliding window: (grandparent, parent) of the module being visited.
    window = [(nn.Identity(), ''), (nn.Identity(), '')]

    for name, module in m.named_modules():
        (pp_type, pp_name), (p_type, p_name) = window

        if pp_type == nn.Conv2d and p_type == nn.BatchNorm2d and type(module) == nn.ReLU:
            # print("FUSED ", pp_name, p_name, name)
            torch.quantization.fuse_modules(m, [pp_name, p_name, name], inplace=True)
        elif pp_type == nn.Conv2d and p_type == nn.BatchNorm2d:
            # print("FUSED ", pp_name, p_name)
            torch.quantization.fuse_modules(m, [pp_name, p_name], inplace=True)
        # elif p_type == nn.Conv2d and type(module) == nn.ReLU:
        #    print("FUSED ", p_name, name)
        #    torch.quantization.fuse_modules(m, [p_name, name], inplace=True)

        window = [(p_type, p_name), (type(module), name)]
\ No newline at end of file
examples/images/diffusion/ldm/modules/midas/midas/transforms.py
0 → 100644
View file @
e532679c
import
numpy
as
np
import
cv2
import
math
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Resize the sample to ensure the given size. Keeps aspect ratio.

    Args:
        sample (dict): sample
        size (tuple): image size

    Returns:
        tuple: new size
    """
    shape = list(sample["disparity"].shape)

    # Already at least as large as requested: nothing to do.
    if shape[0] >= size[0] and shape[1] >= size[1]:
        # NOTE(review): this early exit returns the sample dict while the
        # normal path returns a shape tuple — confirm callers tolerate both.
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    # Use the larger scale for both axes so the aspect ratio is preserved
    # while both dimensions reach the requested minimum.
    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize (cv2 expects (width, height), hence shape[::-1])
    sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)

    sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
    # Masks are resized as float32 and cast back to bool afterwards.
    sample["mask"] = cv2.resize(
        sample["mask"].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = sample["mask"].astype(bool)

    return tuple(shape)
class Resize(object):
    """Resize sample to given size (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as least as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Round ``x`` to the nearest multiple of ``ensure_multiple_of``,
        then floor/ceil instead if that violates ``max_val``/``min_val``."""
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        """Compute the output (width, height) for an input of the given size."""
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            # A single scale is applied to both axes, chosen per resize_method.
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(f"resize_method {self.__resize_method} not implemented")

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        """Resize sample["image"] (and targets, if configured) in place."""
        width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)

            # Masks are resized as float32 and cast back to bool afterwards.
            sample["mask"] = cv2.resize(
                sample["mask"].astype(np.float32),
                (width, height),
                interpolation=cv2.INTER_NEAREST,
            )
            sample["mask"] = sample["mask"].astype(bool)

        return sample
class NormalizeImage(object):
    """Normalize the "image" entry of a sample with a fixed mean and std.

    Args:
        mean: per-channel mean subtracted from the image.
        std: per-channel std the result is divided by.
    """

    def __init__(self, mean, std):
        self._mean = mean
        self._std = std

    def __call__(self, sample):
        """Return ``sample`` with its "image" entry normalized in place."""
        sample["image"] = (sample["image"] - self._mean) / self._std
        return sample
class PrepareForNet(object):
    """Prepare sample for usage as network input:
    HWC->CHW image and contiguous float32 arrays throughout."""

    def __init__(self):
        pass

    def __call__(self, sample):
        """Return ``sample`` with the image transposed to CHW and all known
        entries cast to contiguous float32 arrays (in place)."""
        chw = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(chw).astype(np.float32)

        # Optional targets are converted only when present.
        for key in ("mask", "disparity", "depth"):
            if key in sample:
                sample[key] = np.ascontiguousarray(sample[key].astype(np.float32))

        return sample
examples/images/diffusion/ldm/modules/midas/midas/vit.py
0 → 100644
View file @
e532679c
import
torch
import
torch.nn
as
nn
import
timm
import
types
import
math
import
torch.nn.functional
as
F
class Slice(nn.Module):
    """Drop the leading readout token(s) from a token sequence."""

    def __init__(self, start_index=1):
        super(Slice, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        """Return ``x`` without its first ``start_index`` tokens."""
        return x[:, self.start_index:]
class AddReadout(nn.Module):
    """Fold the readout token(s) into the patch tokens by addition."""

    def __init__(self, start_index=1):
        super(AddReadout, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        """Add the (mean of the) readout token(s) onto every patch token."""
        if self.start_index == 2:
            # Two leading tokens (e.g. cls + dist): use their average.
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        return x[:, self.start_index:] + readout.unsqueeze(1)
class ProjectReadout(nn.Module):
    """Concatenate the readout token to every patch token and project back."""

    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index

        # Linear squeezes 2*C (token ++ readout) back to C, followed by GELU.
        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        """Return the projected fusion of each patch token with the readout."""
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
        stacked = torch.cat((x[:, self.start_index:], readout), -1)
        return self.project(stacked)
class Transpose(nn.Module):
    """Swap two tensor dimensions as a module (usable inside ``Sequential``)."""

    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        """Return ``x`` with ``dim0`` and ``dim1`` exchanged."""
        return x.transpose(self.dim0, self.dim1)
def forward_vit(pretrained, x):
    """Run the hooked ViT and return four spatial feature maps.

    Args:
        pretrained (nn.Module): module built by a ``_make_vit_*_backbone``
            helper — holds the timm model, the activation-recording hooks and
            the ``act_postprocess1..4`` pipelines.
        x (tensor): input image batch (B, C, H, W).

    Returns:
        tuple: four feature tensors, one per hooked stage.
    """
    b, c, h, w = x.shape

    # The forward pass fills `pretrained.activations` via the registered
    # hooks; the returned global feature itself is not used further.
    glob = pretrained.model.forward_flex(x)

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    # First two post-process steps: readout handling + transpose.
    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
    layer_4 = pretrained.act_postprocess4[0:2](layer_4)

    # Fold (B, C, N) token sequences back into a 2D map that matches the
    # input's patch grid (this replaces the fixed-size Unflatten at index 2
    # of each pipeline, so variable input sizes work).
    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size(
                [
                    h // pretrained.model.patch_size[1],
                    w // pretrained.model.patch_size[0],
                ]
            ),
        )
    )

    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    # Remaining post-processing (1x1 conv and up/down-sampling) after the
    # skipped Unflatten slot.
    layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)](layer_1)
    layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)](layer_2)
    layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)](layer_3)
    layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4
def
_resize_pos_embed
(
self
,
posemb
,
gs_h
,
gs_w
):
posemb_tok
,
posemb_grid
=
(
posemb
[:,
:
self
.
start_index
],
posemb
[
0
,
self
.
start_index
:],
)
gs_old
=
int
(
math
.
sqrt
(
len
(
posemb_grid
)))
posemb_grid
=
posemb_grid
.
reshape
(
1
,
gs_old
,
gs_old
,
-
1
).
permute
(
0
,
3
,
1
,
2
)
posemb_grid
=
F
.
interpolate
(
posemb_grid
,
size
=
(
gs_h
,
gs_w
),
mode
=
"bilinear"
)
posemb_grid
=
posemb_grid
.
permute
(
0
,
2
,
3
,
1
).
reshape
(
1
,
gs_h
*
gs_w
,
-
1
)
posemb
=
torch
.
cat
([
posemb_tok
,
posemb_grid
],
dim
=
1
)
return
posemb
def forward_flex(self, x):
    """ViT forward pass that tolerates arbitrary input sizes.

    Injected as a method onto timm VisionTransformer instances (see
    ``_make_vit_b16_backbone``); it resizes the position embedding to the
    input's patch grid before running the transformer blocks.
    """
    b, c, h, w = x.shape

    # Interpolate the stored position embedding to this input's token grid.
    pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1], w // self.patch_size[0])

    B = x.shape[0]

    if hasattr(self.patch_embed, "backbone"):
        # Hybrid models embed patches through a CNN backbone first.
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]    # last feature if backbone outputs list/tuple of features

    # Patch projection: (B, C, H', W') -> (B, N, C).
    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

    if getattr(self, "dist_token", None) is not None:
        # DeiT-style models carry an extra distillation token.
        cls_tokens = self.cls_token.expand(B, -1, -1)    # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
    else:
        cls_tokens = self.cls_token.expand(B, -1, -1)    # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

    x = x + pos_embed
    x = self.pos_drop(x)

    for blk in self.blocks:
        x = blk(x)

    x = self.norm(x)

    return x
# Module-level store the forward hooks write into, keyed by stage name.
activations = {}


def get_activation(name):
    """Build a forward hook that records a module's output in ``activations[name]``."""

    def hook(module, inputs, output):
        # Overwritten on every forward pass; read it out after each call.
        activations[name] = output

    return hook
def get_readout_oper(vit_features, features, use_readout, start_index=1):
    # Build one readout operation per hooked feature map.
    #   vit_features: transformer embedding width.
    #   features: per-hook channel list (only its length is used for
    #             "ignore"/"add"; "project" builds one module per entry).
    #   use_readout: "ignore" drops the token, "add" folds it in by
    #                addition, "project" concatenates and projects.
    #   start_index: index of the first patch token.
    if use_readout == "ignore":
        # Stateless: one shared instance per hook is fine.
        readout_oper = [Slice(start_index)] * len(features)
    elif use_readout == "add":
        readout_oper = [AddReadout(start_index)] * len(features)
    elif use_readout == "project":
        # Projection layers hold weights, so each hook gets its own instance.
        readout_oper = [ProjectReadout(vit_features, start_index) for out_feat in features]
    else:
        assert (False), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"

    return readout_oper
def _make_vit_b16_backbone(
    model,
    features=[96, 192, 384, 768],
    size=[384, 384],
    hooks=[2, 5, 8, 11],
    vit_features=768,
    use_readout="ignore",
    start_index=1,
):
    """Wrap a timm ViT-B/16-style model for DPT feature extraction.

    Registers forward hooks on the requested transformer blocks and builds
    per-hook post-processing pipelines that turn token sequences into 2D
    feature maps at four scales.

    Args:
        model (nn.Module): timm VisionTransformer instance.
        features (list[int], optional): output channels per hooked stage.
        size (list[int], optional): nominal input size [H, W].
        hooks (list[int], optional): indices of the blocks to tap.
        vit_features (int, optional): transformer embedding width.
        use_readout (str, optional): readout-token handling.
        start_index (int, optional): index of the first patch token.

    Returns:
        nn.Module: module with ``model``, ``activations`` and
        ``act_postprocess1..4`` attributes, consumed by ``forward_vit``.

    NOTE(review): the list defaults are mutable default arguments; they are
    only read here, but each default is shared across calls.
    """
    pretrained = nn.Module()

    pretrained.model = model
    # Record the output of the four chosen transformer blocks.
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # 32, 48, 136, 384
    # Stage 1: 1/16 token grid, upsampled 4x by a transposed conv.
    pretrained.act_postprocess1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[0],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[0],
            out_channels=features[0],
            kernel_size=4,
            stride=4,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    # Stage 2: upsampled 2x.
    pretrained.act_postprocess2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[1],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[1],
            out_channels=features[1],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    # Stage 3: kept at the native 1/16 resolution.
    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    # Stage 4: strided conv downsamples to 1/32.
    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model)

    return pretrained
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
    """Create a ViT-L/16 @ 384 backbone wrapped for dense prediction.

    Args:
        pretrained: whether timm loads ImageNet-pretrained weights.
        use_readout: readout-token handling ('ignore', 'add' or 'project').
        hooks: transformer block indices to tap; defaults to [5, 11, 17, 23].
    """
    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)

    # `is None` rather than `== None`: identity check is the correct idiom
    # and avoids surprises from custom __eq__ implementations.
    hooks = [5, 11, 17, 23] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
    """Create a ViT-B/16 @ 384 backbone wrapped for dense prediction.

    Args:
        pretrained: whether timm loads ImageNet-pretrained weights.
        use_readout: readout-token handling ('ignore', 'add' or 'project').
        hooks: transformer block indices to tap; defaults to [2, 5, 8, 11].
    """
    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)

    # `is None` rather than `== None` — identity check, not equality.
    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
    )
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
    """Create a DeiT-B/16 @ 384 backbone wrapped for dense prediction.

    Args:
        pretrained: whether timm loads ImageNet-pretrained weights.
        use_readout: readout-token handling ('ignore', 'add' or 'project').
        hooks: transformer block indices to tap; defaults to [2, 5, 8, 11].
    """
    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)

    # `is None` rather than `== None` — identity check, not equality.
    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
    )
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
    """Create a distilled DeiT-B/16 @ 384 backbone wrapped for dense prediction.

    The distilled variant carries two leading non-patch tokens (cls +
    distillation), hence ``start_index=2``.

    Args:
        pretrained: whether timm loads ImageNet-pretrained weights.
        use_readout: readout-token handling ('ignore', 'add' or 'project').
        hooks: transformer block indices to tap; defaults to [2, 5, 8, 11].
    """
    model = timm.create_model("vit_deit_base_distilled_patch16_384", pretrained=pretrained)

    # `is None` rather than `== None` — identity check, not equality.
    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
        start_index=2,
    )
def _make_vit_b_rn50_backbone(
    model,
    features=(256, 512, 768, 768),  # tuples: mutable list defaults are an anti-pattern
    size=(384, 384),
    hooks=(0, 1, 8, 11),
    vit_features=768,
    use_vit_only=False,
    use_readout="ignore",
    start_index=1,
):
    """Wrap a hybrid ResNet50+ViT *model* so four stages feed dense heads.

    When ``use_vit_only`` is false, the first two taps come from the ResNet
    stem stages (``patch_embed.backbone.stages[0..1]``) and pass through
    identity post-processing; otherwise all four taps come from transformer
    blocks and the first two are unflattened/upsampled like the pure-ViT
    backbone.

    Args:
        model: timm hybrid VisionTransformer instance to wrap.
        features: output channel counts of the four post-processed maps.
        size: assumed input image size (H, W) used to unflatten tokens.
        hooks: indices of the transformer blocks to tap.
        vit_features: token embedding width of the backbone.
        use_vit_only: tap transformer blocks for all four stages instead of
            using the ResNet stem for the first two.
        use_readout: readout-token handling, see ``get_readout_oper``.
        start_index: index of the first patch token.

    Returns:
        nn.Module exposing ``model``, ``activations`` and
        ``act_postprocess1..4``.
    """
    pretrained = nn.Module()

    pretrained.model = model

    # Truthiness test instead of `== True`.
    if use_vit_only:
        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    else:
        # Early features come straight from the ResNet stem stages.
        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(get_activation("1"))
        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(get_activation("2"))

    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    if use_vit_only:
        # Stage 1: tokens -> (H/16, W/16) map, 1x1 conv, then 4x upsample.
        pretrained.act_postprocess1 = nn.Sequential(
            readout_oper[0],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[0],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[0],
                out_channels=features[0],
                kernel_size=4,
                stride=4,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )

        # Stage 2: same layout, 2x upsample.
        pretrained.act_postprocess2 = nn.Sequential(
            readout_oper[1],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[1],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[1],
                out_channels=features[1],
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )
    else:
        # ResNet stem activations are already spatial feature maps at the
        # right resolution — pass them through unchanged.
        pretrained.act_postprocess1 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity())
        pretrained.act_postprocess2 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity())

    # Stage 3: no rescaling.
    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    # Stage 4: strided 3x3 conv halves the resolution.
    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model)

    return pretrained
def _make_pretrained_vitb_rn50_384(pretrained, use_readout="ignore", hooks=None, use_vit_only=False):
    """Create a hybrid ResNet50+ViT-B/16 @ 384 backbone for dense prediction.

    Args:
        pretrained: whether timm loads ImageNet-pretrained weights.
        use_readout: readout-token handling ('ignore', 'add' or 'project').
        hooks: block/stage indices to tap; defaults to [0, 1, 8, 11].
        use_vit_only: tap transformer blocks for all four stages instead of
            using the ResNet stem for the first two.
    """
    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)

    # `is None` rather than `== None` — identity check, not equality.
    hooks = [0, 1, 8, 11] if hooks is None else hooks
    return _make_vit_b_rn50_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
    )
Prev
1
…
15
16
17
18
19
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment