Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
6d5ef87e
Unverified
Commit
6d5ef87e
authored
Jul 14, 2022
by
Patrick von Platen
Committed by
GitHub
Jul 14, 2022
Browse files
[DDPM] Make DDPM work (#88)
* up * finish * uP
parent
e7fe901e
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
188 additions
and
293 deletions
+188
-293
conversion.py
conversion.py
+33
-2
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+40
-38
src/diffusers/models/attention.py
src/diffusers/models/attention.py
+27
-197
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+10
-1
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+0
-2
src/diffusers/models/unet_new.py
src/diffusers/models/unet_new.py
+7
-1
src/diffusers/models/unet_unconditional.py
src/diffusers/models/unet_unconditional.py
+32
-23
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+39
-29
No files found.
conversion.py
View file @
6d5ef87e
...
...
@@ -50,6 +50,8 @@ from diffusers.testing_utils import floats_tensor, slow, torch_device
from
diffusers.training_utils
import
EMAModel
# 1. LDM
def
test_output_pretrained_ldm_dummy
():
model
=
UNetUnconditionalModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
,
ldm
=
True
)
model
.
eval
()
...
...
@@ -86,9 +88,38 @@ def test_output_pretrained_ldm():
import
ipdb
;
ipdb
.
set_trace
()
# To see the how the final model should look like
test_output_pretrained_ldm_dummy
()
test_output_pretrained_ldm
()
# => this is the architecture in which the model should be saved in the new format
# -> verify new repo with the following tests (in `test_modeling_utils.py`)
# - test_ldm_uncond (in PipelineTesterMixin)
# - test_output_pretrained ( in UNetLDMModelTests)
#test_output_pretrained_ldm_dummy()
#test_output_pretrained_ldm()
# 2. DDPM
def
get_model
(
model_id
):
model
=
UNetUnconditionalModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
,
ldm
=
True
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
image_size
,
model
.
config
.
image_size
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
print
(
model
)
# Repos to convert and port to google (part of https://github.com/hojonathanho/diffusion)
# - fusing/ddpm_dummy
# - fusing/ddpm-cifar10
# - https://huggingface.co/fusing/ddpm-lsun-church-ema
# - https://huggingface.co/fusing/ddpm-lsun-bedroom-ema
# - https://huggingface.co/fusing/ddpm-celeba-hq
# tests to make sure to pass
# - test_ddim_cifar10, test_ddim_lsun, test_ddpm_cifar10, test_ddim_cifar10 (in PipelineTesterMixin)
# - test_output_pretrained ( in UNetModelTests)
# e.g.
get_model
(
"fusing/ddpm-cifar10"
)
src/diffusers/modeling_utils.py
View file @
6d5ef87e
...
...
@@ -492,44 +492,46 @@ class ModelMixin(torch.nn.Module):
)
raise
RuntimeError
(
f
"Error(s) in loading state_dict for
{
model
.
__class__
.
__name__
}
:
\n\t
{
error_msg
}
"
)
if
len
(
unexpected_keys
)
>
0
:
logger
.
warning
(
f
"Some weights of the model checkpoint at
{
pretrained_model_name_or_path
}
were not used when"
f
" initializing
{
model
.
__class__
.
__name__
}
:
{
unexpected_keys
}
\n
- This IS expected if you are"
f
" initializing
{
model
.
__class__
.
__name__
}
from the checkpoint of a model trained on another task or"
" with another architecture (e.g. initializing a BertForSequenceClassification model from a"
" BertForPreTraining model).
\n
- This IS NOT expected if you are initializing"
f
"
{
model
.
__class__
.
__name__
}
from the checkpoint of a model that you expect to be exactly identical"
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
else
:
logger
.
info
(
f
"All model checkpoint weights were used when initializing
{
model
.
__class__
.
__name__
}
.
\n
"
)
if
len
(
missing_keys
)
>
0
:
logger
.
warning
(
f
"Some weights of
{
model
.
__class__
.
__name__
}
were not initialized from the model checkpoint at"
f
"
{
pretrained_model_name_or_path
}
and are newly initialized:
{
missing_keys
}
\n
You should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
elif
len
(
mismatched_keys
)
==
0
:
logger
.
info
(
f
"All the weights of
{
model
.
__class__
.
__name__
}
were initialized from the model checkpoint at"
f
"
{
pretrained_model_name_or_path
}
.
\n
If your task is similar to the task the model of the checkpoint"
f
" was trained on, you can already use
{
model
.
__class__
.
__name__
}
for predictions without further"
" training."
)
if
len
(
mismatched_keys
)
>
0
:
mismatched_warning
=
"
\n
"
.
join
(
[
f
"-
{
key
}
: found shape
{
shape1
}
in the checkpoint and
{
shape2
}
in the model instantiated"
for
key
,
shape1
,
shape2
in
mismatched_keys
]
)
logger
.
warning
(
f
"Some weights of
{
model
.
__class__
.
__name__
}
were not initialized from the model checkpoint at"
f
"
{
pretrained_model_name_or_path
}
and are newly initialized because the shapes did not"
f
" match:
\n
{
mismatched_warning
}
\n
You should probably TRAIN this model on a down-stream task to be able"
" to use it for predictions and inference."
)
if
False
:
if
len
(
unexpected_keys
)
>
0
:
logger
.
warning
(
f
"Some weights of the model checkpoint at
{
pretrained_model_name_or_path
}
were not used when"
f
" initializing
{
model
.
__class__
.
__name__
}
:
{
unexpected_keys
}
\n
- This IS expected if you are"
f
" initializing
{
model
.
__class__
.
__name__
}
from the checkpoint of a model trained on another task"
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
" BertForPreTraining model).
\n
- This IS NOT expected if you are initializing"
f
"
{
model
.
__class__
.
__name__
}
from the checkpoint of a model that you expect to be exactly"
" identical (initializing a BertForSequenceClassification model from a"
" BertForSequenceClassification model)."
)
else
:
logger
.
info
(
f
"All model checkpoint weights were used when initializing
{
model
.
__class__
.
__name__
}
.
\n
"
)
if
len
(
missing_keys
)
>
0
:
logger
.
warning
(
f
"Some weights of
{
model
.
__class__
.
__name__
}
were not initialized from the model checkpoint at"
f
"
{
pretrained_model_name_or_path
}
and are newly initialized:
{
missing_keys
}
\n
You should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
elif
len
(
mismatched_keys
)
==
0
:
logger
.
info
(
f
"All the weights of
{
model
.
__class__
.
__name__
}
were initialized from the model checkpoint at"
f
"
{
pretrained_model_name_or_path
}
.
\n
If your task is similar to the task the model of the"
f
" checkpoint was trained on, you can already use
{
model
.
__class__
.
__name__
}
for predictions"
" without further training."
)
if
len
(
mismatched_keys
)
>
0
:
mismatched_warning
=
"
\n
"
.
join
(
[
f
"-
{
key
}
: found shape
{
shape1
}
in the checkpoint and
{
shape2
}
in the model instantiated"
for
key
,
shape1
,
shape2
in
mismatched_keys
]
)
logger
.
warning
(
f
"Some weights of
{
model
.
__class__
.
__name__
}
were not initialized from the model checkpoint at"
f
"
{
pretrained_model_name_or_path
}
and are newly initialized because the shapes did not"
f
" match:
\n
{
mismatched_warning
}
\n
You should probably TRAIN this model on a down-stream task to be"
" able to use it for predictions and inference."
)
return
model
,
missing_keys
,
unexpected_keys
,
mismatched_keys
,
error_msgs
...
...
src/diffusers/models/attention.py
View file @
6d5ef87e
...
...
@@ -166,188 +166,6 @@ class AttentionBlock(nn.Module):
return
result
class
AttentionBlockNew_2
(
nn
.
Module
):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def
__init__
(
self
,
channels
,
num_head_channels
=
1
,
num_groups
=
32
,
encoder_channels
=
None
,
rescale_output_factor
=
1.0
,
eps
=
1e-5
,
):
super
().
__init__
()
self
.
channels
=
channels
self
.
norm
=
nn
.
GroupNorm
(
num_channels
=
channels
,
num_groups
=
num_groups
,
eps
=
eps
,
affine
=
True
)
self
.
qkv
=
nn
.
Conv1d
(
channels
,
channels
*
3
,
1
)
self
.
n_heads
=
channels
//
num_head_channels
self
.
num_head_size
=
num_head_channels
self
.
rescale_output_factor
=
rescale_output_factor
if
encoder_channels
is
not
None
:
self
.
encoder_kv
=
nn
.
Conv1d
(
encoder_channels
,
channels
*
2
,
1
)
self
.
proj
=
zero_module
(
nn
.
Conv1d
(
channels
,
channels
,
1
))
# ------------------------- new -----------------------
num_heads
=
self
.
n_heads
self
.
channels
=
channels
if
num_head_channels
is
None
:
self
.
num_heads
=
num_heads
else
:
assert
(
channels
%
num_head_channels
==
0
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
self
.
num_heads
=
channels
//
num_head_channels
self
.
group_norm
=
nn
.
GroupNorm
(
num_channels
=
channels
,
num_groups
=
num_groups
,
eps
=
eps
,
affine
=
True
)
# define q,k,v as linear layers
self
.
query
=
nn
.
Linear
(
channels
,
channels
)
self
.
key
=
nn
.
Linear
(
channels
,
channels
)
self
.
value
=
nn
.
Linear
(
channels
,
channels
)
self
.
rescale_output_factor
=
rescale_output_factor
self
.
proj_attn
=
zero_module
(
nn
.
Linear
(
channels
,
channels
,
1
))
# ------------------------- new -----------------------
def
set_weight
(
self
,
attn_layer
):
self
.
norm
.
weight
.
data
=
attn_layer
.
norm
.
weight
.
data
self
.
norm
.
bias
.
data
=
attn_layer
.
norm
.
bias
.
data
self
.
qkv
.
weight
.
data
=
attn_layer
.
qkv
.
weight
.
data
self
.
qkv
.
bias
.
data
=
attn_layer
.
qkv
.
bias
.
data
self
.
proj
.
weight
.
data
=
attn_layer
.
proj
.
weight
.
data
self
.
proj
.
bias
.
data
=
attn_layer
.
proj
.
bias
.
data
if
hasattr
(
attn_layer
,
"q"
):
module
=
attn_layer
qkv_weight
=
torch
.
cat
([
module
.
q
.
weight
.
data
,
module
.
k
.
weight
.
data
,
module
.
v
.
weight
.
data
],
dim
=
0
)[
:,
:,
:,
0
]
qkv_bias
=
torch
.
cat
([
module
.
q
.
bias
.
data
,
module
.
k
.
bias
.
data
,
module
.
v
.
bias
.
data
],
dim
=
0
)
self
.
qkv
.
weight
.
data
=
qkv_weight
self
.
qkv
.
bias
.
data
=
qkv_bias
proj_out
=
zero_module
(
nn
.
Conv1d
(
self
.
channels
,
self
.
channels
,
1
))
proj_out
.
weight
.
data
=
module
.
proj_out
.
weight
.
data
[:,
:,
:,
0
]
proj_out
.
bias
.
data
=
module
.
proj_out
.
bias
.
data
self
.
proj
=
proj_out
self
.
set_weights_2
(
attn_layer
)
def
transpose_for_scores
(
self
,
projection
:
torch
.
Tensor
)
->
torch
.
Tensor
:
new_projection_shape
=
projection
.
size
()[:
-
1
]
+
(
self
.
n_heads
,
self
.
num_head_size
)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection
=
projection
.
view
(
new_projection_shape
).
permute
(
0
,
2
,
1
,
3
)
return
new_projection
def
set_weights_2
(
self
,
attn_layer
):
self
.
group_norm
.
weight
.
data
=
attn_layer
.
norm
.
weight
.
data
self
.
group_norm
.
bias
.
data
=
attn_layer
.
norm
.
bias
.
data
qkv_weight
=
attn_layer
.
qkv
.
weight
.
data
.
reshape
(
self
.
n_heads
,
3
*
self
.
channels
//
self
.
n_heads
,
self
.
channels
)
qkv_bias
=
attn_layer
.
qkv
.
bias
.
data
.
reshape
(
self
.
n_heads
,
3
*
self
.
channels
//
self
.
n_heads
)
q_w
,
k_w
,
v_w
=
qkv_weight
.
split
(
self
.
channels
//
self
.
n_heads
,
dim
=
1
)
q_b
,
k_b
,
v_b
=
qkv_bias
.
split
(
self
.
channels
//
self
.
n_heads
,
dim
=
1
)
self
.
query
.
weight
.
data
=
q_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
key
.
weight
.
data
=
k_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
value
.
weight
.
data
=
v_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
query
.
bias
.
data
=
q_b
.
reshape
(
-
1
)
self
.
key
.
bias
.
data
=
k_b
.
reshape
(
-
1
)
self
.
value
.
bias
.
data
=
v_b
.
reshape
(
-
1
)
self
.
proj_attn
.
weight
.
data
=
attn_layer
.
proj
.
weight
.
data
[:,
:,
0
]
self
.
proj_attn
.
bias
.
data
=
attn_layer
.
proj
.
bias
.
data
def
forward_2
(
self
,
hidden_states
):
residual
=
hidden_states
batch
,
channel
,
height
,
width
=
hidden_states
.
shape
# norm
hidden_states
=
self
.
group_norm
(
hidden_states
)
hidden_states
=
hidden_states
.
view
(
batch
,
channel
,
height
*
width
).
transpose
(
1
,
2
)
# proj to q, k, v
query_proj
=
self
.
query
(
hidden_states
)
key_proj
=
self
.
key
(
hidden_states
)
value_proj
=
self
.
value
(
hidden_states
)
# transpose
query_states
=
self
.
transpose_for_scores
(
query_proj
)
key_states
=
self
.
transpose_for_scores
(
key_proj
)
value_states
=
self
.
transpose_for_scores
(
value_proj
)
# get scores
attention_scores
=
torch
.
matmul
(
query_states
,
key_states
.
transpose
(
-
1
,
-
2
))
attention_scores
=
attention_scores
/
math
.
sqrt
(
self
.
channels
//
self
.
n_heads
)
attention_probs
=
nn
.
functional
.
softmax
(
attention_scores
,
dim
=-
1
)
# compute attention output
context_states
=
torch
.
matmul
(
attention_probs
,
value_states
)
context_states
=
context_states
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
new_context_states_shape
=
context_states
.
size
()[:
-
2
]
+
(
self
.
channels
,)
context_states
=
context_states
.
view
(
new_context_states_shape
)
# compute next hidden_states
hidden_states
=
self
.
proj_attn
(
context_states
)
hidden_states
=
hidden_states
.
transpose
(
-
1
,
-
2
).
reshape
(
batch
,
channel
,
height
,
width
)
# res connect and rescale
hidden_states
=
(
hidden_states
+
residual
)
/
self
.
rescale_output_factor
return
hidden_states
def
forward
(
self
,
x
,
encoder_out
=
None
):
b
,
c
,
*
spatial
=
x
.
shape
hid_states
=
self
.
norm
(
x
).
view
(
b
,
c
,
-
1
)
qkv
=
self
.
qkv
(
hid_states
)
bs
,
width
,
length
=
qkv
.
shape
assert
width
%
(
3
*
self
.
n_heads
)
==
0
ch
=
width
//
(
3
*
self
.
n_heads
)
q
,
k
,
v
=
qkv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
3
,
length
).
split
(
ch
,
dim
=
1
)
if
encoder_out
is
not
None
:
encoder_kv
=
self
.
encoder_kv
(
encoder_out
)
assert
encoder_kv
.
shape
[
1
]
==
self
.
n_heads
*
ch
*
2
ek
,
ev
=
encoder_kv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
2
,
-
1
).
split
(
ch
,
dim
=
1
)
k
=
torch
.
cat
([
ek
,
k
],
dim
=-
1
)
v
=
torch
.
cat
([
ev
,
v
],
dim
=-
1
)
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
ch
))
weight
=
torch
.
einsum
(
"bct,bcs->bts"
,
q
*
scale
,
k
*
scale
)
# More stable with f16 than dividing afterwards
weight
=
torch
.
softmax
(
weight
.
float
(),
dim
=-
1
).
type
(
weight
.
dtype
)
a
=
torch
.
einsum
(
"bts,bcs->bct"
,
weight
,
v
)
h
=
a
.
reshape
(
bs
,
-
1
,
length
)
h
=
self
.
proj
(
h
)
h
=
h
.
reshape
(
b
,
c
,
*
spatial
)
result
=
x
+
h
result
=
result
/
self
.
rescale_output_factor
result_2
=
self
.
forward_2
(
x
)
print
((
result
-
result_2
).
abs
().
sum
())
return
result_2
class
AttentionBlockNew
(
nn
.
Module
):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
...
...
@@ -387,7 +205,7 @@ class AttentionBlockNew(nn.Module):
self
.
proj_attn
=
zero_module
(
nn
.
Linear
(
channels
,
channels
,
1
))
def
transpose_for_scores
(
self
,
projection
:
torch
.
Tensor
)
->
torch
.
Tensor
:
new_projection_shape
=
projection
.
size
()[:
-
1
]
+
(
self
.
num_heads
,
self
.
num_head_size
)
new_projection_shape
=
projection
.
size
()[:
-
1
]
+
(
self
.
num_heads
,
-
1
)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection
=
projection
.
view
(
new_projection_shape
).
permute
(
0
,
2
,
1
,
3
)
return
new_projection
...
...
@@ -434,24 +252,36 @@ class AttentionBlockNew(nn.Module):
self
.
group_norm
.
weight
.
data
=
attn_layer
.
norm
.
weight
.
data
self
.
group_norm
.
bias
.
data
=
attn_layer
.
norm
.
bias
.
data
qkv_weight
=
attn_layer
.
qkv
.
weight
.
data
.
reshape
(
self
.
num_heads
,
3
*
self
.
channels
//
self
.
num_heads
,
self
.
channels
)
qkv_bias
=
attn_layer
.
qkv
.
bias
.
data
.
reshape
(
self
.
num_heads
,
3
*
self
.
channels
//
self
.
num_heads
)
if
hasattr
(
attn_layer
,
"q"
):
self
.
query
.
weight
.
data
=
attn_layer
.
q
.
weight
.
data
[:,
:,
0
,
0
]
self
.
key
.
weight
.
data
=
attn_layer
.
k
.
weight
.
data
[:,
:,
0
,
0
]
self
.
value
.
weight
.
data
=
attn_layer
.
v
.
weight
.
data
[:,
:,
0
,
0
]
self
.
query
.
bias
.
data
=
attn_layer
.
q
.
bias
.
data
self
.
key
.
bias
.
data
=
attn_layer
.
k
.
bias
.
data
self
.
value
.
bias
.
data
=
attn_layer
.
v
.
bias
.
data
self
.
proj_attn
.
weight
.
data
=
attn_layer
.
proj_out
.
weight
.
data
[:,
:,
0
,
0
]
self
.
proj_attn
.
bias
.
data
=
attn_layer
.
proj_out
.
bias
.
data
else
:
qkv_weight
=
attn_layer
.
qkv
.
weight
.
data
.
reshape
(
self
.
num_heads
,
3
*
self
.
channels
//
self
.
num_heads
,
self
.
channels
)
qkv_bias
=
attn_layer
.
qkv
.
bias
.
data
.
reshape
(
self
.
num_heads
,
3
*
self
.
channels
//
self
.
num_heads
)
q_w
,
k_w
,
v_w
=
qkv_weight
.
split
(
self
.
channels
//
self
.
num_heads
,
dim
=
1
)
q_b
,
k_b
,
v_b
=
qkv_bias
.
split
(
self
.
channels
//
self
.
num_heads
,
dim
=
1
)
q_w
,
k_w
,
v_w
=
qkv_weight
.
split
(
self
.
channels
//
self
.
num_heads
,
dim
=
1
)
q_b
,
k_b
,
v_b
=
qkv_bias
.
split
(
self
.
channels
//
self
.
num_heads
,
dim
=
1
)
self
.
query
.
weight
.
data
=
q_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
key
.
weight
.
data
=
k_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
value
.
weight
.
data
=
v_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
query
.
weight
.
data
=
q_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
key
.
weight
.
data
=
k_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
value
.
weight
.
data
=
v_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
query
.
bias
.
data
=
q_b
.
reshape
(
-
1
)
self
.
key
.
bias
.
data
=
k_b
.
reshape
(
-
1
)
self
.
value
.
bias
.
data
=
v_b
.
reshape
(
-
1
)
self
.
query
.
bias
.
data
=
q_b
.
reshape
(
-
1
)
self
.
key
.
bias
.
data
=
k_b
.
reshape
(
-
1
)
self
.
value
.
bias
.
data
=
v_b
.
reshape
(
-
1
)
self
.
proj_attn
.
weight
.
data
=
attn_layer
.
proj
.
weight
.
data
[:,
:,
0
]
self
.
proj_attn
.
bias
.
data
=
attn_layer
.
proj
.
bias
.
data
self
.
proj_attn
.
weight
.
data
=
attn_layer
.
proj
.
weight
.
data
[:,
:,
0
]
self
.
proj_attn
.
bias
.
data
=
attn_layer
.
proj
.
bias
.
data
class
SpatialTransformer
(
nn
.
Module
):
...
...
src/diffusers/models/resnet.py
View file @
6d5ef87e
...
...
@@ -87,12 +87,21 @@ class Downsample2D(nn.Module):
self
.
conv
=
conv
def
forward
(
self
,
x
):
# print("use_conv", self.use_conv)
# print("padding", self.padding)
assert
x
.
shape
[
1
]
==
self
.
channels
if
self
.
use_conv
and
self
.
padding
==
0
:
pad
=
(
0
,
1
,
0
,
1
)
x
=
F
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
return
self
.
conv
(
x
)
# print("x", x.abs().sum())
self
.
hey
=
x
assert
x
.
shape
[
1
]
==
self
.
channels
x
=
self
.
conv
(
x
)
self
.
yas
=
x
# print("x", x.abs().sum())
return
x
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
...
...
src/diffusers/models/unet.py
View file @
6d5ef87e
...
...
@@ -177,9 +177,7 @@ class UNetModel(ModelMixin, ConfigMixin):
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
print
(
"hs"
,
hs
[
-
1
].
abs
().
sum
())
h
=
self
.
mid_new
(
hs
[
-
1
],
temb
)
print
(
"h"
,
h
.
abs
().
sum
())
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
...
...
src/diffusers/models/unet_new.py
View file @
6d5ef87e
...
...
@@ -51,6 +51,7 @@ def get_down_block(
add_downsample
=
add_downsample
,
resnet_eps
=
resnet_eps
,
resnet_act_fn
=
resnet_act_fn
,
downsample_padding
=
downsample_padding
,
attn_num_head_channels
=
attn_num_head_channels
,
)
...
...
@@ -186,6 +187,7 @@ class UNetResAttnDownBlock2D(nn.Module):
attn_num_head_channels
=
1
,
attention_type
=
"default"
,
output_scale_factor
=
1.0
,
downsample_padding
=
1
,
add_downsample
=
True
,
):
super
().
__init__
()
...
...
@@ -224,7 +226,11 @@ class UNetResAttnDownBlock2D(nn.Module):
if
add_downsample
:
self
.
downsamplers
=
nn
.
ModuleList
(
[
Downsample2D
(
in_channels
,
use_conv
=
True
,
out_channels
=
out_channels
,
padding
=
1
,
name
=
"op"
)]
[
Downsample2D
(
in_channels
,
use_conv
=
True
,
out_channels
=
out_channels
,
padding
=
downsample_padding
,
name
=
"op"
)
]
)
else
:
self
.
downsamplers
=
None
...
...
src/diffusers/models/unet_unconditional.py
View file @
6d5ef87e
...
...
@@ -94,25 +94,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
):
super
().
__init__
()
# DELETE if statements if not necessary anymore
# DDPM
if
ddpm
:
out_channels
=
out_ch
image_size
=
resolution
block_channels
=
[
x
*
ch
for
x
in
ch_mult
]
conv_resample
=
resamp_with_conv
flip_sin_to_cos
=
False
downscale_freq_shift
=
1
resnet_eps
=
1e-6
block_channels
=
(
32
,
64
)
down_blocks
=
(
"UNetResDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
)
up_blocks
=
(
"UNetResUpBlock2D"
,
"UNetResAttnUpBlock2D"
)
downsample_padding
=
0
num_head_channels
=
64
# register all __init__ params with self.register
self
.
register_to_config
(
image_size
=
image_size
,
...
...
@@ -250,6 +231,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
out_channels
,
)
if
ddpm
:
out_channels
=
out_ch
image_size
=
resolution
block_channels
=
[
x
*
ch
for
x
in
ch_mult
]
conv_resample
=
resamp_with_conv
self
.
init_for_ddpm
(
ch_mult
,
ch
,
...
...
@@ -290,13 +275,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
# append to tuple
down_block_res_samples
+=
res_samples
print
(
"sample"
,
sample
.
abs
().
sum
())
# 4. mid block
if
self
.
config
.
ddpm
:
sample
=
self
.
mid_new_2
(
sample
,
emb
)
else
:
sample
=
self
.
mid
(
sample
,
emb
)
print
(
"sample"
,
sample
.
abs
().
sum
())
# 5. up blocks
for
upsample_block
in
self
.
upsample_blocks
:
...
...
@@ -373,8 +356,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
elif
self
.
config
.
ddpm
:
# =============== SET WEIGHTS ===============
# =============== TIME ======================
self
.
time_embed
[
0
]
=
self
.
temb
.
dense
[
0
]
self
.
time_embed
[
2
]
=
self
.
temb
.
dense
[
1
]
self
.
time_embedding
.
linear_1
.
weight
.
data
=
self
.
temb
.
dense
[
0
].
weight
.
data
self
.
time_embedding
.
linear_1
.
bias
.
data
=
self
.
temb
.
dense
[
0
].
bias
.
data
self
.
time_embedding
.
linear_2
.
weight
.
data
=
self
.
temb
.
dense
[
1
].
weight
.
data
self
.
time_embedding
.
linear_2
.
bias
.
data
=
self
.
temb
.
dense
[
1
].
bias
.
data
for
i
,
block
in
enumerate
(
self
.
down
):
if
hasattr
(
block
,
"downsample"
):
...
...
@@ -391,6 +376,23 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self
.
mid_new_2
.
resnets
[
1
].
set_weight
(
self
.
mid
.
block_2
)
self
.
mid_new_2
.
attentions
[
0
].
set_weight
(
self
.
mid
.
attn_1
)
for
i
,
block
in
enumerate
(
self
.
up
):
k
=
len
(
self
.
up
)
-
1
-
i
if
hasattr
(
block
,
"upsample"
):
self
.
upsample_blocks
[
k
].
upsamplers
[
0
].
conv
.
weight
.
data
=
block
.
upsample
.
conv
.
weight
.
data
self
.
upsample_blocks
[
k
].
upsamplers
[
0
].
conv
.
bias
.
data
=
block
.
upsample
.
conv
.
bias
.
data
if
hasattr
(
block
,
"block"
)
and
len
(
block
.
block
)
>
0
:
for
j
in
range
(
self
.
num_res_blocks
+
1
):
self
.
upsample_blocks
[
k
].
resnets
[
j
].
set_weight
(
block
.
block
[
j
])
if
hasattr
(
block
,
"attn"
)
and
len
(
block
.
attn
)
>
0
:
for
j
in
range
(
self
.
num_res_blocks
+
1
):
self
.
upsample_blocks
[
k
].
attentions
[
j
].
set_weight
(
block
.
attn
[
j
])
self
.
conv_norm_out
.
weight
.
data
=
self
.
norm_out
.
weight
.
data
self
.
conv_norm_out
.
bias
.
data
=
self
.
norm_out
.
bias
.
data
self
.
remove_ddpm
()
def
init_for_ddpm
(
self
,
ch_mult
,
...
...
@@ -685,3 +687,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
del
self
.
middle_block
del
self
.
output_blocks
del
self
.
out
def
remove_ddpm
(
self
):
del
self
.
temb
del
self
.
down
del
self
.
mid_new
del
self
.
up
del
self
.
norm_out
tests/test_modeling_utils.py
View file @
6d5ef87e
...
...
@@ -40,7 +40,6 @@ from diffusers import (
ScoreSdeVpPipeline
,
ScoreSdeVpScheduler
,
UNetLDMModel
,
UNetModel
,
UNetUnconditionalModel
,
VQModel
,
)
...
...
@@ -209,7 +208,7 @@ class ModelTesterMixin:
class
UnetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetModel
model_class
=
UNet
Unconditional
Model
@
property
def
dummy_input
(
self
):
...
...
@@ -234,15 +233,24 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
init_dict
=
{
"ch"
:
32
,
"ch_mult"
:
(
1
,
2
),
"block_channels"
:
(
32
,
64
),
"down_blocks"
:
(
"UNetResDownBlock2D"
,
"UNetResAttnDownBlock2D"
),
"up_blocks"
:
(
"UNetResAttnUpBlock2D"
,
"UNetResUpBlock2D"
),
"num_head_channels"
:
None
,
"out_channels"
:
3
,
"in_channels"
:
3
,
"num_res_blocks"
:
2
,
"attn_resolutions"
:
(
16
,),
"resolution"
:
32
,
"image_size"
:
32
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
,
output_loading_info
=
True
)
model
,
loading_info
=
UNetUnconditionalModel
.
from_pretrained
(
"fusing/ddpm_dummy"
,
output_loading_info
=
True
,
ddpm
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
...
...
@@ -252,27 +260,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
resolution
,
model
.
config
.
resolution
)
time_step
=
torch
.
tensor
([
10
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2595
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
rtol
=
1e-2
))
print
(
"Original success!!!"
)
model
=
UNetUnconditionalModel
.
from_pretrained
(
"fusing/ddpm_dummy"
,
ddpm
=
True
)
model
.
eval
()
...
...
@@ -849,7 +836,9 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
class
PipelineTesterMixin
(
unittest
.
TestCase
):
def
test_from_pretrained_save_pretrained
(
self
):
# 1. Load models
model
=
UNetModel
(
ch
=
32
,
ch_mult
=
(
1
,
2
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
resolution
=
32
)
model
=
UNetUnconditionalModel
(
ch
=
32
,
ch_mult
=
(
1
,
2
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
resolution
=
32
,
ddpm
=
True
)
schedular
=
DDPMScheduler
(
timesteps
=
10
)
ddpm
=
DDPMPipeline
(
model
,
schedular
)
...
...
@@ -888,7 +877,7 @@ class PipelineTesterMixin(unittest.TestCase):
def
test_ddpm_cifar10
(
self
):
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetModel
.
from_pretrained
(
model_id
)
unet
=
UNet
Unconditional
Model
.
from_pretrained
(
model_id
,
ddpm
=
True
)
noise_scheduler
=
DDPMScheduler
.
from_config
(
model_id
)
noise_scheduler
=
noise_scheduler
.
set_format
(
"pt"
)
...
...
@@ -901,7 +890,28 @@ class PipelineTesterMixin(unittest.TestCase):
assert
image
.
shape
==
(
1
,
3
,
32
,
32
)
expected_slice
=
torch
.
tensor
(
[
-
0.5712
,
-
0.6215
,
-
0.5953
,
-
0.5438
,
-
0.4775
,
-
0.4539
,
-
0.5172
,
-
0.4872
,
-
0.5105
]
[
-
0.1601
,
-
0.2823
,
-
0.6123
,
-
0.2305
,
-
0.3236
,
-
0.4706
,
-
0.1691
,
-
0.2836
,
-
0.3231
]
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_ddim_lsun
(
self
):
model_id
=
"fusing/ddpm-lsun-bedroom-ema"
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
,
ddpm
=
True
)
noise_scheduler
=
DDIMScheduler
.
from_config
(
model_id
)
noise_scheduler
=
noise_scheduler
.
set_format
(
"pt"
)
ddpm
=
DDIMPipeline
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
256
,
256
)
expected_slice
=
torch
.
tensor
(
[
-
0.9879
,
-
0.9598
,
-
0.9312
,
-
0.9953
,
-
0.9963
,
-
0.9995
,
-
0.9957
,
-
1.0000
,
-
0.9863
]
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
...
...
@@ -909,7 +919,7 @@ class PipelineTesterMixin(unittest.TestCase):
def
test_ddim_cifar10
(
self
):
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetModel
.
from_pretrained
(
model_id
)
unet
=
UNet
Unconditional
Model
.
from_pretrained
(
model_id
,
ddpm
=
True
)
noise_scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
...
...
@@ -929,7 +939,7 @@ class PipelineTesterMixin(unittest.TestCase):
def
test_pndm_cifar10
(
self
):
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetModel
.
from_pretrained
(
model_id
)
unet
=
UNet
Unconditional
Model
.
from_pretrained
(
model_id
,
ddpm
=
True
)
noise_scheduler
=
PNDMScheduler
(
tensor_format
=
"pt"
)
pndm
=
PNDMPipeline
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment