renzhc / diffusers_dcu · Commits

Commit 571e4062, authored Jul 01, 2022 by Patrick von Platen

merge from master

Parents: 14bd3567, c2bc59d2

Showing 6 changed files with 174 additions and 40 deletions:
- scripts/conversion_bddm.py (+40, -0)
- scripts/conversion_ldm_uncond.py (+56, -0)
- src/diffusers/models/resnet.py (+36, -8)
- src/diffusers/models/unet_sde_score_estimation.py (+38, -27)
- src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py (+1, -4)
- tests/test_modeling_utils.py (+3, -1)
scripts/conversion_bddm.py (new file, mode 100644)

```python
import argparse

import torch

from diffusers.pipelines.bddm import DiffWave, BDDMPipeline
from diffusers import DDPMScheduler


def convert_bddm_orginal(checkpoint_path, noise_scheduler_checkpoint_path, output_path):
    sd = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
    noise_scheduler_sd = torch.load(noise_scheduler_checkpoint_path, map_location="cpu")

    model = DiffWave()
    model.load_state_dict(sd, strict=False)

    ts, _, betas, _ = noise_scheduler_sd
    ts, betas = list(ts.numpy().tolist()), list(betas.numpy().tolist())

    noise_scheduler = DDPMScheduler(
        timesteps=12,
        trained_betas=betas,
        timestep_values=ts,
        clip_sample=False,
        tensor_format="np",
    )

    pipeline = BDDMPipeline(model, noise_scheduler)
    pipeline.save_pretrained(output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", type=str, required=True)
    parser.add_argument("--noise_scheduler_checkpoint_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    args = parser.parse_args()

    convert_bddm_orginal(args.checkpoint_path, args.noise_scheduler_checkpoint_path, args.output_path)
```
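Once saved, the converted pipeline should be reloadable through the standard diffusers loader; a minimal sketch, assuming the hypothetical directory `./bddm-converted` was passed as `--output_path`:

```python
from diffusers.pipelines.bddm import BDDMPipeline

# "./bddm-converted" is a hypothetical --output_path from the script above.
pipeline = BDDMPipeline.from_pretrained("./bddm-converted")
```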
scripts/conversion_ldm_uncond.py (new file, mode 100644)

```python
import argparse

import torch
from omegaconf import OmegaConf

from diffusers import UNetLDMModel, VQModel, LatentDiffusionUncondPipeline, DDIMScheduler


def convert_ldm_original(checkpoint_path, config_path, output_path):
    config = OmegaConf.load(config_path)
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    keys = list(state_dict.keys())

    # extract state_dict for VQVAE
    first_stage_dict = {}
    first_stage_key = "first_stage_model."
    for key in keys:
        if key.startswith(first_stage_key):
            first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]

    # extract state_dict for UNetLDM
    unet_state_dict = {}
    unet_key = "model.diffusion_model."
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = state_dict[key]

    vqvae_init_args = config.model.params.first_stage_config.params
    unet_init_args = config.model.params.unet_config.params

    vqvae = VQModel(**vqvae_init_args).eval()
    vqvae.load_state_dict(first_stage_dict)

    unet = UNetLDMModel(**unet_init_args).eval()
    unet.load_state_dict(unet_state_dict)

    noise_scheduler = DDIMScheduler(
        timesteps=config.model.params.timesteps,
        beta_schedule="scaled_linear",
        beta_start=config.model.params.linear_start,
        beta_end=config.model.params.linear_end,
        clip_sample=False,
    )

    pipeline = LatentDiffusionUncondPipeline(vqvae, unet, noise_scheduler)
    pipeline.save_pretrained(output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", type=str, required=True)
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    args = parser.parse_args()

    convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
```
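The two extraction loops above split one combined checkpoint into per-module state dicts by key prefix. A standalone sketch of the same pattern, using slicing so that only the leading prefix is stripped (the `key.replace(prefix, "")` used above would also rewrite any later occurrence of the prefix inside a key):

```python
def extract_sub_state_dict(state_dict, prefix):
    """Collect all entries whose keys start with `prefix`, with the prefix stripped."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

# extract_sub_state_dict(state_dict, "first_stage_model.") reproduces first_stage_dict,
# and extract_sub_state_dict(state_dict, "model.diffusion_model.") reproduces unet_state_dict.
```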
src/diffusers/models/resnet.py

```diff
 from abc import abstractmethod
+from functools import partial

 import numpy as np
 import torch
@@ -78,18 +79,25 @@ class Upsample(nn.Module):
         upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         self.use_conv_transpose = use_conv_transpose
+        self.name = name

+        conv = None
         if use_conv_transpose:
-            self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
+            conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
         elif use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+        if name == "conv":
+            self.conv = conv
+        else:
+            self.Conv2d_0 = conv

     def forward(self, x):
         assert x.shape[1] == self.channels
```
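The new `name` argument only changes which attribute the convolution is stored under, which matters because `nn.Module` derives `state_dict` keys from attribute names; checkpoints exported with `Conv2d_0.*` keys can then load without remapping. A standalone sketch of the effect, with a plain `nn.Conv2d` standing in for `conv_nd`:

```python
import torch.nn as nn


class Block(nn.Module):
    def __init__(self, name="conv"):
        super().__init__()
        conv = nn.Conv2d(4, 4, 3, padding=1)
        # The attribute name decides the checkpoint key:
        if name == "conv":
            self.conv = conv      # -> "conv.weight", "conv.bias"
        else:
            self.Conv2d_0 = conv  # -> "Conv2d_0.weight", "Conv2d_0.bias"


print(list(Block("conv").state_dict()))      # ['conv.weight', 'conv.bias']
print(list(Block("Conv2d_0").state_dict()))  # ['Conv2d_0.weight', 'Conv2d_0.bias']
```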
```diff
@@ -102,7 +110,10 @@ class Upsample(nn.Module):
         x = F.interpolate(x, scale_factor=2.0, mode="nearest")

         if self.use_conv:
-            x = self.conv(x)
+            if self.name == "conv":
+                x = self.conv(x)
+            else:
+                x = self.Conv2d_0(x)

         return x
@@ -134,6 +145,8 @@ class Downsample(nn.Module):
         if name == "conv":
             self.conv = conv
+        elif name == "Conv2d_0":
+            self.Conv2d_0 = conv
         else:
             self.op = conv
@@ -145,6 +158,8 @@ class Downsample(nn.Module):
         if self.name == "conv":
             return self.conv(x)
+        elif self.name == "Conv2d_0":
+            return self.Conv2d_0(x)
         else:
             return self.op(x)
```
```diff
@@ -469,6 +484,7 @@ class ResnetBlockBigGANpp(nn.Module):
         up=False,
         down=False,
         dropout=0.1,
+        fir=False,
         fir_kernel=(1, 3, 3, 1),
         skip_rescale=True,
         init_scale=0.0,
@@ -479,8 +495,20 @@ class ResnetBlockBigGANpp(nn.Module):
         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
         self.up = up
         self.down = down
+        self.fir = fir
         self.fir_kernel = fir_kernel

+        if self.up:
+            if self.fir:
+                self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+        elif self.down:
+            if self.fir:
+                self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+
         self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
         if temb_dim is not None:
             self.Dense_0 = nn.Linear(temb_dim, out_ch)
@@ -503,11 +531,11 @@ class ResnetBlockBigGANpp(nn.Module):
         h = self.act(self.GroupNorm_0(x))

         if self.up:
-            h = upsample_2d(h, self.fir_kernel, factor=2)
-            x = upsample_2d(x, self.fir_kernel, factor=2)
+            h = self.upsample(h)
+            x = self.upsample(x)
         elif self.down:
-            h = downsample_2d(h, self.fir_kernel, factor=2)
-            x = downsample_2d(x, self.fir_kernel, factor=2)
+            h = self.downsample(h)
+            x = self.downsample(x)

         h = self.Conv_0(h)
         # Add bias to each feature map conditioned on the time embedding
```
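Binding the resampling choice once in `__init__` with `functools.partial` lets `forward` call `self.upsample(h)` or `self.downsample(h)` without re-checking the `fir` flag on every step. A standalone sketch of the non-FIR branches, which are just the bound `F.interpolate` and `F.avg_pool2d` calls:

```python
from functools import partial

import torch
import torch.nn.functional as F

upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)

h = torch.randn(1, 4, 8, 8)
assert upsample(h).shape == (1, 4, 16, 16)
assert downsample(h).shape == (1, 4, 4, 4)
```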
src/diffusers/models/unet_sde_score_estimation.py

```diff
@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import downsample_2d, upfirdn2d, upsample_2d
+from .resnet import downsample_2d, upfirdn2d, upsample_2d, Downsample, Upsample
 from .resnet import ResnetBlock
@@ -185,18 +185,19 @@ class Combine(nn.Module):
 class FirUpsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if with_conv:
-            self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
-        self.with_conv = with_conv
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.use_conv = use_conv
         self.fir_kernel = fir_kernel
-        self.out_ch = out_ch
+        self.out_channels = out_channels

     def forward(self, x):
-        if self.with_conv:
+        if self.use_conv:
             h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
             h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
             h = upsample_2d(x, self.fir_kernel, factor=2)
@@ -204,18 +205,19 @@ class FirUpsample(nn.Module):
 class FirDownsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if with_conv:
-            self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
         self.fir_kernel = fir_kernel
-        self.with_conv = with_conv
-        self.out_ch = out_ch
+        self.use_conv = use_conv
+        self.out_channels = out_channels

     def forward(self, x):
-        if self.with_conv:
+        if self.use_conv:
             x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
             x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
             x = downsample_2d(x, self.fir_kernel, factor=2)
```
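In both `forward` passes above, the convolution is applied through its raw `weight`, so the bias has to be added back by hand; reshaping the `[C]` bias to `[1, C, 1, 1]` broadcasts one scalar per feature map across an NCHW tensor. A standalone check of the broadcast:

```python
import torch

h = torch.zeros(2, 8, 16, 16)  # hypothetical NCHW feature maps
bias = torch.randn(8)          # one bias value per channel
h = h + bias.reshape(1, -1, 1, 1)

# every spatial position in channel c now holds bias[c]
assert torch.allclose(h[:, 3], bias[3].expand(2, 16, 16))
```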
```diff
@@ -229,13 +231,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
         self,
         image_size=1024,
         num_channels=3,
+        centered=False,
         attn_resolutions=(16,),
         ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
         conditional=True,
         conv_size=3,
         dropout=0.0,
         embedding_type="fourier",
-        fir=True,  # TODO (patil-suraj) remove this option from here and pre-trained model configs
+        fir=True,
         fir_kernel=(1, 3, 3, 1),
         fourier_scale=16,
         init_scale=0.0,
@@ -253,12 +256,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
         self.register_to_config(
             image_size=image_size,
             num_channels=num_channels,
+            centered=centered,
             attn_resolutions=attn_resolutions,
             ch_mult=ch_mult,
             conditional=conditional,
             conv_size=conv_size,
             dropout=dropout,
             embedding_type=embedding_type,
+            fir=fir,
             fir_kernel=fir_kernel,
             fourier_scale=fourier_scale,
             init_scale=init_scale,
@@ -308,21 +313,26 @@ class NCSNpp(ModelMixin, ConfigMixin):
         modules.append(Linear(nf * 4, nf * 4))

         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
-        Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+        if self.fir:
+            Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+        else:
+            Up_sample = functools.partial(Upsample, name="Conv2d_0")

         if progressive == "output_skip":
-            self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
         elif progressive == "residual":
-            pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_upsample = functools.partial(Up_sample, use_conv=True)

-        Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+        if self.fir:
+            Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+        else:
+            Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")

         if progressive_input == "input_skip":
-            self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
         elif progressive_input == "residual":
-            pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_downsample = functools.partial(Down_sample, use_conv=True)

         # Downsampling block
         channels = num_channels
         if progressive_input != "none":
@@ -376,7 +386,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
             in_ch *= 2
         elif progressive_input == "residual":
-            modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+            modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
             input_pyramid_ch = in_ch
         hs_c.append(in_ch)
@@ -448,7 +458,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 )
                 pyramid_ch = channels
             elif progressive == "residual":
-                modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
                 pyramid_ch = in_ch
             else:
                 raise ValueError(f"{progressive} is not a valid name")
@@ -505,7 +515,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
         temb = None

         # If input data is in [0, 1]
-        x = 2 * x - 1.0
+        if not self.config.centered:
+            x = 2 * x - 1.0

         # Downsampling block
         input_pyramid = None
```
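The new `centered` flag records whether inputs already live in `[-1, 1]`; when they arrive in `[0, 1]`, the model now rescales them itself only in that case instead of unconditionally. A quick check of the mapping:

```python
import torch

x = torch.tensor([0.0, 0.5, 1.0])  # input in [0, 1]
x = 2 * x - 1.0                    # rescaled to [-1, 1]
print(x)                           # tensor([-1., 0., 1.])
```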
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py

```diff
@@ -63,9 +63,6 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
             # 4. set current image to prev_image: x_t -> x_t-1
             image = pred_prev_image + variance

-        # scale and decode image with vae
-        image = 1 / 0.18215 * image
+        # decode image with vae
         image = self.vqvae.decode(image)

-        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         return image
```
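The dropped `1 / 0.18215` factor is the latent scale factor familiar from KL-regularized latent-diffusion checkpoints and does not apply to this VQ-regularized model; with the clamp also removed, mapping outputs back to `[0, 1]` becomes the caller's job. A hypothetical helper mirroring the deleted line:

```python
import torch

def to_unit_range(image: torch.Tensor) -> torch.Tensor:
    # Same mapping as the removed pipeline line: [-1, 1] -> [0, 1], clamped.
    return torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
```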
tests/test_modeling_utils.py

```diff
@@ -1159,7 +1159,9 @@ class PipelineTesterMixin(unittest.TestCase):
         image_slice = image[0, -1, -3:, -3:].cpu()

         assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor([0.5025, 0.4121, 0.3851, 0.4806, 0.3996, 0.3745, 0.4839, 0.4559, 0.4293])
+        expected_slice = torch.tensor(
+            [-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
+        )
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

     def test_module_from_pipeline(self):
```
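The updated expectation pins the pipeline output to new reference values; the test pattern itself is unchanged: take the bottom-right 3×3 patch of the last channel and bound the per-element deviation. A standalone sketch of that pattern, with placeholder references rather than the real ones:

```python
import torch

image = torch.randn(1, 3, 256, 256)         # stand-in for a pipeline output
image_slice = image[0, -1, -3:, -3:].cpu()  # bottom-right 3x3 patch of the last channel

expected_slice = image_slice.flatten().clone()  # placeholder reference values
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
```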