Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
97e1e3ba
Commit
97e1e3ba
authored
Jul 15, 2022
by
Patrick von Platen
Browse files
finalize model API
parent
dacabaa4
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
120 additions
and
39 deletions
+120
-39
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+4
-2
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+2
-1
src/diffusers/models/unet_unconditional.py
src/diffusers/models/unet_unconditional.py
+50
-24
src/diffusers/pipelines/ddim/pipeline_ddim.py
src/diffusers/pipelines/ddim/pipeline_ddim.py
+3
-0
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+3
-0
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
...tent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+3
-0
src/diffusers/pipelines/pndm/pipeline_pndm.py
src/diffusers/pipelines/pndm/pipeline_pndm.py
+6
-0
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+49
-12
No files found.
src/diffusers/models/unet_glide.py
View file @
97e1e3ba
...
...
@@ -438,7 +438,8 @@ class GlideTextToImageUNetModel(GlideUNetModel):
self
.
transformer_proj
=
nn
.
Linear
(
transformer_dim
,
self
.
model_channels
*
4
)
def
forward
(
self
,
sample
,
timesteps
,
transformer_out
=
None
):
def
forward
(
self
,
sample
,
step_value
,
transformer_out
=
None
):
timesteps
=
step_value
x
=
sample
hs
=
[]
emb
=
self
.
time_embed
(
...
...
@@ -529,7 +530,8 @@ class GlideSuperResUNetModel(GlideUNetModel):
resblock_updown
=
resblock_updown
,
)
def
forward
(
self
,
sample
,
timesteps
,
low_res
=
None
):
def
forward
(
self
,
sample
,
step_value
,
low_res
=
None
):
timesteps
=
step_value
x
=
sample
_
,
_
,
new_height
,
new_width
=
x
.
shape
upsampled
=
F
.
interpolate
(
low_res
,
(
new_height
,
new_width
),
mode
=
"bilinear"
)
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
97e1e3ba
...
...
@@ -323,7 +323,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
self
.
all_modules
=
nn
.
ModuleList
(
modules
)
def
forward
(
self
,
sample
,
timesteps
,
sigmas
=
None
):
def
forward
(
self
,
sample
,
step_value
,
sigmas
=
None
):
timesteps
=
step_value
x
=
sample
# timestep/noise_level embedding; only for continuous training
modules
=
self
.
all_modules
...
...
src/diffusers/models/unet_unconditional.py
View file @
97e1e3ba
from
typing
import
Dict
,
Union
import
torch
import
torch.nn
as
nn
...
...
@@ -9,15 +11,6 @@ from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from
.unet_new
import
UNetMidBlock2D
,
get_down_block
,
get_up_block
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
TimestepEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
channel
,
time_embed_dim
):
super
().
__init__
()
...
...
@@ -79,7 +72,9 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
num_head_channels
=
32
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
,
# To delete once weights are converted
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
# LDM
attention_resolutions
=
(
8
,
4
,
2
),
ldm
=
False
,
...
...
@@ -91,10 +86,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
ch_mult
=
None
,
ch
=
None
,
ddpm
=
False
,
# ======================================
):
super
().
__init__
()
#
register all __init__ params with self.register
# register all __init__ params to be accessible via `self.config.<...>`
#
should probably be automated down the road as this is pure boiler plate code
self
.
register_to_config
(
image_size
=
image_size
,
in_channels
=
in_channels
,
...
...
@@ -109,15 +105,22 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
num_head_channels
=
num_head_channels
,
flip_sin_to_cos
=
flip_sin_to_cos
,
downscale_freq_shift
=
downscale_freq_shift
,
# (TODO(PVP) - To delete once weights are converted
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
attention_resolutions
=
attention_resolutions
,
attn_resolutions
=
attn_resolutions
,
ldm
=
ldm
,
ddpm
=
ddpm
,
# ======================================
)
# To delete - replace with config values
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
self
.
image_size
=
image_size
time_embed_dim
=
block_channels
[
0
]
*
4
# ======================================
# # input
self
.
conv_in
=
nn
.
Conv2d
(
in_channels
,
block_channels
[
0
],
kernel_size
=
3
,
padding
=
(
1
,
1
))
...
...
@@ -202,8 +205,9 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self
.
conv_act
=
nn
.
SiLU
()
self
.
conv_out
=
nn
.
Conv2d
(
block_channels
[
0
],
out_channels
,
3
,
padding
=
1
)
# ======================== Out ====================
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
self
.
is_overwritten
=
False
if
ldm
:
# =========== TO DELETE AFTER CONVERSION ==========
...
...
@@ -231,10 +235,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
out_channels
,
)
if
ddpm
:
out_channels
=
out_ch
image_size
=
resolution
block_channels
=
[
x
*
ch
for
x
in
ch_mult
]
conv_resample
=
resamp_with_conv
out_ch
=
out_channels
resolution
=
image_size
ch
=
block_channels
[
0
]
ch_mult
=
[
b
//
ch
for
b
in
block_channels
]
resamp_with_conv
=
conv_resample
self
.
init_for_ddpm
(
ch_mult
,
ch
,
...
...
@@ -246,13 +251,20 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
out_ch
,
dropout
=
0.1
,
)
def
forward
(
self
,
sample
,
timesteps
=
None
):
# TODO(PVP) - to delete later
# ======================================
def
forward
(
self
,
sample
:
torch
.
FloatTensor
,
step_value
:
Union
[
torch
.
Tensor
,
float
,
int
]
)
->
Dict
[
str
,
torch
.
FloatTensor
]:
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
if
not
self
.
is_overwritten
:
self
.
set_weights
()
# ======================================
# 1. time step embeddings
timesteps
=
step_value
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
...
...
@@ -295,7 +307,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
sample
=
self
.
conv_act
(
sample
)
sample
=
self
.
conv_out
(
sample
)
return
sample
output
=
{
"sample"
:
sample
}
return
output
# !!!IMPORTANT - ALL OF THE FOLLOWING CODE WILL BE DELETED AT RELEASE TIME AND SHOULD NOT BE TAKEN INTO CONSIDERATION WHEN EVALUATING THE API ###
# =================================================================================================================================================
def
set_weights
(
self
):
self
.
is_overwritten
=
True
...
...
@@ -694,3 +711,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
del
self
.
mid_new
del
self
.
up
del
self
.
norm_out
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
src/diffusers/pipelines/ddim/pipeline_ddim.py
View file @
97e1e3ba
...
...
@@ -59,6 +59,9 @@ class DDIMPipeline(DiffusionPipeline):
with
torch
.
no_grad
():
residual
=
self
.
unet
(
image
,
inference_step_times
[
t
])
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
noise_scheduler
.
step
(
residual
,
image
,
t
,
num_inference_steps
,
eta
)
...
...
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
97e1e3ba
...
...
@@ -46,6 +46,9 @@ class DDPMPipeline(DiffusionPipeline):
with
torch
.
no_grad
():
residual
=
self
.
unet
(
image
,
t
)
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
noise_scheduler
.
step
(
residual
,
image
,
t
)
...
...
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
View file @
97e1e3ba
...
...
@@ -51,6 +51,9 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
timesteps
=
torch
.
tensor
([
inference_step_times
[
t
]]
*
image
.
shape
[
0
],
device
=
torch_device
)
pred_noise_t
=
self
.
unet
(
image
,
timesteps
)
if
isinstance
(
pred_noise_t
,
dict
):
pred_noise_t
=
pred_noise_t
[
"sample"
]
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
noise_scheduler
.
step
(
pred_noise_t
,
image
,
t
,
num_inference_steps
,
eta
)
...
...
src/diffusers/pipelines/pndm/pipeline_pndm.py
View file @
97e1e3ba
...
...
@@ -47,6 +47,9 @@ class PNDMPipeline(DiffusionPipeline):
t_orig
=
prk_time_steps
[
t
]
residual
=
self
.
unet
(
image
,
t_orig
)
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
image
=
self
.
noise_scheduler
.
step_prk
(
residual
,
image
,
t
,
num_inference_steps
)
timesteps
=
self
.
noise_scheduler
.
get_time_steps
(
num_inference_steps
)
...
...
@@ -54,6 +57,9 @@ class PNDMPipeline(DiffusionPipeline):
t_orig
=
timesteps
[
t
]
residual
=
self
.
unet
(
image
,
t_orig
)
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
image
=
self
.
noise_scheduler
.
step_plms
(
residual
,
image
,
t
,
num_inference_steps
)
return
image
tests/test_modeling_utils.py
View file @
97e1e3ba
...
...
@@ -109,8 +109,14 @@ class ModelTesterMixin:
with
torch
.
no_grad
():
image
=
model
(
**
inputs_dict
)
if
isinstance
(
image
,
dict
):
image
=
image
[
"sample"
]
new_image
=
new_model
(
**
inputs_dict
)
if
isinstance
(
new_image
,
dict
):
new_image
=
new_image
[
"sample"
]
max_diff
=
(
image
-
new_image
).
abs
().
sum
().
item
()
self
.
assertLessEqual
(
max_diff
,
5e-5
,
"Models give different forward passes"
)
...
...
@@ -121,7 +127,12 @@ class ModelTesterMixin:
model
.
eval
()
with
torch
.
no_grad
():
first
=
model
(
**
inputs_dict
)
if
isinstance
(
first
,
dict
):
first
=
first
[
"sample"
]
second
=
model
(
**
inputs_dict
)
if
isinstance
(
second
,
dict
):
second
=
second
[
"sample"
]
out_1
=
first
.
cpu
().
numpy
()
out_2
=
second
.
cpu
().
numpy
()
...
...
@@ -139,6 +150,9 @@ class ModelTesterMixin:
with
torch
.
no_grad
():
output
=
model
(
**
inputs_dict
)
if
isinstance
(
output
,
dict
):
output
=
output
[
"sample"
]
self
.
assertIsNotNone
(
output
)
expected_shape
=
inputs_dict
[
"sample"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
...
...
@@ -151,7 +165,7 @@ class ModelTesterMixin:
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names
=
[
*
signature
.
parameters
.
keys
()]
expected_arg_names
=
[
"sample"
,
"
timesteps
"
]
expected_arg_names
=
[
"sample"
,
"
step_value
"
]
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
def
test_model_from_config
(
self
):
...
...
@@ -177,8 +191,15 @@ class ModelTesterMixin:
with
torch
.
no_grad
():
output_1
=
model
(
**
inputs_dict
)
if
isinstance
(
output_1
,
dict
):
output_1
=
output_1
[
"sample"
]
output_2
=
new_model
(
**
inputs_dict
)
if
isinstance
(
output_2
,
dict
):
output_2
=
output_2
[
"sample"
]
self
.
assertEqual
(
output_1
.
shape
,
output_2
.
shape
)
def
test_training
(
self
):
...
...
@@ -188,6 +209,10 @@ class ModelTesterMixin:
model
.
to
(
torch_device
)
model
.
train
()
output
=
model
(
**
inputs_dict
)
if
isinstance
(
output
,
dict
):
output
=
output
[
"sample"
]
noise
=
torch
.
randn
((
inputs_dict
[
"sample"
].
shape
[
0
],)
+
self
.
output_shape
).
to
(
torch_device
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
.
backward
()
...
...
@@ -201,6 +226,10 @@ class ModelTesterMixin:
ema_model
=
EMAModel
(
model
,
device
=
torch_device
)
output
=
model
(
**
inputs_dict
)
if
isinstance
(
output
,
dict
):
output
=
output
[
"sample"
]
noise
=
torch
.
randn
((
inputs_dict
[
"sample"
].
shape
[
0
],)
+
self
.
output_shape
).
to
(
torch_device
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
.
backward
()
...
...
@@ -219,7 +248,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"sample"
:
noise
,
"
timesteps
"
:
time_step
}
return
{
"sample"
:
noise
,
"
step_value
"
:
time_step
}
@
property
def
input_shape
(
self
):
...
...
@@ -255,7 +284,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
image
=
model
(
**
self
.
dummy_input
)
[
"sample"
]
assert
image
is
not
None
,
"Make sure output is not None"
...
...
@@ -271,7 +300,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
time_step
=
torch
.
tensor
([
10
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output
=
model
(
noise
,
time_step
)
[
"sample"
]
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
...
...
@@ -294,7 +323,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
low_res
=
torch
.
randn
((
batch_size
,
3
)
+
low_res_size
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
return
{
"sample"
:
noise
,
"
timesteps
"
:
time_step
,
"low_res"
:
low_res
}
return
{
"sample"
:
noise
,
"
step_value
"
:
time_step
,
"low_res"
:
low_res
}
@
property
def
input_shape
(
self
):
...
...
@@ -385,7 +414,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
emb
=
torch
.
randn
((
batch_size
,
seq_len
,
transformer_dim
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
return
{
"sample"
:
noise
,
"
timesteps
"
:
time_step
,
"transformer_out"
:
emb
}
return
{
"sample"
:
noise
,
"
step_value
"
:
time_step
,
"transformer_out"
:
emb
}
@
property
def
input_shape
(
self
):
...
...
@@ -477,7 +506,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"sample"
:
noise
,
"
timesteps
"
:
time_step
}
return
{
"sample"
:
noise
,
"
step_value
"
:
time_step
}
@
property
def
input_shape
(
self
):
...
...
@@ -512,7 +541,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
image
=
model
(
**
self
.
dummy_input
)
[
"sample"
]
assert
image
is
not
None
,
"Make sure output is not None"
...
...
@@ -528,7 +557,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output
=
model
(
noise
,
time_step
)
[
"sample"
]
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
...
...
@@ -572,7 +601,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
(
batch_size
*
[
10
]).
to
(
torch_device
)
return
{
"sample"
:
noise
,
"
timesteps
"
:
time_step
}
return
{
"sample"
:
noise
,
"
step_value
"
:
time_step
}
@
property
def
input_shape
(
self
):
...
...
@@ -837,7 +866,15 @@ class PipelineTesterMixin(unittest.TestCase):
def
test_from_pretrained_save_pretrained
(
self
):
# 1. Load models
model
=
UNetUnconditionalModel
(
ch
=
32
,
ch_mult
=
(
1
,
2
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
resolution
=
32
,
ddpm
=
True
block_channels
=
(
32
,
64
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
image_size
=
32
,
in_channels
=
3
,
out_channels
=
3
,
down_blocks
=
(
"UNetResDownBlock2D"
,
"UNetResAttnDownBlock2D"
),
up_blocks
=
(
"UNetResAttnUpBlock2D"
,
"UNetResUpBlock2D"
),
ddpm
=
True
,
)
schedular
=
DDPMScheduler
(
timesteps
=
10
)
...
...
@@ -1034,7 +1071,7 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
def
test_ldm_uncond
(
self
):
ldm
=
LatentDiffusionUncondPipeline
.
from_pretrained
(
"fusing/latent-diffusion-celeba-256"
)
ldm
=
LatentDiffusionUncondPipeline
.
from_pretrained
(
"fusing/latent-diffusion-celeba-256"
,
ldm
=
True
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ldm
(
generator
=
generator
,
num_inference_steps
=
5
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment