renzhc / diffusers_dcu · Commits

Commit 97e1e3ba, authored Jul 15, 2022 by Patrick von Platen
Parent: dacabaa4

finalize model API

Showing 8 changed files with 120 additions and 39 deletions (+120 / -39)
src/diffusers/models/unet_glide.py                                                    +4   -2
src/diffusers/models/unet_sde_score_estimation.py                                     +2   -1
src/diffusers/models/unet_unconditional.py                                            +50  -24
src/diffusers/pipelines/ddim/pipeline_ddim.py                                         +3   -0
src/diffusers/pipelines/ddpm/pipeline_ddpm.py                                         +3   -0
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py   +3   -0
src/diffusers/pipelines/pndm/pipeline_pndm.py                                         +6   -0
tests/test_modeling_utils.py                                                          +49  -12
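Taken together, the changes below rename the time-step argument of every model's forward from timesteps to step_value and wrap the returned tensor in a dict keyed by "sample"; the pipelines and tests are updated to unwrap that dict. A minimal sketch of the calling convention this commit finalizes — the constructor arguments are copied from the updated test further down, the package-root import path is assumed, and a randomly initialized model is used purely to show the call shape:

import torch

from diffusers import UNetUnconditionalModel  # import path assumed, as in the tests

# constructor arguments taken from the updated test_from_pretrained_save_pretrained
model = UNetUnconditionalModel(
    block_channels=(32, 64),
    num_res_blocks=2,
    attn_resolutions=(16,),
    image_size=32,
    in_channels=3,
    out_channels=3,
    down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
    up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
    ddpm=True,
)

sample = torch.randn(1, 3, 32, 32)

# the second argument is now called `step_value` ...
output = model(sample, step_value=10)

# ... and the prediction comes back wrapped in a dict keyed by "sample"
if isinstance(output, dict):
    output = output["sample"]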
src/diffusers/models/unet_glide.py

@@ -438,7 +438,8 @@ class GlideTextToImageUNetModel(GlideUNetModel):
         self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
 
-    def forward(self, sample, timesteps, transformer_out=None):
+    def forward(self, sample, step_value, transformer_out=None):
+        timesteps = step_value
         x = sample
         hs = []
         emb = self.time_embed(

@@ -529,7 +530,8 @@ class GlideSuperResUNetModel(GlideUNetModel):
             resblock_updown=resblock_updown,
         )
 
-    def forward(self, sample, timesteps, low_res=None):
+    def forward(self, sample, step_value, low_res=None):
+        timesteps = step_value
         x = sample
         _, _, new_height, new_width = x.shape
         upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
src/diffusers/models/unet_sde_score_estimation.py

@@ -323,7 +323,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
         self.all_modules = nn.ModuleList(modules)
 
-    def forward(self, sample, timesteps, sigmas=None):
+    def forward(self, sample, step_value, sigmas=None):
+        timesteps = step_value
         x = sample
         # timestep/noise_level embedding; only for continuous training
         modules = self.all_modules
src/diffusers/models/unet_unconditional.py

+from typing import Dict, Union
+
 import torch
 import torch.nn as nn

@@ -9,15 +11,6 @@ from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
 from .unet_new import UNetMidBlock2D, get_down_block, get_up_block
 
 
-def nonlinearity(x):
-    # swish
-    return x * torch.sigmoid(x)
-
-
-def Normalize(in_channels):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-
-
 class TimestepEmbedding(nn.Module):
     def __init__(self, channel, time_embed_dim):
         super().__init__()

@@ -79,7 +72,9 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         num_head_channels=32,
         flip_sin_to_cos=True,
         downscale_freq_shift=0,
-        # To delete once weights are converted
+        # TODO(PVP) - to delete later at release
+        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
+        # ======================================
         # LDM
         attention_resolutions=(8, 4, 2),
         ldm=False,

@@ -91,10 +86,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         ch_mult=None,
         ch=None,
         ddpm=False,
         # ======================================
     ):
         super().__init__()
-        # register all __init__ params with self.register
+        # register all __init__ params to be accessible via `self.config.<...>`
         # should probably be automated down the road as this is pure boiler plate code
         self.register_to_config(
             image_size=image_size,
             in_channels=in_channels,

@@ -109,15 +105,22 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
             num_head_channels=num_head_channels,
             flip_sin_to_cos=flip_sin_to_cos,
             downscale_freq_shift=downscale_freq_shift,
-            # (TODO(PVP) - To delete once weights are converted
+            # TODO(PVP) - to delete later at release
+            # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
+            # ======================================
             attention_resolutions=attention_resolutions,
             attn_resolutions=attn_resolutions,
             ldm=ldm,
+            ddpm=ddpm,
+            # ======================================
         )
 
-        # To delete - replace with config values
+        # TODO(PVP) - to delete later at release
+        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
+        # ======================================
         self.image_size = image_size
         time_embed_dim = block_channels[0] * 4
+        # ======================================
 
         # # input
         self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))

@@ -202,8 +205,9 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
 
-        # ======================== Out ====================
+        # TODO(PVP) - to delete later at release
+        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
         # ======================================
         self.is_overwritten = False
         if ldm:
             # =========== TO DELETE AFTER CONVERSION ==========

@@ -231,10 +235,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
                 out_channels,
             )
 
         if ddpm:
-            out_channels = out_ch
-            image_size = resolution
-            block_channels = [x * ch for x in ch_mult]
-            conv_resample = resamp_with_conv
+            out_ch = out_channels
+            resolution = image_size
+            ch = block_channels[0]
+            ch_mult = [b // ch for b in block_channels]
+            resamp_with_conv = conv_resample
             self.init_for_ddpm(
                 ch_mult,
                 ch,

@@ -246,13 +251,20 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
                 out_ch,
                 dropout=0.1,
             )
 
-    def forward(self, sample, timesteps=None):
-        # TODO(PVP) - to delete later
-        # ======================================
+    def forward(self, sample: torch.FloatTensor, step_value: Union[torch.Tensor, float, int]) -> Dict[str, torch.FloatTensor]:
+        # TODO(PVP) - to delete later at release
+        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
+        # ======================================
         if not self.is_overwritten:
             self.set_weights()
         # ======================================
 
+        # 1. time step embeddings
+        timesteps = step_value
         if not torch.is_tensor(timesteps):
             timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)

@@ -295,7 +307,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
 
-        return sample
+        output = {"sample": sample}
+
+        return output
 
+    # !!!IMPORTANT - ALL OF THE FOLLOWING CODE WILL BE DELETED AT RELEASE TIME AND SHOULD NOT BE TAKEN INTO CONSIDERATION WHEN EVALUATING THE API ###
+    # =================================================================================================================================================
     def set_weights(self):
         self.is_overwritten = True

@@ -694,3 +711,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         del self.mid_new
         del self.up
         del self.norm_out
+
+
+def nonlinearity(x):
+    # swish
+    return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
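The new typed signature accepts step_value as a tensor, float, or int; scalars are promoted to a one-element long tensor on the sample's device before the time-step embedding is computed. A standalone sketch of that conversion — the helper name is illustrative and not part of the model:

from typing import Union

import torch


def to_timestep_tensor(step_value: Union[torch.Tensor, float, int], sample: torch.FloatTensor) -> torch.Tensor:
    # mirrors the check at the top of UNetUnconditionalModel.forward:
    # non-tensor step values become a one-element long tensor on the sample's device
    if not torch.is_tensor(step_value):
        step_value = torch.tensor([step_value], dtype=torch.long, device=sample.device)
    return step_value


sample = torch.randn(2, 3, 32, 32)
print(to_timestep_tensor(10, sample))                    # tensor([10])
print(to_timestep_tensor(torch.tensor([5, 7]), sample))  # passed through unchanged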
src/diffusers/pipelines/ddim/pipeline_ddim.py

@@ -59,6 +59,9 @@ class DDIMPipeline(DiffusionPipeline):
            with torch.no_grad():
                residual = self.unet(image, inference_step_times[t])
 
+            if isinstance(residual, dict):
+                residual = residual["sample"]
+
            # 2. predict previous mean of image x_t-1
            pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta)
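Each pipeline guards its unet call with the same two-line check so that both the old raw-tensor return and the new {"sample": ...} dict are accepted during the transition. The pattern, pulled out into a hypothetical standalone helper (not part of the diff) for clarity:

import torch


def unwrap_sample(model_output):
    # accept either a raw tensor or the new {"sample": tensor} dict
    # and always hand back the tensor
    if isinstance(model_output, dict):
        model_output = model_output["sample"]
    return model_output


# both return styles yield the same tensor
residual = torch.randn(1, 3, 32, 32)
assert torch.equal(unwrap_sample(residual), unwrap_sample({"sample": residual}))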
src/diffusers/pipelines/ddpm/pipeline_ddpm.py

@@ -46,6 +46,9 @@ class DDPMPipeline(DiffusionPipeline):
            with torch.no_grad():
                residual = self.unet(image, t)
 
+            if isinstance(residual, dict):
+                residual = residual["sample"]
+
            # 2. predict previous mean of image x_t-1
            pred_prev_image = self.noise_scheduler.step(residual, image, t)
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py

@@ -51,6 +51,9 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
            timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
            pred_noise_t = self.unet(image, timesteps)
 
+            if isinstance(pred_noise_t, dict):
+                pred_noise_t = pred_noise_t["sample"]
+
            # 2. predict previous mean of image x_t-1
            pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
src/diffusers/pipelines/pndm/pipeline_pndm.py

@@ -47,6 +47,9 @@ class PNDMPipeline(DiffusionPipeline):
            t_orig = prk_time_steps[t]
            residual = self.unet(image, t_orig)
 
+            if isinstance(residual, dict):
+                residual = residual["sample"]
+
            image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
 
        timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)

@@ -54,6 +57,9 @@ class PNDMPipeline(DiffusionPipeline):
            t_orig = timesteps[t]
            residual = self.unet(image, t_orig)
 
+            if isinstance(residual, dict):
+                residual = residual["sample"]
+
            image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps)
 
        return image
tests/test_modeling_utils.py

@@ -109,8 +109,14 @@ class ModelTesterMixin:
        with torch.no_grad():
            image = model(**inputs_dict)
+            if isinstance(image, dict):
+                image = image["sample"]
+
            new_image = new_model(**inputs_dict)
 
+            if isinstance(new_image, dict):
+                new_image = new_image["sample"]
+
        max_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

@@ -121,7 +127,12 @@ class ModelTesterMixin:
        model.eval()
 
        with torch.no_grad():
            first = model(**inputs_dict)
+            if isinstance(first, dict):
+                first = first["sample"]
+
            second = model(**inputs_dict)
+            if isinstance(second, dict):
+                second = second["sample"]
 
        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()

@@ -139,6 +150,9 @@ class ModelTesterMixin:
        with torch.no_grad():
            output = model(**inputs_dict)
 
+            if isinstance(output, dict):
+                output = output["sample"]
+
        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

@@ -151,7 +165,7 @@ class ModelTesterMixin:
        # signature.parameters is an OrderedDict => so arg_names order is deterministic
        arg_names = [*signature.parameters.keys()]
 
-        expected_arg_names = ["sample", "timesteps"]
+        expected_arg_names = ["sample", "step_value"]
        self.assertListEqual(arg_names[:2], expected_arg_names)
 
    def test_model_from_config(self):
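The signature test above keys on the first two parameter names of forward, which is why expected_arg_names has to follow the rename. A self-contained sketch of the check, using a stand-in module rather than a real diffusers model:

import inspect

import torch.nn as nn


class DummyModel(nn.Module):
    # stand-in with the post-commit argument order; any of the updated models would do
    def forward(self, sample, step_value, transformer_out=None):
        return {"sample": sample}


signature = inspect.signature(DummyModel().forward)
arg_names = [*signature.parameters.keys()]
assert arg_names[:2] == ["sample", "step_value"]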
@@ -177,8 +191,15 @@ class ModelTesterMixin:
        with torch.no_grad():
            output_1 = model(**inputs_dict)
 
+            if isinstance(output_1, dict):
+                output_1 = output_1["sample"]
+
            output_2 = new_model(**inputs_dict)
 
+            if isinstance(output_2, dict):
+                output_2 = output_2["sample"]
+
        self.assertEqual(output_1.shape, output_2.shape)
 
    def test_training(self):

@@ -188,6 +209,10 @@ class ModelTesterMixin:
        model.to(torch_device)
        model.train()
        output = model(**inputs_dict)
+
+        if isinstance(output, dict):
+            output = output["sample"]
+
        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

@@ -201,6 +226,10 @@ class ModelTesterMixin:
        ema_model = EMAModel(model, device=torch_device)
 
        output = model(**inputs_dict)
+
+        if isinstance(output, dict):
+            output = output["sample"]
+
        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

@@ -219,7 +248,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)
 
-        return {"sample": noise, "timesteps": time_step}
+        return {"sample": noise, "step_value": time_step}
 
    @property
    def input_shape(self):
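Because the shared ModelTesterMixin tests call the model via keyword expansion, the dummy_input key has to match the renamed parameter exactly; the old "timesteps" key would no longer bind to any argument of forward. A small sketch of what the mixin effectively does with the renamed key (shapes are illustrative):

import torch

# shape of what UnetModelTests.dummy_input now returns (values illustrative)
inputs_dict = {
    "sample": torch.randn(4, 3, 32, 32),
    "step_value": torch.tensor([10]),
}

# every shared test forwards the dict by keyword expansion, so the key name
# must line up with the renamed argument:
#     output = model(**inputs_dict)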
@@ -255,7 +284,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
        self.assertEqual(len(loading_info["missing_keys"]), 0)
 
        model.to(torch_device)
-        image = model(**self.dummy_input)
+        image = model(**self.dummy_input)["sample"]
 
        assert image is not None, "Make sure output is not None"

@@ -271,7 +300,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
        time_step = torch.tensor([10])
 
        with torch.no_grad():
-            output = model(noise, time_step)
+            output = model(noise, time_step)["sample"]
 
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off

@@ -294,7 +323,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
        low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
 
-        return {"sample": noise, "timesteps": time_step, "low_res": low_res}
+        return {"sample": noise, "step_value": time_step, "low_res": low_res}
 
    @property
    def input_shape(self):

@@ -385,7 +414,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
        emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
 
-        return {"sample": noise, "timesteps": time_step, "transformer_out": emb}
+        return {"sample": noise, "step_value": time_step, "transformer_out": emb}
 
    @property
    def input_shape(self):

@@ -477,7 +506,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)
 
-        return {"sample": noise, "timesteps": time_step}
+        return {"sample": noise, "step_value": time_step}
 
    @property
    def input_shape(self):

@@ -512,7 +541,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
        self.assertEqual(len(loading_info["missing_keys"]), 0)
 
        model.to(torch_device)
-        image = model(**self.dummy_input)
+        image = model(**self.dummy_input)["sample"]
 
        assert image is not None, "Make sure output is not None"

@@ -528,7 +557,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
        time_step = torch.tensor([10] * noise.shape[0])
 
        with torch.no_grad():
-            output = model(noise, time_step)
+            output = model(noise, time_step)["sample"]
 
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off

@@ -572,7 +601,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [10]).to(torch_device)
 
-        return {"sample": noise, "timesteps": time_step}
+        return {"sample": noise, "step_value": time_step}
 
    @property
    def input_shape(self):

@@ -837,7 +866,15 @@ class PipelineTesterMixin(unittest.TestCase):
    def test_from_pretrained_save_pretrained(self):
        # 1. Load models
-        model = UNetUnconditionalModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32, ddpm=True)
+        model = UNetUnconditionalModel(
+            block_channels=(32, 64),
+            num_res_blocks=2,
+            attn_resolutions=(16,),
+            image_size=32,
+            in_channels=3,
+            out_channels=3,
+            down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
+            up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
+            ddpm=True,
+        )
        schedular = DDPMScheduler(timesteps=10)
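The new constructor call is equivalent to the old ddpm-style one: the if ddpm: remapping added in UNetUnconditionalModel.__init__ derives the legacy names back from the new ones, so the values used in this test round-trip exactly. A quick arithmetic check (names taken from the diff, values from the test):

# old-style arguments from the removed call
ch, ch_mult, resolution = 32, (1, 2), 32

# new-style equivalents used in the added call
block_channels = tuple(ch * m for m in ch_mult)  # (32, 64)
image_size = resolution                          # 32

# the remapping in __init__ recovers the legacy values from the new ones
assert block_channels == (32, 64)
assert block_channels[0] == ch
assert [b // block_channels[0] for b in block_channels] == list(ch_mult)
assert image_size == resolution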
@@ -1034,7 +1071,7 @@ class PipelineTesterMixin(unittest.TestCase):
    @slow
    def test_ldm_uncond(self):
-        ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256")
+        ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256", ldm=True)
 
        generator = torch.manual_seed(0)
        image = ldm(generator=generator, num_inference_steps=5)