Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
571e4062
Commit
571e4062
authored
Jul 01, 2022
by
Patrick von Platen
Browse files
merge from master
parents
14bd3567
c2bc59d2
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
174 additions
and
40 deletions
+174
-40
scripts/conversion_bddm.py
scripts/conversion_bddm.py
+40
-0
scripts/conversion_ldm_uncond.py
scripts/conversion_ldm_uncond.py
+56
-0
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+36
-8
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+38
-27
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
...tent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+1
-4
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+3
-1
No files found.
scripts/conversion_bddm.py
0 → 100644
View file @
571e4062
import
argparse
import
torch
from
diffusers.pipelines.bddm
import
DiffWave
,
BDDMPipeline
from
diffusers
import
DDPMScheduler
def convert_bddm_orginal(checkpoint_path, noise_scheduler_checkpoint_path, output_path):
    """Convert an original BDDM checkpoint into a saved ``BDDMPipeline``.

    Args:
        checkpoint_path: Path of the checkpoint passed to ``torch.load``; its
            ``"model_state_dict"`` entry holds the model weights.
        noise_scheduler_checkpoint_path: Path of the noise-schedule checkpoint
            passed to ``torch.load``. The loaded object is unpacked as exactly
            four items; the first (timestep values) and third (betas) are used.
        output_path: Destination handed to ``pipeline.save_pretrained``.
    """
    # Both files are deserialized onto the CPU before any model is built.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model_weights = checkpoint["model_state_dict"]
    schedule = torch.load(noise_scheduler_checkpoint_path, map_location="cpu")

    # Non-strict loading: keys that do not match the DiffWave module are ignored.
    diffwave = DiffWave()
    diffwave.load_state_dict(model_weights, strict=False)

    # The second and fourth entries of the schedule are not needed here.
    steps_tensor, _, betas_tensor, _ = schedule
    timestep_values = list(steps_tensor.numpy().tolist())
    trained_betas = list(betas_tensor.numpy().tolist())

    scheduler_kwargs = {
        "timesteps": 12,
        "trained_betas": trained_betas,
        "timestep_values": timestep_values,
        "clip_sample": False,
        "tensor_format": "np",
    }
    scheduler = DDPMScheduler(**scheduler_kwargs)

    BDDMPipeline(diffwave, scheduler).save_pretrained(output_path)
if __name__ == "__main__":
    # Command-line entry point: all three options are required string paths.
    cli = argparse.ArgumentParser()
    for flag in ("--checkpoint_path", "--noise_scheduler_checkpoint_path", "--output_path"):
        cli.add_argument(flag, type=str, required=True)
    options = cli.parse_args()

    convert_bddm_orginal(
        options.checkpoint_path,
        options.noise_scheduler_checkpoint_path,
        options.output_path,
    )
scripts/conversion_ldm_uncond.py
0 → 100644
View file @
571e4062
import argparse

import torch
# `OmegaConf` is a class exported by the `omegaconf` package; there is no
# top-level module called `OmegaConf`, so `import OmegaConf` fails on import.
from omegaconf import OmegaConf

from diffusers import DDIMScheduler, LatentDiffusionUncondPipeline, UNetLDMModel, VQModel


def _extract_sub_state_dict(state_dict, prefix):
    """Return the entries of ``state_dict`` whose key starts with ``prefix``, with only that leading prefix removed."""
    # Slice instead of ``str.replace`` so a repeated occurrence of the prefix
    # later in the key is left untouched.
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


def convert_ldm_original(checkpoint_path, config_path, output_path):
    """Convert an original latent-diffusion checkpoint into a saved ``LatentDiffusionUncondPipeline``.

    Args:
        checkpoint_path: Path of the checkpoint passed to ``torch.load``; its
            ``"model"`` entry is the combined state dict.
        config_path: Path of the configuration file read with ``OmegaConf.load``.
        output_path: Destination handed to ``pipeline.save_pretrained``.
    """
    config = OmegaConf.load(config_path)
    # NOTE(review): torch.load unpickles the file -- only run this on trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]

    # Split the combined state dict into its VQ-VAE and UNet parts.
    first_stage_dict = _extract_sub_state_dict(state_dict, "first_stage_model.")
    unet_state_dict = _extract_sub_state_dict(state_dict, "model.diffusion_model.")

    model_params = config.model.params
    vqvae_init_args = model_params.first_stage_config.params
    unet_init_args = model_params.unet_config.params

    vqvae = VQModel(**vqvae_init_args).eval()
    vqvae.load_state_dict(first_stage_dict)

    unet = UNetLDMModel(**unet_init_args).eval()
    unet.load_state_dict(unet_state_dict)

    # Scheduler hyper-parameters are taken from the original model config.
    noise_scheduler = DDIMScheduler(
        timesteps=model_params.timesteps,
        beta_schedule="scaled_linear",
        beta_start=model_params.linear_start,
        beta_end=model_params.linear_end,
        clip_sample=False,
    )

    pipeline = LatentDiffusionUncondPipeline(vqvae, unet, noise_scheduler)
    pipeline.save_pretrained(output_path)
if __name__ == "__main__":
    # Command-line entry point: all three options are required string paths.
    cli = argparse.ArgumentParser()
    for flag in ("--checkpoint_path", "--config_path", "--output_path"):
        cli.add_argument(flag, type=str, required=True)
    options = cli.parse_args()

    convert_ldm_original(
        options.checkpoint_path,
        options.config_path,
        options.output_path,
    )
src/diffusers/models/resnet.py
View file @
571e4062
from
abc
import
abstractmethod
from
functools
import
partial
import
numpy
as
np
import
torch
...
...
@@ -78,18 +79,25 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
dims
=
2
,
out_channels
=
None
):
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
dims
=
2
,
out_channels
=
None
,
name
=
"conv"
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
dims
=
dims
self
.
use_conv_transpose
=
use_conv_transpose
self
.
name
=
name
conv
=
None
if
use_conv_transpose
:
self
.
conv
=
conv_transpose_nd
(
dims
,
channels
,
self
.
out_channels
,
4
,
2
,
1
)
conv
=
conv_transpose_nd
(
dims
,
channels
,
self
.
out_channels
,
4
,
2
,
1
)
elif
use_conv
:
self
.
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
if
name
==
"conv"
:
self
.
conv
=
conv
else
:
self
.
Conv2d_0
=
conv
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
...
...
@@ -102,7 +110,10 @@ class Upsample(nn.Module):
x
=
F
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
use_conv
:
x
=
self
.
conv
(
x
)
if
self
.
name
==
"conv"
:
x
=
self
.
conv
(
x
)
else
:
x
=
self
.
Conv2d_0
(
x
)
return
x
...
...
@@ -134,6 +145,8 @@ class Downsample(nn.Module):
if
name
==
"conv"
:
self
.
conv
=
conv
elif
name
==
"Conv2d_0"
:
self
.
Conv2d_0
=
conv
else
:
self
.
op
=
conv
...
...
@@ -145,6 +158,8 @@ class Downsample(nn.Module):
if
self
.
name
==
"conv"
:
return
self
.
conv
(
x
)
elif
self
.
name
==
"Conv2d_0"
:
return
self
.
Conv2d_0
(
x
)
else
:
return
self
.
op
(
x
)
...
...
@@ -469,6 +484,7 @@ class ResnetBlockBigGANpp(nn.Module):
up
=
False
,
down
=
False
,
dropout
=
0.1
,
fir
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
skip_rescale
=
True
,
init_scale
=
0.0
,
...
...
@@ -479,8 +495,20 @@ class ResnetBlockBigGANpp(nn.Module):
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
min
(
in_ch
//
4
,
32
),
num_channels
=
in_ch
,
eps
=
1e-6
)
self
.
up
=
up
self
.
down
=
down
self
.
fir
=
fir
self
.
fir_kernel
=
fir_kernel
if
self
.
up
:
if
self
.
fir
:
self
.
upsample
=
partial
(
upsample_2d
,
k
=
self
.
fir_kernel
,
factor
=
2
)
else
:
self
.
upsample
=
partial
(
F
.
interpolate
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
elif
self
.
down
:
if
self
.
fir
:
self
.
downsample
=
partial
(
downsample_2d
,
k
=
self
.
fir_kernel
,
factor
=
2
)
else
:
self
.
downsample
=
partial
(
F
.
avg_pool2d
,
kernel_size
=
2
,
stride
=
2
)
self
.
Conv_0
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
padding
=
1
)
if
temb_dim
is
not
None
:
self
.
Dense_0
=
nn
.
Linear
(
temb_dim
,
out_ch
)
...
...
@@ -503,11 +531,11 @@ class ResnetBlockBigGANpp(nn.Module):
h
=
self
.
act
(
self
.
GroupNorm_0
(
x
))
if
self
.
up
:
h
=
upsample
_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
upsample
_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
h
=
self
.
upsample
(
h
)
x
=
self
.
upsample
(
x
)
elif
self
.
down
:
h
=
downsample
_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
downsample
_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
h
=
self
.
downsample
(
h
)
x
=
self
.
downsample
(
x
)
h
=
self
.
Conv_0
(
h
)
# Add bias to each feature map conditioned on the time embedding
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
571e4062
...
...
@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.resnet
import
downsample_2d
,
upfirdn2d
,
upsample_2d
from
.resnet
import
downsample_2d
,
upfirdn2d
,
upsample_2d
,
Downsample
,
Upsample
from
.resnet
import
ResnetBlock
...
...
@@ -185,18 +185,19 @@ class Combine(nn.Module):
class
FirUpsample
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
=
None
,
out_ch
=
None
,
with
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
def
__init__
(
self
,
channels
=
None
,
out_ch
annels
=
None
,
use
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
if
with
_conv
:
self
.
Conv2d_0
=
Conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
with
_conv
=
with
_conv
out_ch
annels
=
out_ch
annels
if
out_ch
annels
else
channels
if
use
_conv
:
self
.
Conv2d_0
=
Conv2d
(
channels
,
out_ch
annels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
use
_conv
=
use
_conv
self
.
fir_kernel
=
fir_kernel
self
.
out_ch
=
out_ch
self
.
out_ch
annels
=
out_ch
annels
def
forward
(
self
,
x
):
if
self
.
with
_conv
:
if
self
.
use
_conv
:
h
=
_upsample_conv_2d
(
x
,
self
.
Conv2d_0
.
weight
,
k
=
self
.
fir_kernel
)
h
=
h
+
self
.
Conv2d_0
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
else
:
h
=
upsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
...
...
@@ -204,18 +205,19 @@ class FirUpsample(nn.Module):
class
FirDownsample
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
=
None
,
out_ch
=
None
,
with
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
def
__init__
(
self
,
channels
=
None
,
out_ch
annels
=
None
,
use
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
if
with
_conv
:
self
.
Conv2d_0
=
self
.
Conv2d_0
=
Conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
out_ch
annels
=
out_ch
annels
if
out_ch
annels
else
channels
if
use
_conv
:
self
.
Conv2d_0
=
self
.
Conv2d_0
=
Conv2d
(
channels
,
out_ch
annels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
fir_kernel
=
fir_kernel
self
.
with
_conv
=
with
_conv
self
.
out_ch
=
out_ch
self
.
use
_conv
=
use
_conv
self
.
out_ch
annels
=
out_ch
annels
def
forward
(
self
,
x
):
if
self
.
with
_conv
:
if
self
.
use
_conv
:
x
=
_conv_downsample_2d
(
x
,
self
.
Conv2d_0
.
weight
,
k
=
self
.
fir_kernel
)
x
=
x
+
self
.
Conv2d_0
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
else
:
x
=
downsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
...
...
@@ -229,13 +231,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
self
,
image_size
=
1024
,
num_channels
=
3
,
centered
=
False
,
attn_resolutions
=
(
16
,),
ch_mult
=
(
1
,
2
,
4
,
8
,
16
,
32
,
32
,
32
),
conditional
=
True
,
conv_size
=
3
,
dropout
=
0.0
,
embedding_type
=
"fourier"
,
fir
=
True
,
# TODO (patil-suraj) remove this option from here and pre-trained model configs
fir
=
True
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
fourier_scale
=
16
,
init_scale
=
0.0
,
...
...
@@ -253,12 +256,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
self
.
register_to_config
(
image_size
=
image_size
,
num_channels
=
num_channels
,
centered
=
centered
,
attn_resolutions
=
attn_resolutions
,
ch_mult
=
ch_mult
,
conditional
=
conditional
,
conv_size
=
conv_size
,
dropout
=
dropout
,
embedding_type
=
embedding_type
,
fir
=
fir
,
fir_kernel
=
fir_kernel
,
fourier_scale
=
fourier_scale
,
init_scale
=
init_scale
,
...
...
@@ -308,21 +313,26 @@ class NCSNpp(ModelMixin, ConfigMixin):
modules
.
append
(
Linear
(
nf
*
4
,
nf
*
4
))
AttnBlock
=
functools
.
partial
(
AttentionBlock
,
overwrite_linear
=
True
,
rescale_output_factor
=
math
.
sqrt
(
2.0
))
Up_sample
=
functools
.
partial
(
FirUpsample
,
with_conv
=
resamp_with_conv
,
fir_kernel
=
fir_kernel
)
if
self
.
fir
:
Up_sample
=
functools
.
partial
(
FirUpsample
,
fir_kernel
=
fir_kernel
,
use_conv
=
resamp_with_conv
)
else
:
Up_sample
=
functools
.
partial
(
Upsample
,
name
=
"Conv2d_0"
)
if
progressive
==
"output_skip"
:
self
.
pyramid_upsample
=
Up_sample
(
fir_kernel
=
fir_ker
ne
l
,
with
_conv
=
False
)
self
.
pyramid_upsample
=
Up_sample
(
channels
=
No
ne
,
use
_conv
=
False
)
elif
progressive
==
"residual"
:
pyramid_upsample
=
functools
.
partial
(
Up_sample
,
fir_kernel
=
fir_kernel
,
with
_conv
=
True
)
pyramid_upsample
=
functools
.
partial
(
Up_sample
,
use
_conv
=
True
)
Down_sample
=
functools
.
partial
(
FirDownsample
,
with_conv
=
resamp_with_conv
,
fir_kernel
=
fir_kernel
)
if
self
.
fir
:
Down_sample
=
functools
.
partial
(
FirDownsample
,
fir_kernel
=
fir_kernel
,
use_conv
=
resamp_with_conv
)
else
:
Down_sample
=
functools
.
partial
(
Downsample
,
padding
=
0
,
name
=
"Conv2d_0"
)
if
progressive_input
==
"input_skip"
:
self
.
pyramid_downsample
=
Down_sample
(
fir_kernel
=
fir_ker
ne
l
,
with
_conv
=
False
)
self
.
pyramid_downsample
=
Down_sample
(
channels
=
No
ne
,
use
_conv
=
False
)
elif
progressive_input
==
"residual"
:
pyramid_downsample
=
functools
.
partial
(
Down_sample
,
fir_kernel
=
fir_kernel
,
with_conv
=
True
)
# Downsampling block
pyramid_downsample
=
functools
.
partial
(
Down_sample
,
use_conv
=
True
)
channels
=
num_channels
if
progressive_input
!=
"none"
:
...
...
@@ -376,7 +386,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
in_ch
*=
2
elif
progressive_input
==
"residual"
:
modules
.
append
(
pyramid_downsample
(
in_ch
=
input_pyramid_ch
,
out_ch
=
in_ch
))
modules
.
append
(
pyramid_downsample
(
channels
=
input_pyramid_ch
,
out_ch
annels
=
in_ch
))
input_pyramid_ch
=
in_ch
hs_c
.
append
(
in_ch
)
...
...
@@ -448,7 +458,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
)
pyramid_ch
=
channels
elif
progressive
==
"residual"
:
modules
.
append
(
pyramid_upsample
(
in_ch
=
pyramid_ch
,
out_ch
=
in_ch
))
modules
.
append
(
pyramid_upsample
(
channels
=
pyramid_ch
,
out_ch
annels
=
in_ch
))
pyramid_ch
=
in_ch
else
:
raise
ValueError
(
f
"
{
progressive
}
is not a valid name"
)
...
...
@@ -505,7 +515,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
temb
=
None
# If input data is in [0, 1]
x
=
2
*
x
-
1.0
if
not
self
.
config
.
centered
:
x
=
2
*
x
-
1.0
# Downsampling block
input_pyramid
=
None
...
...
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
View file @
571e4062
...
...
@@ -63,9 +63,6 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
# scale and decode image with vae
image
=
1
/
0.18215
*
image
# decode image with vae
image
=
self
.
vqvae
.
decode
(
image
)
image
=
torch
.
clamp
((
image
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
return
image
tests/test_modeling_utils.py
View file @
571e4062
...
...
@@ -1159,7 +1159,9 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
256
,
256
)
expected_slice
=
torch
.
tensor
([
0.5025
,
0.4121
,
0.3851
,
0.4806
,
0.3996
,
0.3745
,
0.4839
,
0.4559
,
0.4293
])
expected_slice
=
torch
.
tensor
(
[
-
0.1202
,
-
0.1005
,
-
0.0635
,
-
0.0520
,
-
0.1282
,
-
0.0838
,
-
0.0981
,
-
0.1318
,
-
0.1106
]
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
def
test_module_from_pipeline
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment