Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
97e1e3ba
Commit
97e1e3ba
authored
Jul 15, 2022
by
Patrick von Platen
Browse files
finalize model API
parent
dacabaa4
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
120 additions
and
39 deletions
+120
-39
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+4
-2
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+2
-1
src/diffusers/models/unet_unconditional.py
src/diffusers/models/unet_unconditional.py
+50
-24
src/diffusers/pipelines/ddim/pipeline_ddim.py
src/diffusers/pipelines/ddim/pipeline_ddim.py
+3
-0
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+3
-0
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
...tent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+3
-0
src/diffusers/pipelines/pndm/pipeline_pndm.py
src/diffusers/pipelines/pndm/pipeline_pndm.py
+6
-0
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+49
-12
No files found.
src/diffusers/models/unet_glide.py
View file @
97e1e3ba
...
@@ -438,7 +438,8 @@ class GlideTextToImageUNetModel(GlideUNetModel):
...
@@ -438,7 +438,8 @@ class GlideTextToImageUNetModel(GlideUNetModel):
self
.
transformer_proj
=
nn
.
Linear
(
transformer_dim
,
self
.
model_channels
*
4
)
self
.
transformer_proj
=
nn
.
Linear
(
transformer_dim
,
self
.
model_channels
*
4
)
def
forward
(
self
,
sample
,
timesteps
,
transformer_out
=
None
):
def
forward
(
self
,
sample
,
step_value
,
transformer_out
=
None
):
timesteps
=
step_value
x
=
sample
x
=
sample
hs
=
[]
hs
=
[]
emb
=
self
.
time_embed
(
emb
=
self
.
time_embed
(
...
@@ -529,7 +530,8 @@ class GlideSuperResUNetModel(GlideUNetModel):
...
@@ -529,7 +530,8 @@ class GlideSuperResUNetModel(GlideUNetModel):
resblock_updown
=
resblock_updown
,
resblock_updown
=
resblock_updown
,
)
)
def
forward
(
self
,
sample
,
timesteps
,
low_res
=
None
):
def
forward
(
self
,
sample
,
step_value
,
low_res
=
None
):
timesteps
=
step_value
x
=
sample
x
=
sample
_
,
_
,
new_height
,
new_width
=
x
.
shape
_
,
_
,
new_height
,
new_width
=
x
.
shape
upsampled
=
F
.
interpolate
(
low_res
,
(
new_height
,
new_width
),
mode
=
"bilinear"
)
upsampled
=
F
.
interpolate
(
low_res
,
(
new_height
,
new_width
),
mode
=
"bilinear"
)
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
97e1e3ba
...
@@ -323,7 +323,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -323,7 +323,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
self
.
all_modules
=
nn
.
ModuleList
(
modules
)
self
.
all_modules
=
nn
.
ModuleList
(
modules
)
def
forward
(
self
,
sample
,
timesteps
,
sigmas
=
None
):
def
forward
(
self
,
sample
,
step_value
,
sigmas
=
None
):
timesteps
=
step_value
x
=
sample
x
=
sample
# timestep/noise_level embedding; only for continuous training
# timestep/noise_level embedding; only for continuous training
modules
=
self
.
all_modules
modules
=
self
.
all_modules
...
...
src/diffusers/models/unet_unconditional.py
View file @
97e1e3ba
from
typing
import
Dict
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -9,15 +11,6 @@ from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
...
@@ -9,15 +11,6 @@ from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from
.unet_new
import
UNetMidBlock2D
,
get_down_block
,
get_up_block
from
.unet_new
import
UNetMidBlock2D
,
get_down_block
,
get_up_block
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
TimestepEmbedding
(
nn
.
Module
):
class
TimestepEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
channel
,
time_embed_dim
):
def
__init__
(
self
,
channel
,
time_embed_dim
):
super
().
__init__
()
super
().
__init__
()
...
@@ -79,7 +72,9 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -79,7 +72,9 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
num_head_channels
=
32
,
num_head_channels
=
32
,
flip_sin_to_cos
=
True
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
,
downscale_freq_shift
=
0
,
# To delete once weights are converted
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
# LDM
# LDM
attention_resolutions
=
(
8
,
4
,
2
),
attention_resolutions
=
(
8
,
4
,
2
),
ldm
=
False
,
ldm
=
False
,
...
@@ -91,10 +86,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -91,10 +86,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
ch_mult
=
None
,
ch_mult
=
None
,
ch
=
None
,
ch
=
None
,
ddpm
=
False
,
ddpm
=
False
,
# ======================================
):
):
super
().
__init__
()
super
().
__init__
()
# register all __init__ params to be accessible via `self.config.<...>`
#
register all __init__ params with self.register
#
should probably be automated down the road as this is pure boiler plate code
self
.
register_to_config
(
self
.
register_to_config
(
image_size
=
image_size
,
image_size
=
image_size
,
in_channels
=
in_channels
,
in_channels
=
in_channels
,
...
@@ -109,15 +105,22 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -109,15 +105,22 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
num_head_channels
=
num_head_channels
,
num_head_channels
=
num_head_channels
,
flip_sin_to_cos
=
flip_sin_to_cos
,
flip_sin_to_cos
=
flip_sin_to_cos
,
downscale_freq_shift
=
downscale_freq_shift
,
downscale_freq_shift
=
downscale_freq_shift
,
# (TODO(PVP) - To delete once weights are converted
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
attention_resolutions
=
attention_resolutions
,
attention_resolutions
=
attention_resolutions
,
attn_resolutions
=
attn_resolutions
,
ldm
=
ldm
,
ldm
=
ldm
,
ddpm
=
ddpm
,
ddpm
=
ddpm
,
# ======================================
)
)
# To delete - replace with config values
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
self
.
image_size
=
image_size
self
.
image_size
=
image_size
time_embed_dim
=
block_channels
[
0
]
*
4
time_embed_dim
=
block_channels
[
0
]
*
4
# ======================================
# # input
# # input
self
.
conv_in
=
nn
.
Conv2d
(
in_channels
,
block_channels
[
0
],
kernel_size
=
3
,
padding
=
(
1
,
1
))
self
.
conv_in
=
nn
.
Conv2d
(
in_channels
,
block_channels
[
0
],
kernel_size
=
3
,
padding
=
(
1
,
1
))
...
@@ -202,8 +205,9 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -202,8 +205,9 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self
.
conv_act
=
nn
.
SiLU
()
self
.
conv_act
=
nn
.
SiLU
()
self
.
conv_out
=
nn
.
Conv2d
(
block_channels
[
0
],
out_channels
,
3
,
padding
=
1
)
self
.
conv_out
=
nn
.
Conv2d
(
block_channels
[
0
],
out_channels
,
3
,
padding
=
1
)
# ======================== Out ====================
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
self
.
is_overwritten
=
False
self
.
is_overwritten
=
False
if
ldm
:
if
ldm
:
# =========== TO DELETE AFTER CONVERSION ==========
# =========== TO DELETE AFTER CONVERSION ==========
...
@@ -231,10 +235,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -231,10 +235,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
out_channels
,
out_channels
,
)
)
if
ddpm
:
if
ddpm
:
out_channels
=
out_ch
out_ch
=
out_channels
image_size
=
resolution
resolution
=
image_size
block_channels
=
[
x
*
ch
for
x
in
ch_mult
]
ch
=
block_channels
[
0
]
conv_resample
=
resamp_with_conv
ch_mult
=
[
b
//
ch
for
b
in
block_channels
]
resamp_with_conv
=
conv_resample
self
.
init_for_ddpm
(
self
.
init_for_ddpm
(
ch_mult
,
ch_mult
,
ch
,
ch
,
...
@@ -246,13 +251,20 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -246,13 +251,20 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
out_ch
,
out_ch
,
dropout
=
0.1
,
dropout
=
0.1
,
)
)
# ======================================
def
forward
(
self
,
sample
,
timesteps
=
None
):
# TODO(PVP) - to delete later
def
forward
(
self
,
sample
:
torch
.
FloatTensor
,
step_value
:
Union
[
torch
.
Tensor
,
float
,
int
]
)
->
Dict
[
str
,
torch
.
FloatTensor
]:
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
if
not
self
.
is_overwritten
:
if
not
self
.
is_overwritten
:
self
.
set_weights
()
self
.
set_weights
()
# ======================================
# 1. time step embeddings
# 1. time step embeddings
timesteps
=
step_value
if
not
torch
.
is_tensor
(
timesteps
):
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
...
@@ -295,7 +307,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -295,7 +307,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
sample
=
self
.
conv_act
(
sample
)
sample
=
self
.
conv_act
(
sample
)
sample
=
self
.
conv_out
(
sample
)
sample
=
self
.
conv_out
(
sample
)
return
sample
output
=
{
"sample"
:
sample
}
return
output
# !!!IMPORTANT - ALL OF THE FOLLOWING CODE WILL BE DELETED AT RELEASE TIME AND SHOULD NOT BE TAKEN INTO CONSIDERATION WHEN EVALUATING THE API ###
# =================================================================================================================================================
def
set_weights
(
self
):
def
set_weights
(
self
):
self
.
is_overwritten
=
True
self
.
is_overwritten
=
True
...
@@ -694,3 +711,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -694,3 +711,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
del
self
.
mid_new
del
self
.
mid_new
del
self
.
up
del
self
.
up
del
self
.
norm_out
del
self
.
norm_out
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
src/diffusers/pipelines/ddim/pipeline_ddim.py
View file @
97e1e3ba
...
@@ -59,6 +59,9 @@ class DDIMPipeline(DiffusionPipeline):
...
@@ -59,6 +59,9 @@ class DDIMPipeline(DiffusionPipeline):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
residual
=
self
.
unet
(
image
,
inference_step_times
[
t
])
residual
=
self
.
unet
(
image
,
inference_step_times
[
t
])
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
# 2. predict previous mean of image x_t-1
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
noise_scheduler
.
step
(
residual
,
image
,
t
,
num_inference_steps
,
eta
)
pred_prev_image
=
self
.
noise_scheduler
.
step
(
residual
,
image
,
t
,
num_inference_steps
,
eta
)
...
...
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
97e1e3ba
...
@@ -46,6 +46,9 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -46,6 +46,9 @@ class DDPMPipeline(DiffusionPipeline):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
residual
=
self
.
unet
(
image
,
t
)
residual
=
self
.
unet
(
image
,
t
)
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
# 2. predict previous mean of image x_t-1
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
noise_scheduler
.
step
(
residual
,
image
,
t
)
pred_prev_image
=
self
.
noise_scheduler
.
step
(
residual
,
image
,
t
)
...
...
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
View file @
97e1e3ba
...
@@ -51,6 +51,9 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
...
@@ -51,6 +51,9 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
timesteps
=
torch
.
tensor
([
inference_step_times
[
t
]]
*
image
.
shape
[
0
],
device
=
torch_device
)
timesteps
=
torch
.
tensor
([
inference_step_times
[
t
]]
*
image
.
shape
[
0
],
device
=
torch_device
)
pred_noise_t
=
self
.
unet
(
image
,
timesteps
)
pred_noise_t
=
self
.
unet
(
image
,
timesteps
)
if
isinstance
(
pred_noise_t
,
dict
):
pred_noise_t
=
pred_noise_t
[
"sample"
]
# 2. predict previous mean of image x_t-1
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
noise_scheduler
.
step
(
pred_noise_t
,
image
,
t
,
num_inference_steps
,
eta
)
pred_prev_image
=
self
.
noise_scheduler
.
step
(
pred_noise_t
,
image
,
t
,
num_inference_steps
,
eta
)
...
...
src/diffusers/pipelines/pndm/pipeline_pndm.py
View file @
97e1e3ba
...
@@ -47,6 +47,9 @@ class PNDMPipeline(DiffusionPipeline):
...
@@ -47,6 +47,9 @@ class PNDMPipeline(DiffusionPipeline):
t_orig
=
prk_time_steps
[
t
]
t_orig
=
prk_time_steps
[
t
]
residual
=
self
.
unet
(
image
,
t_orig
)
residual
=
self
.
unet
(
image
,
t_orig
)
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
image
=
self
.
noise_scheduler
.
step_prk
(
residual
,
image
,
t
,
num_inference_steps
)
image
=
self
.
noise_scheduler
.
step_prk
(
residual
,
image
,
t
,
num_inference_steps
)
timesteps
=
self
.
noise_scheduler
.
get_time_steps
(
num_inference_steps
)
timesteps
=
self
.
noise_scheduler
.
get_time_steps
(
num_inference_steps
)
...
@@ -54,6 +57,9 @@ class PNDMPipeline(DiffusionPipeline):
...
@@ -54,6 +57,9 @@ class PNDMPipeline(DiffusionPipeline):
t_orig
=
timesteps
[
t
]
t_orig
=
timesteps
[
t
]
residual
=
self
.
unet
(
image
,
t_orig
)
residual
=
self
.
unet
(
image
,
t_orig
)
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
image
=
self
.
noise_scheduler
.
step_plms
(
residual
,
image
,
t
,
num_inference_steps
)
image
=
self
.
noise_scheduler
.
step_plms
(
residual
,
image
,
t
,
num_inference_steps
)
return
image
return
image
tests/test_modeling_utils.py
View file @
97e1e3ba
...
@@ -109,8 +109,14 @@ class ModelTesterMixin:
...
@@ -109,8 +109,14 @@ class ModelTesterMixin:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
image
=
model
(
**
inputs_dict
)
image
=
model
(
**
inputs_dict
)
if
isinstance
(
image
,
dict
):
image
=
image
[
"sample"
]
new_image
=
new_model
(
**
inputs_dict
)
new_image
=
new_model
(
**
inputs_dict
)
if
isinstance
(
new_image
,
dict
):
new_image
=
new_image
[
"sample"
]
max_diff
=
(
image
-
new_image
).
abs
().
sum
().
item
()
max_diff
=
(
image
-
new_image
).
abs
().
sum
().
item
()
self
.
assertLessEqual
(
max_diff
,
5e-5
,
"Models give different forward passes"
)
self
.
assertLessEqual
(
max_diff
,
5e-5
,
"Models give different forward passes"
)
...
@@ -121,7 +127,12 @@ class ModelTesterMixin:
...
@@ -121,7 +127,12 @@ class ModelTesterMixin:
model
.
eval
()
model
.
eval
()
with
torch
.
no_grad
():
with
torch
.
no_grad
():
first
=
model
(
**
inputs_dict
)
first
=
model
(
**
inputs_dict
)
if
isinstance
(
first
,
dict
):
first
=
first
[
"sample"
]
second
=
model
(
**
inputs_dict
)
second
=
model
(
**
inputs_dict
)
if
isinstance
(
second
,
dict
):
second
=
second
[
"sample"
]
out_1
=
first
.
cpu
().
numpy
()
out_1
=
first
.
cpu
().
numpy
()
out_2
=
second
.
cpu
().
numpy
()
out_2
=
second
.
cpu
().
numpy
()
...
@@ -139,6 +150,9 @@ class ModelTesterMixin:
...
@@ -139,6 +150,9 @@ class ModelTesterMixin:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output
=
model
(
**
inputs_dict
)
output
=
model
(
**
inputs_dict
)
if
isinstance
(
output
,
dict
):
output
=
output
[
"sample"
]
self
.
assertIsNotNone
(
output
)
self
.
assertIsNotNone
(
output
)
expected_shape
=
inputs_dict
[
"sample"
].
shape
expected_shape
=
inputs_dict
[
"sample"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
...
@@ -151,7 +165,7 @@ class ModelTesterMixin:
...
@@ -151,7 +165,7 @@ class ModelTesterMixin:
# signature.parameters is an OrderedDict => so arg_names order is deterministic
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names
=
[
*
signature
.
parameters
.
keys
()]
arg_names
=
[
*
signature
.
parameters
.
keys
()]
expected_arg_names
=
[
"sample"
,
"
timesteps
"
]
expected_arg_names
=
[
"sample"
,
"
step_value
"
]
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
def
test_model_from_config
(
self
):
def
test_model_from_config
(
self
):
...
@@ -177,8 +191,15 @@ class ModelTesterMixin:
...
@@ -177,8 +191,15 @@ class ModelTesterMixin:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output_1
=
model
(
**
inputs_dict
)
output_1
=
model
(
**
inputs_dict
)
if
isinstance
(
output_1
,
dict
):
output_1
=
output_1
[
"sample"
]
output_2
=
new_model
(
**
inputs_dict
)
output_2
=
new_model
(
**
inputs_dict
)
if
isinstance
(
output_2
,
dict
):
output_2
=
output_2
[
"sample"
]
self
.
assertEqual
(
output_1
.
shape
,
output_2
.
shape
)
self
.
assertEqual
(
output_1
.
shape
,
output_2
.
shape
)
def
test_training
(
self
):
def
test_training
(
self
):
...
@@ -188,6 +209,10 @@ class ModelTesterMixin:
...
@@ -188,6 +209,10 @@ class ModelTesterMixin:
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
train
()
model
.
train
()
output
=
model
(
**
inputs_dict
)
output
=
model
(
**
inputs_dict
)
if
isinstance
(
output
,
dict
):
output
=
output
[
"sample"
]
noise
=
torch
.
randn
((
inputs_dict
[
"sample"
].
shape
[
0
],)
+
self
.
output_shape
).
to
(
torch_device
)
noise
=
torch
.
randn
((
inputs_dict
[
"sample"
].
shape
[
0
],)
+
self
.
output_shape
).
to
(
torch_device
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
.
backward
()
loss
.
backward
()
...
@@ -201,6 +226,10 @@ class ModelTesterMixin:
...
@@ -201,6 +226,10 @@ class ModelTesterMixin:
ema_model
=
EMAModel
(
model
,
device
=
torch_device
)
ema_model
=
EMAModel
(
model
,
device
=
torch_device
)
output
=
model
(
**
inputs_dict
)
output
=
model
(
**
inputs_dict
)
if
isinstance
(
output
,
dict
):
output
=
output
[
"sample"
]
noise
=
torch
.
randn
((
inputs_dict
[
"sample"
].
shape
[
0
],)
+
self
.
output_shape
).
to
(
torch_device
)
noise
=
torch
.
randn
((
inputs_dict
[
"sample"
].
shape
[
0
],)
+
self
.
output_shape
).
to
(
torch_device
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
.
backward
()
loss
.
backward
()
...
@@ -219,7 +248,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -219,7 +248,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"sample"
:
noise
,
"
timesteps
"
:
time_step
}
return
{
"sample"
:
noise
,
"
step_value
"
:
time_step
}
@
property
@
property
def
input_shape
(
self
):
def
input_shape
(
self
):
...
@@ -255,7 +284,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -255,7 +284,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
image
=
model
(
**
self
.
dummy_input
)
[
"sample"
]
assert
image
is
not
None
,
"Make sure output is not None"
assert
image
is
not
None
,
"Make sure output is not None"
...
@@ -271,7 +300,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -271,7 +300,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
time_step
=
torch
.
tensor
([
10
])
time_step
=
torch
.
tensor
([
10
])
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output
=
model
(
noise
,
time_step
)
[
"sample"
]
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
# fmt: off
...
@@ -294,7 +323,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
...
@@ -294,7 +323,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
low_res
=
torch
.
randn
((
batch_size
,
3
)
+
low_res_size
).
to
(
torch_device
)
low_res
=
torch
.
randn
((
batch_size
,
3
)
+
low_res_size
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
return
{
"sample"
:
noise
,
"
timesteps
"
:
time_step
,
"low_res"
:
low_res
}
return
{
"sample"
:
noise
,
"
step_value
"
:
time_step
,
"low_res"
:
low_res
}
@
property
@
property
def
input_shape
(
self
):
def
input_shape
(
self
):
...
@@ -385,7 +414,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -385,7 +414,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
emb
=
torch
.
randn
((
batch_size
,
seq_len
,
transformer_dim
)).
to
(
torch_device
)
emb
=
torch
.
randn
((
batch_size
,
seq_len
,
transformer_dim
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
return
{
"sample"
:
noise
,
"
timesteps
"
:
time_step
,
"transformer_out"
:
emb
}
return
{
"sample"
:
noise
,
"
step_value
"
:
time_step
,
"transformer_out"
:
emb
}
@
property
@
property
def
input_shape
(
self
):
def
input_shape
(
self
):
...
@@ -477,7 +506,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -477,7 +506,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"sample"
:
noise
,
"
timesteps
"
:
time_step
}
return
{
"sample"
:
noise
,
"
step_value
"
:
time_step
}
@
property
@
property
def
input_shape
(
self
):
def
input_shape
(
self
):
...
@@ -512,7 +541,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -512,7 +541,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
image
=
model
(
**
self
.
dummy_input
)
[
"sample"
]
assert
image
is
not
None
,
"Make sure output is not None"
assert
image
is
not
None
,
"Make sure output is not None"
...
@@ -528,7 +557,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -528,7 +557,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
])
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
])
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output
=
model
(
noise
,
time_step
)
[
"sample"
]
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
# fmt: off
...
@@ -572,7 +601,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -572,7 +601,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
(
batch_size
*
[
10
]).
to
(
torch_device
)
time_step
=
torch
.
tensor
(
batch_size
*
[
10
]).
to
(
torch_device
)
return
{
"sample"
:
noise
,
"
timesteps
"
:
time_step
}
return
{
"sample"
:
noise
,
"
step_value
"
:
time_step
}
@
property
@
property
def
input_shape
(
self
):
def
input_shape
(
self
):
...
@@ -837,7 +866,15 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -837,7 +866,15 @@ class PipelineTesterMixin(unittest.TestCase):
def
test_from_pretrained_save_pretrained
(
self
):
def
test_from_pretrained_save_pretrained
(
self
):
# 1. Load models
# 1. Load models
model
=
UNetUnconditionalModel
(
model
=
UNetUnconditionalModel
(
ch
=
32
,
ch_mult
=
(
1
,
2
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
resolution
=
32
,
ddpm
=
True
block_channels
=
(
32
,
64
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
image_size
=
32
,
in_channels
=
3
,
out_channels
=
3
,
down_blocks
=
(
"UNetResDownBlock2D"
,
"UNetResAttnDownBlock2D"
),
up_blocks
=
(
"UNetResAttnUpBlock2D"
,
"UNetResUpBlock2D"
),
ddpm
=
True
,
)
)
schedular
=
DDPMScheduler
(
timesteps
=
10
)
schedular
=
DDPMScheduler
(
timesteps
=
10
)
...
@@ -1034,7 +1071,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -1034,7 +1071,7 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
@
slow
def
test_ldm_uncond
(
self
):
def
test_ldm_uncond
(
self
):
ldm
=
LatentDiffusionUncondPipeline
.
from_pretrained
(
"fusing/latent-diffusion-celeba-256"
)
ldm
=
LatentDiffusionUncondPipeline
.
from_pretrained
(
"fusing/latent-diffusion-celeba-256"
,
ldm
=
True
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ldm
(
generator
=
generator
,
num_inference_steps
=
5
)
image
=
ldm
(
generator
=
generator
,
num_inference_steps
=
5
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment