chenpangpang / Ruyi-Mini-7B · Commits

Commit 08a21d59, authored Dec 27, 2024 by chenpangpang
feat: initial commit
Parent: 1a6b26f1
Pipeline #2165 failed with stages in 0 seconds
Showing 20 changed files with 5443 additions and 0 deletions (+5443, -0)
Ruyi-Models/ruyi/vae/ldm/models/enc_dec_pytorch.py                  +234  -0
Ruyi-Models/ruyi/vae/ldm/models/omnigen_casual3dcnn.py              +340  -0
Ruyi-Models/ruyi/vae/ldm/models/omnigen_enc_dec.py                  +565  -0
Ruyi-Models/ruyi/vae/ldm/modules/diffusionmodules/__init__.py       +0    -0
Ruyi-Models/ruyi/vae/ldm/modules/diffusionmodules/model.py          +701  -0
Ruyi-Models/ruyi/vae/ldm/modules/diffusionmodules/util.py           +268  -0
Ruyi-Models/ruyi/vae/ldm/modules/distributions/__init__.py          +0    -0
Ruyi-Models/ruyi/vae/ldm/modules/distributions/distributions.py     +92   -0
Ruyi-Models/ruyi/vae/ldm/modules/ema.py                             +115  -0
Ruyi-Models/ruyi/vae/ldm/modules/image_degradation/__init__.py      +3    -0
Ruyi-Models/ruyi/vae/ldm/modules/image_degradation/bsrgan.py        +732  -0
Ruyi-Models/ruyi/vae/ldm/modules/image_degradation/bsrgan_light.py  +650  -0
Ruyi-Models/ruyi/vae/ldm/modules/image_degradation/utils/test.png   +0    -0
Ruyi-Models/ruyi/vae/ldm/modules/image_degradation/utils_image.py   +920  -0
Ruyi-Models/ruyi/vae/ldm/modules/losses/__init__.py                 +1    -0
Ruyi-Models/ruyi/vae/ldm/modules/losses/contperceptual.py           +148  -0
Ruyi-Models/ruyi/vae/ldm/modules/losses/vqperceptual.py             +168  -0
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/__init__.py             +0    -0
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/activations.py          +27   -0
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/attention.py            +479  -0
Ruyi-Models/ruyi/vae/ldm/models/enc_dec_pytorch.py (new file, 0 → 100644) @ 08a21d59
import torch
import torch.nn.functional as F
from torch import nn


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


def divisible_by(num, den):
    return (num % den) == 0


def is_odd(n):
    return not divisible_by(n, 2)


class CausalConv3d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, pad_mode='constant', **kwargs):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)
        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        dilation = kwargs.pop('dilation', 1)
        stride = kwargs.pop('stride', 1)

        self.pad_mode = pad_mode
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.time_pad = time_pad
        # Pad the time axis on the left only, so each output frame depends
        # exclusively on current and past input frames (causal convolution).
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        stride = (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode='replicate')
        return self.conv(x)


class Swish(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


class ResBlockX(nn.Module):
    def __init__(self, inchannel) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.GroupNorm(32, inchannel),
            Swish(),
            CausalConv3d(inchannel, inchannel, 3),
            nn.GroupNorm(32, inchannel),
            Swish(),
            CausalConv3d(inchannel, inchannel, 3),
        )

    def forward(self, x):
        return x + self.conv(x)


class ResBlockXY(nn.Module):
    def __init__(self, inchannel, outchannel) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.GroupNorm(32, inchannel),
            Swish(),
            CausalConv3d(inchannel, outchannel, 3),
            nn.GroupNorm(32, outchannel),
            Swish(),
            CausalConv3d(outchannel, outchannel, 3),
        )
        # 1x1x1 projection so the residual matches the new channel count.
        self.conv_1 = nn.Conv3d(inchannel, outchannel, 1)

    def forward(self, x):
        return self.conv_1(x) + self.conv(x)


class PoolDown222(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.pool = nn.AvgPool3d(2, 2)

    def forward(self, x):
        # Replicate one frame at the front of the time axis so the 2x temporal
        # pooling also works for odd-length clips.
        x = F.pad(x, (0, 0, 0, 0, 1, 0), 'replicate')
        return self.pool(x)


class PoolDown122(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.pool = nn.AvgPool3d((1, 2, 2), (1, 2, 2))

    def forward(self, x):
        return self.pool(x)


class Unpool222(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        x = self.up(x)
        # Drop the first upsampled frame, undoing the front padding from PoolDown222.
        return x[:, :, 1:]


class Unpool122(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.up = nn.Upsample(scale_factor=(1, 2, 2), mode='nearest')

    def forward(self, x):
        x = self.up(x)
        return x


class ResBlockDown(nn.Module):
    def __init__(self, inchannel, outchannel) -> None:
        super().__init__()
        self.block = nn.Sequential(
            CausalConv3d(inchannel, outchannel, 3),
            nn.LeakyReLU(inplace=True),
            PoolDown222(),
            CausalConv3d(outchannel, outchannel, 3),
            nn.LeakyReLU(inplace=True),
        )
        self.res = nn.Sequential(
            PoolDown222(),
            nn.Conv3d(inchannel, outchannel, 1),
        )

    def forward(self, x):
        return self.res(x) + self.block(x)


class Discriminator(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.block = nn.Sequential(
            CausalConv3d(3, 64, 3),
            nn.LeakyReLU(inplace=True),
            ResBlockDown(64, 128),
            ResBlockDown(128, 256),
            ResBlockDown(256, 256),
            ResBlockDown(256, 256),
            ResBlockDown(256, 256),
            CausalConv3d(256, 256, 3),
            nn.LeakyReLU(inplace=True),
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(256, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        if x.ndim == 4:
            x = x.unsqueeze(2)
        return self.block(x)


class Encoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            CausalConv3d(3, 64, 3),
            ResBlockX(64),
            ResBlockX(64),
            PoolDown222(),
            ResBlockXY(64, 128),
            ResBlockX(128),
            PoolDown222(),
            ResBlockX(128),
            ResBlockX(128),
            PoolDown122(),
            ResBlockXY(128, 256),
            ResBlockX(256),
            ResBlockX(256),
            ResBlockX(256),
            nn.GroupNorm(32, 256),
            Swish(),
            # 16 output channels: 8 for the latent mean, 8 for the log-variance.
            nn.Conv3d(256, 16, 1),
        )

    def forward(self, x):
        return self.encoder(x)


class Decoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.decoder = nn.Sequential(
            CausalConv3d(8, 256, 3),
            ResBlockX(256),
            ResBlockX(256),
            ResBlockX(256),
            ResBlockX(256),
            Unpool122(),
            CausalConv3d(256, 256, 3),
            ResBlockXY(256, 128),
            ResBlockX(128),
            Unpool222(),
            CausalConv3d(128, 128, 3),
            ResBlockX(128),
            ResBlockX(128),
            Unpool222(),
            CausalConv3d(128, 128, 3),
            ResBlockXY(128, 64),
            ResBlockX(64),
            nn.GroupNorm(32, 64),
            Swish(),
            CausalConv3d(64, 64, 3),
        )
        self.conv_out = nn.Conv3d(64, 3, 1)

    def forward(self, x):
        return self.conv_out(self.decoder(x))


if __name__ == '__main__':
    encoder = Encoder()
    decoder = Decoder()
    dis = Discriminator()

    x = torch.randn((1, 3, 1, 64, 64))
    embedding = encoder(x)                    # 16 channels: 8 mean + 8 logvar
    mean, logvar = embedding.chunk(2, dim=1)  # the decoder consumes the 8 mean channels
    y = decoder(mean)
    tmp = torch.randn((1, 4, 1, 64, 64))
    print('something mmm')
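A quick sanity check of the design above is to verify the causal property directly: perturbing a later frame must not change earlier outputs. The following smoke test is an illustrative sketch, not part of the commit, and assumes the classes above are in scope.

import torch

conv = CausalConv3d(3, 3, 3)
x = torch.randn(1, 3, 8, 16, 16)

y1 = conv(x)
x2 = x.clone()
x2[:, :, -1] += 1.0  # perturb only the final frame
y2 = conv(x2)

# Every output frame except the last is unchanged, because the left-only
# time padding means output frame t only sees input frames <= t.
assert torch.allclose(y1[:, :, :-1], y2[:, :, :-1])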
Ruyi-Models/ruyi/vae/ldm/models/omnigen_casual3dcnn.py (new file, 0 → 100644) @ 08a21d59
import itertools
from dataclasses import dataclass
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from ..util import instantiate_from_config
from .omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
from .omnigen_enc_dec import Encoder as omnigen_Mag_Encoder


class DiagonalGaussianDistribution:
    def __init__(
        self,
        mean: torch.Tensor,
        logvar: torch.Tensor,
        deterministic: bool = False,
    ):
        self.mean = mean
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.deterministic = deterministic

        if deterministic:
            self.var = self.std = torch.zeros_like(self.mean)
        else:
            self.std = torch.exp(0.5 * self.logvar)
            self.var = torch.exp(self.logvar)

    def sample(self, generator=None) -> torch.FloatTensor:
        x = torch.randn(
            self.mean.shape,
            generator=generator,
            device=self.mean.device,
            dtype=self.mean.dtype,
        )
        return self.mean + self.std * x

    def mode(self):
        return self.mean

    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        dims = list(range(1, self.mean.ndim))

        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=dims,
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=dims,
                )

    def nll(self, sample: torch.Tensor) -> torch.Tensor:
        dims = list(range(1, self.mean.ndim))

        if self.deterministic:
            return torch.Tensor([0.0])

        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )


@dataclass
class EncoderOutput:
    latent_dist: DiagonalGaussianDistribution


@dataclass
class DecoderOutput:
    sample: torch.Tensor


def str_eval(item):
    if type(item) == str:
        return eval(item)
    else:
        return item


class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        ch=128,
        ch_mult=[1, 2, 4, 4],
        use_gc_blocks=None,
        down_block_types: tuple = None,
        up_block_types: tuple = None,
        mid_block_type: str = "MidBlock3D",
        mid_block_use_attention: bool = True,
        mid_block_attention_type: str = "3d",
        mid_block_num_attention_heads: int = 1,
        layers_per_block: int = 2,
        act_fn: str = "silu",
        num_attention_heads: int = 1,
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        image_key="image",
        monitor=None,
        ckpt_path=None,
        lossconfig=None,
        slice_mag_vae=False,
        slice_compression_vae=False,
        cache_compression_vae=False,
        spatial_group_norm=False,
        mini_batch_encoder=9,
        mini_batch_decoder=3,
        train_decoder_only=False,
        train_encoder_only=False,
    ):
        super().__init__()
        self.image_key = image_key
        down_block_types = str_eval(down_block_types)
        up_block_types = str_eval(up_block_types)
        self.encoder = omnigen_Mag_Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            ch=ch,
            ch_mult=ch_mult,
            use_gc_blocks=use_gc_blocks,
            mid_block_type=mid_block_type,
            mid_block_use_attention=mid_block_use_attention,
            mid_block_attention_type=mid_block_attention_type,
            mid_block_num_attention_heads=mid_block_num_attention_heads,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            num_attention_heads=num_attention_heads,
            double_z=True,
            slice_mag_vae=slice_mag_vae,
            slice_compression_vae=slice_compression_vae,
            cache_compression_vae=cache_compression_vae,
            spatial_group_norm=spatial_group_norm,
            mini_batch_encoder=mini_batch_encoder,
        )

        self.decoder = omnigen_Mag_Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            ch=ch,
            ch_mult=ch_mult,
            use_gc_blocks=use_gc_blocks,
            mid_block_type=mid_block_type,
            mid_block_use_attention=mid_block_use_attention,
            mid_block_attention_type=mid_block_attention_type,
            mid_block_num_attention_heads=mid_block_num_attention_heads,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            num_attention_heads=num_attention_heads,
            slice_mag_vae=slice_mag_vae,
            slice_compression_vae=slice_compression_vae,
            cache_compression_vae=cache_compression_vae,
            spatial_group_norm=spatial_group_norm,
            mini_batch_decoder=mini_batch_decoder,
        )

        self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
        self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)

        self.mini_batch_encoder = mini_batch_encoder
        self.mini_batch_decoder = mini_batch_decoder
        self.train_decoder_only = train_decoder_only
        self.train_encoder_only = train_encoder_only
        if train_decoder_only:
            self.encoder.requires_grad_(False)
            self.quant_conv.requires_grad_(False)
        if train_encoder_only:
            self.decoder.requires_grad_(False)
            self.post_quant_conv.requires_grad_(False)
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            # ignore_keys is iterated over, so pass a list rather than a bare string.
            self.init_from_ckpt(ckpt_path, ignore_keys=["loss"])
        if lossconfig is not None:
            self.loss = instantiate_from_config(lossconfig)

    def init_from_ckpt(self, path, ignore_keys=list()):
        if path.endswith("safetensors"):
            from safetensors.torch import load_file, safe_open
            sd = load_file(path)
        else:
            sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)  # missing loss.* keys can be ignored safely
        print(f"Restored from {path}")

    def encode(self, x: torch.Tensor) -> EncoderOutput:
        h = self.encoder(x)
        moments: torch.Tensor = self.quant_conv(h)
        mean, logvar = moments.chunk(2, dim=1)
        posterior = DiagonalGaussianDistribution(mean, logvar)
        # return EncoderOutput(latent_dist=posterior)
        return posterior

    def decode(self, z: torch.Tensor) -> DecoderOutput:
        z = self.post_quant_conv(z)
        decoded = self.decoder(z)
        # return DecoderOutput(sample=decoded)
        return decoded

    def forward(self, input, sample_posterior=True):
        if input.ndim == 4:
            input = input.unsqueeze(2)
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        # print("stt latent shape", z.shape)
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if x.ndim == 5:
            x = x.permute(0, 4, 1, 2, 3).to(memory_format=torch.contiguous_format).float()
            return x
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # tic = time.time()
        inputs = self.get_input(batch, self.image_key)
        # print(f"get_input time {time.time() - tic}")
        # tic = time.time()
        reconstructions, posterior = self(inputs)
        # print(f"model forward time {time.time() - tic}")

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(
                inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            # print(f"cal loss time {time.time() - tic}")
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(
                inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                last_layer=self.get_last_layer(), split="train")
            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            # print(f"cal loss time {time.time() - tic}")
            return discloss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            inputs = self.get_input(batch, self.image_key)
            reconstructions, posterior = self(inputs)
            aeloss, log_dict_ae = self.loss(
                inputs, reconstructions, posterior, 0, self.global_step,
                last_layer=self.get_last_layer(), split="val")
            discloss, log_dict_disc = self.loss(
                inputs, reconstructions, posterior, 1, self.global_step,
                last_layer=self.get_last_layer(), split="val")
            self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
            self.log_dict(log_dict_ae)
            self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        if self.train_decoder_only:
            opt_ae = torch.optim.Adam(
                list(self.decoder.parameters()) + list(self.post_quant_conv.parameters()),
                lr=lr, betas=(0.5, 0.9))
        elif self.train_encoder_only:
            opt_ae = torch.optim.Adam(
                list(self.encoder.parameters()) + list(self.quant_conv.parameters()),
                lr=lr, betas=(0.5, 0.9))
        else:
            opt_ae = torch.optim.Adam(
                list(self.encoder.parameters()) +
                list(self.decoder.parameters()) +
                list(self.quant_conv.parameters()) +
                list(self.post_quant_conv.parameters()),
                lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(
            list(self.loss.discriminator3d.parameters()) +
            list(self.loss.discriminator.parameters()),
            lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
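For readers unfamiliar with the objects above: encode() returns a DiagonalGaussianDistribution whose sample() draws mean + std * eps (the reparameterization trick) and whose kl() sums over every non-batch dimension. A minimal illustrative sketch, with latent shapes chosen purely for demonstration:

import torch

mean = torch.zeros(2, 4, 8, 16, 16)    # (B, C, T, H, W), hypothetical shape
logvar = torch.zeros(2, 4, 8, 16, 16)
posterior = DiagonalGaussianDistribution(mean, logvar)

z = posterior.sample()    # mean + std * N(0, 1), same shape as mean
kl = posterior.kl()       # reduced over C, T, H, W -> shape (2,)
print(z.shape, kl.shape)  # torch.Size([2, 4, 8, 16, 16]) torch.Size([2])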
Ruyi-Models/ruyi/vae/ldm/models/omnigen_enc_dec.py (new file, 0 → 100644) @ 08a21d59
import torch
import torch.nn as nn
import numpy as np

from typing import Any, Dict, Optional, Tuple
from einops import rearrange, repeat

from diffusers.utils import (
    USE_PEFT_BACKEND,
    BaseOutput,
    is_torch_version,
    logging,
)

from ..modules.vaemodules.activations import get_activation
from ..modules.vaemodules.common import CausalConv3d
from ..modules.vaemodules.down_blocks import get_down_block
from ..modules.vaemodules.mid_blocks import get_mid_block
from ..modules.vaemodules.up_blocks import get_up_block


def create_custom_forward(module, return_dict=None):
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        else:
            return module(*inputs)

    return custom_forward


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 8):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("SpatialDownBlock3D",)`):
            The types of down blocks to use.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        use_gc_blocks (`Tuple[bool, ...]`, *optional*, defaults to `None`):
            Whether to use global context blocks for each down block.
        mid_block_type (`str`, *optional*, defaults to `"MidBlock3D"`):
            The type of mid block to use.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        num_attention_heads (`int`, *optional*, defaults to 1):
            The number of attention heads to use.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 8,
        down_block_types=("SpatialDownBlock3D",),
        ch=128,
        ch_mult=[1, 2, 4, 4,],
        use_gc_blocks=None,
        mid_block_type: str = "MidBlock3D",
        mid_block_use_attention: bool = True,
        mid_block_attention_type: str = "3d",
        mid_block_num_attention_heads: int = 1,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        num_attention_heads: int = 1,
        double_z: bool = True,
        slice_mag_vae: bool = False,
        slice_compression_vae: bool = False,
        cache_compression_vae: bool = False,
        spatial_group_norm: bool = False,
        mini_batch_encoder: int = 9,
        verbose=False,
    ):
        super().__init__()
        block_out_channels = [ch * i for i in ch_mult]
        assert len(down_block_types) == len(block_out_channels), (
            "Number of down block types must match number of block output channels."
        )
        if use_gc_blocks is not None:
            assert len(use_gc_blocks) == len(down_block_types), (
                "Number of GC blocks must match number of down block types."
            )
        else:
            use_gc_blocks = [False] * len(down_block_types)
        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
        )

        self.down_blocks = nn.ModuleList([])

        output_channels = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channels = output_channels
            output_channels = block_out_channels[i]
            is_final_block = (i == len(block_out_channels) - 1)
            down_block = get_down_block(
                down_block_type,
                in_channels=input_channels,
                out_channels=output_channels,
                num_layers=layers_per_block,
                act_fn=act_fn,
                norm_num_groups=norm_num_groups,
                norm_eps=1e-6,
                num_attention_heads=num_attention_heads,
                add_gc_block=use_gc_blocks[i],
                add_downsample=not is_final_block,
            )
            self.down_blocks.append(down_block)

        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            num_layers=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=1e-6,
            add_attention=mid_block_use_attention,
            attention_type=mid_block_attention_type,
            num_attention_heads=mid_block_num_attention_heads,
        )

        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[-1],
            num_groups=norm_num_groups,
            eps=1e-6,
        )
        self.conv_act = get_activation(act_fn)

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)

        self.slice_mag_vae = slice_mag_vae
        self.slice_compression_vae = slice_compression_vae
        self.cache_compression_vae = cache_compression_vae
        self.mini_batch_encoder = mini_batch_encoder
        self.spatial_group_norm = spatial_group_norm
        self.verbose = verbose

    # The padding_flag values set below are read by submodules that expose a
    # `padding_flag` attribute: 1/2 select one-frame vs. multi-frame padding,
    # 5/6 the cached-slice variants.
    def set_padding_one_frame(self):
        def _set_padding_one_frame(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 1
            for sub_name, sub_mod in module.named_children():
                _set_padding_one_frame(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_padding_one_frame(name, module)

    def set_padding_more_frame(self):
        def _set_padding_more_frame(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 2
            for sub_name, sub_mod in module.named_children():
                _set_padding_more_frame(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_padding_more_frame(name, module)

    def set_cache_slice_vae_padding_one_frame(self):
        def _set_cache_slice_vae_padding_one_frame(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 5
            for sub_name, sub_mod in module.named_children():
                _set_cache_slice_vae_padding_one_frame(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_cache_slice_vae_padding_one_frame(name, module)

    def set_cache_slice_vae_padding_more_frame(self):
        def _set_cache_slice_vae_padding_more_frame(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 6
            for sub_name, sub_mod in module.named_children():
                _set_cache_slice_vae_padding_more_frame(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_cache_slice_vae_padding_more_frame(name, module)

    def set_3dgroupnorm_for_submodule(self):
        def _set_3dgroupnorm_for_submodule(name, module):
            if hasattr(module, 'set_3dgroupnorm'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.set_3dgroupnorm = True
            for sub_name, sub_mod in module.named_children():
                _set_3dgroupnorm_for_submodule(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_3dgroupnorm_for_submodule(name, module)

    def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W)
        if previous_features is not None and after_features is None:
            x = torch.concat([previous_features, x], 2)
        elif previous_features is None and after_features is not None:
            x = torch.concat([x, after_features], 2)
        elif previous_features is not None and after_features is not None:
            x = torch.concat([previous_features, x, after_features], 2)

        x = self.conv_in(x)

        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
        for down_block in self.down_blocks:
            x = down_block(x)

        x = self.mid_block(x)

        if self.spatial_group_norm:
            batch_size = x.shape[0]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = self.conv_norm_out(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
        else:
            x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)

        if previous_features is not None and after_features is None:
            x = x[:, :, 1:]
        elif previous_features is None and after_features is not None:
            x = x[:, :, :2]
        elif previous_features is not None and after_features is not None:
            x = x[:, :, 1:3]

        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.spatial_group_norm:
            self.set_3dgroupnorm_for_submodule()

        if self.cache_compression_vae:
            _, _, f, _, _ = x.size()
            if f % 2 != 0:
                self.set_padding_one_frame()
                first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
                self.set_padding_more_frame()
                new_pixel_values = [first_frames]
                start_index = 1
            else:
                self.set_padding_more_frame()
                new_pixel_values = []
                start_index = 0

            for i in range(start_index, x.shape[2], self.mini_batch_encoder):
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None)
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        elif self.slice_compression_vae:
            _, _, f, _, _ = x.size()
            if f % 2 != 0:
                self.set_padding_one_frame()
                first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
                self.set_padding_more_frame()
                new_pixel_values = [first_frames]
                start_index = 1
            else:
                self.set_padding_more_frame()
                new_pixel_values = []
                start_index = 0

            for i in range(start_index, x.shape[2], self.mini_batch_encoder):
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None)
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        elif self.slice_mag_vae:
            _, _, f, _, _ = x.size()
            new_pixel_values = []
            for i in range(0, x.shape[2], self.mini_batch_encoder):
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None)
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        else:
            new_pixel_values = self.single_forward(x, None, None)
        return new_pixel_values
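# Illustrative note (not in the original source): with mini_batch_encoder = 9,
# a 17-frame clip under slice_compression_vae is encoded as
#   frames [0:1]  with one-frame padding (17 is odd, so the first frame is split off),
#   frames [1:10] and [10:17] with multi-frame padding,
# and the per-chunk latents are concatenated along the time axis (dim=2).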
class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        in_channels (`int`, *optional*, defaults to 8):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("SpatialUpBlock3D",)`):
            The types of up blocks to use.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        use_gc_blocks (`Tuple[bool, ...]`, *optional*, defaults to `None`):
            Whether to use global context blocks for each down block.
        mid_block_type (`str`, *optional*, defaults to `"MidBlock3D"`):
            The type of mid block to use.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        num_attention_heads (`int`, *optional*, defaults to 1):
            The number of attention heads to use.
    """

    def __init__(
        self,
        in_channels: int = 8,
        out_channels: int = 3,
        up_block_types=("SpatialUpBlock3D",),
        ch=128,
        ch_mult=[1, 2, 4, 4,],
        use_gc_blocks=None,
        mid_block_type: str = "MidBlock3D",
        mid_block_use_attention: bool = True,
        mid_block_attention_type: str = "3d",
        mid_block_num_attention_heads: int = 1,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        num_attention_heads: int = 1,
        slice_mag_vae: bool = False,
        slice_compression_vae: bool = False,
        cache_compression_vae: bool = False,
        spatial_group_norm: bool = False,
        mini_batch_decoder: int = 3,
        verbose=False,
    ):
        super().__init__()
        block_out_channels = [ch * i for i in ch_mult]
        assert len(up_block_types) == len(block_out_channels), (
            "Number of up block types must match number of block output channels."
        )
        if use_gc_blocks is not None:
            assert len(use_gc_blocks) == len(up_block_types), (
                "Number of GC blocks must match number of up block types."
            )
        else:
            use_gc_blocks = [False] * len(up_block_types)

        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
        )

        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            num_layers=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=1e-6,
            add_attention=mid_block_use_attention,
            attention_type=mid_block_attention_type,
            num_attention_heads=mid_block_num_attention_heads,
        )

        self.up_blocks = nn.ModuleList([])

        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channels = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            input_channels = output_channels
            output_channels = reversed_block_out_channels[i]
            # is_first_block = i == 0
            is_final_block = i == len(block_out_channels) - 1
            up_block = get_up_block(
                up_block_type,
                in_channels=input_channels,
                out_channels=output_channels,
                num_layers=layers_per_block + 1,
                act_fn=act_fn,
                norm_num_groups=norm_num_groups,
                norm_eps=1e-6,
                num_attention_heads=num_attention_heads,
                add_gc_block=use_gc_blocks[i],
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(up_block)

        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0],
            num_groups=norm_num_groups,
            eps=1e-6,
        )
        self.conv_act = get_activation(act_fn)
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.slice_mag_vae = slice_mag_vae
        self.slice_compression_vae = slice_compression_vae
        self.cache_compression_vae = cache_compression_vae
        self.mini_batch_decoder = mini_batch_decoder
        self.spatial_group_norm = spatial_group_norm
        self.verbose = verbose

    def set_padding_one_frame(self):
        def _set_padding_one_frame(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 1
            for sub_name, sub_mod in module.named_children():
                _set_padding_one_frame(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_padding_one_frame(name, module)

    def set_padding_more_frame(self):
        def _set_padding_more_frame(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 2
            for sub_name, sub_mod in module.named_children():
                _set_padding_more_frame(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_padding_more_frame(name, module)

    def set_cache_slice_vae_padding_one_frame(self):
        def _set_cache_slice_vae_padding_one_frame(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 5
            for sub_name, sub_mod in module.named_children():
                _set_cache_slice_vae_padding_one_frame(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_cache_slice_vae_padding_one_frame(name, module)

    def set_cache_slice_vae_padding_more_frame(self):
        def _set_cache_slice_vae_padding_more_frame(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 6
            for sub_name, sub_mod in module.named_children():
                _set_cache_slice_vae_padding_more_frame(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_cache_slice_vae_padding_more_frame(name, module)

    def set_3dgroupnorm_for_submodule(self):
        def _set_3dgroupnorm_for_submodule(name, module):
            if hasattr(module, 'set_3dgroupnorm'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.set_3dgroupnorm = True
            for sub_name, sub_mod in module.named_children():
                _set_3dgroupnorm_for_submodule(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_3dgroupnorm_for_submodule(name, module)

    def clear_cache(self):
        def _clear_cache(name, module):
            if hasattr(module, 'prev_features'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.prev_features = None
            for sub_name, sub_mod in module.named_children():
                _clear_cache(sub_name, sub_mod)

        for name, module in self.named_children():
            _clear_cache(name, module)

    def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W)
        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        if previous_features is not None and after_features is None:
            b, c, t, h, w = x.size()
            x = torch.concat([previous_features, x], 2)
            x = self.conv_in(x)
            x = self.mid_block(x)
            x = x[:, :, -t:]
        elif previous_features is None and after_features is not None:
            b, c, t, h, w = x.size()
            x = torch.concat([x, after_features], 2)
            x = self.conv_in(x)
            x = self.mid_block(x)
            x = x[:, :, :t]
        elif previous_features is not None and after_features is not None:
            _, _, t_1, _, _ = previous_features.size()
            _, _, t_2, _, _ = x.size()
            x = torch.concat([previous_features, x, after_features], 2)
            x = self.conv_in(x)
            x = self.mid_block(x)
            x = x[:, :, t_1:(t_1 + t_2)]
        else:
            x = self.conv_in(x)
            x = self.mid_block(x)

        for up_block in self.up_blocks:
            x = up_block(x)

        if self.spatial_group_norm:
            batch_size = x.shape[0]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = self.conv_norm_out(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
        else:
            x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)

        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.spatial_group_norm:
            self.set_3dgroupnorm_for_submodule()

        if self.cache_compression_vae:
            _, _, f, _, _ = x.size()
            if f == 1:
                self.set_padding_one_frame()
                first_frames = self.single_forward(x[:, :, :1, :, :], None, None)
                new_pixel_values = [first_frames]
                start_index = 1
            else:
                self.set_cache_slice_vae_padding_one_frame()
                first_frames = self.single_forward(x[:, :, :self.mini_batch_decoder, :, :], None, None)
                new_pixel_values = [first_frames]
                start_index = self.mini_batch_decoder

            for i in range(start_index, x.shape[2], self.mini_batch_decoder):
                self.set_cache_slice_vae_padding_more_frame()
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], None, None)
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        elif self.slice_compression_vae:
            _, _, f, _, _ = x.size()
            if f % 2 != 0:
                self.set_padding_one_frame()
                first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
                self.set_padding_more_frame()
                new_pixel_values = [first_frames]
                start_index = 1
            else:
                self.set_padding_more_frame()
                new_pixel_values = []
                start_index = 0

            previous_features = None
            for i in range(start_index, x.shape[2], self.mini_batch_decoder):
                after_features = x[:, :, i + self.mini_batch_decoder: i + 2 * self.mini_batch_decoder, :, :] if i + self.mini_batch_decoder < x.shape[2] else None
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], previous_features, after_features)
                previous_features = x[:, :, i: i + self.mini_batch_decoder, :, :]
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        elif self.slice_mag_vae:
            _, _, f, _, _ = x.size()
            new_pixel_values = []
            for i in range(0, x.shape[2], self.mini_batch_decoder):
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], None, None)
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        else:
            new_pixel_values = self.single_forward(x, None, None)
        return new_pixel_values
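The Decoder's slice_compression_vae path above stitches chunks using one chunk of look-behind and look-ahead context, then crops the result back to the chunk's own frames inside single_forward. Reduced to a toy sketch (the names and the stand-in process() function are hypothetical, for illustration only):

import torch

def process(chunk):  # stand-in for the decoder's convolutional stack
    return chunk

def sliding_decode(x, window=3):
    outputs, previous = [], None
    for i in range(0, x.shape[2], window):
        cur = x[:, :, i: i + window]
        after = x[:, :, i + window: i + 2 * window] if i + window < x.shape[2] else None
        parts = [p for p in (previous, cur, after) if p is not None]
        y = process(torch.cat(parts, dim=2))            # run with temporal context
        t0 = 0 if previous is None else previous.shape[2]
        outputs.append(y[:, :, t0: t0 + cur.shape[2]])  # crop back to cur's frames
        previous = cur
    return torch.cat(outputs, dim=2)

x = torch.randn(1, 4, 9, 8, 8)
assert sliding_decode(x).shape == x.shape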
Ruyi-Models/ruyi/vae/ldm/modules/diffusionmodules/__init__.py (new empty file, 0 → 100644) @ 08a21d59
Ruyi-Models/ruyi/vae/ldm/modules/diffusionmodules/model.py (new file, 0 → 100644) @ 08a21d59
# pytorch_diffusion + derived encoder decoder
import math
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

from ...util import instantiate_from_config
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution


def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
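# Example (added for illustration, not in the original file):
#   get_timestep_embedding(torch.tensor([0, 1, 2, 3]), 8) has shape (4, 8);
#   the first 4 columns are sin terms and the last 4 are cos terms, with
#   frequencies spaced geometrically from 1 down to 1/10000.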
def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)      # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)        # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)    # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)       # b,c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_
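# Shape walk-through (illustrative comment, not in the original source): for x of
# shape (B, C, H, W), AttnBlock flattens the H*W positions, forms a full
# (H*W x H*W) attention matrix scaled by C**-0.5, and returns x + proj_out(out),
# so the output shape always equals the input shape.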
class
LinearAttention
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
heads
=
4
,
dim_head
=
32
):
super
().
__init__
()
self
.
heads
=
heads
hidden_dim
=
dim_head
*
heads
self
.
to_qkv
=
nn
.
Conv2d
(
dim
,
hidden_dim
*
3
,
1
,
bias
=
False
)
self
.
to_out
=
nn
.
Conv2d
(
hidden_dim
,
dim
,
1
)
def
forward
(
self
,
x
):
b
,
c
,
h
,
w
=
x
.
shape
qkv
=
self
.
to_qkv
(
x
)
q
,
k
,
v
=
rearrange
(
qkv
,
'b (qkv heads c) h w -> qkv b heads c (h w)'
,
heads
=
self
.
heads
,
qkv
=
3
)
k
=
k
.
softmax
(
dim
=-
1
)
context
=
torch
.
einsum
(
'bhdn,bhen->bhde'
,
k
,
v
)
out
=
torch
.
einsum
(
'bhde,bhdn->bhen'
,
context
,
q
)
out
=
rearrange
(
out
,
'b heads c (h w) -> b (heads c) h w'
,
heads
=
self
.
heads
,
h
=
h
,
w
=
w
)
return
self
.
to_out
(
out
)
class
LinAttnBlock
(
LinearAttention
):
"""to match AttnBlock usage"""
def
__init__
(
self
,
in_channels
):
super
().
__init__
(
dim
=
in_channels
,
heads
=
1
,
dim_head
=
in_channels
)
def
make_attn
(
in_channels
,
attn_type
=
"vanilla"
):
assert
attn_type
in
[
"vanilla"
,
"linear"
,
"none"
],
f
'attn_type
{
attn_type
}
unknown'
print
(
f
"making attention of type '
{
attn_type
}
' with
{
in_channels
}
in_channels"
)
if
attn_type
==
"vanilla"
:
return
AttnBlock
(
in_channels
)
elif
attn_type
==
"none"
:
return
nn
.
Identity
(
in_channels
)
else
:
return
LinAttnBlock
(
in_channels
)
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
double_z
=
True
,
use_linear_attn
=
False
,
attn_type
=
"vanilla"
,
**
ignore_kwargs
):
super
().
__init__
()
if
use_linear_attn
:
attn_type
=
"linear"
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
in_ch_mult
=
in_ch_mult
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
))
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
make_attn
(
block_in
,
attn_type
=
attn_type
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
make_attn
(
block_in
,
attn_type
=
attn_type
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
2
*
z_channels
if
double_z
else
z_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
# timestep embedding
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
give_pre_end
=
False
,
tanh_out
=
False
,
use_linear_attn
=
False
,
attn_type
=
"vanilla"
,
**
ignorekwargs
):
super
().
__init__
()
if
use_linear_attn
:
attn_type
=
"linear"
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
give_pre_end
=
give_pre_end
self
.
tanh_out
=
tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
block_in
=
ch
*
ch_mult
[
self
.
num_resolutions
-
1
]
curr_res
=
resolution
//
2
**
(
self
.
num_resolutions
-
1
)
self
.
z_shape
=
(
1
,
z_channels
,
curr_res
,
curr_res
)
print
(
"Working with z of shape {} = {} dimensions."
.
format
(
self
.
z_shape
,
np
.
prod
(
self
.
z_shape
)))
# z to block_in
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
z_channels
,
block_in
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
make_attn
(
block_in
,
attn_type
=
attn_type
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
))
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
make_attn
(
block_in
,
attn_type
=
attn_type
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
z
):
#assert z.shape[1:] == self.z_shape[1:]
self
.
last_z_shape
=
z
.
shape
# timestep embedding
temb
=
None
# z to block_in
h
=
self
.
conv_in
(
z
)
# middle
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
h
,
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h


class SimpleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
                                    ResnetBlock(in_channels=in_channels,
                                                out_channels=2 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    ResnetBlock(in_channels=2 * in_channels,
                                                out_channels=4 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    ResnetBlock(in_channels=4 * in_channels,
                                                out_channels=2 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    nn.Conv2d(2 * in_channels, in_channels, 1),
                                    Upsample(in_channels, with_conv=True)])
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        for i, layer in enumerate(self.model):
            if i in [1, 2, 3]:
                x = layer(x, None)
            else:
                x = layer(x)

        h = self.norm_out(x)
        h = nonlinearity(h)
        x = self.conv_out(h)
        return x


class UpsampleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                 ch_mult=(2, 2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(in_channels=block_in,
                                             out_channels=block_out,
                                             temb_channels=self.temb_ch,
                                             dropout=dropout))
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # upsampling
        h = x
        for k, i_level in enumerate(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[k](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class LatentRescaler(nn.Module):
    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
        self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
                                                     out_channels=mid_channels,
                                                     temb_channels=0, dropout=0.0)
                                         for _ in range(depth)])
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
                                                     out_channels=mid_channels,
                                                     temb_channels=0, dropout=0.0)
                                         for _ in range(depth)])
        self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2] * self.factor)),
                                                     int(round(x.shape[3] * self.factor))))
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        x = self.conv_out(x)
        return x
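
# Illustrative sketch (not part of the original file): a minimal shape check for
# LatentRescaler under assumed toy sizes. factor=0.5 halves H and W via the
# interpolate() call while conv_in/conv_out remap the channel counts.
def _latent_rescaler_demo():
    rescaler = LatentRescaler(factor=0.5, in_channels=16, mid_channels=32,
                              out_channels=8, depth=2)
    z = torch.randn(1, 16, 32, 32)
    assert rescaler(z).shape == (1, 8, 16, 16)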
class MergedRescaleEncoder(nn.Module):
    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 ch_mult=(1, 2, 4, 8), rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        intermediate_chn = ch * ch_mult[-1]
        self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch,
                               ch_mult=ch_mult, z_channels=intermediate_chn, double_z=False,
                               resolution=resolution, attn_resolutions=attn_resolutions,
                               dropout=dropout, resamp_with_conv=resamp_with_conv, out_ch=None)
        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
                                       mid_channels=intermediate_chn, out_channels=out_ch,
                                       depth=rescale_module_depth)

    def forward(self, x):
        x = self.encoder(x)
        x = self.rescaler(x)
        return x


class MergedRescaleDecoder(nn.Module):
    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch,
                 ch_mult=(1, 2, 4, 8), dropout=0.0, resamp_with_conv=True,
                 rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        tmp_chn = z_channels * ch_mult[-1]
        self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions,
                               dropout=dropout, resamp_with_conv=resamp_with_conv,
                               in_channels=None, num_res_blocks=num_res_blocks,
                               ch_mult=ch_mult, resolution=resolution, ch=ch)
        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels,
                                       mid_channels=tmp_chn, out_channels=tmp_chn,
                                       depth=rescale_module_depth)

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Upsampler(nn.Module):
    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        num_blocks = int(np.log2(out_size // in_size)) + 1
        factor_up = 1. + (out_size % in_size)
        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels,
                                       mid_channels=2 * in_channels, out_channels=in_channels)
        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels,
                               num_res_blocks=2, attn_resolutions=[], in_channels=None,
                               ch=in_channels, ch_mult=[ch_mult for _ in range(num_blocks)])

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Resize(nn.Module):
    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
            raise NotImplementedError()
            assert in_channels is not None
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, scale_factor=1.0):
        if scale_factor == 1.0:
            return x
        else:
            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False,
                                                scale_factor=scale_factor)
        return x


class FirstStagePostProcessor(nn.Module):

    def __init__(self, ch_mult: list, in_channels,
                 pretrained_model: nn.Module = None,
                 reshape=False,
                 n_channels=None,
                 dropout=0.,
                 pretrained_config=None):
        super().__init__()
        if pretrained_config is None:
            assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.pretrained_model = pretrained_model
        else:
            assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.instantiate_pretrained(pretrained_config)

        self.do_reshape = reshape

        if n_channels is None:
            n_channels = self.pretrained_model.encoder.ch

        self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
        self.proj = nn.Conv2d(in_channels, n_channels, kernel_size=3, stride=1, padding=1)

        blocks = []
        downs = []
        ch_in = n_channels
        for m in ch_mult:
            blocks.append(ResnetBlock(in_channels=ch_in, out_channels=m * n_channels, dropout=dropout))
            ch_in = m * n_channels
            downs.append(Downsample(ch_in, with_conv=False))

        self.model = nn.ModuleList(blocks)
        self.downsampler = nn.ModuleList(downs)

    def instantiate_pretrained(self, config):
        model = instantiate_from_config(config)
        self.pretrained_model = model.eval()
        # self.pretrained_model.train = False
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def encode_with_pretrained(self, x):
        c = self.pretrained_model.encode(x)
        if isinstance(c, DiagonalGaussianDistribution):
            c = c.mode()
        return c

    def forward(self, x):
        z_fs = self.encode_with_pretrained(x)
        z = self.proj_norm(z_fs)
        z = self.proj(z)
        z = nonlinearity(z)

        for submodel, downmodel in zip(self.model, self.downsampler):
            z = submodel(z, temb=None)
            z = downmodel(z)

        if self.do_reshape:
            z = rearrange(z, 'b c h w -> b (h w) c')
        return z
Ruyi-Models/ruyi/vae/ldm/modules/diffusionmodules/util.py
0 → 100644
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!

import math
import os

import numpy as np
import torch
import torch.nn as nn
from einops import repeat

from ...util import instantiate_from_config


def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if schedule == "linear":
        betas = (
                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )
    elif schedule == "cosine":
        timesteps = (
                torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
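
# Illustrative sketch (not part of the original file): deriving the cumulative
# alpha-bar curve from a linear schedule, the way a DDPM-style sampler would.
# The 1000-step horizon is just an assumed example value.
def _beta_schedule_demo():
    betas = make_beta_schedule("linear", n_timestep=1000)
    alphas_cumprod = np.cumprod(1.0 - betas)            # \bar{alpha}_t per step
    print(betas[0], betas[-1], alphas_cumprod[-1])      # 1e-4, 2e-2, near zero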
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    if ddim_discr_method == 'uniform':
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
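
# Illustrative sketch (not part of the original file): broadcasting per-timestep
# schedule values over an image batch, the typical use of extract_into_tensor in
# q_sample-style code. Sizes here are assumed toy values.
def _extract_demo():
    sqrt_alphas = torch.rand(1000)                    # one value per timestep
    t = torch.randint(0, 1000, (4,))                  # a timestep per sample
    coeff = extract_into_tensor(sqrt_alphas, t, (4, 3, 64, 64))
    print(coeff.shape)                                # torch.Size([4, 1, 1, 1])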
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
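
# Illustrative sketch (not part of the original file): trading compute for
# memory by re-running a block in the backward pass. `flag` mirrors the
# use_checkpoint switches used elsewhere in this repo; the tiny Linear block
# is an assumed example.
def _checkpoint_demo():
    layer = nn.Linear(8, 8)
    x = torch.randn(2, 8, requires_grad=True)
    y = checkpoint(layer, (x,), tuple(layer.parameters()), flag=True)
    y.sum().backward()                        # activations recomputed here
    print(x.grad.shape)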
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding
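
# Illustrative sketch (not part of the original file): sinusoidal embeddings for
# a batch of timesteps, as consumed by a UNet's time MLP. dim=128 is an assumed
# example width.
def _timestep_embedding_demo():
    t = torch.tensor([0, 10, 999])
    emb = timestep_embedding(t, dim=128)
    print(emb.shape)            # torch.Size([3, 128]); cos half, then sin half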
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
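
# Illustrative sketch (not part of the original file): how the *_nd factories
# let the same model code build 1D/2D/3D variants. Channel sizes are assumed
# example values.
def _nd_factory_demo():
    conv3d = conv_nd(3, 16, 32, kernel_size=3, padding=1)
    pool2d = avg_pool_nd(2, kernel_size=2)
    x = torch.randn(1, 16, 8, 32, 32)
    print(conv3d(x).shape)                              # torch.Size([1, 32, 8, 32, 32])
    print(pool2d(torch.randn(1, 3, 64, 64)).shape)      # torch.Size([1, 3, 32, 32])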
class HybridConditioner(nn.Module):

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()
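
# Illustrative sketch (not part of the original file): with repeat=True a single
# noise draw is shared across the batch, which can be useful for correlated
# sampling experiments.
def _noise_like_demo():
    n = noise_like((4, 3, 8, 8), device="cpu", repeat=True)
    print(torch.equal(n[0], n[1]))      # True: every sample shares one draw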
Ruyi-Models/ruyi/vae/ldm/modules/distributions/__init__.py
0 → 100644
Ruyi-Models/ruyi/vae/ldm/modules/distributions/distributions.py
0 → 100644
import numpy as np
import torch


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean
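
# Illustrative sketch (not part of the original file): the VAE encoder emits
# 2*z_channels moment maps (mean, logvar); this class turns them into a
# reparameterized sample plus a KL term. Sizes here are assumed toy values.
def _diagonal_gaussian_demo():
    moments = torch.randn(2, 8, 16, 16)        # 4 latent channels -> 8 maps
    posterior = DiagonalGaussianDistribution(moments)
    z = posterior.sample()                      # mean + std * eps
    print(z.shape, posterior.kl().shape)        # (2, 4, 16, 16), (2,)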
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
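
# Illustrative sketch (not part of the original file): a quick sanity check of
# the closed-form KL. Identical Gaussians give 0; KL(N(0,1) || N(1,1)) = 0.5.
def _normal_kl_demo():
    zero = torch.tensor(0.0)
    print(normal_kl(zero, zero, zero, zero))                  # tensor(0.)
    print(normal_kl(zero, zero, torch.tensor(1.0), zero))     # tensor(0.5000)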
Ruyi-Models/ruyi/vae/ldm/modules/ema.py
0 → 100644
#-*- encoding:utf-8 -*-
import torch
from torch import nn
from pytorch_lightning.callbacks import Callback


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)
                             if use_num_upates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
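
# Illustrative sketch (not part of the original file): the store / copy_to /
# restore dance around an EMA-weighted evaluation, assuming some `model` and
# validation callable `val_fn`.
def _lit_ema_demo(model, val_fn):
    ema = LitEma(model)
    ema(model)                              # one EMA update after a train step
    ema.store(model.parameters())           # stash the live weights
    ema.copy_to(model)                      # swap in the EMA weights
    val_fn(model)                           # evaluate with EMA weights
    ema.restore(model.parameters())         # put the live weights back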
class EMACallback(Callback):
    def __init__(self, decay=0.9999):
        self.decay = decay
        self.shadow_params = {}

    def on_train_start(self, trainer, pl_module):
        # initialize shadow parameters for original models
        total_ema_cnt = 0
        for name, param in pl_module.named_parameters():
            if name not in self.shadow_params:
                self.shadow_params[name] = param.data.clone()
            else:
                # already in dict, maybe loaded from checkpoint
                pass
            print('will calc ema for param: %s' % name)
            total_ema_cnt += 1
        print('total_ema_cnt=%d' % total_ema_cnt)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Update the shadow params at the end of each batch
        for name, param in pl_module.named_parameters():
            assert name in self.shadow_params
            new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow_params[name]
            self.shadow_params[name] = new_average.clone()

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Save EMA parameters in the checkpoint
        checkpoint['ema_params'] = self.shadow_params

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        # Restore EMA parameters from the checkpoint
        if 'ema_params' in checkpoint:
            self.shadow_params = checkpoint.get('ema_params', {})
            for k in self.shadow_params:
                self.shadow_params[k] = self.shadow_params[k].cuda()
            print('load shadow params from checkpoint, cnt=%d' % len(self.shadow_params))
        else:
            print('ema_params is not in checkpoint')
Ruyi-Models/ruyi/vae/ldm/modules/image_degradation/__init__.py
0 → 100644
from .bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
from .bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
Ruyi-Models/ruyi/vae/ldm/modules/image_degradation/bsrgan.py
0 → 100644
# -*- coding: utf-8 -*-
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""

import random
from functools import partial

import albumentations
import cv2
import numpy as np
import scipy
import scipy.stats as ss
import torch
from scipy import ndimage
from scipy.interpolate import interp2d
from scipy.linalg import orth

from . import utils_image as util


def modcrop_np(img, sf):
    '''
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        cropped image
    '''
    w, h = img.shape[:2]
    im = np.copy(img)
    return im[:w - w % sf, :h - h % sf, ...]


"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""


def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
    k_size = k.shape[0]
    # Calculate the big kernels size
    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to ignore very small values and increase run time of SR
    crop = k_size // 2
    cropped_big_k = big_k[crop:-crop, crop:-crop]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()


def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1, 50], scaling of eigenvalues
        l2    : [0.1, l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel
    """
    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
    return k


def gm_blur_kernel(mean, cov, size=15):
    center = size / 2.0 + 0.5
    k = np.zeros([size, size])
    for y in range(size):
        for x in range(size):
            cy = y - center + 1
            cx = x - center + 1
            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
    k = k / np.sum(k)
    return k
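
# Illustrative sketch (not part of the original file): sampling one anisotropic
# blur kernel and checking that it is a valid (non-negative, unit-sum) filter.
def _kernel_demo():
    k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6, l2=1)
    print(k.shape, k.min() >= 0, abs(k.sum() - 1.0) < 1e-8)   # (15, 15) True True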
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x


def blur(x, k):
    '''
    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    '''
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])
    return x


def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """"
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    """
    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Set COV matrix using Lambdas and Theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # shift the kernel so it will be centered
    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)

    # Normalize the kernel and return
    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel


def fspecial_gaussian(hsize, sigma):
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h


def fspecial_laplacian(alpha):
    alpha = max([0, min([alpha, 1])])
    h1 = alpha / (alpha + 1)
    h2 = (1 - alpha) / (alpha + 1)
    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
    h = np.array(h)
    return h


def fspecial(filter_type, *args, **kwargs):
    '''
    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    '''
    if filter_type == 'gaussian':
        return fspecial_gaussian(*args, **kwargs)
    if filter_type == 'laplacian':
        return fspecial_laplacian(*args, **kwargs)


"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""


def bicubic_degradation(x, sf=3):
    '''
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubicly downsampled LR image
    '''
    x = util.imresize_np(x, scale=1 / sf)
    return x


def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x


def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x


def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    '''
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0
    return x[st::sf, st::sf, ...]


def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 1.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int):
    """
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype('float32')
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    K = img + weight * residual
    K = np.clip(K, 0, 1)
    return soft_mask * K + (1 - soft_mask) * img


def add_blur(img, sf=4):
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
    return img


def add_resize(img, sf=4):
    rnum = np.random.rand()
    if rnum > 0.8:  # up
        sf1 = random.uniform(1, 2)
    elif rnum < 0.7:  # down
        sf1 = random.uniform(0.5 / sf, 1)
    else:
        sf1 = 1.0
    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
    img = np.clip(img, 0.0, 1.0)
    return img


# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
#     noise_level = random.randint(noise_level1, noise_level2)
#     rnum = np.random.rand()
#     if rnum > 0.6:  # add color Gaussian noise
#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
#     elif rnum < 0.4:  # add grayscale Gaussian noise
#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
#     else:  # add noise
#         L = noise_level2 / 255.
#         D = np.diag(np.random.rand(3))
#         U = orth(np.random.rand(3, 3))
#         conv = np.dot(np.dot(np.transpose(U), D), U)
#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
#     img = np.clip(img, 0.0, 1.0)
#     return img


def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    if rnum > 0.6:  # add color Gaussian noise
        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:  # add grayscale Gaussian noise
        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:  # add noise
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img


def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    rnum = random.random()
    if rnum > 0.6:
        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:
        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img


def add_Poisson_noise(img):
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
        img += noise_gray[:, :, np.newaxis]
    img = np.clip(img, 0.0, 1.0)
    return img


def add_JPEG_noise(img):
    quality_factor = random.randint(30, 95)
    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    img = cv2.imdecode(encimg, 1)
    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
    return img
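
# Illustrative sketch (not part of the original file): composing a few of the
# primitives above on one float32 RGB image in [0, 1], the same pattern that
# degradation_bsrgan() below randomizes. Assumes utils_image provides
# single2uint/uint2single as used by add_JPEG_noise.
def _degradation_chain_demo():
    img = np.random.rand(128, 128, 3).astype(np.float32)
    img = add_blur(img, sf=4)
    img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
    img = add_JPEG_noise(img)
    print(img.shape, img.dtype)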
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    h, w = lq.shape[:2]
    rnd_h = random.randint(0, h - lq_patchsize)
    rnd_w = random.randint(0, w - lq_patchsize)
    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]

    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
    return lq, hq


def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                             interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                                 interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq


# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    image = util.uint2single(image)
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = image.shape[:2]
    image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
    h, w = image.shape[:2]

    hq = image.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            image = add_blur(image, sf=sf)

        elif i == 1:
            image = add_blur(image, sf=sf)

        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

        # elif i == 6:
        #     # add processed camera sensor noise
        #     if random.random() < isp_prob and isp_model is not None:
        #         with torch.no_grad():
        #             img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example


# TODO: in case there is a pickle error, one needs to replace a += x with a = a + x in add_speckle_noise etc...
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
    """
    This is an extended degradation model by combining
    the degradation models of BSRGAN and Real-ESRGAN
    ----------
    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    use_shuffle: the degradation shuffle
    use_sharp: sharpening the img
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """

    h1, w1 = img.shape[:2]
    img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    if use_sharp:
        img = add_sharpening(img)
    hq = img.copy()

    if random.random() < shuffle_prob:
        shuffle_order = random.sample(range(13), 13)
    else:
        shuffle_order = list(range(13))
        # local shuffle for noise, JPEG is always the last one
        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))

    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)
        elif i == 1:
            img = add_resize(img, sf=sf)
        elif i == 2:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 3:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 4:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 5:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        elif i == 6:
            img = add_JPEG_noise(img)
        elif i == 7:
            img = add_blur(img, sf=sf)
        elif i == 8:
            img = add_resize(img, sf=sf)
        elif i == 9:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 10:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 11:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 12:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        else:
            print('check the shuffle!')

    # resize to desired size
    img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
                     interpolation=random.choice([1, 2, 3]))

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf, lq_patchsize)

    return img, hq


if __name__ == '__main__':
    '''
    print("hey")
    img = util.imread_uint('utils/test.png', 3)
    print(img)
    img = util.uint2single(img)
    print(img)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        img_lq = deg_fn(img)
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                        interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')
    '''
Ruyi-Models/ruyi/vae/ldm/modules/image_degradation/bsrgan_light.py
0 → 100644
# -*- coding: utf-8 -*-
import
random
from
functools
import
partial
import
albumentations
import
cv2
import
numpy
as
np
import
scipy
import
scipy.stats
as
ss
import
torch
from
scipy
import
ndimage
from
scipy.interpolate
import
interp2d
from
scipy.linalg
import
orth
from
.
import
utils_image
as
util
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
def
modcrop_np
(
img
,
sf
):
'''
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
'''
w
,
h
=
img
.
shape
[:
2
]
im
=
np
.
copy
(
img
)
return
im
[:
w
-
w
%
sf
,
:
h
-
h
%
sf
,
...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def
analytic_kernel
(
k
):
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
k_size
=
k
.
shape
[
0
]
# Calculate the big kernels size
big_k
=
np
.
zeros
((
3
*
k_size
-
2
,
3
*
k_size
-
2
))
# Loop over the small kernel to fill the big one
for
r
in
range
(
k_size
):
for
c
in
range
(
k_size
):
big_k
[
2
*
r
:
2
*
r
+
k_size
,
2
*
c
:
2
*
c
+
k_size
]
+=
k
[
r
,
c
]
*
k
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop
=
k_size
//
2
cropped_big_k
=
big_k
[
crop
:
-
crop
,
crop
:
-
crop
]
# Normalize to 1
return
cropped_big_k
/
cropped_big_k
.
sum
()
def
anisotropic_Gaussian
(
ksize
=
15
,
theta
=
np
.
pi
,
l1
=
6
,
l2
=
6
):
""" generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
l1 : [0.1,50], scaling of eigenvalues
l2 : [0.1,l1], scaling of eigenvalues
If l1 = l2, will get an isotropic Gaussian kernel.
Returns:
k : kernel
"""
v
=
np
.
dot
(
np
.
array
([[
np
.
cos
(
theta
),
-
np
.
sin
(
theta
)],
[
np
.
sin
(
theta
),
np
.
cos
(
theta
)]]),
np
.
array
([
1.
,
0.
]))
V
=
np
.
array
([[
v
[
0
],
v
[
1
]],
[
v
[
1
],
-
v
[
0
]]])
D
=
np
.
array
([[
l1
,
0
],
[
0
,
l2
]])
Sigma
=
np
.
dot
(
np
.
dot
(
V
,
D
),
np
.
linalg
.
inv
(
V
))
k
=
gm_blur_kernel
(
mean
=
[
0
,
0
],
cov
=
Sigma
,
size
=
ksize
)
return
k
def
gm_blur_kernel
(
mean
,
cov
,
size
=
15
):
center
=
size
/
2.0
+
0.5
k
=
np
.
zeros
([
size
,
size
])
for
y
in
range
(
size
):
for
x
in
range
(
size
):
cy
=
y
-
center
+
1
cx
=
x
-
center
+
1
k
[
y
,
x
]
=
ss
.
multivariate_normal
.
pdf
([
cx
,
cy
],
mean
=
mean
,
cov
=
cov
)
k
=
k
/
np
.
sum
(
k
)
return
k
def
shift_pixel
(
x
,
sf
,
upper_left
=
True
):
"""shift pixel for super-resolution with different scale factors
Args:
x: WxHxC or WxH
sf: scale factor
upper_left: shift direction
"""
h
,
w
=
x
.
shape
[:
2
]
shift
=
(
sf
-
1
)
*
0.5
xv
,
yv
=
np
.
arange
(
0
,
w
,
1.0
),
np
.
arange
(
0
,
h
,
1.0
)
if
upper_left
:
x1
=
xv
+
shift
y1
=
yv
+
shift
else
:
x1
=
xv
-
shift
y1
=
yv
-
shift
x1
=
np
.
clip
(
x1
,
0
,
w
-
1
)
y1
=
np
.
clip
(
y1
,
0
,
h
-
1
)
if
x
.
ndim
==
2
:
x
=
interp2d
(
xv
,
yv
,
x
)(
x1
,
y1
)
if
x
.
ndim
==
3
:
for
i
in
range
(
x
.
shape
[
-
1
]):
x
[:,
:,
i
]
=
interp2d
(
xv
,
yv
,
x
[:,
:,
i
])(
x1
,
y1
)
return
x
def
blur
(
x
,
k
):
'''
x: image, NxcxHxW
k: kernel, Nx1xhxw
'''
n
,
c
=
x
.
shape
[:
2
]
p1
,
p2
=
(
k
.
shape
[
-
2
]
-
1
)
//
2
,
(
k
.
shape
[
-
1
]
-
1
)
//
2
x
=
torch
.
nn
.
functional
.
pad
(
x
,
pad
=
(
p1
,
p2
,
p1
,
p2
),
mode
=
'replicate'
)
k
=
k
.
repeat
(
1
,
c
,
1
,
1
)
k
=
k
.
view
(
-
1
,
1
,
k
.
shape
[
2
],
k
.
shape
[
3
])
x
=
x
.
view
(
1
,
-
1
,
x
.
shape
[
2
],
x
.
shape
[
3
])
x
=
torch
.
nn
.
functional
.
conv2d
(
x
,
k
,
bias
=
None
,
stride
=
1
,
padding
=
0
,
groups
=
n
*
c
)
x
=
x
.
view
(
n
,
c
,
x
.
shape
[
2
],
x
.
shape
[
3
])
return
x
def
gen_kernel
(
k_size
=
np
.
array
([
15
,
15
]),
scale_factor
=
np
.
array
([
4
,
4
]),
min_var
=
0.6
,
max_var
=
10.
,
noise_level
=
0
):
""""
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
# max_var = 2.5 * sf
"""
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
lambda_1
=
min_var
+
np
.
random
.
rand
()
*
(
max_var
-
min_var
)
lambda_2
=
min_var
+
np
.
random
.
rand
()
*
(
max_var
-
min_var
)
theta
=
np
.
random
.
rand
()
*
np
.
pi
# random theta
noise
=
-
noise_level
+
np
.
random
.
rand
(
*
k_size
)
*
noise_level
*
2
# Set COV matrix using Lambdas and Theta
LAMBDA
=
np
.
diag
([
lambda_1
,
lambda_2
])
Q
=
np
.
array
([[
np
.
cos
(
theta
),
-
np
.
sin
(
theta
)],
[
np
.
sin
(
theta
),
np
.
cos
(
theta
)]])
SIGMA
=
Q
@
LAMBDA
@
Q
.
T
INV_SIGMA
=
np
.
linalg
.
inv
(
SIGMA
)[
None
,
None
,
:,
:]
# Set expectation position (shifting kernel for aligned image)
MU
=
k_size
//
2
-
0.5
*
(
scale_factor
-
1
)
# - 0.5 * (scale_factor - k_size % 2)
MU
=
MU
[
None
,
None
,
:,
None
]
# Create meshgrid for Gaussian
[
X
,
Y
]
=
np
.
meshgrid
(
range
(
k_size
[
0
]),
range
(
k_size
[
1
]))
Z
=
np
.
stack
([
X
,
Y
],
2
)[:,
:,
:,
None
]
# Calcualte Gaussian for every pixel of the kernel
ZZ
=
Z
-
MU
ZZ_t
=
ZZ
.
transpose
(
0
,
1
,
3
,
2
)
raw_kernel
=
np
.
exp
(
-
0.5
*
np
.
squeeze
(
ZZ_t
@
INV_SIGMA
@
ZZ
))
*
(
1
+
noise
)
# shift the kernel so it will be centered
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
# Normalize the kernel and return
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
kernel
=
raw_kernel
/
np
.
sum
(
raw_kernel
)
return
kernel
def
fspecial_gaussian
(
hsize
,
sigma
):
hsize
=
[
hsize
,
hsize
]
siz
=
[(
hsize
[
0
]
-
1.0
)
/
2.0
,
(
hsize
[
1
]
-
1.0
)
/
2.0
]
std
=
sigma
[
x
,
y
]
=
np
.
meshgrid
(
np
.
arange
(
-
siz
[
1
],
siz
[
1
]
+
1
),
np
.
arange
(
-
siz
[
0
],
siz
[
0
]
+
1
))
arg
=
-
(
x
*
x
+
y
*
y
)
/
(
2
*
std
*
std
)
h
=
np
.
exp
(
arg
)
h
[
h
<
scipy
.
finfo
(
float
).
eps
*
h
.
max
()]
=
0
sumh
=
h
.
sum
()
if
sumh
!=
0
:
h
=
h
/
sumh
return
h
def
fspecial_laplacian
(
alpha
):
alpha
=
max
([
0
,
min
([
alpha
,
1
])])
h1
=
alpha
/
(
alpha
+
1
)
h2
=
(
1
-
alpha
)
/
(
alpha
+
1
)
h
=
[[
h1
,
h2
,
h1
],
[
h2
,
-
4
/
(
alpha
+
1
),
h2
],
[
h1
,
h2
,
h1
]]
h
=
np
.
array
(
h
)
return
h
def
fspecial
(
filter_type
,
*
args
,
**
kwargs
):
'''
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
'''
if
filter_type
==
'gaussian'
:
return
fspecial_gaussian
(
*
args
,
**
kwargs
)
if
filter_type
==
'laplacian'
:
return
fspecial_laplacian
(
*
args
,
**
kwargs
)
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def
bicubic_degradation
(
x
,
sf
=
3
):
'''
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
'''
x
=
util
.
imresize_np
(
x
,
scale
=
1
/
sf
)
return
x
def
srmd_degradation
(
x
,
k
,
sf
=
3
):
''' blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2018learning,
title={Learning a single convolutional super-resolution network for multiple degradations},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={3262--3271},
year={2018}
}
'''
x
=
ndimage
.
filters
.
convolve
(
x
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'wrap'
)
# 'nearest' | 'mirror'
x
=
bicubic_degradation
(
x
,
sf
=
sf
)
return
x
def
dpsr_degradation
(
x
,
k
,
sf
=
3
):
''' bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2019deep,
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={1671--1681},
year={2019}
}
'''
x
=
bicubic_degradation
(
x
,
sf
=
sf
)
x
=
ndimage
.
filters
.
convolve
(
x
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'wrap'
)
return
x
def
classical_degradation
(
x
,
k
,
sf
=
3
):
''' blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
'''
x
=
ndimage
.
filters
.
convolve
(
x
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'wrap'
)
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st
=
0
return
x
[
st
::
sf
,
st
::
sf
,
...]
def
add_sharpening
(
img
,
weight
=
0.5
,
radius
=
50
,
threshold
=
10
):
"""USM sharpening. borrowed from real-ESRGAN
Input image: I; Blurry image: B.
1. K = I + weight * (I - B)
2. Mask = 1 if abs(I - B) > threshold, else: 0
3. Blur mask:
4. Out = Mask * K + (1 - Mask) * I
Args:
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
weight (float): Sharp weight. Default: 1.
radius (float): Kernel size of Gaussian blur. Default: 50.
threshold (int):
"""
if
radius
%
2
==
0
:
radius
+=
1
blur
=
cv2
.
GaussianBlur
(
img
,
(
radius
,
radius
),
0
)
residual
=
img
-
blur
mask
=
np
.
abs
(
residual
)
*
255
>
threshold
mask
=
mask
.
astype
(
'float32'
)
soft_mask
=
cv2
.
GaussianBlur
(
mask
,
(
radius
,
radius
),
0
)
K
=
img
+
weight
*
residual
K
=
np
.
clip
(
K
,
0
,
1
)
return
soft_mask
*
K
+
(
1
-
soft_mask
)
*
img
def
add_blur
(
img
,
sf
=
4
):
wd2
=
4.0
+
sf
wd
=
2.0
+
0.2
*
sf
wd2
=
wd2
/
4
wd
=
wd
/
4
if
random
.
random
()
<
0.5
:
l1
=
wd2
*
random
.
random
()
l2
=
wd2
*
random
.
random
()
k
=
anisotropic_Gaussian
(
ksize
=
random
.
randint
(
2
,
11
)
+
3
,
theta
=
random
.
random
()
*
np
.
pi
,
l1
=
l1
,
l2
=
l2
)
else
:
k
=
fspecial
(
'gaussian'
,
random
.
randint
(
2
,
4
)
+
3
,
wd
*
random
.
random
())
img
=
ndimage
.
filters
.
convolve
(
img
,
np
.
expand_dims
(
k
,
axis
=
2
),
mode
=
'mirror'
)
return
img
def
add_resize
(
img
,
sf
=
4
):
rnum
=
np
.
random
.
rand
()
if
rnum
>
0.8
:
# up
sf1
=
random
.
uniform
(
1
,
2
)
elif
rnum
<
0.7
:
# down
sf1
=
random
.
uniform
(
0.5
/
sf
,
1
)
else
:
sf1
=
1.0
img
=
cv2
.
resize
(
img
,
(
int
(
sf1
*
img
.
shape
[
1
]),
int
(
sf1
*
img
.
shape
[
0
])),
interpolation
=
random
.
choice
([
1
,
2
,
3
]))
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
return
img
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    if rnum > 0.6:  # add color Gaussian noise
        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:  # add grayscale Gaussian noise
        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:  # add noise
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    rnum = random.random()
    if rnum > 0.6:
        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:
        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img
def add_Poisson_noise(img):
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
        img += noise_gray[:, :, np.newaxis]
    img = np.clip(img, 0.0, 1.0)
    return img
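Here `vals` in [1e2, 1e4] plays the role of an effective photon count, so larger values mean weaker noise (per-pixel std is roughly sqrt(x / vals) for intensity x). A minimal sketch on a synthetic image:

# Illustrative only: shape is preserved and the output stays clipped to [0, 1].
flat = np.full((64, 64, 3), 0.5, dtype=np.float32)
noisy = add_Poisson_noise(flat)
assert noisy.shape == flat.shape and noisy.min() >= 0.0 and noisy.max() <= 1.0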
def add_JPEG_noise(img):
    quality_factor = random.randint(80, 95)
    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    img = cv2.imdecode(encimg, 1)
    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
    return img
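The JPEG degradation is a real encode/decode round trip at a random quality in [80, 95]. A minimal sketch with synthetic data:

# Illustrative only: shape and float32 range are preserved by the round trip.
rgb = np.random.rand(64, 64, 3).astype(np.float32)
jpg = add_JPEG_noise(rgb)
assert jpg.shape == rgb.shape and jpg.dtype == np.float32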
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    h, w = lq.shape[:2]
    rnd_h = random.randint(0, h - lq_patchsize)
    rnd_w = random.randint(0, w - lq_patchsize)
    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]

    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
    return lq, hq
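The HQ window is the LQ window scaled by sf, so the paired patches stay spatially aligned. A minimal sketch:

# Illustrative only: a 32x32 LQ patch pairs with a 128x128 HQ patch at sf=4.
lq_demo = np.random.rand(64, 64, 3).astype(np.float32)
hq_demo = np.random.rand(256, 256, 3).astype(np.float32)
lq_p, hq_p = random_crop(lq_demo, hq_demo, sf=4, lq_patchsize=32)
assert lq_p.shape[:2] == (32, 32) and hq_p.shape[:2] == (128, 128)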
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsize x lq_patchsize x C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) x (lq_patchsize x sf) x C, range: [0, 1]
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop (h/w axis order fixed; the original sliced height by w1 and width by h1)
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                             interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                                 interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq
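A minimal end-to-end sketch on a synthetic HQ image; the input must satisfy min(H, W) >= lq_patchsize * sf (here 72 * 4 = 288):

# Illustrative only: the randomized pipeline always returns an aligned pair
# of patches (downsample3 restores the net 1/sf size before the crop).
hq_demo = np.random.rand(512, 512, 3).astype(np.float32)
lq_patch, hq_patch = degradation_bsrgan(hq_demo, sf=4, lq_patchsize=72)
assert lq_patch.shape == (72, 72, 3) and hq_patch.shape == (288, 288, 3)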
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsize x lq_patchsize x C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) x (lq_patchsize x sf) x C, range: [0, 1]
    """
    image = util.uint2single(image)
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = image.shape[:2]
    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop (h/w axis order fixed, as in degradation_bsrgan)
    h, w = image.shape[:2]

    hq = image.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:
        if i == 0:
            image = add_blur(image, sf=sf)

        # elif i == 1:
        #     image = add_blur(image, sf=sf)

        if i == 0:
            pass

        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.8:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling

            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)
        #
        # elif i == 6:
        #     # add processed camera sensor noise
        #     if random.random() < isp_prob and isp_model is not None:
        #         with torch.no_grad():
        #             img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example
if __name__ == '__main__':
    print("hey")
    img = util.imread_uint('utils/test.png', 3)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        img_hq = img
        img_lq = deg_fn(img)["image"]
        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                        interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')
Ruyi-Models/ruyi/vae/ldm/modules/image_degradation/utils/test.png (binary image, 431 KB)
Ruyi-Models/ruyi/vae/ldm/modules/image_degradation/utils_image.py
import math
import os
import random
from datetime import datetime

import cv2
import numpy as np
import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def get_timestamp():
    return datetime.now().strftime('%y%m%d-%H%M%S')


def imshow(x, title=None, cbar=False, figsize=None):
    plt.figure(figsize=figsize)
    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()
def surf(Z, cmap='rainbow', figsize=None):
    plt.figure(figsize=figsize)
    ax3 = plt.axes(projection='3d')

    w, h = Z.shape[:2]
    xx = np.arange(0, w, 1)
    yy = np.arange(0, h, 1)
    X, Y = np.meshgrid(xx, yy)
    ax3.plot_surface(X, Y, Z, cmap=cmap)
    #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
    plt.show()
'''
# --------------------------------------------
# get image paths
# --------------------------------------------
'''
def get_image_paths(dataroot):
    paths = None  # return None if dataroot is None
    if dataroot is not None:
        paths = sorted(_get_paths_from_images(dataroot))
    return paths


def _get_paths_from_images(path):
    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
    images = []
    for dirpath, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                img_path = os.path.join(dirpath, fname)
                images.append(img_path)
    assert images, '{:s} has no valid image file'.format(path)
    return images
'''
# --------------------------------------------
# split large images into small images
# --------------------------------------------
'''
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
    w, h = img.shape[:2]
    patches = []
    if w > p_max and h > p_max:
        w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=int))  # np.int is deprecated/removed in modern NumPy
        h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=int))
        w1.append(w - p_size)
        h1.append(h - p_size)
        # print(w1)
        # print(h1)
        for i in w1:
            for j in h1:
                patches.append(img[i:i + p_size, j:j + p_size, :])
    else:
        patches.append(img)

    return patches
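A minimal sketch of the overlap arithmetic:

# Illustrative only: for a 1000x1000 input with p_size=512 and p_overlap=64,
# the top-left corners along each axis are [0, 448, 488], i.e. a 3x3 grid.
big = np.zeros((1000, 1000, 3), dtype=np.uint8)
patches = patches_from_image(big, p_size=512, p_overlap=64, p_max=800)
assert len(patches) == 9 and patches[0].shape == (512, 512, 3)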
def imssave(imgs, img_path):
    """
    imgs: list, N images of size WxHxC
    """
    img_name, ext = os.path.splitext(os.path.basename(img_path))
    for i, img in enumerate(imgs):
        if img.ndim == 3:
            img = img[:, :, [2, 1, 0]]
        new_path = os.path.join(os.path.dirname(img_path), img_name + str('_s{:04d}'.format(i)) + '.png')
        cv2.imwrite(new_path, img)
def split_imageset(original_dataroot, target_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
    """
    split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
    and save them into target_dataroot; only the images with larger size than (p_max)x(p_max)
    will be split.
    Args:
        original_dataroot:
        target_dataroot:
        p_size: size of small images
        p_overlap: patch size in training is a good choice
        p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
    """
    paths = get_image_paths(original_dataroot)
    for img_path in paths:
        # img_name, ext = os.path.splitext(os.path.basename(img_path))
        img = imread_uint(img_path, n_channels=n_channels)
        patches = patches_from_image(img, p_size, p_overlap, p_max)
        imssave(patches, os.path.join(target_dataroot, os.path.basename(img_path)))
        #if original_dataroot == target_dataroot:
        #del img_path
'''
# --------------------------------------------
# makedir
# --------------------------------------------
'''
def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def mkdirs(paths):
    if isinstance(paths, str):
        mkdir(paths)
    else:
        for path in paths:
            mkdir(path)


def mkdir_and_rename(path):
    if os.path.exists(path):
        new_name = path + '_archived_' + get_timestamp()
        print('Path already exists. Rename it to [{:s}]'.format(new_name))
        os.rename(path, new_name)
    os.makedirs(path)
'''
# --------------------------------------------
# read image from path
# opencv is fast, but read BGR numpy image
# --------------------------------------------
'''
# --------------------------------------------
# get uint8 image of size HxWxn_channels (RGB)
# --------------------------------------------
def imread_uint(path, n_channels=3):
    # input: path
    # output: HxWx3 (RGB or GGG), or HxWx1 (G)
    if n_channels == 1:
        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
        img = np.expand_dims(img, axis=2)  # HxWx1
    elif n_channels == 3:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
    return img
# --------------------------------------------
# matlab's imwrite
# --------------------------------------------
def imsave(img, img_path):
    img = np.squeeze(img)
    if img.ndim == 3:
        img = img[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, img)


def imwrite(img, img_path):
    img = np.squeeze(img)
    if img.ndim == 3:
        img = img[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, img)
# --------------------------------------------
# get single image of size HxWxn_channels (BGR)
# --------------------------------------------
def read_img(path):
    # read image by cv2
    # return: Numpy float32, HWC, BGR, [0,1]
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img
'''
# --------------------------------------------
# image format conversion
# --------------------------------------------
# numpy(single) <--->  numpy(uint)
# numpy(single) <--->  tensor
# numpy(uint)   <--->  tensor
# --------------------------------------------
'''


# --------------------------------------------
# numpy(single) [0, 1] <--->  numpy(uint)
# --------------------------------------------
def uint2single(img):
    return np.float32(img / 255.)


def single2uint(img):
    return np.uint8((img.clip(0, 1) * 255.).round())


def uint162single(img):
    return np.float32(img / 65535.)


def single2uint16(img):
    return np.uint16((img.clip(0, 1) * 65535.).round())
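A minimal sketch of the conversion pair:

# Illustrative only: the uint8 <-> [0, 1] float32 round trip is exact up to rounding.
u8 = np.array([[0, 128, 255]], dtype=np.uint8)
assert np.array_equal(single2uint(uint2single(u8)), u8)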
# --------------------------------------------
# numpy(uint) (HxWxC or HxW) <--->  tensor
# --------------------------------------------


# convert uint to 4-dimensional torch tensor
def uint2tensor4(img):
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)


# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)


# convert 2/3/4-dimensional torch tensor to uint
def tensor2uint(img):
    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return np.uint8((img * 255.0).round())
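A minimal round-trip sketch:

# Illustrative only: HWC uint8 -> 1xCxHxW float tensor -> HWC uint8.
u8 = np.random.randint(0, 256, (8, 8, 3), dtype=np.uint8)
t4 = uint2tensor4(u8)
assert tuple(t4.shape) == (1, 3, 8, 8)
assert np.array_equal(tensor2uint(t4), u8)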
# --------------------------------------------
# numpy(single) (HxWxC) <---> tensor
# --------------------------------------------
# convert single (HxWxC) to 3-dimensional torch tensor
def single2tensor3(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()


# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)


# convert torch tensor to single
def tensor2single(img):
    img = img.data.squeeze().float().cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return img


# convert torch tensor to single
def tensor2single3(img):
    img = img.data.squeeze().float().cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    elif img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return img


def single2tensor5(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)


def single32tensor5(img):
    return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)


def single42tensor4(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array of BGR channel order
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    '''
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError('Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
    return img_np.astype(out_type)
'''
# --------------------------------------------
# Augmentation, flip and/or rotate
# --------------------------------------------
# The following two are enough.
# (1) augment_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
'''
def augment_img(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)
    '''
    if mode == 0:
        return img
    elif mode == 1:
        return np.flipud(np.rot90(img))
    elif mode == 2:
        return np.flipud(img)
    elif mode == 3:
        return np.rot90(img, k=3)
    elif mode == 4:
        return np.flipud(np.rot90(img, k=2))
    elif mode == 5:
        return np.rot90(img)
    elif mode == 6:
        return np.rot90(img, k=2)
    elif mode == 7:
        return np.flipud(np.rot90(img, k=3))
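The 8 modes realize the dihedral group of the square (identity, three rotations, and their flipped counterparts). A minimal sketch:

# Illustrative only: mode 0 is the identity; odd-rotation modes swap H and W.
arr = np.arange(12).reshape(3, 4, 1)
assert np.array_equal(augment_img(arr, mode=0), arr)
assert augment_img(arr, mode=5).shape == (4, 3, 1)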
def augment_img_tensor4(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)
    '''
    if mode == 0:
        return img
    elif mode == 1:
        return img.rot90(1, [2, 3]).flip([2])
    elif mode == 2:
        return img.flip([2])
    elif mode == 3:
        return img.rot90(3, [2, 3])
    elif mode == 4:
        return img.rot90(2, [2, 3]).flip([2])
    elif mode == 5:
        return img.rot90(1, [2, 3])
    elif mode == 6:
        return img.rot90(2, [2, 3])
    elif mode == 7:
        return img.rot90(3, [2, 3]).flip([2])
def augment_img_tensor(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)
    '''
    img_size = img.size()
    img_np = img.data.cpu().numpy()
    if len(img_size) == 3:
        img_np = np.transpose(img_np, (1, 2, 0))
    elif len(img_size) == 4:
        img_np = np.transpose(img_np, (2, 3, 1, 0))
    img_np = augment_img(img_np, mode=mode)
    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
    if len(img_size) == 3:
        img_tensor = img_tensor.permute(2, 0, 1)
    elif len(img_size) == 4:
        img_tensor = img_tensor.permute(3, 2, 0, 1)
    return img_tensor.type_as(img)
def augment_img_np3(img, mode=0):
    if mode == 0:
        return img
    elif mode == 1:
        return img.transpose(1, 0, 2)
    elif mode == 2:
        return img[::-1, :, :]
    elif mode == 3:
        img = img[::-1, :, :]
        img = img.transpose(1, 0, 2)
        return img
    elif mode == 4:
        return img[:, ::-1, :]
    elif mode == 5:
        img = img[:, ::-1, :]
        img = img.transpose(1, 0, 2)
        return img
    elif mode == 6:
        img = img[:, ::-1, :]
        img = img[::-1, :, :]
        return img
    elif mode == 7:
        img = img[:, ::-1, :]
        img = img[::-1, :, :]
        img = img.transpose(1, 0, 2)
        return img


def augment_imgs(img_list, hflip=True, rot=True):
    # horizontal flip OR rotate
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    return [_augment(img) for img in img_list]
'''
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
'''
def modcrop(img_in, scale):
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    if img.ndim == 2:
        H, W = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[:H - H_r, :W - W_r]
    elif img.ndim == 3:
        H, W, C = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[:H - H_r, :W - W_r, :]
    else:
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    return img
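A one-line sketch:

# Illustrative only: modcrop trims H and W down to multiples of `scale`.
assert modcrop(np.zeros((37, 41, 3)), 4).shape == (36, 40, 3)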
def shave(img_in, border=0):
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    h, w = img.shape[:2]
    img = img[border:h - border, border:w - border]
    return img
'''
# --------------------------------------------
# image processing process on numpy image
# channel_convert(in_c, tar_type, img_list):
# rgb2ycbcr(img, only_y=True):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
'''
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img = img.astype(np.float32)  # fixed: astype() returns a copy, so its result must be assigned
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img = img.astype(np.float32)  # fixed: astype() returns a copy, so its result must be assigned
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img = img.astype(np.float32)  # fixed: astype() returns a copy, so its result must be assigned
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def channel_convert(in_c, tar_type, img_list):
    # conversion among BGR, gray and y
    if in_c == 3 and tar_type == 'gray':  # BGR to gray
        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in gray_list]
    elif in_c == 3 and tar_type == 'y':  # BGR to y
        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in y_list]
    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    else:
        return img_list
'''
# --------------------------------------------
# metric, PSNR and SSIM
# --------------------------------------------
'''
# --------------------------------------------
# PSNR
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
    # img1 and img2 have range [0, 255]
    #img1 = img1.squeeze()
    #img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h - border, border:w - border]
    img2 = img2[border:h - border, border:w - border]

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
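A worked example of the formula:

# Illustrative only: a uniform error of 5 gray levels gives MSE = 25, hence
# PSNR = 20 * log10(255 / 5) ≈ 34.15 dB.
ref = np.full((16, 16), 100, dtype=np.float64)
assert abs(calculate_psnr(ref, ref + 5) - 20 * math.log10(255 / 5)) < 1e-9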
# --------------------------------------------
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]
    '''
    #img1 = img1.squeeze()
    #img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h - border, border:w - border]
    img2 = img2[border:h - border, border:w - border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')
def ssim(img1, img2):
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()
'''
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
'''
# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
    absx = torch.abs(x)
    absx2 = absx ** 2
    absx3 = absx ** 3
    return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + \
        (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * (absx <= 2)).type_as(absx))
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias - larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation?  Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand(
        out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices

    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
# --------------------------------------------
# imresize for tensor image [0, 1]
# --------------------------------------------
def imresize(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: pytorch tensor, CHW or HW [0,1]
    # output: CHW or HW [0,1] w/o round
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2
# --------------------------------------------
# imresize for numpy image [0, 1]
# --------------------------------------------
def imresize_np(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(2)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()

    return out_2.numpy()
if __name__ == '__main__':
    print('---')
    # img = imread_uint('test.bmp', 3)
    # img = uint2single(img)
    # img_bicubic = imresize_np(img, 1/4)
Ruyi-Models/ruyi/vae/ldm/modules/losses/__init__.py
from .contperceptual import LPIPSWithDiscriminator
Ruyi-Models/ruyi/vae/ldm/modules/losses/contperceptual.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from taming.modules.losses.vqperceptual import *  # TODO: taming dependency yes/no?

from ..vaemodules.discriminator import Discriminator3D
class LPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge", l2_loss_weight=0.0, l1_loss_weight=1.0):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm
                                                 ).apply(weights_init)
        self.discriminator3d = Discriminator3D(
            in_channels=disc_in_channels,
            block_out_channels=(64, 128, 256)
        ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

        self.l1_loss_weight = l1_loss_weight
        self.l2_loss_weight = l2_loss_weight
    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight
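The adaptive weight balances the GAN term against the reconstruction term via the ratio of their gradient norms at the last layer. A standalone numeric sketch of the rule (not part of the class; scalar stand-ins for the losses):

# Illustrative only: grad norms 3.0 and 0.5 give a weight of about 6.0,
# before scaling by discriminator_weight.
w = torch.tensor(2.0, requires_grad=True)
nll = 3.0 * w
g = 0.5 * w
nll_g = torch.autograd.grad(nll, w, retain_graph=True)[0]   # 3.0
g_g = torch.autograd.grad(g, w, retain_graph=True)[0]       # 0.5
d_w = (nll_g.norm() / (g_g.norm() + 1e-4)).clamp(0.0, 1e4)  # ≈ 6.0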
    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train",
                weights=None):
        if inputs.ndim == 4:
            inputs = inputs.unsqueeze(2)
        if reconstructions.ndim == 4:
            reconstructions = reconstructions.unsqueeze(2)
        inputs_ori = inputs
        reconstructions_ori = reconstructions

        # get new loss_weight
        loss_weights = 1
        # b, _, f, _, _ = reconstructions.size()
        # loss_weights = torch.ones([b, f]).view(b, 1, f, 1, 1)
        # loss_weights[:, :, 0] = 3
        # for i in range(1, f, 8):
        #     loss_weights[:, :, i - 1] = 3
        #     loss_weights[:, :, i] = 3
        # loss_weights[:, :, -1] = 3
        # loss_weights = loss_weights.permute(0, 2, 1, 3, 4).flatten(0, 1).to(reconstructions.device)

        inputs = inputs.permute(0, 2, 1, 3, 4).flatten(0, 1)
        reconstructions = reconstructions.permute(0, 2, 1, 3, 4).flatten(0, 1)

        rec_loss = 0
        if self.l1_loss_weight > 0:
            rec_loss += torch.abs(inputs.contiguous() - reconstructions.contiguous()) * self.l1_loss_weight
        if self.l2_loss_weight > 0:
            rec_loss += F.mse_loss(inputs.contiguous(), reconstructions.contiguous(), reduction="none") * self.l2_loss_weight

        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        rec_loss = rec_loss * loss_weights

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            logits_fake_3d = self.discriminator3d(reconstructions_ori.contiguous())
            g_loss = -torch.mean(logits_fake) - torch.mean(logits_fake_3d)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/logvar".format(split): self.logvar.detach(),
                   "{}/kl_loss".format(split): kl_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
            logits_real_3d = self.discriminator3d(inputs_ori.contiguous().detach())
            logits_fake_3d = self.discriminator3d(reconstructions_ori.contiguous().detach())

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + \
                disc_factor * self.disc_loss(logits_real_3d, logits_fake_3d)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
Ruyi-Models/ruyi/vae/ldm/modules/losses/vqperceptual.py
import torch
import torch.nn.functional as F
from einops import repeat
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
from torch import nn
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
    loss_real = torch.mean(F.relu(1. - logits_real), dim=[1, 2, 3])
    loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1, 2, 3])
    loss_real = (weights * loss_real).sum() / weights.sum()
    loss_fake = (weights * loss_fake).sum() / weights.sum()
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.):
    if global_step < threshold:
        weight = value
    return weight
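A one-line sketch of the warm-up behavior:

# Illustrative only: the weight is forced to `value` (default 0) until
# global_step reaches the threshold.
assert adopt_weight(1.0, global_step=100, threshold=500) == 0.0
assert adopt_weight(1.0, global_step=500, threshold=500) == 1.0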
def measure_perplexity(predicted_indices, n_embed):
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
    encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    avg_probs = encodings.mean(0)
    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
    cluster_use = torch.sum(avg_probs > 0)
    return perplexity, cluster_use
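A minimal sketch of the uniform-usage case:

# Illustrative only: with perfectly uniform code usage the perplexity equals
# n_embed and every embedding counts as used.
idx = torch.arange(8).repeat(4)
ppl, used = measure_perplexity(idx, n_embed=8)
assert used.item() == 8 and abs(ppl.item() - 8.0) < 1e-3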
def l1(x, y):
    return torch.abs(x - y)


def l2(x, y):
    return torch.pow((x - y), 2)
'''
class VQLPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
pixel_loss="l1"):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
assert perceptual_loss in ["lpips", "clips", "dists"]
assert pixel_loss in ["l1", "l2"]
self.codebook_weight = codebook_weight
self.pixel_weight = pixelloss_weight
if perceptual_loss == "lpips":
print(f"{self.__class__.__name__}: Running with LPIPS.")
self.perceptual_loss = LPIPS().eval()
else:
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
self.perceptual_weight = perceptual_weight
if pixel_loss == "l1":
self.pixel_loss = l1
else:
self.pixel_loss = l2
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
ndf=disc_ndf
).apply(weights_init)
self.discriminator_iter_start = disc_start
if disc_loss == "hinge":
self.disc_loss = hinge_d_loss
elif disc_loss == "vanilla":
self.disc_loss = vanilla_d_loss
else:
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
self.n_classes = n_classes
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
if not exists(codebook_loss):
codebook_loss = torch.tensor([0.]).to(inputs.device)
#rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
rec_loss = rec_loss + self.perceptual_weight * p_loss
else:
p_loss = torch.tensor([0.0])
nll_loss = rec_loss
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
nll_loss = torch.mean(nll_loss)
# now the GAN part
if optimizer_idx == 0:
# generator update
if cond is None:
assert not self.disc_conditional
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
g_loss = -torch.mean(logits_fake)
try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/p_loss".format(split): p_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
if predicted_indices is not None:
assert self.n_classes is not None
with torch.no_grad():
perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
log[f"{split}/perplexity"] = perplexity
log[f"{split}/cluster_usage"] = cluster_usage
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean()
}
return d_loss, log
'''
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/__init__.py
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/activations.py
import torch.nn as nn

ACTIVATION_FUNCTIONS = {
    "elu": nn.ELU(),
    "swish": nn.SiLU(),
    "silu": nn.SiLU(),
    "mish": nn.Mish(),
    "gelu": nn.GELU(),
    "relu": nn.ReLU(),
}


def get_activation(act_fn: str) -> nn.Module:
    """Helper function to get activation function from string.

    Args:
        act_fn (str): Name of activation function.

    Returns:
        nn.Module: Activation function.
    """
    act_fn = act_fn.lower()
    if act_fn in ACTIVATION_FUNCTIONS:
        return ACTIVATION_FUNCTIONS[act_fn]
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")
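A minimal usage sketch:

# Illustrative only: lookup is case-insensitive and returns one of the shared
# module instances from ACTIVATION_FUNCTIONS.
assert isinstance(get_activation("SiLU"), nn.SiLU)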
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/attention.py
import inspect

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .attention_processors import AttnProcessor, AttnProcessor2_0
from .common import SpatialNorm3D
class Attention(nn.Module):
    r"""
A cross attention layer.
Parameters:
query_dim (`int`):
The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
nheads (`int`, *optional*, defaults to 8):
The number of heads to use for multi-head attention.
head_dim (`int`, *optional*, defaults to 64):
The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
upcast_attention (`bool`, *optional*, defaults to False):
Set to `True` to upcast the attention computation to `float32`.
upcast_softmax (`bool`, *optional*, defaults to False):
Set to `True` to upcast the softmax computation to `float32`.
cross_attention_norm (`str`, *optional*, defaults to `None`):
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the group norm in the cross attention.
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the added key and value projections. If `None`, no projection is used.
norm_num_groups (`int`, *optional*, defaults to `None`):
The number of groups to use for the group norm in the attention.
spatial_norm_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the spatial normalization.
out_bias (`bool`, *optional*, defaults to `True`):
Set to `True` to use a bias in the output linear layer.
scale_qk (`bool`, *optional*, defaults to `True`):
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
only_cross_attention (`bool`, *optional*, defaults to `False`):
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
`added_kv_proj_dim` is not `None`.
eps (`float`, *optional*, defaults to 1e-5):
An additional value added to the denominator in group normalization that is used for numerical stability.
rescale_output_factor (`float`, *optional*, defaults to 1.0):
A factor to rescale the output by dividing it with this value.
residual_connection (`bool`, *optional*, defaults to `False`):
Set to `True` to add the residual connection to the output.
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
Set to `True` if the attention block is loaded from a deprecated state dict.
processor (`AttnProcessor`, *optional*, defaults to `None`):
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
`AttnProcessor` otherwise.
"""
    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: int = None,
        nheads: int = 8,
        head_dim: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm=None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim=None,
        norm_num_groups=None,
        spatial_norm_dim=None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        processor=None,
        out_dim: int = None,
    ):
        super().__init__()

        self.query_dim = query_dim
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.inner_dim = out_dim if out_dim is not None else head_dim * nheads
        self.nheads = out_dim // head_dim if out_dim is not None else nheads
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. "
                "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        self.scale_qk = scale_qk
        self.scale = head_dim**-0.5 if scale_qk else 1.0
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
        else:
            self.group_norm = None

        if spatial_norm_dim is not None:
            self.spatial_norm = SpatialNorm3D(f_channels=query_dim, zq_channels=spatial_norm_dim)
        else:
            self.spatial_norm = None

        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
        elif cross_attention_norm == "group_norm":
            if self.added_kv_proj_dim is not None:
                # The given `encoder_hidden_states` are initially of shape
                # (batch_size, seq_len, added_kv_proj_dim) before being projected
                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
                # before the projection, so we need to use `added_kv_proj_dim` as
                # the number of channels for the group norm.
                norm_cross_num_channels = added_kv_proj_dim
            else:
                norm_cross_num_channels = self.cross_attention_dim

            self.norm_cross = nn.GroupNorm(
                num_channels=norm_cross_num_channels,
                num_groups=cross_attention_norm_num_groups,
                eps=1e-5,
                affine=True,
            )
        else:
            raise ValueError(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
            )

        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)

        if not self.only_cross_attention:
            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
        else:
            self.to_k = None
            self.to_v = None

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)

        self.to_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
        self.dropout = nn.Dropout(dropout)

        if processor is None:
            processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
        self.set_processor(processor)
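
    # Dimension bookkeeping sketch (illustrative comment, not part of the original
    # file). With the defaults nheads=8 and head_dim=64 and no out_dim override:
    #   inner_dim = head_dim * nheads = 512   -> width of to_q / to_k / to_v
    #   out_dim   = query_dim                 -> width of to_out
    # Passing out_dim instead re-derives the head count as out_dim // head_dim,
    # so out_dim should be a multiple of head_dim for the heads to split evenly.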
    def set_processor(self, processor: AttnProcessor) -> None:
        r"""
        Set the attention processor to use.

        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
        """
        # If the current processor is an `nn.Module` registered in `self._modules`
        # and the new `processor` is not, pop the old one from `self._modules` so
        # it is no longer tracked as a submodule.
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            self._modules.pop("processor")

        self.processor = processor
        self._attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
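
    # Example (sketch): force the pure-PyTorch attention math even on torch 2.x,
    # e.g. to debug numerical differences against scaled_dot_product_attention:
    #   attn.set_processor(AttnProcessor())
    # `_attn_parameters` is recomputed here so forward() can filter
    # cross_attention_kwargs against the new processor's signature.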
    def prepare_attention_mask(
        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
    ) -> torch.Tensor:
        r"""
        Prepare the attention mask for the attention computation.

        Args:
            attention_mask (`torch.Tensor`):
                The attention mask to prepare.
            target_length (`int`):
                The target length of the attention mask. This is the length of the attention mask after padding.
            batch_size (`int`):
                The batch size, which is used to repeat the attention mask.
            out_dim (`int`, *optional*, defaults to `3`):
                The output dimension of the attention mask. Can be either `3` or `4`.

        Returns:
            `torch.Tensor`: The prepared attention mask.
        """
        head_size = self.nheads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS does not support padding by more than the dimension of the input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                # we want to instead pad by (0, remaining_length), where remaining_length is:
                # remaining_length: int = target_length - current_length
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask
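
    # Shape sketch (illustrative): a (batch, 1, key_len) additive mask with
    # batch=2 and nheads=8 becomes (16, 1, key_len) for out_dim=3, or
    # (2, 8, 1, key_len) for out_dim=4, matching the per-head score tensors
    # produced by get_attention_scores / scaled_dot_product_attention.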
    def get_attention_scores(
        self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        r"""
        Compute the attention scores.

        Args:
            query (`torch.Tensor`): The query tensor.
            key (`torch.Tensor`): The key tensor.
            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.

        Returns:
            `torch.Tensor`: The attention probabilities/scores.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs
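
    # Math sketch: the baddbmm call above computes
    #   scores = beta * mask + alpha * (Q @ K^T),  with alpha = self.scale,
    # so the softmax yields the classic scaled dot-product attention weights
    #   softmax((Q @ K^T) / sqrt(head_dim) + mask)
    # when scale_qk is enabled, with optional float32 upcasting for stability.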
    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        r"""
        Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
        `Attention` class.

        Args:
            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.

        Returns:
            `torch.Tensor`: The normalized encoder hidden states.
        """
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states
    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // nheads, seq_len, dim * nheads]`.
        `nheads` is the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.nheads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor
    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, nheads, seq_len, dim // nheads]`.
        `nheads` is the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.
            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
                reshaped to `[batch_size * nheads, seq_len, dim // nheads]`.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.nheads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)
        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor
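
    # Round-trip sketch (illustrative): with nheads=8,
    #   x = torch.randn(2, 77, 512)
    #   head_to_batch_dim(x)     -> (16, 77, 64)   # heads folded into batch
    #   batch_to_head_dim(that)  -> (2, 77, 512)   # inverse of the above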
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Attention` class.

        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.

        Returns:
            `torch.Tensor`: The output of the attention layer.
        """
        # The `Attention` class can call different attention processors / attention functions.
        # Here we simply pass along all tensors to the selected processor class.
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty.
        unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in self._attn_parameters]
        # if len(unused_kwargs) > 0:
        #     logger.warning(
        #         f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
        #     )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in self._attn_parameters}

        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
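
# Usage sketch (illustrative comment, not part of the original file): plain
# token-level self-attention over a (batch, seq_len, channels) tensor. The
# einops-based subclasses below differ only in how they flatten video tensors
# into this form before delegating to Attention.forward:
#   attn = Attention(query_dim=320, nheads=8, head_dim=40)   # inner_dim == 320
#   out = attn(torch.randn(2, 1024, 320))                    # -> (2, 1024, 320)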
class SpatialAttention(Attention):
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        is_image = hidden_states.ndim == 4
        if is_image:
            hidden_states = rearrange(hidden_states, "b c h w -> b c 1 h w")

        bsz, h = hidden_states.shape[0], hidden_states.shape[3]

        hidden_states = rearrange(hidden_states, "b c t h w -> (b t) (h w) c")
        if encoder_hidden_states is not None:
            encoder_hidden_states = rearrange(encoder_hidden_states, "b c t h w -> (b t) (h w) c")
        if attention_mask is not None:
            attention_mask = rearrange(attention_mask, "b t h w -> (b t) (h w)")

        hidden_states = super().forward(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        hidden_states = rearrange(hidden_states, "(b t) (h w) c -> b c t h w", b=bsz, h=h)

        if is_image:
            hidden_states = rearrange(hidden_states, "b c 1 h w -> b c h w")

        return hidden_states
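
# Shape flow (comment sketch): SpatialAttention folds time into the batch and
# attends over the h*w tokens of each frame independently:
#   (b, c, t, h, w) -> ((b t), (h w), c) -> attention -> (b, c, t, h, w)
# 4D image inputs (b, c, h, w) are handled by inserting a singleton time axis.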
class TemporalAttention(Attention):
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        bsz, h = hidden_states.shape[0], hidden_states.shape[3]

        hidden_states = rearrange(hidden_states, "b c t h w -> (b h w) t c")
        if encoder_hidden_states is not None:
            encoder_hidden_states = rearrange(encoder_hidden_states, "b c t h w -> (b h w) t c")
        if attention_mask is not None:
            attention_mask = rearrange(attention_mask, "b t h w -> (b h w) t")

        hidden_states = super().forward(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        hidden_states = rearrange(hidden_states, "(b h w) t c -> b c t h w", b=bsz, h=h)

        return hidden_states
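
# Shape flow (comment sketch): TemporalAttention is the transpose of the above,
# folding space into the batch and attending over the t tokens at each spatial
# location:
#   (b, c, t, h, w) -> ((b h w), t, c) -> attention -> (b, c, t, h, w)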
class Attention3D(Attention):
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        t, h = hidden_states.shape[2], hidden_states.shape[3]

        hidden_states = rearrange(hidden_states, "b c t h w -> b (t h w) c")
        if encoder_hidden_states is not None:
            encoder_hidden_states = rearrange(encoder_hidden_states, "b c t h w -> b (t h w) c")
        if attention_mask is not None:
            attention_mask = rearrange(attention_mask, "b t h w -> b (t h w)")

        hidden_states = super().forward(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        hidden_states = rearrange(hidden_states, "b (t h w) c -> b c t h w", t=t, h=h)

        return hidden_states
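
# Attention3D flattens all t*h*w positions into one sequence, so its attention
# cost grows as O((t*h*w)^2), versus O(t * (h*w)^2) for SpatialAttention and
# O(h*w * t^2) for TemporalAttention. Minimal shape check (illustrative sketch,
# not part of the original file; assumes the relative imports above resolve and
# uses made-up hyperparameters):
if __name__ == "__main__":
    video = torch.randn(1, 64, 4, 8, 8)  # (batch, channels, frames, height, width)
    for attn_cls in (SpatialAttention, TemporalAttention, Attention3D):
        attn = attn_cls(query_dim=64, nheads=4, head_dim=16)
        assert attn(video).shape == video.shape, attn_cls.__name__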