ModelZoo / stablediffusion_v2.1_pytorch / Commits

Commit 4007efdd, authored May 12, 2024 by lijian6
Initial commit
Pipeline #994 canceled with stages | Changes: 138 | Pipelines: 1

Showing 20 changed files with 5014 additions and 0 deletions (+5014 / -0)
Changed files:

  ldm/modules/diffusionmodules/openaimodel.py        +807   -0
  ldm/modules/diffusionmodules/upscaling.py           +81   -0
  ldm/modules/diffusionmodules/util.py               +278   -0
  ldm/modules/distributions/__init__.py                +0   -0
  ldm/modules/distributions/distributions.py          +92   -0
  ldm/modules/ema.py                                  +80   -0
  ldm/modules/encoders/__init__.py                     +0   -0
  ldm/modules/encoders/modules.py                    +350   -0
  ldm/modules/image_degradation/__init__.py            +2   -0
  ldm/modules/image_degradation/bsrgan.py            +730   -0
  ldm/modules/image_degradation/bsrgan_light.py      +651   -0
  ldm/modules/image_degradation/utils/test.png         +0   -0
  ldm/modules/image_degradation/utils_image.py       +917   -0
  ldm/modules/karlo/__init__.py                        +0   -0
  ldm/modules/karlo/diffusers_pipeline.py            +513   -0
  ldm/modules/karlo/kakao/__init__.py                  +0   -0
  ldm/modules/karlo/kakao/models/__init__.py           +0   -0
  ldm/modules/karlo/kakao/models/clip.py             +182   -0
  ldm/modules/karlo/kakao/models/decoder_model.py    +193   -0
  ldm/modules/karlo/kakao/models/prior_model.py      +138   -0
ldm/modules/diffusionmodules/openaimodel.py (new file, mode 0 → 100644)
from abc import abstractmethod
import math

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from ldm.modules.diffusionmodules.util import (
    checkpoint,
    conv_nd,
    linear,
    avg_pool_nd,
    zero_module,
    normalization,
    timestep_embedding,
)
from ldm.modules.attention import SpatialTransformer
from ldm.util import exists


# dummy replace
def convert_module_to_f16(x):
    pass


def convert_module_to_f32(x):
    pass


## go
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class TransposedUpsample(nn.Module):
    'Learned 2x upsampling without padding'

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2)

    def forward(self, x):
        return self.up(x)


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
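When `use_scale_shift_norm` is set, the timestep embedding above conditions the block FiLM-style: it is split into a per-channel scale and shift applied after normalization, rather than added as a plain bias. A minimal standalone sketch of just that step, with illustrative tensor sizes that are not taken from the file:

import torch

# Hypothetical sizes: batch 2, 64 feature channels, 32x32 spatial grid.
h = torch.randn(2, 64, 32, 32)
emb_out = torch.randn(2, 128, 1, 1)              # emb_layers yields 2 * out_channels here
scale, shift = torch.chunk(emb_out, 2, dim=1)    # split into per-channel scale and shift
# Same form as out_norm(h) * (1 + scale) + shift in ResBlock._forward above.
h_cond = torch.nn.GroupNorm(32, 64)(h) * (1 + scale) + shift
print(h_cond.shape)                              # torch.Size([2, 64, 32, 32])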
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), True)  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
        # return pt_checkpoint(self._forward, x)  # pytorch

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # We perform two matmuls with the same number of ops.
    # The first computes the weight matrix, the second computes
    # the combination of the value vectors.
    matmul_ops = 2 * b * (num_spatial ** 2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
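For readers tracking the tensor shapes, the following small standalone check mirrors the fused-QKV layout that `QKVAttention` expects; the sizes are arbitrary examples, not values used by the model:

import torch

n_heads, ch, length, bs = 4, 8, 16, 2
qkv = torch.randn(bs, 3 * n_heads * ch, length)  # [N x (3 * H * C) x T], as in QKVAttention.forward
q, k, v = qkv.chunk(3, dim=1)                    # each [N x (H * C) x T]
scale = 1 / (ch ** 0.25)                         # 1 / sqrt(sqrt(ch)), applied to both q and k
weight = torch.einsum(
    "bct,bcs->bts",
    (q * scale).view(bs * n_heads, ch, length),
    (k * scale).view(bs * n_heads, ch, length),
).softmax(dim=-1)
out = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * n_heads, ch, length))
print(out.reshape(bs, -1, length).shape)         # torch.Size([2, 32, 16])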
class Timestep(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        return timestep_embedding(t, self.dim)


class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        use_bf16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,            # custom transformer support
        context_dim=None,               # custom transformer support
        n_embed=None,                   # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        adm_in_channels=None,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.dtype = th.bfloat16 if use_bf16 else self.dtype
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(ch, time_embed_dim, dropout, out_channels=mult * model_channels,
                             dims=dims, use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm)
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads,
                                           num_head_channels=dim_head,
                                           use_new_attention_order=use_new_attention_order)
                            if not use_spatial_transformer else
                            SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth,
                                               context_dim=context_dim, disable_self_attn=disabled_sa,
                                               use_linear=use_linear_in_transformer,
                                               use_checkpoint=use_checkpoint)
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims,
                                 use_checkpoint=use_checkpoint,
                                 use_scale_shift_norm=use_scale_shift_norm, down=True)
                        if resblock_updown else
                        Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint,
                     use_scale_shift_norm=use_scale_shift_norm),
            AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads,
                           num_head_channels=dim_head,
                           use_new_attention_order=use_new_attention_order)
            if not use_spatial_transformer else
            SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint),
            ResBlock(ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint,
                     use_scale_shift_norm=use_scale_shift_norm),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(ch + ich, time_embed_dim, dropout, out_channels=model_channels * mult,
                             dims=dims, use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm)
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads_upsample,
                                           num_head_channels=dim_head,
                                           use_new_attention_order=use_new_attention_order)
                            if not use_spatial_transformer else
                            SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth,
                                               context_dim=context_dim, disable_self_attn=disabled_sa,
                                               use_linear=use_linear_in_transformer,
                                               use_checkpoint=use_checkpoint)
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        ResBlock(ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims,
                                 use_checkpoint=use_checkpoint,
                                 use_scale_shift_norm=use_scale_shift_norm, up=True)
                        if resblock_updown else
                        Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, model_channels, n_embed, 1),
                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
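A minimal smoke test of the UNet defined above can help confirm that the input and output latents share the same shape. The hyperparameters below are illustrative only, not the Stable Diffusion v2.1 config, and running it requires the full `ldm` package from this repository:

import torch as th
from ldm.modules.diffusionmodules.openaimodel import UNetModel

# Tiny, made-up configuration: two levels, attention at 2x downsampling.
unet = UNetModel(
    image_size=32, in_channels=4, model_channels=32, out_channels=4,
    num_res_blocks=1, attention_resolutions=(2,), channel_mult=(1, 2),
    num_head_channels=8,
)
x = th.randn(1, 4, 32, 32)         # latent-shaped input (batch, channels, H, W)
t = th.randint(0, 1000, (1,))      # one diffusion timestep per batch element
print(unet(x, timesteps=t).shape)  # expected: torch.Size([1, 4, 32, 32])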
ldm/modules/diffusionmodules/upscaling.py (new file, mode 0 → 100644)
import torch
import torch.nn as nn
import numpy as np
from functools import partial

from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
from ldm.util import default


class AbstractLowScaleModel(nn.Module):
    # for concatenating a downsampled image to the latent representation
    def __init__(self, noise_schedule_config=None):
        super(AbstractLowScaleModel, self).__init__()
        if noise_schedule_config is not None:
            self.register_schedule(**noise_schedule_config)

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start,
                                   linear_end=linear_end, cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def forward(self, x):
        return x, None

    def decode(self, x):
        return x


class SimpleImageConcat(AbstractLowScaleModel):
    # no noise level conditioning
    def __init__(self):
        super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
        self.max_noise_level = 0

    def forward(self, x):
        # fix to constant noise level
        return x, torch.zeros(x.shape[0], device=x.device).long()


class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level

    def forward(self, x, noise_level=None):
        if noise_level is None:
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        z = self.q_sample(x, noise_level)
        return z, noise_level
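The buffers registered above implement the closed-form forward diffusion q(x_t | x_0) = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, which `q_sample` evaluates. A numpy-only sketch with toy values (not taken from the file) that reproduces the same computation:

import numpy as np

# Recompute the "linear" schedule exactly as make_beta_schedule does.
T = 1000
betas = np.linspace(1e-4 ** 0.5, 2e-2 ** 0.5, T) ** 2
alphas_cumprod = np.cumprod(1. - betas)

# q_sample(x_0, t): blend the clean signal with Gaussian noise at level t.
t = 250
x0 = np.ones(4)                      # toy "image"
noise = np.random.randn(4)
xt = np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1. - alphas_cumprod[t]) * noise
print(alphas_cumprod[t], xt)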
ldm/modules/diffusionmodules/util.py (new file, mode 0 → 100644)
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!


import os
import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat

from ldm.util import instantiate_from_config


def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if schedule == "linear":
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)
    elif schedule == "squaredcos_cap_v2":  # used for karlo prior
        # return early
        return betas_for_alpha_bar(
            n_timestep,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    if ddim_discr_method == 'uniform':
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according the the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), \
                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class HybridConditioner(nn.Module):

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()
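Two of the helpers above are used throughout the UNet: `checkpoint` trades activation memory for recomputation in the backward pass, and `timestep_embedding` builds the sinusoidal features fed into `time_embed`. A small sketch of both, with made-up sizes, assuming the `ldm` package is importable:

import torch
import torch.nn as nn
from ldm.modules.diffusionmodules.util import checkpoint, timestep_embedding

# Gradient checkpointing: the same call works with the flag on or off.
layer = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
y = checkpoint(lambda inp: layer(inp).relu(), (x,), layer.parameters(), True)
y.sum().backward()            # activations are recomputed inside CheckpointFunction.backward
print(x.grad.shape)           # torch.Size([2, 8])

# Sinusoidal embeddings: dim/2 cosine terms followed by dim/2 sine terms.
emb = timestep_embedding(torch.tensor([0, 500, 999]), dim=64)
print(emb.shape)              # torch.Size([3, 64])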
ldm/modules/distributions/__init__.py (new empty file, mode 0 → 100644)
ldm/modules/distributions/distributions.py (new file, mode 0 → 100644)
import torch
import numpy as np


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
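`DiagonalGaussianDistribution` treats the first-stage encoder output as mean and log-variance stacked along the channel axis, which is how the VAE latent is sampled and regularized. A short usage sketch with toy sizes (not the model's real dimensions):

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

# 8 parameter channels are split into 4 mean channels and 4 log-variance channels.
params = torch.randn(2, 8, 16, 16)
dist = DiagonalGaussianDistribution(params)
z = dist.sample()      # mean + std * eps, shape [2, 4, 16, 16]
kl = dist.kl()         # KL(q || N(0, I)) summed over C, H, W -> shape [2]
print(z.shape, kl.shape)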
ldm/modules/ema.py (new file, mode 0 → 100644)
import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)
                             if use_num_upates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def reset_num_updates(self):
        del self.num_updates
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
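The intended usage pattern of `LitEma` follows directly from its methods: call the module after each optimizer step to update the shadow weights, and wrap evaluation in a store / copy_to / restore cycle. A minimal sketch with a stand-in model (not code from this commit):

import torch.nn as nn
from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)          # stand-in for the diffusion model
ema = LitEma(model, decay=0.9999)

# Inside the training loop, after optimizer.step():
ema(model)                       # update shadow weights toward the live weights

# For validation or sampling with EMA weights:
ema.store(model.parameters())    # stash the live weights
ema.copy_to(model)               # load the EMA weights into the model
# ... run evaluation here ...
ema.restore(model.parameters())  # put the live weights back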
ldm/modules/encoders/__init__.py (new empty file, mode 0 → 100644)
ldm/modules/encoders/modules.py (new file, mode 0 → 100644)
import
torch
import
torch.nn
as
nn
import
kornia
from
torch.utils.checkpoint
import
checkpoint
from
transformers
import
T5Tokenizer
,
T5EncoderModel
,
CLIPTokenizer
,
CLIPTextModel
import
open_clip
from
ldm.util
import
default
,
count_params
,
autocast
class
AbstractEncoder
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
def
encode
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
class
IdentityEncoder
(
AbstractEncoder
):
def
encode
(
self
,
x
):
return
x
class
ClassEmbedder
(
nn
.
Module
):
def
__init__
(
self
,
embed_dim
,
n_classes
=
1000
,
key
=
'class'
,
ucg_rate
=
0.1
):
super
().
__init__
()
self
.
key
=
key
self
.
embedding
=
nn
.
Embedding
(
n_classes
,
embed_dim
)
self
.
n_classes
=
n_classes
self
.
ucg_rate
=
ucg_rate
def
forward
(
self
,
batch
,
key
=
None
,
disable_dropout
=
False
):
if
key
is
None
:
key
=
self
.
key
# this is for use in crossattn
c
=
batch
[
key
][:,
None
]
if
self
.
ucg_rate
>
0.
and
not
disable_dropout
:
mask
=
1.
-
torch
.
bernoulli
(
torch
.
ones_like
(
c
)
*
self
.
ucg_rate
)
c
=
mask
*
c
+
(
1
-
mask
)
*
torch
.
ones_like
(
c
)
*
(
self
.
n_classes
-
1
)
c
=
c
.
long
()
c
=
self
.
embedding
(
c
)
return
c
def
get_unconditional_conditioning
(
self
,
bs
,
device
=
"cuda"
):
uc_class
=
self
.
n_classes
-
1
# 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
uc
=
torch
.
ones
((
bs
,),
device
=
device
)
*
uc_class
uc
=
{
self
.
key
:
uc
}
return
uc
def
disabled_train
(
self
,
mode
=
True
):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return
self
class
FrozenT5Embedder
(
AbstractEncoder
):
"""Uses the T5 transformer encoder for text"""
def
__init__
(
self
,
version
=
"google/t5-v1_1-large"
,
device
=
"cuda"
,
max_length
=
77
,
freeze
=
True
):
# others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super
().
__init__
()
self
.
tokenizer
=
T5Tokenizer
.
from_pretrained
(
version
)
self
.
transformer
=
T5EncoderModel
.
from_pretrained
(
version
)
self
.
device
=
device
self
.
max_length
=
max_length
# TODO: typical value?
if
freeze
:
self
.
freeze
()
def
freeze
(
self
):
self
.
transformer
=
self
.
transformer
.
eval
()
# self.train = disabled_train
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
def
forward
(
self
,
text
):
batch_encoding
=
self
.
tokenizer
(
text
,
truncation
=
True
,
max_length
=
self
.
max_length
,
return_length
=
True
,
return_overflowing_tokens
=
False
,
padding
=
"max_length"
,
return_tensors
=
"pt"
)
tokens
=
batch_encoding
[
"input_ids"
].
to
(
self
.
device
)
outputs
=
self
.
transformer
(
input_ids
=
tokens
)
z
=
outputs
.
last_hidden_state
return
z
def
encode
(
self
,
text
):
return
self
(
text
)
class
FrozenCLIPEmbedder
(
AbstractEncoder
):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS
=
[
"last"
,
"pooled"
,
"hidden"
]
def
__init__
(
self
,
version
=
"openai/clip-vit-large-patch14"
,
device
=
"cuda"
,
max_length
=
77
,
freeze
=
True
,
layer
=
"last"
,
layer_idx
=
None
):
# clip-vit-base-patch32
super
().
__init__
()
assert
layer
in
self
.
LAYERS
self
.
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
version
)
self
.
transformer
=
CLIPTextModel
.
from_pretrained
(
version
)
self
.
device
=
device
self
.
max_length
=
max_length
if
freeze
:
self
.
freeze
()
self
.
layer
=
layer
self
.
layer_idx
=
layer_idx
if
layer
==
"hidden"
:
assert
layer_idx
is
not
None
assert
0
<=
abs
(
layer_idx
)
<=
12
def
freeze
(
self
):
self
.
transformer
=
self
.
transformer
.
eval
()
# self.train = disabled_train
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
def
forward
(
self
,
text
):
batch_encoding
=
self
.
tokenizer
(
text
,
truncation
=
True
,
max_length
=
self
.
max_length
,
return_length
=
True
,
return_overflowing_tokens
=
False
,
padding
=
"max_length"
,
return_tensors
=
"pt"
)
tokens
=
batch_encoding
[
"input_ids"
].
to
(
self
.
device
)
outputs
=
self
.
transformer
(
input_ids
=
tokens
,
output_hidden_states
=
self
.
layer
==
"hidden"
)
if
self
.
layer
==
"last"
:
z
=
outputs
.
last_hidden_state
elif
self
.
layer
==
"pooled"
:
z
=
outputs
.
pooler_output
[:,
None
,
:]
else
:
z
=
outputs
.
hidden_states
[
self
.
layer_idx
]
return
z
def
encode
(
self
,
text
):
return
self
(
text
)
class
ClipImageEmbedder
(
nn
.
Module
):
def
__init__
(
self
,
model
,
jit
=
False
,
device
=
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
,
antialias
=
True
,
ucg_rate
=
0.
):
super
().
__init__
()
from
clip
import
load
as
load_clip
self
.
model
,
_
=
load_clip
(
name
=
model
,
device
=
device
,
jit
=
jit
)
self
.
antialias
=
antialias
self
.
register_buffer
(
'mean'
,
torch
.
Tensor
([
0.48145466
,
0.4578275
,
0.40821073
]),
persistent
=
False
)
self
.
register_buffer
(
'std'
,
torch
.
Tensor
([
0.26862954
,
0.26130258
,
0.27577711
]),
persistent
=
False
)
self
.
ucg_rate
=
ucg_rate
def
preprocess
(
self
,
x
):
# normalize to [0,1]
x
=
kornia
.
geometry
.
resize
(
x
,
(
224
,
224
),
interpolation
=
'bicubic'
,
align_corners
=
True
,
antialias
=
self
.
antialias
)
x
=
(
x
+
1.
)
/
2.
# re-normalize according to clip
x
=
kornia
.
enhance
.
normalize
(
x
,
self
.
mean
,
self
.
std
)
return
x
def
forward
(
self
,
x
,
no_dropout
=
False
):
# x is assumed to be in range [-1,1]
out
=
self
.
model
.
encode_image
(
self
.
preprocess
(
x
))
out
=
out
.
to
(
x
.
dtype
)
if
self
.
ucg_rate
>
0.
and
not
no_dropout
:
out
=
torch
.
bernoulli
((
1.
-
self
.
ucg_rate
)
*
torch
.
ones
(
out
.
shape
[
0
],
device
=
out
.
device
))[:,
None
]
*
out
return
out
class
FrozenOpenCLIPEmbedder
(
AbstractEncoder
):
"""
Uses the OpenCLIP transformer encoder for text
"""
LAYERS
=
[
# "pooled",
"last"
,
"penultimate"
]
def
__init__
(
self
,
arch
=
"ViT-H-14"
,
version
=
"laion2b_s32b_b79k"
,
device
=
"cuda"
,
max_length
=
77
,
freeze
=
True
,
layer
=
"last"
):
super
().
__init__
()
assert
layer
in
self
.
LAYERS
model
,
_
,
_
=
open_clip
.
create_model_and_transforms
(
arch
,
device
=
torch
.
device
(
'cpu'
),
pretrained
=
version
)
del
model
.
visual
self
.
model
=
model
self
.
device
=
device
self
.
max_length
=
max_length
if
freeze
:
self
.
freeze
()
self
.
layer
=
layer
if
self
.
layer
==
"last"
:
self
.
layer_idx
=
0
elif
self
.
layer
==
"penultimate"
:
self
.
layer_idx
=
1
else
:
raise
NotImplementedError
()
def
freeze
(
self
):
self
.
model
=
self
.
model
.
eval
()
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
def
forward
(
self
,
text
):
tokens
=
open_clip
.
tokenize
(
text
)
z
=
self
.
encode_with_transformer
(
tokens
.
to
(
self
.
device
))
return
z
def
encode_with_transformer
(
self
,
text
):
x
=
self
.
model
.
token_embedding
(
text
)
# [batch_size, n_ctx, d_model]
x
=
x
+
self
.
model
.
positional_embedding
x
=
x
.
permute
(
1
,
0
,
2
)
# NLD -> LND
x
=
self
.
text_transformer_forward
(
x
,
attn_mask
=
self
.
model
.
attn_mask
)
x
=
x
.
permute
(
1
,
0
,
2
)
# LND -> NLD
x
=
self
.
model
.
ln_final
(
x
)
return
x
def
text_transformer_forward
(
self
,
x
:
torch
.
Tensor
,
attn_mask
=
None
):
for
i
,
r
in
enumerate
(
self
.
model
.
transformer
.
resblocks
):
if
i
==
len
(
self
.
model
.
transformer
.
resblocks
)
-
self
.
layer_idx
:
break
if
self
.
model
.
transformer
.
grad_checkpointing
and
not
torch
.
jit
.
is_scripting
():
x
=
checkpoint
(
r
,
x
,
attn_mask
)
else
:
x
=
r
(
x
,
attn_mask
=
attn_mask
)
return
x
def
encode
(
self
,
text
):
return
self
(
text
)
class
FrozenOpenCLIPImageEmbedder
(
AbstractEncoder
):
"""
Uses the OpenCLIP vision transformer encoder for images
"""
def
__init__
(
self
,
arch
=
"ViT-H-14"
,
version
=
"laion2b_s32b_b79k"
,
device
=
"cuda"
,
max_length
=
77
,
freeze
=
True
,
layer
=
"pooled"
,
antialias
=
True
,
ucg_rate
=
0.
):
super
().
__init__
()
model
,
_
,
_
=
open_clip
.
create_model_and_transforms
(
arch
,
device
=
torch
.
device
(
'cpu'
),
pretrained
=
version
,
)
del
model
.
transformer
self
.
model
=
model
self
.
device
=
device
self
.
max_length
=
max_length
if
freeze
:
self
.
freeze
()
self
.
layer
=
layer
if
self
.
layer
==
"penultimate"
:
raise
NotImplementedError
()
self
.
layer_idx
=
1
self
.
antialias
=
antialias
self
.
register_buffer
(
'mean'
,
torch
.
Tensor
([
0.48145466
,
0.4578275
,
0.40821073
]),
persistent
=
False
)
self
.
register_buffer
(
'std'
,
torch
.
Tensor
([
0.26862954
,
0.26130258
,
0.27577711
]),
persistent
=
False
)
self
.
ucg_rate
=
ucg_rate
def
preprocess
(
self
,
x
):
# normalize to [0,1]
x
=
kornia
.
geometry
.
resize
(
x
,
(
224
,
224
),
interpolation
=
'bicubic'
,
align_corners
=
True
,
antialias
=
self
.
antialias
)
x
=
(
x
+
1.
)
/
2.
# renormalize according to clip
x
=
kornia
.
enhance
.
normalize
(
x
,
self
.
mean
,
self
.
std
)
return
x
def
freeze
(
self
):
self
.
model
=
self
.
model
.
eval
()
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
@
autocast
def
forward
(
self
,
image
,
no_dropout
=
False
):
z
=
self
.
encode_with_vision_transformer
(
image
)
if
self
.
ucg_rate
>
0.
and
not
no_dropout
:
z
=
torch
.
bernoulli
((
1.
-
self
.
ucg_rate
)
*
torch
.
ones
(
z
.
shape
[
0
],
device
=
z
.
device
))[:,
None
]
*
z
return
z
def
encode_with_vision_transformer
(
self
,
img
):
img
=
self
.
preprocess
(
img
)
x
=
self
.
model
.
visual
(
img
)
return
x
def
encode
(
self
,
text
):
return
self
(
text
)
class
FrozenCLIPT5Encoder
(
AbstractEncoder
):
def
__init__
(
self
,
clip_version
=
"openai/clip-vit-large-patch14"
,
t5_version
=
"google/t5-v1_1-xl"
,
device
=
"cuda"
,
clip_max_length
=
77
,
t5_max_length
=
77
):
super
().
__init__
()
self
.
clip_encoder
=
FrozenCLIPEmbedder
(
clip_version
,
device
,
max_length
=
clip_max_length
)
self
.
t5_encoder
=
FrozenT5Embedder
(
t5_version
,
device
,
max_length
=
t5_max_length
)
print
(
f
"
{
self
.
clip_encoder
.
__class__
.
__name__
}
has
{
count_params
(
self
.
clip_encoder
)
*
1.e-6
:.
2
f
}
M parameters, "
f
"
{
self
.
t5_encoder
.
__class__
.
__name__
}
comes with
{
count_params
(
self
.
t5_encoder
)
*
1.e-6
:.
2
f
}
M params."
)
def
encode
(
self
,
text
):
return
self
(
text
)
def
forward
(
self
,
text
):
clip_z
=
self
.
clip_encoder
.
encode
(
text
)
t5_z
=
self
.
t5_encoder
.
encode
(
text
)
return
[
clip_z
,
t5_z
]
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from ldm.modules.diffusionmodules.openaimodel import Timestep


class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
    def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
        super().__init__(*args, **kwargs)
        if clip_stats_path is None:
            clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
        else:
            clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
        self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("data_std", clip_std[None, :], persistent=False)
        self.time_embed = Timestep(timestep_dim)

    def scale(self, x):
        # re-normalize to centered mean and unit variance
        x = (x - self.data_mean) * 1. / self.data_std
        return x

    def unscale(self, x):
        # back to original data stats
        x = (x * self.data_std) + self.data_mean
        return x

    def forward(self, x, noise_level=None):
        if noise_level is None:
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        x = self.scale(x)
        z = self.q_sample(x, noise_level)
        z = self.unscale(z)
        noise_level = self.time_embed(noise_level)
        return z, noise_level
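
# A rough usage sketch for the noise augmentation above (added; how the module is
# constructed depends on the noise-schedule config consumed by
# ImageConcatWithNoiseAugmentation, which lives elsewhere, so construction is left to
# the caller here).
def _example_clip_noise_augmentation(aug, clip_emb):
    # `aug` is a configured CLIPEmbeddingNoiseAugmentation, `clip_emb` a (B, D) batch of
    # CLIP image embeddings. Pick an explicit noise level below aug.max_noise_level:
    level = torch.full((clip_emb.shape[0],), 100, device=clip_emb.device, dtype=torch.long)
    noised, level_emb = aug(clip_emb, noise_level=level)  # level_emb has timestep_dim channels
    return noised, level_emb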
ldm/modules/image_degradation/__init__.py  (new file, mode 100644, commit 4007efdd)
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
ldm/modules/image_degradation/bsrgan.py  (new file, mode 100644, commit 4007efdd)
# -*- coding: utf-8 -*-
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""

import numpy as np
import cv2
import torch

from functools import partial
import random
from scipy import ndimage
import scipy
import scipy.stats as ss
from scipy.interpolate import interp2d
from scipy.linalg import orth
import albumentations

import ldm.modules.image_degradation.utils_image as util
def modcrop_np(img, sf):
    '''
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        cropped image
    '''
    w, h = img.shape[:2]
    im = np.copy(img)
    return im[:w - w % sf, :h - h % sf, ...]
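
# Tiny illustration (added): modcrop_np trims an image so that both spatial dimensions
# are divisible by the scale factor, which the degradation pipeline below relies on.
def _example_modcrop():
    img = np.zeros((37, 41, 3), dtype=np.float32)
    assert modcrop_np(img, sf=4).shape == (36, 40, 3)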
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def
analytic_kernel
(
k
):
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
k_size
=
k
.
shape
[
0
]
# Calculate the big kernels size
big_k
=
np
.
zeros
((
3
*
k_size
-
2
,
3
*
k_size
-
2
))
# Loop over the small kernel to fill the big one
for
r
in
range
(
k_size
):
for
c
in
range
(
k_size
):
big_k
[
2
*
r
:
2
*
r
+
k_size
,
2
*
c
:
2
*
c
+
k_size
]
+=
k
[
r
,
c
]
*
k
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop
=
k_size
//
2
cropped_big_k
=
big_k
[
crop
:
-
crop
,
crop
:
-
crop
]
# Normalize to 1
return
cropped_big_k
/
cropped_big_k
.
sum
()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0,  pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel
    """

    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)

    return k
def gm_blur_kernel(mean, cov, size=15):
    center = size / 2.0 + 0.5
    k = np.zeros([size, size])
    for y in range(size):
        for x in range(size):
            cy = y - center + 1
            cx = x - center + 1
            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)

    k = k / np.sum(k)
    return k
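
# A small self-contained sketch (added) showing how the kernel helpers above combine:
# anisotropic_Gaussian() builds a rotated 2x2 covariance matrix and gm_blur_kernel()
# evaluates the corresponding Gaussian pdf on a ksize x ksize grid; analytic_kernel()
# then lifts an X2 kernel to an X4 kernel.
def _example_kernels():
    k2 = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6, l2=1)  # elongated, rotated 45 deg
    assert k2.shape == (15, 15)
    assert abs(k2.sum() - 1.0) < 1e-6   # kernels are normalized to sum to 1
    k4 = analytic_kernel(k2)
    assert k4.shape == (29, 29)         # (3*15 - 2) cropped by 15 // 2 on every side
    assert abs(k4.sum() - 1.0) < 1e-6
    return k2, k4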
def
shift_pixel
(
x
,
sf
,
upper_left
=
True
):
"""shift pixel for super-resolution with different scale factors
Args:
x: WxHxC or WxH
sf: scale factor
upper_left: shift direction
"""
h
,
w
=
x
.
shape
[:
2
]
shift
=
(
sf
-
1
)
*
0.5
xv
,
yv
=
np
.
arange
(
0
,
w
,
1.0
),
np
.
arange
(
0
,
h
,
1.0
)
if
upper_left
:
x1
=
xv
+
shift
y1
=
yv
+
shift
else
:
x1
=
xv
-
shift
y1
=
yv
-
shift
x1
=
np
.
clip
(
x1
,
0
,
w
-
1
)
y1
=
np
.
clip
(
y1
,
0
,
h
-
1
)
if
x
.
ndim
==
2
:
x
=
interp2d
(
xv
,
yv
,
x
)(
x1
,
y1
)
if
x
.
ndim
==
3
:
for
i
in
range
(
x
.
shape
[
-
1
]):
x
[:,
:,
i
]
=
interp2d
(
xv
,
yv
,
x
[:,
:,
i
])(
x1
,
y1
)
return
x
def
blur
(
x
,
k
):
'''
x: image, NxcxHxW
k: kernel, Nx1xhxw
'''
n
,
c
=
x
.
shape
[:
2
]
p1
,
p2
=
(
k
.
shape
[
-
2
]
-
1
)
//
2
,
(
k
.
shape
[
-
1
]
-
1
)
//
2
x
=
torch
.
nn
.
functional
.
pad
(
x
,
pad
=
(
p1
,
p2
,
p1
,
p2
),
mode
=
'replicate'
)
k
=
k
.
repeat
(
1
,
c
,
1
,
1
)
k
=
k
.
view
(
-
1
,
1
,
k
.
shape
[
2
],
k
.
shape
[
3
])
x
=
x
.
view
(
1
,
-
1
,
x
.
shape
[
2
],
x
.
shape
[
3
])
x
=
torch
.
nn
.
functional
.
conv2d
(
x
,
k
,
bias
=
None
,
stride
=
1
,
padding
=
0
,
groups
=
n
*
c
)
x
=
x
.
view
(
n
,
c
,
x
.
shape
[
2
],
x
.
shape
[
3
])
return
x
def
gen_kernel
(
k_size
=
np
.
array
([
15
,
15
]),
scale_factor
=
np
.
array
([
4
,
4
]),
min_var
=
0.6
,
max_var
=
10.
,
noise_level
=
0
):
""""
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
# max_var = 2.5 * sf
"""
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
lambda_1
=
min_var
+
np
.
random
.
rand
()
*
(
max_var
-
min_var
)
lambda_2
=
min_var
+
np
.
random
.
rand
()
*
(
max_var
-
min_var
)
theta
=
np
.
random
.
rand
()
*
np
.
pi
# random theta
noise
=
-
noise_level
+
np
.
random
.
rand
(
*
k_size
)
*
noise_level
*
2
# Set COV matrix using Lambdas and Theta
LAMBDA
=
np
.
diag
([
lambda_1
,
lambda_2
])
Q
=
np
.
array
([[
np
.
cos
(
theta
),
-
np
.
sin
(
theta
)],
[
np
.
sin
(
theta
),
np
.
cos
(
theta
)]])
SIGMA
=
Q
@
LAMBDA
@
Q
.
T
INV_SIGMA
=
np
.
linalg
.
inv
(
SIGMA
)[
None
,
None
,
:,
:]
# Set expectation position (shifting kernel for aligned image)
MU
=
k_size
//
2
-
0.5
*
(
scale_factor
-
1
)
# - 0.5 * (scale_factor - k_size % 2)
MU
=
MU
[
None
,
None
,
:,
None
]
# Create meshgrid for Gaussian
[
X
,
Y
]
=
np
.
meshgrid
(
range
(
k_size
[
0
]),
range
(
k_size
[
1
]))
Z
=
np
.
stack
([
X
,
Y
],
2
)[:,
:,
:,
None
]
# Calcualte Gaussian for every pixel of the kernel
ZZ
=
Z
-
MU
ZZ_t
=
ZZ
.
transpose
(
0
,
1
,
3
,
2
)
raw_kernel
=
np
.
exp
(
-
0.5
*
np
.
squeeze
(
ZZ_t
@
INV_SIGMA
@
ZZ
))
*
(
1
+
noise
)
# shift the kernel so it will be centered
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
# Normalize the kernel and return
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
kernel
=
raw_kernel
/
np
.
sum
(
raw_kernel
)
return
kernel
def
fspecial_gaussian
(
hsize
,
sigma
):
hsize
=
[
hsize
,
hsize
]
siz
=
[(
hsize
[
0
]
-
1.0
)
/
2.0
,
(
hsize
[
1
]
-
1.0
)
/
2.0
]
std
=
sigma
[
x
,
y
]
=
np
.
meshgrid
(
np
.
arange
(
-
siz
[
1
],
siz
[
1
]
+
1
),
np
.
arange
(
-
siz
[
0
],
siz
[
0
]
+
1
))
arg
=
-
(
x
*
x
+
y
*
y
)
/
(
2
*
std
*
std
)
h
=
np
.
exp
(
arg
)
h
[
h
<
scipy
.
finfo
(
float
).
eps
*
h
.
max
()]
=
0
sumh
=
h
.
sum
()
if
sumh
!=
0
:
h
=
h
/
sumh
return
h
def
fspecial_laplacian
(
alpha
):
alpha
=
max
([
0
,
min
([
alpha
,
1
])])
h1
=
alpha
/
(
alpha
+
1
)
h2
=
(
1
-
alpha
)
/
(
alpha
+
1
)
h
=
[[
h1
,
h2
,
h1
],
[
h2
,
-
4
/
(
alpha
+
1
),
h2
],
[
h1
,
h2
,
h1
]]
h
=
np
.
array
(
h
)
return
h
def
fspecial
(
filter_type
,
*
args
,
**
kwargs
):
'''
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
'''
if
filter_type
==
'gaussian'
:
return
fspecial_gaussian
(
*
args
,
**
kwargs
)
if
filter_type
==
'laplacian'
:
return
fspecial_laplacian
(
*
args
,
**
kwargs
)
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def
bicubic_degradation
(
x
,
sf
=
3
):
'''
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
'''
x
=
util
.
imresize_np
(
x
,
scale
=
1
/
sf
)
return
x
def
srmd_degradation
(
x
,
k
,
sf
=
3
):
''' blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2018learning,
title={Learning a single convolutional super-resolution network for multiple degradations},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={3262--3271},
year={2018}
}
'''
x
=
ndimage
.
filters
.
convolve
(
x
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'wrap'
)
# 'nearest' | 'mirror'
x
=
bicubic_degradation
(
x
,
sf
=
sf
)
return
x
def
dpsr_degradation
(
x
,
k
,
sf
=
3
):
''' bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2019deep,
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={1671--1681},
year={2019}
}
'''
x
=
bicubic_degradation
(
x
,
sf
=
sf
)
x
=
ndimage
.
filters
.
convolve
(
x
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'wrap'
)
return
x
def
classical_degradation
(
x
,
k
,
sf
=
3
):
''' blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
'''
x
=
ndimage
.
filters
.
convolve
(
x
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'wrap'
)
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st
=
0
return
x
[
st
::
sf
,
st
::
sf
,
...]
def
add_sharpening
(
img
,
weight
=
0.5
,
radius
=
50
,
threshold
=
10
):
"""USM sharpening. borrowed from real-ESRGAN
Input image: I; Blurry image: B.
1. K = I + weight * (I - B)
2. Mask = 1 if abs(I - B) > threshold, else: 0
3. Blur mask:
4. Out = Mask * K + (1 - Mask) * I
Args:
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
weight (float): Sharp weight. Default: 1.
radius (float): Kernel size of Gaussian blur. Default: 50.
threshold (int):
"""
if
radius
%
2
==
0
:
radius
+=
1
blur
=
cv2
.
GaussianBlur
(
img
,
(
radius
,
radius
),
0
)
residual
=
img
-
blur
mask
=
np
.
abs
(
residual
)
*
255
>
threshold
mask
=
mask
.
astype
(
'float32'
)
soft_mask
=
cv2
.
GaussianBlur
(
mask
,
(
radius
,
radius
),
0
)
K
=
img
+
weight
*
residual
K
=
np
.
clip
(
K
,
0
,
1
)
return
soft_mask
*
K
+
(
1
-
soft_mask
)
*
img
def
add_blur
(
img
,
sf
=
4
):
wd2
=
4.0
+
sf
wd
=
2.0
+
0.2
*
sf
if
random
.
random
()
<
0.5
:
l1
=
wd2
*
random
.
random
()
l2
=
wd2
*
random
.
random
()
k
=
anisotropic_Gaussian
(
ksize
=
2
*
random
.
randint
(
2
,
11
)
+
3
,
theta
=
random
.
random
()
*
np
.
pi
,
l1
=
l1
,
l2
=
l2
)
else
:
k
=
fspecial
(
'gaussian'
,
2
*
random
.
randint
(
2
,
11
)
+
3
,
wd
*
random
.
random
())
img
=
ndimage
.
filters
.
convolve
(
img
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'mirror'
)
return
img
def
add_resize
(
img
,
sf
=
4
):
rnum
=
np
.
random
.
rand
()
if
rnum
>
0.8
:
# up
sf1
=
random
.
uniform
(
1
,
2
)
elif
rnum
<
0.7
:
# down
sf1
=
random
.
uniform
(
0.5
/
sf
,
1
)
else
:
sf1
=
1.0
img
=
cv2
.
resize
(
img
,
(
int
(
sf1
*
img
.
shape
[
1
]),
int
(
sf1
*
img
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
return
img
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def
add_Gaussian_noise
(
img
,
noise_level1
=
2
,
noise_level2
=
25
):
noise_level
=
random
.
randint
(
noise_level1
,
noise_level2
)
rnum
=
np
.
random
.
rand
()
if
rnum
>
0.6
:
# add color Gaussian noise
img
=
img
+
np
.
random
.
normal
(
0
,
noise_level
/
255.0
,
img
.
shape
).
astype
(
np
.
float32
)
elif
rnum
<
0.4
:
# add grayscale Gaussian noise
img
=
img
+
np
.
random
.
normal
(
0
,
noise_level
/
255.0
,
(
*
img
.
shape
[:
2
],
1
)).
astype
(
np
.
float32
)
else
:
# add noise
L
=
noise_level2
/
255.
D
=
np
.
diag
(
np
.
random
.
rand
(
3
))
U
=
orth
(
np
.
random
.
rand
(
3
,
3
))
conv
=
np
.
dot
(
np
.
dot
(
np
.
transpose
(
U
),
D
),
U
)
img
=
img
+
np
.
random
.
multivariate_normal
([
0
,
0
,
0
],
np
.
abs
(
L
**
2
*
conv
),
img
.
shape
[:
2
]).
astype
(
np
.
float32
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
return
img
def
add_speckle_noise
(
img
,
noise_level1
=
2
,
noise_level2
=
25
):
noise_level
=
random
.
randint
(
noise_level1
,
noise_level2
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
rnum
=
random
.
random
()
if
rnum
>
0.6
:
img
+=
img
*
np
.
random
.
normal
(
0
,
noise_level
/
255.0
,
img
.
shape
).
astype
(
np
.
float32
)
elif
rnum
<
0.4
:
img
+=
img
*
np
.
random
.
normal
(
0
,
noise_level
/
255.0
,
(
*
img
.
shape
[:
2
],
1
)).
astype
(
np
.
float32
)
else
:
L
=
noise_level2
/
255.
D
=
np
.
diag
(
np
.
random
.
rand
(
3
))
U
=
orth
(
np
.
random
.
rand
(
3
,
3
))
conv
=
np
.
dot
(
np
.
dot
(
np
.
transpose
(
U
),
D
),
U
)
img
+=
img
*
np
.
random
.
multivariate_normal
([
0
,
0
,
0
],
np
.
abs
(
L
**
2
*
conv
),
img
.
shape
[:
2
]).
astype
(
np
.
float32
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
return
img
def
add_Poisson_noise
(
img
):
img
=
np
.
clip
((
img
*
255.0
).
round
(),
0
,
255
)
/
255.
vals
=
10
**
(
2
*
random
.
random
()
+
2.0
)
# [2, 4]
if
random
.
random
()
<
0.5
:
img
=
np
.
random
.
poisson
(
img
*
vals
).
astype
(
np
.
float32
)
/
vals
else
:
img_gray
=
np
.
dot
(
img
[...,
:
3
],
[
0.299
,
0.587
,
0.114
])
img_gray
=
np
.
clip
((
img_gray
*
255.0
).
round
(),
0
,
255
)
/
255.
noise_gray
=
np
.
random
.
poisson
(
img_gray
*
vals
).
astype
(
np
.
float32
)
/
vals
-
img_gray
img
+=
noise_gray
[:,
:,
np
.
newaxis
]
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
return
img
def
add_JPEG_noise
(
img
):
quality_factor
=
random
.
randint
(
30
,
95
)
img
=
cv2
.
cvtColor
(
util
.
single2uint
(
img
),
cv2
.
COLOR_RGB2BGR
)
result
,
encimg
=
cv2
.
imencode
(
'.jpg'
,
img
,
[
int
(
cv2
.
IMWRITE_JPEG_QUALITY
),
quality_factor
])
img
=
cv2
.
imdecode
(
encimg
,
1
)
img
=
cv2
.
cvtColor
(
util
.
uint2single
(
img
),
cv2
.
COLOR_BGR2RGB
)
return
img
def
random_crop
(
lq
,
hq
,
sf
=
4
,
lq_patchsize
=
64
):
h
,
w
=
lq
.
shape
[:
2
]
rnd_h
=
random
.
randint
(
0
,
h
-
lq_patchsize
)
rnd_w
=
random
.
randint
(
0
,
w
-
lq_patchsize
)
lq
=
lq
[
rnd_h
:
rnd_h
+
lq_patchsize
,
rnd_w
:
rnd_w
+
lq_patchsize
,
:]
rnd_h_H
,
rnd_w_H
=
int
(
rnd_h
*
sf
),
int
(
rnd_w
*
sf
)
hq
=
hq
[
rnd_h_H
:
rnd_h_H
+
lq_patchsize
*
sf
,
rnd_w_H
:
rnd_w_H
+
lq_patchsize
*
sf
,
:]
return
lq
,
hq
def
degradation_bsrgan
(
img
,
sf
=
4
,
lq_patchsize
=
72
,
isp_model
=
None
):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
isp_prob
,
jpeg_prob
,
scale2_prob
=
0.25
,
0.9
,
0.25
sf_ori
=
sf
h1
,
w1
=
img
.
shape
[:
2
]
img
=
img
.
copy
()[:
w1
-
w1
%
sf
,
:
h1
-
h1
%
sf
,
...]
# mod crop
h
,
w
=
img
.
shape
[:
2
]
if
h
<
lq_patchsize
*
sf
or
w
<
lq_patchsize
*
sf
:
raise
ValueError
(
f
'img size (
{
h1
}
X
{
w1
}
) is too small!'
)
hq
=
img
.
copy
()
if
sf
==
4
and
random
.
random
()
<
scale2_prob
:
# downsample1
if
np
.
random
.
rand
()
<
0.5
:
img
=
cv2
.
resize
(
img
,
(
int
(
1
/
2
*
img
.
shape
[
1
]),
int
(
1
/
2
*
img
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
else
:
img
=
util
.
imresize_np
(
img
,
1
/
2
,
True
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
sf
=
2
shuffle_order
=
random
.
sample
(
range
(
7
),
7
)
idx1
,
idx2
=
shuffle_order
.
index
(
2
),
shuffle_order
.
index
(
3
)
if
idx1
>
idx2
:
# keep downsample3 last
shuffle_order
[
idx1
],
shuffle_order
[
idx2
]
=
shuffle_order
[
idx2
],
shuffle_order
[
idx1
]
for
i
in
shuffle_order
:
if
i
==
0
:
img
=
add_blur
(
img
,
sf
=
sf
)
elif
i
==
1
:
img
=
add_blur
(
img
,
sf
=
sf
)
elif
i
==
2
:
a
,
b
=
img
.
shape
[
1
],
img
.
shape
[
0
]
# downsample2
if
random
.
random
()
<
0.75
:
sf1
=
random
.
uniform
(
1
,
2
*
sf
)
img
=
cv2
.
resize
(
img
,
(
int
(
1
/
sf1
*
img
.
shape
[
1
]),
int
(
1
/
sf1
*
img
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
else
:
k
=
fspecial
(
'gaussian'
,
25
,
random
.
uniform
(
0.1
,
0.6
*
sf
))
k_shifted
=
shift_pixel
(
k
,
sf
)
k_shifted
=
k_shifted
/
k_shifted
.
sum
()
# blur with shifted kernel
img
=
ndimage
.
filters
.
convolve
(
img
,
np
.
expand_dims
(
k_shifted
,
axis
=
2
),
mode
=
'mirror'
)
img
=
img
[
0
::
sf
,
0
::
sf
,
...]
# nearest downsampling
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
elif
i
==
3
:
# downsample3
img
=
cv2
.
resize
(
img
,
(
int
(
1
/
sf
*
a
),
int
(
1
/
sf
*
b
)),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
elif
i
==
4
:
# add Gaussian noise
img
=
add_Gaussian_noise
(
img
,
noise_level1
=
2
,
noise_level2
=
25
)
elif
i
==
5
:
# add JPEG noise
if
random
.
random
()
<
jpeg_prob
:
img
=
add_JPEG_noise
(
img
)
elif
i
==
6
:
# add processed camera sensor noise
if
random
.
random
()
<
isp_prob
and
isp_model
is
not
None
:
with
torch
.
no_grad
():
img
,
hq
=
isp_model
.
forward
(
img
.
copy
(),
hq
)
# add final JPEG compression noise
img
=
add_JPEG_noise
(
img
)
# random crop
img
,
hq
=
random_crop
(
img
,
hq
,
sf_ori
,
lq_patchsize
)
return
img
,
hq
# todo no isp_model?
def
degradation_bsrgan_variant
(
image
,
sf
=
4
,
isp_model
=
None
):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
image
=
util
.
uint2single
(
image
)
isp_prob
,
jpeg_prob
,
scale2_prob
=
0.25
,
0.9
,
0.25
sf_ori
=
sf
h1
,
w1
=
image
.
shape
[:
2
]
image
=
image
.
copy
()[:
w1
-
w1
%
sf
,
:
h1
-
h1
%
sf
,
...]
# mod crop
h
,
w
=
image
.
shape
[:
2
]
hq
=
image
.
copy
()
if
sf
==
4
and
random
.
random
()
<
scale2_prob
:
# downsample1
if
np
.
random
.
rand
()
<
0.5
:
image
=
cv2
.
resize
(
image
,
(
int
(
1
/
2
*
image
.
shape
[
1
]),
int
(
1
/
2
*
image
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
else
:
image
=
util
.
imresize_np
(
image
,
1
/
2
,
True
)
image
=
np
.
clip
(
image
,
0.0
,
1.0
)
sf
=
2
shuffle_order
=
random
.
sample
(
range
(
7
),
7
)
idx1
,
idx2
=
shuffle_order
.
index
(
2
),
shuffle_order
.
index
(
3
)
if
idx1
>
idx2
:
# keep downsample3 last
shuffle_order
[
idx1
],
shuffle_order
[
idx2
]
=
shuffle_order
[
idx2
],
shuffle_order
[
idx1
]
for
i
in
shuffle_order
:
if
i
==
0
:
image
=
add_blur
(
image
,
sf
=
sf
)
elif
i
==
1
:
image
=
add_blur
(
image
,
sf
=
sf
)
elif
i
==
2
:
a
,
b
=
image
.
shape
[
1
],
image
.
shape
[
0
]
# downsample2
if
random
.
random
()
<
0.75
:
sf1
=
random
.
uniform
(
1
,
2
*
sf
)
image
=
cv2
.
resize
(
image
,
(
int
(
1
/
sf1
*
image
.
shape
[
1
]),
int
(
1
/
sf1
*
image
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
else
:
k
=
fspecial
(
'gaussian'
,
25
,
random
.
uniform
(
0.1
,
0.6
*
sf
))
k_shifted
=
shift_pixel
(
k
,
sf
)
k_shifted
=
k_shifted
/
k_shifted
.
sum
()
# blur with shifted kernel
image
=
ndimage
.
filters
.
convolve
(
image
,
np
.
expand_dims
(
k_shifted
,
axis
=
2
),
mode
=
'mirror'
)
image
=
image
[
0
::
sf
,
0
::
sf
,
...]
# nearest downsampling
image
=
np
.
clip
(
image
,
0.0
,
1.0
)
elif
i
==
3
:
# downsample3
image
=
cv2
.
resize
(
image
,
(
int
(
1
/
sf
*
a
),
int
(
1
/
sf
*
b
)),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
image
=
np
.
clip
(
image
,
0.0
,
1.0
)
elif
i
==
4
:
# add Gaussian noise
image
=
add_Gaussian_noise
(
image
,
noise_level1
=
2
,
noise_level2
=
25
)
elif
i
==
5
:
# add JPEG noise
if
random
.
random
()
<
jpeg_prob
:
image
=
add_JPEG_noise
(
image
)
# elif i == 6:
# # add processed camera sensor noise
# if random.random() < isp_prob and isp_model is not None:
# with torch.no_grad():
# img, hq = isp_model.forward(img.copy(), hq)
# add final JPEG compression noise
image
=
add_JPEG_noise
(
image
)
image
=
util
.
single2uint
(
image
)
example
=
{
"image"
:
image
}
return
example
# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
def
degradation_bsrgan_plus
(
img
,
sf
=
4
,
shuffle_prob
=
0.5
,
use_sharp
=
True
,
lq_patchsize
=
64
,
isp_model
=
None
):
"""
This is an extended degradation model by combining
the degradation models of BSRGAN and Real-ESRGAN
----------
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
sf: scale factor
use_shuffle: the degradation shuffle
use_sharp: sharpening the img
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
h1
,
w1
=
img
.
shape
[:
2
]
img
=
img
.
copy
()[:
w1
-
w1
%
sf
,
:
h1
-
h1
%
sf
,
...]
# mod crop
h
,
w
=
img
.
shape
[:
2
]
if
h
<
lq_patchsize
*
sf
or
w
<
lq_patchsize
*
sf
:
raise
ValueError
(
f
'img size (
{
h1
}
X
{
w1
}
) is too small!'
)
if
use_sharp
:
img
=
add_sharpening
(
img
)
hq
=
img
.
copy
()
if
random
.
random
()
<
shuffle_prob
:
shuffle_order
=
random
.
sample
(
range
(
13
),
13
)
else
:
shuffle_order
=
list
(
range
(
13
))
# local shuffle for noise, JPEG is always the last one
shuffle_order
[
2
:
6
]
=
random
.
sample
(
shuffle_order
[
2
:
6
],
len
(
range
(
2
,
6
)))
shuffle_order
[
9
:
13
]
=
random
.
sample
(
shuffle_order
[
9
:
13
],
len
(
range
(
9
,
13
)))
poisson_prob
,
speckle_prob
,
isp_prob
=
0.1
,
0.1
,
0.1
for
i
in
shuffle_order
:
if
i
==
0
:
img
=
add_blur
(
img
,
sf
=
sf
)
elif
i
==
1
:
img
=
add_resize
(
img
,
sf
=
sf
)
elif
i
==
2
:
img
=
add_Gaussian_noise
(
img
,
noise_level1
=
2
,
noise_level2
=
25
)
elif
i
==
3
:
if
random
.
random
()
<
poisson_prob
:
img
=
add_Poisson_noise
(
img
)
elif
i
==
4
:
if
random
.
random
()
<
speckle_prob
:
img
=
add_speckle_noise
(
img
)
elif
i
==
5
:
if
random
.
random
()
<
isp_prob
and
isp_model
is
not
None
:
with
torch
.
no_grad
():
img
,
hq
=
isp_model
.
forward
(
img
.
copy
(),
hq
)
elif
i
==
6
:
img
=
add_JPEG_noise
(
img
)
elif
i
==
7
:
img
=
add_blur
(
img
,
sf
=
sf
)
elif
i
==
8
:
img
=
add_resize
(
img
,
sf
=
sf
)
elif
i
==
9
:
img
=
add_Gaussian_noise
(
img
,
noise_level1
=
2
,
noise_level2
=
25
)
elif
i
==
10
:
if
random
.
random
()
<
poisson_prob
:
img
=
add_Poisson_noise
(
img
)
elif
i
==
11
:
if
random
.
random
()
<
speckle_prob
:
img
=
add_speckle_noise
(
img
)
elif
i
==
12
:
if
random
.
random
()
<
isp_prob
and
isp_model
is
not
None
:
with
torch
.
no_grad
():
img
,
hq
=
isp_model
.
forward
(
img
.
copy
(),
hq
)
else
:
print
(
'check the shuffle!'
)
# resize to desired size
img
=
cv2
.
resize
(
img
,
(
int
(
1
/
sf
*
hq
.
shape
[
1
]),
int
(
1
/
sf
*
hq
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
# add final JPEG compression noise
img
=
add_JPEG_noise
(
img
)
# random crop
img
,
hq
=
random_crop
(
img
,
hq
,
sf
,
lq_patchsize
)
return
img
,
hq
if
__name__
==
'__main__'
:
print
(
"hey"
)
img
=
util
.
imread_uint
(
'utils/test.png'
,
3
)
print
(
img
)
img
=
util
.
uint2single
(
img
)
print
(
img
)
img
=
img
[:
448
,
:
448
]
h
=
img
.
shape
[
0
]
//
4
print
(
"resizing to"
,
h
)
sf
=
4
deg_fn
=
partial
(
degradation_bsrgan_variant
,
sf
=
sf
)
for
i
in
range
(
20
):
print
(
i
)
img_lq
=
deg_fn
(
img
)
print
(
img_lq
)
img_lq_bicubic
=
albumentations
.
SmallestMaxSize
(
max_size
=
h
,
interpolation
=
cv2
.
INTER_CUBIC
)(
image
=
img
)[
"image"
]
print
(
img_lq
.
shape
)
print
(
"bicubic"
,
img_lq_bicubic
.
shape
)
print
(
img_hq
.
shape
)
lq_nearest
=
cv2
.
resize
(
util
.
single2uint
(
img_lq
),
(
int
(
sf
*
img_lq
.
shape
[
1
]),
int
(
sf
*
img_lq
.
shape
[
0
])),
interpolation
=
0
)
lq_bicubic_nearest
=
cv2
.
resize
(
util
.
single2uint
(
img_lq_bicubic
),
(
int
(
sf
*
img_lq
.
shape
[
1
]),
int
(
sf
*
img_lq
.
shape
[
0
])),
interpolation
=
0
)
img_concat
=
np
.
concatenate
([
lq_bicubic_nearest
,
lq_nearest
,
util
.
single2uint
(
img_hq
)],
axis
=
1
)
util
.
imsave
(
img_concat
,
str
(
i
)
+
'.png'
)
ldm/modules/image_degradation/bsrgan_light.py  (new file, mode 100644, commit 4007efdd)
# -*- coding: utf-8 -*-
import numpy as np
import cv2
import torch

from functools import partial
import random
from scipy import ndimage
import scipy
import scipy.stats as ss
from scipy.interpolate import interp2d
from scipy.linalg import orth
import albumentations

import ldm.modules.image_degradation.utils_image as util

"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
def
modcrop_np
(
img
,
sf
):
'''
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
'''
w
,
h
=
img
.
shape
[:
2
]
im
=
np
.
copy
(
img
)
return
im
[:
w
-
w
%
sf
,
:
h
-
h
%
sf
,
...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def
analytic_kernel
(
k
):
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
k_size
=
k
.
shape
[
0
]
# Calculate the big kernels size
big_k
=
np
.
zeros
((
3
*
k_size
-
2
,
3
*
k_size
-
2
))
# Loop over the small kernel to fill the big one
for
r
in
range
(
k_size
):
for
c
in
range
(
k_size
):
big_k
[
2
*
r
:
2
*
r
+
k_size
,
2
*
c
:
2
*
c
+
k_size
]
+=
k
[
r
,
c
]
*
k
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop
=
k_size
//
2
cropped_big_k
=
big_k
[
crop
:
-
crop
,
crop
:
-
crop
]
# Normalize to 1
return
cropped_big_k
/
cropped_big_k
.
sum
()
def
anisotropic_Gaussian
(
ksize
=
15
,
theta
=
np
.
pi
,
l1
=
6
,
l2
=
6
):
""" generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
l1 : [0.1,50], scaling of eigenvalues
l2 : [0.1,l1], scaling of eigenvalues
If l1 = l2, will get an isotropic Gaussian kernel.
Returns:
k : kernel
"""
v
=
np
.
dot
(
np
.
array
([[
np
.
cos
(
theta
),
-
np
.
sin
(
theta
)],
[
np
.
sin
(
theta
),
np
.
cos
(
theta
)]]),
np
.
array
([
1.
,
0.
]))
V
=
np
.
array
([[
v
[
0
],
v
[
1
]],
[
v
[
1
],
-
v
[
0
]]])
D
=
np
.
array
([[
l1
,
0
],
[
0
,
l2
]])
Sigma
=
np
.
dot
(
np
.
dot
(
V
,
D
),
np
.
linalg
.
inv
(
V
))
k
=
gm_blur_kernel
(
mean
=
[
0
,
0
],
cov
=
Sigma
,
size
=
ksize
)
return
k
def
gm_blur_kernel
(
mean
,
cov
,
size
=
15
):
center
=
size
/
2.0
+
0.5
k
=
np
.
zeros
([
size
,
size
])
for
y
in
range
(
size
):
for
x
in
range
(
size
):
cy
=
y
-
center
+
1
cx
=
x
-
center
+
1
k
[
y
,
x
]
=
ss
.
multivariate_normal
.
pdf
([
cx
,
cy
],
mean
=
mean
,
cov
=
cov
)
k
=
k
/
np
.
sum
(
k
)
return
k
def
shift_pixel
(
x
,
sf
,
upper_left
=
True
):
"""shift pixel for super-resolution with different scale factors
Args:
x: WxHxC or WxH
sf: scale factor
upper_left: shift direction
"""
h
,
w
=
x
.
shape
[:
2
]
shift
=
(
sf
-
1
)
*
0.5
xv
,
yv
=
np
.
arange
(
0
,
w
,
1.0
),
np
.
arange
(
0
,
h
,
1.0
)
if
upper_left
:
x1
=
xv
+
shift
y1
=
yv
+
shift
else
:
x1
=
xv
-
shift
y1
=
yv
-
shift
x1
=
np
.
clip
(
x1
,
0
,
w
-
1
)
y1
=
np
.
clip
(
y1
,
0
,
h
-
1
)
if
x
.
ndim
==
2
:
x
=
interp2d
(
xv
,
yv
,
x
)(
x1
,
y1
)
if
x
.
ndim
==
3
:
for
i
in
range
(
x
.
shape
[
-
1
]):
x
[:,
:,
i
]
=
interp2d
(
xv
,
yv
,
x
[:,
:,
i
])(
x1
,
y1
)
return
x
def
blur
(
x
,
k
):
'''
x: image, NxcxHxW
k: kernel, Nx1xhxw
'''
n
,
c
=
x
.
shape
[:
2
]
p1
,
p2
=
(
k
.
shape
[
-
2
]
-
1
)
//
2
,
(
k
.
shape
[
-
1
]
-
1
)
//
2
x
=
torch
.
nn
.
functional
.
pad
(
x
,
pad
=
(
p1
,
p2
,
p1
,
p2
),
mode
=
'replicate'
)
k
=
k
.
repeat
(
1
,
c
,
1
,
1
)
k
=
k
.
view
(
-
1
,
1
,
k
.
shape
[
2
],
k
.
shape
[
3
])
x
=
x
.
view
(
1
,
-
1
,
x
.
shape
[
2
],
x
.
shape
[
3
])
x
=
torch
.
nn
.
functional
.
conv2d
(
x
,
k
,
bias
=
None
,
stride
=
1
,
padding
=
0
,
groups
=
n
*
c
)
x
=
x
.
view
(
n
,
c
,
x
.
shape
[
2
],
x
.
shape
[
3
])
return
x
def
gen_kernel
(
k_size
=
np
.
array
([
15
,
15
]),
scale_factor
=
np
.
array
([
4
,
4
]),
min_var
=
0.6
,
max_var
=
10.
,
noise_level
=
0
):
""""
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
# max_var = 2.5 * sf
"""
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
lambda_1
=
min_var
+
np
.
random
.
rand
()
*
(
max_var
-
min_var
)
lambda_2
=
min_var
+
np
.
random
.
rand
()
*
(
max_var
-
min_var
)
theta
=
np
.
random
.
rand
()
*
np
.
pi
# random theta
noise
=
-
noise_level
+
np
.
random
.
rand
(
*
k_size
)
*
noise_level
*
2
# Set COV matrix using Lambdas and Theta
LAMBDA
=
np
.
diag
([
lambda_1
,
lambda_2
])
Q
=
np
.
array
([[
np
.
cos
(
theta
),
-
np
.
sin
(
theta
)],
[
np
.
sin
(
theta
),
np
.
cos
(
theta
)]])
SIGMA
=
Q
@
LAMBDA
@
Q
.
T
INV_SIGMA
=
np
.
linalg
.
inv
(
SIGMA
)[
None
,
None
,
:,
:]
# Set expectation position (shifting kernel for aligned image)
MU
=
k_size
//
2
-
0.5
*
(
scale_factor
-
1
)
# - 0.5 * (scale_factor - k_size % 2)
MU
=
MU
[
None
,
None
,
:,
None
]
# Create meshgrid for Gaussian
[
X
,
Y
]
=
np
.
meshgrid
(
range
(
k_size
[
0
]),
range
(
k_size
[
1
]))
Z
=
np
.
stack
([
X
,
Y
],
2
)[:,
:,
:,
None
]
# Calcualte Gaussian for every pixel of the kernel
ZZ
=
Z
-
MU
ZZ_t
=
ZZ
.
transpose
(
0
,
1
,
3
,
2
)
raw_kernel
=
np
.
exp
(
-
0.5
*
np
.
squeeze
(
ZZ_t
@
INV_SIGMA
@
ZZ
))
*
(
1
+
noise
)
# shift the kernel so it will be centered
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
# Normalize the kernel and return
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
kernel
=
raw_kernel
/
np
.
sum
(
raw_kernel
)
return
kernel
def
fspecial_gaussian
(
hsize
,
sigma
):
hsize
=
[
hsize
,
hsize
]
siz
=
[(
hsize
[
0
]
-
1.0
)
/
2.0
,
(
hsize
[
1
]
-
1.0
)
/
2.0
]
std
=
sigma
[
x
,
y
]
=
np
.
meshgrid
(
np
.
arange
(
-
siz
[
1
],
siz
[
1
]
+
1
),
np
.
arange
(
-
siz
[
0
],
siz
[
0
]
+
1
))
arg
=
-
(
x
*
x
+
y
*
y
)
/
(
2
*
std
*
std
)
h
=
np
.
exp
(
arg
)
h
[
h
<
scipy
.
finfo
(
float
).
eps
*
h
.
max
()]
=
0
sumh
=
h
.
sum
()
if
sumh
!=
0
:
h
=
h
/
sumh
return
h
def
fspecial_laplacian
(
alpha
):
alpha
=
max
([
0
,
min
([
alpha
,
1
])])
h1
=
alpha
/
(
alpha
+
1
)
h2
=
(
1
-
alpha
)
/
(
alpha
+
1
)
h
=
[[
h1
,
h2
,
h1
],
[
h2
,
-
4
/
(
alpha
+
1
),
h2
],
[
h1
,
h2
,
h1
]]
h
=
np
.
array
(
h
)
return
h
def
fspecial
(
filter_type
,
*
args
,
**
kwargs
):
'''
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
'''
if
filter_type
==
'gaussian'
:
return
fspecial_gaussian
(
*
args
,
**
kwargs
)
if
filter_type
==
'laplacian'
:
return
fspecial_laplacian
(
*
args
,
**
kwargs
)
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def
bicubic_degradation
(
x
,
sf
=
3
):
'''
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
'''
x
=
util
.
imresize_np
(
x
,
scale
=
1
/
sf
)
return
x
def
srmd_degradation
(
x
,
k
,
sf
=
3
):
''' blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2018learning,
title={Learning a single convolutional super-resolution network for multiple degradations},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={3262--3271},
year={2018}
}
'''
x
=
ndimage
.
convolve
(
x
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'wrap'
)
# 'nearest' | 'mirror'
x
=
bicubic_degradation
(
x
,
sf
=
sf
)
return
x
def
dpsr_degradation
(
x
,
k
,
sf
=
3
):
''' bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2019deep,
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={1671--1681},
year={2019}
}
'''
x
=
bicubic_degradation
(
x
,
sf
=
sf
)
x
=
ndimage
.
convolve
(
x
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'wrap'
)
return
x
def
classical_degradation
(
x
,
k
,
sf
=
3
):
''' blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
'''
x
=
ndimage
.
convolve
(
x
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'wrap'
)
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st
=
0
return
x
[
st
::
sf
,
st
::
sf
,
...]
def
add_sharpening
(
img
,
weight
=
0.5
,
radius
=
50
,
threshold
=
10
):
"""USM sharpening. borrowed from real-ESRGAN
Input image: I; Blurry image: B.
1. K = I + weight * (I - B)
2. Mask = 1 if abs(I - B) > threshold, else: 0
3. Blur mask:
4. Out = Mask * K + (1 - Mask) * I
Args:
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
weight (float): Sharp weight. Default: 1.
radius (float): Kernel size of Gaussian blur. Default: 50.
threshold (int):
"""
if
radius
%
2
==
0
:
radius
+=
1
blur
=
cv2
.
GaussianBlur
(
img
,
(
radius
,
radius
),
0
)
residual
=
img
-
blur
mask
=
np
.
abs
(
residual
)
*
255
>
threshold
mask
=
mask
.
astype
(
'float32'
)
soft_mask
=
cv2
.
GaussianBlur
(
mask
,
(
radius
,
radius
),
0
)
K
=
img
+
weight
*
residual
K
=
np
.
clip
(
K
,
0
,
1
)
return
soft_mask
*
K
+
(
1
-
soft_mask
)
*
img
def
add_blur
(
img
,
sf
=
4
):
wd2
=
4.0
+
sf
wd
=
2.0
+
0.2
*
sf
wd2
=
wd2
/
4
wd
=
wd
/
4
if
random
.
random
()
<
0.5
:
l1
=
wd2
*
random
.
random
()
l2
=
wd2
*
random
.
random
()
k
=
anisotropic_Gaussian
(
ksize
=
random
.
randint
(
2
,
11
)
+
3
,
theta
=
random
.
random
()
*
np
.
pi
,
l1
=
l1
,
l2
=
l2
)
else
:
k
=
fspecial
(
'gaussian'
,
random
.
randint
(
2
,
4
)
+
3
,
wd
*
random
.
random
())
img
=
ndimage
.
convolve
(
img
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'mirror'
)
return
img
def
add_resize
(
img
,
sf
=
4
):
rnum
=
np
.
random
.
rand
()
if
rnum
>
0.8
:
# up
sf1
=
random
.
uniform
(
1
,
2
)
elif
rnum
<
0.7
:
# down
sf1
=
random
.
uniform
(
0.5
/
sf
,
1
)
else
:
sf1
=
1.0
img
=
cv2
.
resize
(
img
,
(
int
(
sf1
*
img
.
shape
[
1
]),
int
(
sf1
*
img
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
return
img
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def
add_Gaussian_noise
(
img
,
noise_level1
=
2
,
noise_level2
=
25
):
noise_level
=
random
.
randint
(
noise_level1
,
noise_level2
)
rnum
=
np
.
random
.
rand
()
if
rnum
>
0.6
:
# add color Gaussian noise
img
=
img
+
np
.
random
.
normal
(
0
,
noise_level
/
255.0
,
img
.
shape
).
astype
(
np
.
float32
)
elif
rnum
<
0.4
:
# add grayscale Gaussian noise
img
=
img
+
np
.
random
.
normal
(
0
,
noise_level
/
255.0
,
(
*
img
.
shape
[:
2
],
1
)).
astype
(
np
.
float32
)
else
:
# add noise
L
=
noise_level2
/
255.
D
=
np
.
diag
(
np
.
random
.
rand
(
3
))
U
=
orth
(
np
.
random
.
rand
(
3
,
3
))
conv
=
np
.
dot
(
np
.
dot
(
np
.
transpose
(
U
),
D
),
U
)
img
=
img
+
np
.
random
.
multivariate_normal
([
0
,
0
,
0
],
np
.
abs
(
L
**
2
*
conv
),
img
.
shape
[:
2
]).
astype
(
np
.
float32
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
return
img
def
add_speckle_noise
(
img
,
noise_level1
=
2
,
noise_level2
=
25
):
noise_level
=
random
.
randint
(
noise_level1
,
noise_level2
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
rnum
=
random
.
random
()
if
rnum
>
0.6
:
img
+=
img
*
np
.
random
.
normal
(
0
,
noise_level
/
255.0
,
img
.
shape
).
astype
(
np
.
float32
)
elif
rnum
<
0.4
:
img
+=
img
*
np
.
random
.
normal
(
0
,
noise_level
/
255.0
,
(
*
img
.
shape
[:
2
],
1
)).
astype
(
np
.
float32
)
else
:
L
=
noise_level2
/
255.
D
=
np
.
diag
(
np
.
random
.
rand
(
3
))
U
=
orth
(
np
.
random
.
rand
(
3
,
3
))
conv
=
np
.
dot
(
np
.
dot
(
np
.
transpose
(
U
),
D
),
U
)
img
+=
img
*
np
.
random
.
multivariate_normal
([
0
,
0
,
0
],
np
.
abs
(
L
**
2
*
conv
),
img
.
shape
[:
2
]).
astype
(
np
.
float32
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
return
img
def
add_Poisson_noise
(
img
):
img
=
np
.
clip
((
img
*
255.0
).
round
(),
0
,
255
)
/
255.
vals
=
10
**
(
2
*
random
.
random
()
+
2.0
)
# [2, 4]
if
random
.
random
()
<
0.5
:
img
=
np
.
random
.
poisson
(
img
*
vals
).
astype
(
np
.
float32
)
/
vals
else
:
img_gray
=
np
.
dot
(
img
[...,
:
3
],
[
0.299
,
0.587
,
0.114
])
img_gray
=
np
.
clip
((
img_gray
*
255.0
).
round
(),
0
,
255
)
/
255.
noise_gray
=
np
.
random
.
poisson
(
img_gray
*
vals
).
astype
(
np
.
float32
)
/
vals
-
img_gray
img
+=
noise_gray
[:,
:,
np
.
newaxis
]
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
return
img
def
add_JPEG_noise
(
img
):
quality_factor
=
random
.
randint
(
80
,
95
)
img
=
cv2
.
cvtColor
(
util
.
single2uint
(
img
),
cv2
.
COLOR_RGB2BGR
)
result
,
encimg
=
cv2
.
imencode
(
'.jpg'
,
img
,
[
int
(
cv2
.
IMWRITE_JPEG_QUALITY
),
quality_factor
])
img
=
cv2
.
imdecode
(
encimg
,
1
)
img
=
cv2
.
cvtColor
(
util
.
uint2single
(
img
),
cv2
.
COLOR_BGR2RGB
)
return
img
def
random_crop
(
lq
,
hq
,
sf
=
4
,
lq_patchsize
=
64
):
h
,
w
=
lq
.
shape
[:
2
]
rnd_h
=
random
.
randint
(
0
,
h
-
lq_patchsize
)
rnd_w
=
random
.
randint
(
0
,
w
-
lq_patchsize
)
lq
=
lq
[
rnd_h
:
rnd_h
+
lq_patchsize
,
rnd_w
:
rnd_w
+
lq_patchsize
,
:]
rnd_h_H
,
rnd_w_H
=
int
(
rnd_h
*
sf
),
int
(
rnd_w
*
sf
)
hq
=
hq
[
rnd_h_H
:
rnd_h_H
+
lq_patchsize
*
sf
,
rnd_w_H
:
rnd_w_H
+
lq_patchsize
*
sf
,
:]
return
lq
,
hq
def
degradation_bsrgan
(
img
,
sf
=
4
,
lq_patchsize
=
72
,
isp_model
=
None
):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
isp_prob
,
jpeg_prob
,
scale2_prob
=
0.25
,
0.9
,
0.25
sf_ori
=
sf
h1
,
w1
=
img
.
shape
[:
2
]
img
=
img
.
copy
()[:
w1
-
w1
%
sf
,
:
h1
-
h1
%
sf
,
...]
# mod crop
h
,
w
=
img
.
shape
[:
2
]
if
h
<
lq_patchsize
*
sf
or
w
<
lq_patchsize
*
sf
:
raise
ValueError
(
f
'img size (
{
h1
}
X
{
w1
}
) is too small!'
)
hq
=
img
.
copy
()
if
sf
==
4
and
random
.
random
()
<
scale2_prob
:
# downsample1
if
np
.
random
.
rand
()
<
0.5
:
img
=
cv2
.
resize
(
img
,
(
int
(
1
/
2
*
img
.
shape
[
1
]),
int
(
1
/
2
*
img
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
else
:
img
=
util
.
imresize_np
(
img
,
1
/
2
,
True
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
sf
=
2
shuffle_order
=
random
.
sample
(
range
(
7
),
7
)
idx1
,
idx2
=
shuffle_order
.
index
(
2
),
shuffle_order
.
index
(
3
)
if
idx1
>
idx2
:
# keep downsample3 last
shuffle_order
[
idx1
],
shuffle_order
[
idx2
]
=
shuffle_order
[
idx2
],
shuffle_order
[
idx1
]
for
i
in
shuffle_order
:
if
i
==
0
:
img
=
add_blur
(
img
,
sf
=
sf
)
elif
i
==
1
:
img
=
add_blur
(
img
,
sf
=
sf
)
elif
i
==
2
:
a
,
b
=
img
.
shape
[
1
],
img
.
shape
[
0
]
# downsample2
if
random
.
random
()
<
0.75
:
sf1
=
random
.
uniform
(
1
,
2
*
sf
)
img
=
cv2
.
resize
(
img
,
(
int
(
1
/
sf1
*
img
.
shape
[
1
]),
int
(
1
/
sf1
*
img
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
else
:
k
=
fspecial
(
'gaussian'
,
25
,
random
.
uniform
(
0.1
,
0.6
*
sf
))
k_shifted
=
shift_pixel
(
k
,
sf
)
k_shifted
=
k_shifted
/
k_shifted
.
sum
()
# blur with shifted kernel
img
=
ndimage
.
convolve
(
img
,
np
.
expand_dims
(
k_shifted
,
axis
=
2
),
mode
=
'mirror'
)
img
=
img
[
0
::
sf
,
0
::
sf
,
...]
# nearest downsampling
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
elif
i
==
3
:
# downsample3
img
=
cv2
.
resize
(
img
,
(
int
(
1
/
sf
*
a
),
int
(
1
/
sf
*
b
)),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
elif
i
==
4
:
# add Gaussian noise
img
=
add_Gaussian_noise
(
img
,
noise_level1
=
2
,
noise_level2
=
8
)
elif
i
==
5
:
# add JPEG noise
if
random
.
random
()
<
jpeg_prob
:
img
=
add_JPEG_noise
(
img
)
elif
i
==
6
:
# add processed camera sensor noise
if
random
.
random
()
<
isp_prob
and
isp_model
is
not
None
:
with
torch
.
no_grad
():
img
,
hq
=
isp_model
.
forward
(
img
.
copy
(),
hq
)
# add final JPEG compression noise
img
=
add_JPEG_noise
(
img
)
# random crop
img
,
hq
=
random_crop
(
img
,
hq
,
sf_ori
,
lq_patchsize
)
return
img
,
hq
# todo no isp_model?
def
degradation_bsrgan_variant
(
image
,
sf
=
4
,
isp_model
=
None
,
up
=
False
):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
image
=
util
.
uint2single
(
image
)
isp_prob
,
jpeg_prob
,
scale2_prob
=
0.25
,
0.9
,
0.25
sf_ori
=
sf
h1
,
w1
=
image
.
shape
[:
2
]
image
=
image
.
copy
()[:
w1
-
w1
%
sf
,
:
h1
-
h1
%
sf
,
...]
# mod crop
h
,
w
=
image
.
shape
[:
2
]
hq
=
image
.
copy
()
if
sf
==
4
and
random
.
random
()
<
scale2_prob
:
# downsample1
if
np
.
random
.
rand
()
<
0.5
:
image
=
cv2
.
resize
(
image
,
(
int
(
1
/
2
*
image
.
shape
[
1
]),
int
(
1
/
2
*
image
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
else
:
image
=
util
.
imresize_np
(
image
,
1
/
2
,
True
)
image
=
np
.
clip
(
image
,
0.0
,
1.0
)
sf
=
2
shuffle_order
=
random
.
sample
(
range
(
7
),
7
)
idx1
,
idx2
=
shuffle_order
.
index
(
2
),
shuffle_order
.
index
(
3
)
if
idx1
>
idx2
:
# keep downsample3 last
shuffle_order
[
idx1
],
shuffle_order
[
idx2
]
=
shuffle_order
[
idx2
],
shuffle_order
[
idx1
]
for
i
in
shuffle_order
:
if
i
==
0
:
image
=
add_blur
(
image
,
sf
=
sf
)
# elif i == 1:
# image = add_blur(image, sf=sf)
if
i
==
0
:
pass
elif
i
==
2
:
a
,
b
=
image
.
shape
[
1
],
image
.
shape
[
0
]
# downsample2
if
random
.
random
()
<
0.8
:
sf1
=
random
.
uniform
(
1
,
2
*
sf
)
image
=
cv2
.
resize
(
image
,
(
int
(
1
/
sf1
*
image
.
shape
[
1
]),
int
(
1
/
sf1
*
image
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
else
:
k
=
fspecial
(
'gaussian'
,
25
,
random
.
uniform
(
0.1
,
0.6
*
sf
))
k_shifted
=
shift_pixel
(
k
,
sf
)
k_shifted
=
k_shifted
/
k_shifted
.
sum
()
# blur with shifted kernel
image
=
ndimage
.
convolve
(
image
,
np
.
expand_dims
(
k_shifted
,
axis
=
2
),
mode
=
'mirror'
)
image
=
image
[
0
::
sf
,
0
::
sf
,
...]
# nearest downsampling
image
=
np
.
clip
(
image
,
0.0
,
1.0
)
elif
i
==
3
:
# downsample3
image
=
cv2
.
resize
(
image
,
(
int
(
1
/
sf
*
a
),
int
(
1
/
sf
*
b
)),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
image
=
np
.
clip
(
image
,
0.0
,
1.0
)
elif
i
==
4
:
# add Gaussian noise
image
=
add_Gaussian_noise
(
image
,
noise_level1
=
1
,
noise_level2
=
2
)
elif
i
==
5
:
# add JPEG noise
if
random
.
random
()
<
jpeg_prob
:
image
=
add_JPEG_noise
(
image
)
#
# elif i == 6:
# # add processed camera sensor noise
# if random.random() < isp_prob and isp_model is not None:
# with torch.no_grad():
# img, hq = isp_model.forward(img.copy(), hq)
# add final JPEG compression noise
image
=
add_JPEG_noise
(
image
)
image
=
util
.
single2uint
(
image
)
if
up
:
image
=
cv2
.
resize
(
image
,
(
w1
,
h1
),
interpolation
=
cv2
.
INTER_CUBIC
)
# todo: random, as above? want to condition on it then
example
=
{
"image"
:
image
}
return
example
if
__name__
==
'__main__'
:
print
(
"hey"
)
img
=
util
.
imread_uint
(
'utils/test.png'
,
3
)
img
=
img
[:
448
,
:
448
]
h
=
img
.
shape
[
0
]
//
4
print
(
"resizing to"
,
h
)
sf
=
4
deg_fn
=
partial
(
degradation_bsrgan_variant
,
sf
=
sf
)
for
i
in
range
(
20
):
print
(
i
)
img_hq
=
img
img_lq
=
deg_fn
(
img
)[
"image"
]
img_hq
,
img_lq
=
util
.
uint2single
(
img_hq
),
util
.
uint2single
(
img_lq
)
print
(
img_lq
)
img_lq_bicubic
=
albumentations
.
SmallestMaxSize
(
max_size
=
h
,
interpolation
=
cv2
.
INTER_CUBIC
)(
image
=
img_hq
)[
"image"
]
print
(
img_lq
.
shape
)
print
(
"bicubic"
,
img_lq_bicubic
.
shape
)
print
(
img_hq
.
shape
)
lq_nearest
=
cv2
.
resize
(
util
.
single2uint
(
img_lq
),
(
int
(
sf
*
img_lq
.
shape
[
1
]),
int
(
sf
*
img_lq
.
shape
[
0
])),
interpolation
=
0
)
lq_bicubic_nearest
=
cv2
.
resize
(
util
.
single2uint
(
img_lq_bicubic
),
(
int
(
sf
*
img_lq
.
shape
[
1
]),
int
(
sf
*
img_lq
.
shape
[
0
])),
interpolation
=
0
)
img_concat
=
np
.
concatenate
([
lq_bicubic_nearest
,
lq_nearest
,
util
.
single2uint
(
img_hq
)],
axis
=
1
)
util
.
imsave
(
img_concat
,
str
(
i
)
+
'.png'
)
ldm/modules/image_degradation/utils/test.png  (new binary file, 431 KB, commit 4007efdd)
ldm/modules/image_degradation/utils_image.py  (new file, mode 100644, commit 4007efdd)
import os
import math
import random
import numpy as np
import torch
import cv2
from torchvision.utils import make_grid
from datetime import datetime
#import matplotlib.pyplot as plt   # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py


os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def get_timestamp():
    return datetime.now().strftime('%y%m%d-%H%M%S')
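
# Short sketch (added) of the uint8 <-> float32 conversions defined further down in this
# file, which the degradation code uses throughout:
def _example_dtype_roundtrip():
    x = np.random.rand(8, 8, 3).astype(np.float32)   # single precision, range [0, 1]
    x8 = single2uint(x)                              # -> uint8 in [0, 255]
    assert x8.dtype == np.uint8
    x_back = uint2single(x8)                         # -> float32 back in [0, 1]
    assert x_back.dtype == np.float32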
def
imshow
(
x
,
title
=
None
,
cbar
=
False
,
figsize
=
None
):
plt
.
figure
(
figsize
=
figsize
)
plt
.
imshow
(
np
.
squeeze
(
x
),
interpolation
=
'nearest'
,
cmap
=
'gray'
)
if
title
:
plt
.
title
(
title
)
if
cbar
:
plt
.
colorbar
()
plt
.
show
()
def
surf
(
Z
,
cmap
=
'rainbow'
,
figsize
=
None
):
plt
.
figure
(
figsize
=
figsize
)
ax3
=
plt
.
axes
(
projection
=
'3d'
)
w
,
h
=
Z
.
shape
[:
2
]
xx
=
np
.
arange
(
0
,
w
,
1
)
yy
=
np
.
arange
(
0
,
h
,
1
)
X
,
Y
=
np
.
meshgrid
(
xx
,
yy
)
ax3
.
plot_surface
(
X
,
Y
,
Z
,
cmap
=
cmap
)
#ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
plt
.
show
()
'''
# --------------------------------------------
# get image pathes
# --------------------------------------------
'''
def
get_image_paths
(
dataroot
):
paths
=
None
# return None if dataroot is None
if
dataroot
is
not
None
:
paths
=
sorted
(
_get_paths_from_images
(
dataroot
))
return
paths
def
_get_paths_from_images
(
path
):
assert
os
.
path
.
isdir
(
path
),
'{:s} is not a valid directory'
.
format
(
path
)
images
=
[]
for
dirpath
,
_
,
fnames
in
sorted
(
os
.
walk
(
path
)):
for
fname
in
sorted
(
fnames
):
if
is_image_file
(
fname
):
img_path
=
os
.
path
.
join
(
dirpath
,
fname
)
images
.
append
(
img_path
)
assert
images
,
'{:s} has no valid image file'
.
format
(
path
)
return
images
'''
# --------------------------------------------
# split large images into small images
# --------------------------------------------
'''
def
patches_from_image
(
img
,
p_size
=
512
,
p_overlap
=
64
,
p_max
=
800
):
w
,
h
=
img
.
shape
[:
2
]
patches
=
[]
if
w
>
p_max
and
h
>
p_max
:
w1
=
list
(
np
.
arange
(
0
,
w
-
p_size
,
p_size
-
p_overlap
,
dtype
=
np
.
int
))
h1
=
list
(
np
.
arange
(
0
,
h
-
p_size
,
p_size
-
p_overlap
,
dtype
=
np
.
int
))
w1
.
append
(
w
-
p_size
)
h1
.
append
(
h
-
p_size
)
# print(w1)
# print(h1)
for
i
in
w1
:
for
j
in
h1
:
patches
.
append
(
img
[
i
:
i
+
p_size
,
j
:
j
+
p_size
,:])
else
:
patches
.
append
(
img
)
return
patches
def
imssave
(
imgs
,
img_path
):
"""
imgs: list, N images of size WxHxC
"""
img_name
,
ext
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
img_path
))
for
i
,
img
in
enumerate
(
imgs
):
if
img
.
ndim
==
3
:
img
=
img
[:,
:,
[
2
,
1
,
0
]]
new_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
img_path
),
img_name
+
str
(
'_s{:04d}'
.
format
(
i
))
+
'.png'
)
cv2
.
imwrite
(
new_path
,
img
)
def
split_imageset
(
original_dataroot
,
taget_dataroot
,
n_channels
=
3
,
p_size
=
800
,
p_overlap
=
96
,
p_max
=
1000
):
"""
split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
will be splitted.
Args:
original_dataroot:
taget_dataroot:
p_size: size of small images
p_overlap: patch size in training is a good choice
p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
"""
paths
=
get_image_paths
(
original_dataroot
)
for
img_path
in
paths
:
# img_name, ext = os.path.splitext(os.path.basename(img_path))
img
=
imread_uint
(
img_path
,
n_channels
=
n_channels
)
patches
=
patches_from_image
(
img
,
p_size
,
p_overlap
,
p_max
)
imssave
(
patches
,
os
.
path
.
join
(
taget_dataroot
,
os
.
path
.
basename
(
img_path
)))
#if original_dataroot == taget_dataroot:
#del img_path
'''
# --------------------------------------------
# makedir
# --------------------------------------------
'''
def
mkdir
(
path
):
if
not
os
.
path
.
exists
(
path
):
os
.
makedirs
(
path
)
def
mkdirs
(
paths
):
if
isinstance
(
paths
,
str
):
mkdir
(
paths
)
else
:
for
path
in
paths
:
mkdir
(
path
)
def
mkdir_and_rename
(
path
):
if
os
.
path
.
exists
(
path
):
new_name
=
path
+
'_archived_'
+
get_timestamp
()
print
(
'Path already exists. Rename it to [{:s}]'
.
format
(
new_name
))
os
.
rename
(
path
,
new_name
)
os
.
makedirs
(
path
)
'''
# --------------------------------------------
# read image from path
# opencv is fast, but read BGR numpy image
# --------------------------------------------
'''
# --------------------------------------------
# get uint8 image of size HxWxn_channles (RGB)
# --------------------------------------------
def
imread_uint
(
path
,
n_channels
=
3
):
# input: path
# output: HxWx3(RGB or GGG), or HxWx1 (G)
if
n_channels
==
1
:
img
=
cv2
.
imread
(
path
,
0
)
# cv2.IMREAD_GRAYSCALE
img
=
np
.
expand_dims
(
img
,
axis
=
2
)
# HxWx1
elif
n_channels
==
3
:
img
=
cv2
.
imread
(
path
,
cv2
.
IMREAD_UNCHANGED
)
# BGR or G
if
img
.
ndim
==
2
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2RGB
)
# GGG
else
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
# RGB
return
img
# --------------------------------------------
# matlab's imwrite
# --------------------------------------------
def
imsave
(
img
,
img_path
):
img
=
np
.
squeeze
(
img
)
if
img
.
ndim
==
3
:
img
=
img
[:,
:,
[
2
,
1
,
0
]]
cv2
.
imwrite
(
img_path
,
img
)
def
imwrite
(
img
,
img_path
):
img
=
np
.
squeeze
(
img
)
if
img
.
ndim
==
3
:
img
=
img
[:,
:,
[
2
,
1
,
0
]]
cv2
.
imwrite
(
img_path
,
img
)
# --------------------------------------------
# get single image of size HxWxn_channles (BGR)
# --------------------------------------------
def
read_img
(
path
):
# read image by cv2
# return: Numpy float32, HWC, BGR, [0,1]
img
=
cv2
.
imread
(
path
,
cv2
.
IMREAD_UNCHANGED
)
# cv2.IMREAD_GRAYSCALE
img
=
img
.
astype
(
np
.
float32
)
/
255.
if
img
.
ndim
==
2
:
img
=
np
.
expand_dims
(
img
,
axis
=
2
)
# some images have 4 channels
if
img
.
shape
[
2
]
>
3
:
img
=
img
[:,
:,
:
3
]
return
img
'''
# --------------------------------------------
# image format conversion
# --------------------------------------------
# numpy(single) <---> numpy(unit)
# numpy(single) <---> tensor
# numpy(unit) <---> tensor
# --------------------------------------------
'''
# --------------------------------------------
# numpy(single) [0, 1] <---> numpy(unit)
# --------------------------------------------
def
uint2single
(
img
):
return
np
.
float32
(
img
/
255.
)
def
single2uint
(
img
):
return
np
.
uint8
((
img
.
clip
(
0
,
1
)
*
255.
).
round
())
def
uint162single
(
img
):
return
np
.
float32
(
img
/
65535.
)
def
single2uint16
(
img
):
return
np
.
uint16
((
img
.
clip
(
0
,
1
)
*
65535.
).
round
())
# --------------------------------------------
# numpy(unit) (HxWxC or HxW) <---> tensor
# --------------------------------------------
# convert uint to 4-dimensional torch tensor
def
uint2tensor4
(
img
):
if
img
.
ndim
==
2
:
img
=
np
.
expand_dims
(
img
,
axis
=
2
)
return
torch
.
from_numpy
(
np
.
ascontiguousarray
(
img
)).
permute
(
2
,
0
,
1
).
float
().
div
(
255.
).
unsqueeze
(
0
)
# convert uint to 3-dimensional torch tensor
def
uint2tensor3
(
img
):
if
img
.
ndim
==
2
:
img
=
np
.
expand_dims
(
img
,
axis
=
2
)
return
torch
.
from_numpy
(
np
.
ascontiguousarray
(
img
)).
permute
(
2
,
0
,
1
).
float
().
div
(
255.
)
# convert 2/3/4-dimensional torch tensor to uint
def
tensor2uint
(
img
):
img
=
img
.
data
.
squeeze
().
float
().
clamp_
(
0
,
1
).
cpu
().
numpy
()
if
img
.
ndim
==
3
:
img
=
np
.
transpose
(
img
,
(
1
,
2
,
0
))
return
np
.
uint8
((
img
*
255.0
).
round
())
# --------------------------------------------
# numpy(single) (HxWxC) <---> tensor
# --------------------------------------------
# convert single (HxWxC) to 3-dimensional torch tensor
def
single2tensor3
(
img
):
return
torch
.
from_numpy
(
np
.
ascontiguousarray
(
img
)).
permute
(
2
,
0
,
1
).
float
()
# convert single (HxWxC) to 4-dimensional torch tensor
def
single2tensor4
(
img
):
return
torch
.
from_numpy
(
np
.
ascontiguousarray
(
img
)).
permute
(
2
,
0
,
1
).
float
().
unsqueeze
(
0
)
# convert torch tensor to single
def
tensor2single
(
img
):
img
=
img
.
data
.
squeeze
().
float
().
cpu
().
numpy
()
if
img
.
ndim
==
3
:
img
=
np
.
transpose
(
img
,
(
1
,
2
,
0
))
return
img
# convert torch tensor to single
def
tensor2single3
(
img
):
img
=
img
.
data
.
squeeze
().
float
().
cpu
().
numpy
()
if
img
.
ndim
==
3
:
img
=
np
.
transpose
(
img
,
(
1
,
2
,
0
))
elif
img
.
ndim
==
2
:
img
=
np
.
expand_dims
(
img
,
axis
=
2
)
return
img
def
single2tensor5
(
img
):
return
torch
.
from_numpy
(
np
.
ascontiguousarray
(
img
)).
permute
(
2
,
0
,
1
,
3
).
float
().
unsqueeze
(
0
)
def
single32tensor5
(
img
):
return
torch
.
from_numpy
(
np
.
ascontiguousarray
(
img
)).
float
().
unsqueeze
(
0
).
unsqueeze
(
0
)
def
single42tensor4
(
img
):
return
torch
.
from_numpy
(
np
.
ascontiguousarray
(
img
)).
permute
(
2
,
0
,
1
,
3
).
float
()
# from skimage.io import imread, imsave
def
tensor2img
(
tensor
,
out_type
=
np
.
uint8
,
min_max
=
(
0
,
1
)):
'''
Converts a torch Tensor into an image Numpy array of BGR channel order
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
'''
tensor
=
tensor
.
squeeze
().
float
().
cpu
().
clamp_
(
*
min_max
)
# squeeze first, then clamp
tensor
=
(
tensor
-
min_max
[
0
])
/
(
min_max
[
1
]
-
min_max
[
0
])
# to range [0,1]
n_dim
=
tensor
.
dim
()
if
n_dim
==
4
:
n_img
=
len
(
tensor
)
img_np
=
make_grid
(
tensor
,
nrow
=
int
(
math
.
sqrt
(
n_img
)),
normalize
=
False
).
numpy
()
img_np
=
np
.
transpose
(
img_np
[[
2
,
1
,
0
],
:,
:],
(
1
,
2
,
0
))
# HWC, BGR
elif
n_dim
==
3
:
img_np
=
tensor
.
numpy
()
img_np
=
np
.
transpose
(
img_np
[[
2
,
1
,
0
],
:,
:],
(
1
,
2
,
0
))
# HWC, BGR
elif
n_dim
==
2
:
img_np
=
tensor
.
numpy
()
else
:
raise
TypeError
(
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'
.
format
(
n_dim
))
if
out_type
==
np
.
uint8
:
img_np
=
(
img_np
*
255.0
).
round
()
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
return
img_np
.
astype
(
out_type
)
'''
# --------------------------------------------
# Augmentation, flipe and/or rotate
# --------------------------------------------
# The following two are enough.
# (1) augmet_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
'''
def augment_img(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)
    '''
    if mode == 0:
        return img
    elif mode == 1:
        return np.flipud(np.rot90(img))
    elif mode == 2:
        return np.flipud(img)
    elif mode == 3:
        return np.rot90(img, k=3)
    elif mode == 4:
        return np.flipud(np.rot90(img, k=2))
    elif mode == 5:
        return np.rot90(img)
    elif mode == 6:
        return np.rot90(img, k=2)
    elif mode == 7:
        return np.flipud(np.rot90(img, k=3))
def augment_img_tensor4(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)
    '''
    if mode == 0:
        return img
    elif mode == 1:
        return img.rot90(1, [2, 3]).flip([2])
    elif mode == 2:
        return img.flip([2])
    elif mode == 3:
        return img.rot90(3, [2, 3])
    elif mode == 4:
        return img.rot90(2, [2, 3]).flip([2])
    elif mode == 5:
        return img.rot90(1, [2, 3])
    elif mode == 6:
        return img.rot90(2, [2, 3])
    elif mode == 7:
        return img.rot90(3, [2, 3]).flip([2])


def augment_img_tensor(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)
    '''
    img_size = img.size()
    img_np = img.data.cpu().numpy()
    if len(img_size) == 3:
        img_np = np.transpose(img_np, (1, 2, 0))
    elif len(img_size) == 4:
        img_np = np.transpose(img_np, (2, 3, 1, 0))
    img_np = augment_img(img_np, mode=mode)
    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
    if len(img_size) == 3:
        img_tensor = img_tensor.permute(2, 0, 1)
    elif len(img_size) == 4:
        img_tensor = img_tensor.permute(3, 2, 0, 1)
    return img_tensor.type_as(img)


def augment_img_np3(img, mode=0):
    if mode == 0:
        return img
    elif mode == 1:
        return img.transpose(1, 0, 2)
    elif mode == 2:
        return img[::-1, :, :]
    elif mode == 3:
        img = img[::-1, :, :]
        img = img.transpose(1, 0, 2)
        return img
    elif mode == 4:
        return img[:, ::-1, :]
    elif mode == 5:
        img = img[:, ::-1, :]
        img = img.transpose(1, 0, 2)
        return img
    elif mode == 6:
        img = img[:, ::-1, :]
        img = img[::-1, :, :]
        return img
    elif mode == 7:
        img = img[:, ::-1, :]
        img = img[::-1, :, :]
        img = img.transpose(1, 0, 2)
        return img


def augment_imgs(img_list, hflip=True, rot=True):
    # horizontal flip OR rotate
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    return [_augment(img) for img in img_list]
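

# Editor's note: illustrative sketch, not in the original file. The eight modes of
# augment_img cover the dihedral group of flips/rotations; augment_imgs applies one
# random flip/rotation consistently to every image in a list (for example an LR/HR pair).
# All shapes below are assumptions chosen only for the demo.
def _example_augmentation():
    img = np.random.rand(16, 16, 3).astype(np.float32)
    all_modes = [augment_img(img, mode=m) for m in range(8)]        # 8 deterministic variants
    lr = np.random.rand(8, 8, 3).astype(np.float32)
    hr = np.random.rand(16, 16, 3).astype(np.float32)
    lr_aug, hr_aug = augment_imgs([lr, hr], hflip=True, rot=True)   # same random transform for both
    return len(all_modes), lr_aug.shape, hr_aug.shape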
'''
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
'''
def modcrop(img_in, scale):
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    if img.ndim == 2:
        H, W = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[:H - H_r, :W - W_r]
    elif img.ndim == 3:
        H, W, C = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[:H - H_r, :W - W_r, :]
    else:
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    return img


def shave(img_in, border=0):
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    h, w = img.shape[:2]
    img = img[border:h-border, border:w-border]
    return img
'''
# --------------------------------------------
# image processing process on numpy image
# channel_convert(in_c, tar_type, img_list):
# rgb2ycbcr(img, only_y=True):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
'''
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def channel_convert(in_c, tar_type, img_list):
    # conversion among BGR, gray and y
    if in_c == 3 and tar_type == 'gray':  # BGR to gray
        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in gray_list]
    elif in_c == 3 and tar_type == 'y':  # BGR to y
        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in y_list]
    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    else:
        return img_list
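

# Editor's note: usage sketch, not in the original file. channel_convert maps a list of
# BGR images to grayscale or to the Y channel used by the MATLAB-style metrics below;
# the shapes and values here are assumptions chosen only for illustration.
def _example_channel_convert():
    imgs = [np.random.rand(32, 32, 3).astype(np.float32) for _ in range(2)]   # BGR, [0, 1]
    y_imgs = channel_convert(3, 'y', imgs)        # each becomes HxWx1, Y channel only
    gray_imgs = channel_convert(3, 'gray', imgs)  # cv2 BGR->GRAY, expanded back to HxWx1
    return y_imgs[0].shape, gray_imgs[0].shape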
'''
# --------------------------------------------
# metric, PSNR and SSIM
# --------------------------------------------
'''
# --------------------------------------------
# PSNR
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
    # img1 and img2 have range [0, 255]
    #img1 = img1.squeeze()
    #img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


# --------------------------------------------
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]
    '''
    #img1 = img1.squeeze()
    #img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')


def ssim(img1, img2):
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()
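

# Editor's note: sketch, not in the original file, of the PSNR/SSIM helpers above. Both
# expect images in [0, 255]; `border` crops that many pixels from each side before scoring.
# The synthetic ground-truth/noisy pair is an assumption made only for the demo.
def _example_metrics():
    gt = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
    noisy = np.clip(gt.astype(np.float64) + np.random.randn(64, 64, 3) * 5, 0, 255)
    psnr = calculate_psnr(gt, noisy, border=4)     # 20 * log10(255 / sqrt(MSE))
    ssim_val = calculate_ssim(gt, noisy, border=4)
    return psnr, ssim_val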
'''
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
'''
# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
        (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))


def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation?  Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
        1, P).expand(out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices

    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
# --------------------------------------------
# imresize for tensor image [0, 1]
# --------------------------------------------
def imresize(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: pytorch tensor, CHW or HW [0,1]
    # output: CHW or HW [0,1] w/o round
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2
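

# Editor's note: sketch, not in the original file. imresize mimics MATLAB's bicubic
# imresize for CHW (or HW) tensors in [0, 1]; imresize_np below does the same for HWC
# numpy arrays. The 0.5x scale factor here is only an example.
def _example_imresize():
    chw = torch.rand(3, 64, 64)
    small = imresize(chw, 1 / 2, antialiasing=True)   # -> 3x32x32, still in [0, 1]
    return small.shape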
# --------------------------------------------
# imresize for numpy image [0, 1]
# --------------------------------------------
def imresize_np(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(2)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()

    return out_2.numpy()
if __name__ == '__main__':
    print('---')
# img = imread_uint('test.bmp', 3)
# img = uint2single(img)
# img_bicubic = imresize_np(img, 1/4)
ldm/modules/karlo/__init__.py
0 → 100644
ldm/modules/karlo/diffusers_pipeline.py
0 → 100644
# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import functional as F
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput

from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import UnCLIPScheduler
from ...utils import is_accelerate_available, logging, randn_tensor

from .text_proj import UnCLIPTextProjModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
class UnCLIPPipeline(DiffusionPipeline):
    """
    Pipeline for text-to-image generation using unCLIP

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        text_proj ([`UnCLIPTextProjModel`]):
            Utility class to prepare and combine the embeddings before they are passed to the decoder.
        decoder ([`UNet2DConditionModel`]):
            The decoder to invert the image embedding into an image.
        super_res_first ([`UNet2DModel`]):
            Super resolution unet. Used in all but the last step of the super resolution diffusion process.
        super_res_last ([`UNet2DModel`]):
            Super resolution unet. Used in the last step of the super resolution diffusion process.
        prior_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the prior denoising process. Just a modified DDPMScheduler.
        decoder_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
        super_res_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
    """

    prior: PriorTransformer
    decoder: UNet2DConditionModel
    text_proj: UnCLIPTextProjModel
    text_encoder: CLIPTextModelWithProjection
    tokenizer: CLIPTokenizer
    super_res_first: UNet2DModel
    super_res_last: UNet2DModel

    prior_scheduler: UnCLIPScheduler
    decoder_scheduler: UnCLIPScheduler
    super_res_scheduler: UnCLIPScheduler

    def __init__(
        self,
        prior: PriorTransformer,
        decoder: UNet2DConditionModel,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_proj: UnCLIPTextProjModel,
        super_res_first: UNet2DModel,
        super_res_last: UNet2DModel,
        prior_scheduler: UnCLIPScheduler,
        decoder_scheduler: UnCLIPScheduler,
        super_res_scheduler: UnCLIPScheduler,
    ):
        super().__init__()

        self.register_modules(
            prior=prior,
            decoder=decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_proj=text_proj,
            super_res_first=super_res_first,
            super_res_last=super_res_last,
            prior_scheduler=prior_scheduler,
            decoder_scheduler=decoder_scheduler,
            super_res_scheduler=super_res_scheduler,
        )
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
    ):
        if text_model_output is None:
            batch_size = len(prompt) if isinstance(prompt, list) else 1
            # get prompt text embeddings
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            text_mask = text_inputs.attention_mask.bool().to(device)

            if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
                removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )
                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

            text_encoder_output = self.text_encoder(text_input_ids.to(device))

            text_embeddings = text_encoder_output.text_embeds
            text_encoder_hidden_states = text_encoder_output.last_hidden_state

        else:
            batch_size = text_model_output[0].shape[0]
            text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
            text_mask = text_attention_mask

        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens = [""] * batch_size

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
            uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

            uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
            uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return text_embeddings, text_encoder_hidden_states, text_mask
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
        models = [
            self.decoder,
            self.text_proj,
            self.text_encoder,
            self.super_res_first,
            self.super_res_last,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
            return self.device
        for module in self.decoder.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device
    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        prior_num_inference_steps: int = 25,
        decoder_num_inference_steps: int = 25,
        super_res_num_inference_steps: int = 7,
        generator: Optional[torch.Generator] = None,
        prior_latents: Optional[torch.FloatTensor] = None,
        decoder_latents: Optional[torch.FloatTensor] = None,
        super_res_latents: Optional[torch.FloatTensor] = None,
        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        prior_guidance_scale: float = 4.0,
        decoder_guidance_scale: float = 8.0,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation. This can only be left undefined if
                `text_model_output` and `text_attention_mask` are passed.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            prior_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
                image at the expense of slower inference.
            decoder_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
                image at the expense of slower inference.
            super_res_num_inference_steps (`int`, *optional*, defaults to 7):
                The number of denoising steps for super resolution. More denoising steps usually lead to a higher
                quality image at the expense of slower inference.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):
                Pre-generated noisy latents to be used as inputs for the prior.
            decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
                Pre-generated noisy latents to be used as inputs for the decoder.
            super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
                Pre-generated noisy latents to be used as inputs for the super resolution.
            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            text_model_output (`CLIPTextModelOutput`, *optional*):
                Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs
                can be passed for tasks like text embedding interpolations. Make sure to also pass
                `text_attention_mask` in this case. `prompt` can then be left to `None`.
            text_attention_mask (`torch.Tensor`, *optional*):
                Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
                masks are necessary when passing `text_model_output`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        """
        if prompt is not None:
            if isinstance(prompt, str):
                batch_size = 1
            elif isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        else:
            batch_size = text_model_output[0].shape[0]

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0

        text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
        )

        # prior

        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
        prior_timesteps_tensor = self.prior_scheduler.timesteps

        embedding_dim = self.prior.config.embedding_dim

        prior_latents = self.prepare_latents(
            (batch_size, embedding_dim),
            text_embeddings.dtype,
            device,
            generator,
            prior_latents,
            self.prior_scheduler,
        )

        for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents

            predicted_image_embedding = self.prior(
                latent_model_input,
                timestep=t,
                proj_embedding=text_embeddings,
                encoder_hidden_states=text_encoder_hidden_states,
                attention_mask=text_mask,
            ).predicted_image_embedding

            if do_classifier_free_guidance:
                predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
                predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
                    predicted_image_embedding_text - predicted_image_embedding_uncond
                )

            if i + 1 == prior_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = prior_timesteps_tensor[i + 1]

            prior_latents = self.prior_scheduler.step(
                predicted_image_embedding,
                timestep=t,
                sample=prior_latents,
                generator=generator,
                prev_timestep=prev_timestep,
            ).prev_sample

        prior_latents = self.prior.post_process_latents(prior_latents)

        image_embeddings = prior_latents

        # done prior

        # decoder

        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
            image_embeddings=image_embeddings,
            text_embeddings=text_embeddings,
            text_encoder_hidden_states=text_encoder_hidden_states,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)

        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
        decoder_timesteps_tensor = self.decoder_scheduler.timesteps

        num_channels_latents = self.decoder.in_channels
        height = self.decoder.sample_size
        width = self.decoder.sample_size

        decoder_latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            text_encoder_hidden_states.dtype,
            device,
            generator,
            decoder_latents,
            self.decoder_scheduler,
        )

        for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents

            noise_pred = self.decoder(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=text_encoder_hidden_states,
                class_labels=additive_clip_time_embeddings,
                attention_mask=decoder_text_mask,
            ).sample

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
                noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            if i + 1 == decoder_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = decoder_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            decoder_latents = self.decoder_scheduler.step(
                noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        decoder_latents = decoder_latents.clamp(-1, 1)

        image_small = decoder_latents

        # done decoder

        # super res

        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
        super_res_timesteps_tensor = self.super_res_scheduler.timesteps

        channels = self.super_res_first.in_channels // 2
        height = self.super_res_first.sample_size
        width = self.super_res_first.sample_size

        super_res_latents = self.prepare_latents(
            (batch_size, channels, height, width),
            image_small.dtype,
            device,
            generator,
            super_res_latents,
            self.super_res_scheduler,
        )

        interpolate_antialias = {}
        if "antialias" in inspect.signature(F.interpolate).parameters:
            interpolate_antialias["antialias"] = True

        image_upscaled = F.interpolate(
            image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
        )

        for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
            # no classifier free guidance

            if i == super_res_timesteps_tensor.shape[0] - 1:
                unet = self.super_res_last
            else:
                unet = self.super_res_first

            latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)

            noise_pred = unet(
                sample=latent_model_input,
                timestep=t,
            ).sample

            if i + 1 == super_res_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = super_res_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            super_res_latents = self.super_res_scheduler.step(
                noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        image = super_res_latents
        # done super res

        # post processing

        image = image * 0.5 + 0.5
        image = image.clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
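

# Editor's note: a usage sketch, not part of the original file. It assumes the pipeline is
# loaded through the installed `diffusers` package (this vendored copy uses relative imports),
# and the checkpoint name "kakaobrain/karlo-v1-alpha" and the CUDA device are assumptions.
if __name__ == "__main__":
    from diffusers import UnCLIPPipeline as InstalledUnCLIPPipeline

    pipe = InstalledUnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
    pipe = pipe.to("cuda")
    out = pipe(
        "a high-resolution photograph of a red fox in the snow",
        prior_guidance_scale=4.0,
        decoder_guidance_scale=8.0,
    )
    out.images[0].save("fox.png")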
ldm/modules/karlo/kakao/__init__.py
0 → 100644
ldm/modules/karlo/kakao/models/__init__.py
0 → 100644
ldm/modules/karlo/kakao/models/clip.py
0 → 100644
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------
# Adapted from OpenAI's CLIP (https://github.com/openai/CLIP/)
# ------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip

from clip.model import CLIP, convert_weights
from clip.simple_tokenizer import SimpleTokenizer, default_bpe


"""===== Monkey-Patching original CLIP for JIT compile ====="""


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = F.layer_norm(
            x.type(torch.float32),
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return ret.type(orig_type)


clip.model.LayerNorm = LayerNorm
delattr(clip.model.CLIP, "forward")

"""===== End of Monkey-Patching ====="""


class CustomizedCLIP(CLIP):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.jit.export
    def encode_image(self, image):
        return self.visual(image)

    @torch.jit.export
    def encode_text(self, text):
        # re-define this function to return unpooled text features
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        x_seq = x
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x_out = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x_out, x_seq

    @torch.jit.ignore
    def forward(self, image, text):
        super().forward(image, text)

    @classmethod
    def load_from_checkpoint(cls, ckpt_path: str):
        state_dict = torch.load(ckpt_path, map_location="cpu").state_dict()

        vit = "visual.proj" in state_dict
        if vit:
            vision_width = state_dict["visual.conv1.weight"].shape[0]
            vision_layers = len(
                [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]
            )
            vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
            grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
            image_resolution = vision_patch_size * grid_size
        else:
            counts: list = [
                len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}")))
                for b in [1, 2, 3, 4]
            ]
            vision_layers = tuple(counts)
            vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
            output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
            vision_patch_size = None
            assert (
                output_width**2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
            )
            image_resolution = output_width * 32

        embed_dim = state_dict["text_projection"].shape[1]
        context_length = state_dict["positional_embedding"].shape[0]
        vocab_size = state_dict["token_embedding.weight"].shape[0]
        transformer_width = state_dict["ln_final.weight"].shape[0]
        transformer_heads = transformer_width // 64
        transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

        model = cls(
            embed_dim,
            image_resolution,
            vision_layers,
            vision_width,
            vision_patch_size,
            context_length,
            vocab_size,
            transformer_width,
            transformer_heads,
            transformer_layers,
        )

        for key in ["input_resolution", "context_length", "vocab_size"]:
            if key in state_dict:
                del state_dict[key]

        convert_weights(model)
        model.load_state_dict(state_dict)
        model.eval()
        model.float()
        return model


class CustomizedTokenizer(SimpleTokenizer):
    def __init__(self):
        super().__init__(bpe_path=default_bpe())

        self.sot_token = self.encoder["<|startoftext|>"]
        self.eot_token = self.encoder["<|endoftext|>"]

    def padded_tokens_and_mask(self, texts, text_ctx):
        assert isinstance(texts, list) and all(
            isinstance(elem, str) for elem in texts
        ), "texts should be a list of strings"

        all_tokens = [[self.sot_token] + self.encode(text) + [self.eot_token] for text in texts]

        mask = [
            [True] * min(text_ctx, len(tokens)) + [False] * max(text_ctx - len(tokens), 0)
            for tokens in all_tokens
        ]
        mask = torch.tensor(mask, dtype=torch.bool)
        result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int)
        for i, tokens in enumerate(all_tokens):
            if len(tokens) > text_ctx:
                tokens = tokens[:text_ctx]
                tokens[-1] = self.eot_token
            result[i, : len(tokens)] = torch.tensor(tokens)

        return result, mask
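

# Editor's note: usage sketch, not part of the original file. padded_tokens_and_mask pads
# (or truncates) each encoded prompt to `text_ctx` tokens and returns the matching boolean
# mask; text_ctx=77 mirrors CLIP's context length but is an assumption made for this demo.
if __name__ == "__main__":
    tokenizer = CustomizedTokenizer()
    tokens, mask = tokenizer.padded_tokens_and_mask(["a photo of a cat", ""], text_ctx=77)
    print(tokens.shape, mask.shape)  # both (2, 77)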
ldm/modules/karlo/kakao/models/decoder_model.py
0 → 100644
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
import copy
import torch

from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
from ldm.modules.karlo.kakao.modules.unet import PLMImUNet


class Text2ImProgressiveModel(torch.nn.Module):
    """
    A decoder that generates 64x64px images based on the text prompt.

    :param config: yaml config to define the decoder.
    :param tokenizer: tokenizer used in clip.
    """

    def __init__(
        self,
        config,
        tokenizer,
    ):
        super().__init__()

        self._conf = config
        self._model_conf = config.model.hparams
        self._diffusion_kwargs = dict(
            steps=config.diffusion.steps,
            learn_sigma=config.diffusion.learn_sigma,
            sigma_small=config.diffusion.sigma_small,
            noise_schedule=config.diffusion.noise_schedule,
            use_kl=config.diffusion.use_kl,
            predict_xstart=config.diffusion.predict_xstart,
            rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
            timestep_respacing=config.diffusion.timestep_respacing,
        )
        self._tokenizer = tokenizer

        self.model = self.create_plm_dec_model()

        cf_token, cf_mask = self.set_cf_text_tensor()
        self.register_buffer("cf_token", cf_token, persistent=False)
        self.register_buffer("cf_mask", cf_mask, persistent=False)

    @classmethod
    def load_from_checkpoint(cls, config, tokenizer, ckpt_path, strict: bool = True):
        ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]

        model = cls(config, tokenizer)
        model.load_state_dict(ckpt, strict=strict)
        return model

    def create_plm_dec_model(self):
        image_size = self._model_conf.image_size
        if self._model_conf.channel_mult == "":
            if image_size == 256:
                channel_mult = (1, 1, 2, 2, 4, 4)
            elif image_size == 128:
                channel_mult = (1, 1, 2, 3, 4)
            elif image_size == 64:
                channel_mult = (1, 2, 3, 4)
            else:
                raise ValueError(f"unsupported image size: {image_size}")
        else:
            channel_mult = tuple(int(ch_mult) for ch_mult in self._model_conf.channel_mult.split(","))
            assert 2 ** (len(channel_mult) + 2) == image_size

        attention_ds = []
        for res in self._model_conf.attention_resolutions.split(","):
            attention_ds.append(image_size // int(res))

        return PLMImUNet(
            text_ctx=self._model_conf.text_ctx,
            xf_width=self._model_conf.xf_width,
            in_channels=3,
            model_channels=self._model_conf.num_channels,
            out_channels=6 if self._model_conf.learn_sigma else 3,
            num_res_blocks=self._model_conf.num_res_blocks,
            attention_resolutions=tuple(attention_ds),
            dropout=self._model_conf.dropout,
            channel_mult=channel_mult,
            num_heads=self._model_conf.num_heads,
            num_head_channels=self._model_conf.num_head_channels,
            num_heads_upsample=self._model_conf.num_heads_upsample,
            use_scale_shift_norm=self._model_conf.use_scale_shift_norm,
            resblock_updown=self._model_conf.resblock_updown,
            clip_dim=self._model_conf.clip_dim,
            clip_emb_mult=self._model_conf.clip_emb_mult,
            clip_emb_type=self._model_conf.clip_emb_type,
            clip_emb_drop=self._model_conf.clip_emb_drop,
        )

    def set_cf_text_tensor(self):
        return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)

    def get_sample_fn(self, timestep_respacing):
        use_ddim = timestep_respacing.startswith(("ddim", "fast"))

        diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
        diffusion_kwargs.update(timestep_respacing=timestep_respacing)
        diffusion = create_gaussian_diffusion(**diffusion_kwargs)
        sample_fn = (
            diffusion.ddim_sample_loop_progressive
            if use_ddim
            else diffusion.p_sample_loop_progressive
        )

        return sample_fn

    def forward(
        self,
        txt_feat,
        txt_feat_seq,
        tok,
        mask,
        img_feat=None,
        cf_guidance_scales=None,
        timestep_respacing=None,
    ):
        # cfg should be enabled in inference
        assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
        assert img_feat is not None

        bsz = txt_feat.shape[0]
        img_sz = self._model_conf.image_size

        def guided_model_fn(x_t, ts, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.model(combined, ts, **kwargs)
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + cf_guidance_scales.view(-1, 1, 1, 1) * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        cf_feat = self.model.cf_param.unsqueeze(0)
        cf_feat = cf_feat.expand(bsz // 2, -1)
        feat = torch.cat([img_feat, cf_feat.to(txt_feat.device)], dim=0)

        cond = {
            "y": feat,
            "txt_feat": txt_feat,
            "txt_feat_seq": txt_feat_seq,
            "mask": mask,
        }
        sample_fn = self.get_sample_fn(timestep_respacing)
        sample_outputs = sample_fn(
            guided_model_fn,
            (bsz, 3, img_sz, img_sz),
            noise=None,
            device=txt_feat.device,
            clip_denoised=True,
            model_kwargs=cond,
        )

        for out in sample_outputs:
            sample = out["sample"]
            yield sample if cf_guidance_scales is None else sample[: sample.shape[0] // 2]


class Text2ImModel(Text2ImProgressiveModel):
    def forward(
        self,
        txt_feat,
        txt_feat_seq,
        tok,
        mask,
        img_feat=None,
        cf_guidance_scales=None,
        timestep_respacing=None,
    ):
        last_out = None
        for out in super().forward(
            txt_feat,
            txt_feat_seq,
            tok,
            mask,
            img_feat,
            cf_guidance_scales,
            timestep_respacing,
        ):
            last_out = out
        return last_out
ldm/modules/karlo/kakao/models/prior_model.py
0 → 100644
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
import copy
import torch

from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
from ldm.modules.karlo.kakao.modules.xf import PriorTransformer


class PriorDiffusionModel(torch.nn.Module):
    """
    A prior that generates clip image feature based on the text prompt.

    :param config: yaml config to define the decoder.
    :param tokenizer: tokenizer used in clip.
    :param clip_mean: mean to normalize the clip image feature (zero-mean, unit variance).
    :param clip_std: std to normalize the clip image feature (zero-mean, unit variance).
    """

    def __init__(self, config, tokenizer, clip_mean, clip_std):
        super().__init__()

        self._conf = config
        self._model_conf = config.model.hparams
        self._diffusion_kwargs = dict(
            steps=config.diffusion.steps,
            learn_sigma=config.diffusion.learn_sigma,
            sigma_small=config.diffusion.sigma_small,
            noise_schedule=config.diffusion.noise_schedule,
            use_kl=config.diffusion.use_kl,
            predict_xstart=config.diffusion.predict_xstart,
            rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
            timestep_respacing=config.diffusion.timestep_respacing,
        )
        self._tokenizer = tokenizer

        self.register_buffer("clip_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("clip_std", clip_std[None, :], persistent=False)

        causal_mask = self.get_causal_mask()
        self.register_buffer("causal_mask", causal_mask, persistent=False)

        self.model = PriorTransformer(
            text_ctx=self._model_conf.text_ctx,
            xf_width=self._model_conf.xf_width,
            xf_layers=self._model_conf.xf_layers,
            xf_heads=self._model_conf.xf_heads,
            xf_final_ln=self._model_conf.xf_final_ln,
            clip_dim=self._model_conf.clip_dim,
        )

        cf_token, cf_mask = self.set_cf_text_tensor()
        self.register_buffer("cf_token", cf_token, persistent=False)
        self.register_buffer("cf_mask", cf_mask, persistent=False)

    @classmethod
    def load_from_checkpoint(cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True):
        ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]

        model = cls(config, tokenizer, clip_mean, clip_std)
        model.load_state_dict(ckpt, strict=strict)
        return model

    def set_cf_text_tensor(self):
        return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)

    def get_sample_fn(self, timestep_respacing):
        use_ddim = timestep_respacing.startswith(("ddim", "fast"))

        diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
        diffusion_kwargs.update(timestep_respacing=timestep_respacing)
        diffusion = create_gaussian_diffusion(**diffusion_kwargs)
        sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop

        return sample_fn

    def get_causal_mask(self):
        seq_len = self._model_conf.text_ctx + 4
        mask = torch.empty(seq_len, seq_len)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        mask = mask[None, ...]
        return mask

    def forward(
        self,
        txt_feat,
        txt_feat_seq,
        mask,
        cf_guidance_scales=None,
        timestep_respacing=None,
        denoised_fn=True,
    ):
        # cfg should be enabled in inference
        assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)

        bsz_ = txt_feat.shape[0]
        bsz = bsz_ // 2

        def guided_model_fn(x_t, ts, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.model(combined, ts, **kwargs)
            eps, rest = (
                model_out[:, : int(x_t.shape[1])],
                model_out[:, int(x_t.shape[1]):],
            )
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        cond = {
            "text_emb": txt_feat,
            "text_enc": txt_feat_seq,
            "mask": mask,
            "causal_mask": self.causal_mask,
        }
        sample_fn = self.get_sample_fn(timestep_respacing)
        sample = sample_fn(
            guided_model_fn,
            (bsz_, self.model.clip_dim),
            noise=None,
            device=txt_feat.device,
            clip_denoised=False,
            denoised_fn=lambda x: torch.clamp(x, -10, 10),
            model_kwargs=cond,
        )
        sample = (sample * self.clip_std) + self.clip_mean

        return sample[:bsz]
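

# Editor's note: illustrative sketch, not part of the original file. On toy tensors it shows
# the classifier-free guidance arithmetic used inside guided_model_fn above: the batch holds
# conditional and unconditional halves, and the guided epsilon is
# uncond + scale * (cond - uncond), duplicated back to the full batch. The 768-dim CLIP
# embedding size and the guidance scale of 4.0 are assumptions chosen only for the demo.
if __name__ == "__main__":
    cond_eps = torch.randn(2, 768)
    uncond_eps = torch.randn(2, 768)
    cf_guidance_scales = torch.tensor([4.0, 4.0])
    half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * (cond_eps - uncond_eps)
    guided_eps = torch.cat([half_eps, half_eps], dim=0)
    print(guided_eps.shape)  # torch.Size([4, 768])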