renzhc / diffusers_dcu · Commits

Unverified commit 6d5ef87e, authored Jul 14, 2022 by Patrick von Platen, committed by GitHub on Jul 14, 2022

[DDPM] Make DDPM work (#88)

* up
* finish
* uP

Parent: e7fe901e

Showing 8 changed files with 188 additions (+188) and 293 deletions (-293)
conversion.py                                  +33   -2
src/diffusers/modeling_utils.py                +40   -38
src/diffusers/models/attention.py              +27   -197
src/diffusers/models/resnet.py                 +10   -1
src/diffusers/models/unet.py                   +0    -2
src/diffusers/models/unet_new.py               +7    -1
src/diffusers/models/unet_unconditional.py     +32   -23
tests/test_modeling_utils.py                   +39   -29
conversion.py

@@ -50,6 +50,8 @@ from diffusers.testing_utils import floats_tensor, slow, torch_device
 from diffusers.training_utils import EMAModel


 # 1. LDM
 def test_output_pretrained_ldm_dummy():
     model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
     model.eval()
@@ -86,9 +88,38 @@ def test_output_pretrained_ldm():
     import ipdb; ipdb.set_trace()


 # To see the how the final model should look like
-test_output_pretrained_ldm_dummy()
-test_output_pretrained_ldm()
+# => this is the architecture in which the model should be saved in the new format
+# -> verify new repo with the following tests (in `test_modeling_utils.py`)
+# - test_ldm_uncond (in PipelineTesterMixin)
+# - test_output_pretrained (in UNetLDMModelTests)
+# test_output_pretrained_ldm_dummy()
+# test_output_pretrained_ldm()
+
+
+# 2. DDPM
+def get_model(model_id):
+    model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
+    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+    time_step = torch.tensor([10] * noise.shape[0])
+    with torch.no_grad():
+        output = model(noise, time_step)
+    print(model)
+
+
+# Repos to convert and port to google (part of https://github.com/hojonathanho/diffusion)
+# - fusing/ddpm_dummy
+# - fusing/ddpm-cifar10
+# - https://huggingface.co/fusing/ddpm-lsun-church-ema
+# - https://huggingface.co/fusing/ddpm-lsun-bedroom-ema
+# - https://huggingface.co/fusing/ddpm-celeba-hq
+
+# tests to make sure to pass
+# - test_ddim_cifar10, test_ddim_lsun, test_ddpm_cifar10, test_ddim_cifar10 (in PipelineTesterMixin)
+# - test_output_pretrained (in UNetModelTests)
+
+# e.g.
+get_model("fusing/ddpm-cifar10")
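Note that get_model() above ignores its model_id argument and always loads the LDM dummy checkpoint; that looks like work-in-progress scaffolding. A minimal sketch of the DDPM smoke test it is presumably heading toward, assuming the fusing/ddpm-cifar10 checkpoint and the ddpm=True loading path introduced in this commit:

import torch
from diffusers import UNetUnconditionalModel

# Load an old-format DDPM checkpoint through the new unconditional UNet
# (ddpm=True routes the legacy config through init_for_ddpm, see unet_unconditional.py below).
model = UNetUnconditionalModel.from_pretrained("fusing/ddpm-cifar10", ddpm=True)
model.eval()

# One forward pass on random noise at a fixed timestep, as in get_model() above.
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
    output = model(noise, time_step)
print(output.shape)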
src/diffusers/modeling_utils.py

@@ -492,44 +492,46 @@ class ModelMixin(torch.nn.Module):
             )
             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
-                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
-                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
-                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
-                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
-                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
-                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
-            )
-        else:
-            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
-        if len(missing_keys) > 0:
-            logger.warning(
-                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
-                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
-                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
-            )
-        elif len(mismatched_keys) == 0:
-            logger.info(
-                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
-                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
-                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
-                " training."
-            )
-        if len(mismatched_keys) > 0:
-            mismatched_warning = "\n".join(
-                [
-                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
-                    for key, shape1, shape2 in mismatched_keys
-                ]
-            )
-            logger.warning(
-                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
-                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
-                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
-                " to use it for predictions and inference."
-            )
+        if False:
+            if len(unexpected_keys) > 0:
+                logger.warning(
+                    f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+                    f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+                    f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+                    " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+                    " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+                    f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+                    " identical (initializing a BertForSequenceClassification model from a"
+                    " BertForSequenceClassification model)."
+                )
+            else:
+                logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+            if len(missing_keys) > 0:
+                logger.warning(
+                    f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                    f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+                    " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+                )
+            elif len(mismatched_keys) == 0:
+                logger.info(
+                    f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+                    f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+                    f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+                    " without further training."
+                )
+            if len(mismatched_keys) > 0:
+                mismatched_warning = "\n".join(
+                    [
+                        f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                        for key, shape1, shape2 in mismatched_keys
+                    ]
+                )
+                logger.warning(
+                    f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                    f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+                    f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+                    " able to use it for predictions and inference."
+                )

         return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
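The hunk above silences the key-mismatch warnings (the block is dead under if False:), but the key lists are still returned to the caller. A sketch of inspecting them via output_loading_info=True, as the tests below do; only the "missing_keys" entry is exercised in this commit, the other dict keys are assumed to mirror the returned tuple:

from diffusers import UNetUnconditionalModel

# from_pretrained(..., output_loading_info=True) returns (model, loading_info);
# loading_info carries the key lists from the loading routine above.
model, loading_info = UNetUnconditionalModel.from_pretrained(
    "fusing/ddpm_dummy", output_loading_info=True, ddpm=True
)
for name in ("missing_keys", "unexpected_keys", "mismatched_keys", "error_msgs"):
    print(name, loading_info.get(name, "<not present>"))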
src/diffusers/models/attention.py

@@ -166,188 +166,6 @@ class AttentionBlock(nn.Module):
         return result


-class AttentionBlockNew_2(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other.
-
-    Originally ported from here, but adapted to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_head_channels=1,
-        num_groups=32,
-        encoder_channels=None,
-        rescale_output_factor=1.0,
-        eps=1e-5,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
-        self.qkv = nn.Conv1d(channels, channels * 3, 1)
-        self.n_heads = channels // num_head_channels
-        self.num_head_size = num_head_channels
-        self.rescale_output_factor = rescale_output_factor
-
-        if encoder_channels is not None:
-            self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-
-        self.proj = zero_module(nn.Conv1d(channels, channels, 1))
-
-        # ------------------------- new -----------------------
-        num_heads = self.n_heads
-        self.channels = channels
-
-        if num_head_channels is None:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-
-        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
-
-        # define q,k,v as linear layers
-        self.query = nn.Linear(channels, channels)
-        self.key = nn.Linear(channels, channels)
-        self.value = nn.Linear(channels, channels)
-
-        self.rescale_output_factor = rescale_output_factor
-        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
-        # ------------------------- new -----------------------
-
-    def set_weight(self, attn_layer):
-        self.norm.weight.data = attn_layer.norm.weight.data
-        self.norm.bias.data = attn_layer.norm.bias.data
-
-        self.qkv.weight.data = attn_layer.qkv.weight.data
-        self.qkv.bias.data = attn_layer.qkv.bias.data
-
-        self.proj.weight.data = attn_layer.proj.weight.data
-        self.proj.bias.data = attn_layer.proj.bias.data
-
-        if hasattr(attn_layer, "q"):
-            module = attn_layer
-            qkv_weight = torch.cat(
-                [module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0
-            )[:, :, :, 0]
-            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
-
-            self.qkv.weight.data = qkv_weight
-            self.qkv.bias.data = qkv_bias
-
-            proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
-            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
-            proj_out.bias.data = module.proj_out.bias.data
-
-            self.proj = proj_out
-
-        self.set_weights_2(attn_layer)
-
-    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
-        new_projection_shape = projection.size()[:-1] + (self.n_heads, self.num_head_size)
-        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
-        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
-        return new_projection
-
-    def set_weights_2(self, attn_layer):
-        self.group_norm.weight.data = attn_layer.norm.weight.data
-        self.group_norm.bias.data = attn_layer.norm.bias.data
-
-        qkv_weight = attn_layer.qkv.weight.data.reshape(self.n_heads, 3 * self.channels // self.n_heads, self.channels)
-        qkv_bias = attn_layer.qkv.bias.data.reshape(self.n_heads, 3 * self.channels // self.n_heads)
-
-        q_w, k_w, v_w = qkv_weight.split(self.channels // self.n_heads, dim=1)
-        q_b, k_b, v_b = qkv_bias.split(self.channels // self.n_heads, dim=1)
-
-        self.query.weight.data = q_w.reshape(-1, self.channels)
-        self.key.weight.data = k_w.reshape(-1, self.channels)
-        self.value.weight.data = v_w.reshape(-1, self.channels)
-
-        self.query.bias.data = q_b.reshape(-1)
-        self.key.bias.data = k_b.reshape(-1)
-        self.value.bias.data = v_b.reshape(-1)
-
-        self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
-        self.proj_attn.bias.data = attn_layer.proj.bias.data
-
-    def forward_2(self, hidden_states):
-        residual = hidden_states
-        batch, channel, height, width = hidden_states.shape
-
-        # norm
-        hidden_states = self.group_norm(hidden_states)
-        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
-
-        # proj to q, k, v
-        query_proj = self.query(hidden_states)
-        key_proj = self.key(hidden_states)
-        value_proj = self.value(hidden_states)
-
-        # transpose
-        query_states = self.transpose_for_scores(query_proj)
-        key_states = self.transpose_for_scores(key_proj)
-        value_states = self.transpose_for_scores(value_proj)
-
-        # get scores
-        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
-        attention_scores = attention_scores / math.sqrt(self.channels // self.n_heads)
-        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-
-        # compute attention output
-        context_states = torch.matmul(attention_probs, value_states)
-
-        context_states = context_states.permute(0, 2, 1, 3).contiguous()
-        new_context_states_shape = context_states.size()[:-2] + (self.channels,)
-        context_states = context_states.view(new_context_states_shape)
-
-        # compute next hidden_states
-        hidden_states = self.proj_attn(context_states)
-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
-
-        # res connect and rescale
-        hidden_states = (hidden_states + residual) / self.rescale_output_factor
-        return hidden_states
-
-    def forward(self, x, encoder_out=None):
-        b, c, *spatial = x.shape
-        hid_states = self.norm(x).view(b, c, -1)
-
-        qkv = self.qkv(hid_states)
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-
-        if encoder_out is not None:
-            encoder_kv = self.encoder_kv(encoder_out)
-            assert encoder_kv.shape[1] == self.n_heads * ch * 2
-            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = torch.cat([ek, k], dim=-1)
-            v = torch.cat([ev, v], dim=-1)
-
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-
-        a = torch.einsum("bts,bcs->bct", weight, v)
-        h = a.reshape(bs, -1, length)
-
-        h = self.proj(h)
-        h = h.reshape(b, c, *spatial)
-
-        result = x + h
-        result = result / self.rescale_output_factor
-
-        result_2 = self.forward_2(x)
-        print((result - result_2).abs().sum())
-
-        return result_2
-
-
 class AttentionBlockNew(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted

@@ -387,7 +205,7 @@ class AttentionBlockNew(nn.Module):
         self.proj_attn = zero_module(nn.Linear(channels, channels, 1))

     def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
-        new_projection_shape = projection.size()[:-1] + (self.num_heads, self.num_head_size)
+        new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
         new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
         return new_projection

@@ -434,24 +252,36 @@ class AttentionBlockNew(nn.Module):
         self.group_norm.weight.data = attn_layer.norm.weight.data
         self.group_norm.bias.data = attn_layer.norm.bias.data

-        qkv_weight = attn_layer.qkv.weight.data.reshape(
-            self.num_heads, 3 * self.channels // self.num_heads, self.channels
-        )
-        qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
-
-        q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
-        q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
-
-        self.query.weight.data = q_w.reshape(-1, self.channels)
-        self.key.weight.data = k_w.reshape(-1, self.channels)
-        self.value.weight.data = v_w.reshape(-1, self.channels)
-
-        self.query.bias.data = q_b.reshape(-1)
-        self.key.bias.data = k_b.reshape(-1)
-        self.value.bias.data = v_b.reshape(-1)
-
-        self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
-        self.proj_attn.bias.data = attn_layer.proj.bias.data
+        if hasattr(attn_layer, "q"):
+            self.query.weight.data = attn_layer.q.weight.data[:, :, 0, 0]
+            self.key.weight.data = attn_layer.k.weight.data[:, :, 0, 0]
+            self.value.weight.data = attn_layer.v.weight.data[:, :, 0, 0]
+
+            self.query.bias.data = attn_layer.q.bias.data
+            self.key.bias.data = attn_layer.k.bias.data
+            self.value.bias.data = attn_layer.v.bias.data
+
+            self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
+            self.proj_attn.bias.data = attn_layer.proj_out.bias.data
+        else:
+            qkv_weight = attn_layer.qkv.weight.data.reshape(
+                self.num_heads, 3 * self.channels // self.num_heads, self.channels
+            )
+            qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
+
+            q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
+            q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
+
+            self.query.weight.data = q_w.reshape(-1, self.channels)
+            self.key.weight.data = k_w.reshape(-1, self.channels)
+            self.value.weight.data = v_w.reshape(-1, self.channels)
+
+            self.query.bias.data = q_b.reshape(-1)
+            self.key.bias.data = k_b.reshape(-1)
+            self.value.bias.data = v_b.reshape(-1)
+
+            self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
+            self.proj_attn.bias.data = attn_layer.proj.bias.data


 class SpatialTransformer(nn.Module):
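The surviving set_weight logic is the subtle part of this file: a fused qkv 1x1 convolution stores its output rows head-major, [q_h; k_h; v_h] for each head h, so the weight is reshaped to (num_heads, 3 * head_dim, channels), split into thirds along dim 1, and the per-head chunks are re-stacked into plain Linear layers. A self-contained sketch checking that round trip (hypothetical sizes, not from the commit):

import torch
import torch.nn as nn

channels, num_heads, length = 8, 2, 5
head_dim = channels // num_heads

# Fused projection as in the old attention block: Conv1d(channels -> 3 * channels, kernel 1).
qkv = nn.Conv1d(channels, channels * 3, 1)

# The reshape/split from AttentionBlockNew.set_weight.
qkv_weight = qkv.weight.data.reshape(num_heads, 3 * head_dim, channels)
qkv_bias = qkv.bias.data.reshape(num_heads, 3 * head_dim)
q_w, _, _ = qkv_weight.split(head_dim, dim=1)
q_b, _, _ = qkv_bias.split(head_dim, dim=1)

query = nn.Linear(channels, channels)
query.weight.data = q_w.reshape(-1, channels)
query.bias.data = q_b.reshape(-1)

# Compare the q slice of the fused conv with the ported Linear layer.
x = torch.randn(1, channels, length)
q_fused = qkv(x).reshape(num_heads, 3 * head_dim, length).split(head_dim, dim=1)[0]
q_linear = query(x.transpose(1, 2)).transpose(1, 2).reshape(num_heads, head_dim, length)
print(torch.allclose(q_fused, q_linear, atol=1e-6))  # expected: True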
src/diffusers/models/resnet.py

@@ -87,12 +87,21 @@ class Downsample2D(nn.Module):
         self.conv = conv

     def forward(self, x):
+        # print("use_conv", self.use_conv)
+        # print("padding", self.padding)
         assert x.shape[1] == self.channels
         if self.use_conv and self.padding == 0:
             pad = (0, 1, 0, 1)
             x = F.pad(x, pad, mode="constant", value=0)
-        return self.conv(x)
+        # print("x", x.abs().sum())
+        self.hey = x
+
+        assert x.shape[1] == self.channels
+        x = self.conv(x)
+        self.yas = x
+        # print("x", x.abs().sum())
+        return x


 # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
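The padding == 0 branch above reproduces the original DDPM downsampling, which pads the input asymmetrically (one extra column on the right, one extra row on the bottom) before a stride-2 conv, instead of using symmetric conv padding. A quick sketch of the shape arithmetic:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 32, 32)
conv = torch.nn.Conv2d(3, 3, kernel_size=3, stride=2, padding=0)

# F.pad takes (left, right, top, bottom): pad only the right and bottom edges.
x_padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)  # -> (1, 3, 33, 33)
y = conv(x_padded)
print(y.shape)  # torch.Size([1, 3, 16, 16]): floor((33 - 3) / 2) + 1 = 16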
src/diffusers/models/unet.py

@@ -177,9 +177,7 @@ class UNetModel(ModelMixin, ConfigMixin):
                 hs.append(self.down[i_level].downsample(hs[-1]))

         # middle
-        print("hs", hs[-1].abs().sum())
         h = self.mid_new(hs[-1], temb)
-        print("h", h.abs().sum())

         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
src/diffusers/models/unet_new.py

@@ -51,6 +51,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            downsample_padding=downsample_padding,
             attn_num_head_channels=attn_num_head_channels,
         )

@@ -186,6 +187,7 @@ class UNetResAttnDownBlock2D(nn.Module):
         attn_num_head_channels=1,
         attention_type="default",
         output_scale_factor=1.0,
+        downsample_padding=1,
         add_downsample=True,
     ):
         super().__init__()

@@ -224,7 +226,11 @@ class UNetResAttnDownBlock2D(nn.Module):
         if add_downsample:
             self.downsamplers = nn.ModuleList(
-                [Downsample2D(in_channels, use_conv=True, out_channels=out_channels, padding=1, name="op")]
+                [
+                    Downsample2D(
+                        in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
             )
         else:
             self.downsamplers = None
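downsample_padding exists so the DDPM port can request the pad-then-conv behaviour (padding=0, see resnet.py above) while LDM-style blocks keep the default padding=1. A sketch of the two settings on Downsample2D directly, assuming it is importable from diffusers.models.resnet at this commit and that use_conv creates the usual 3x3 stride-2 conv:

import torch
from diffusers.models.resnet import Downsample2D  # module path assumed from this commit

x = torch.randn(1, 32, 32, 32)

# DDPM style: padding=0 triggers the asymmetric F.pad branch in forward().
down_ddpm = Downsample2D(32, use_conv=True, out_channels=32, padding=0, name="op")
# LDM style: the conv itself pads symmetrically.
down_ldm = Downsample2D(32, use_conv=True, out_channels=32, padding=1, name="op")

print(down_ddpm(x).shape, down_ldm(x).shape)  # both torch.Size([1, 32, 16, 16])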
src/diffusers/models/unet_unconditional.py

@@ -94,25 +94,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
     ):
         super().__init__()

         # DELETE if statements if not necessary anymore

-        # DDPM
-        if ddpm:
-            out_channels = out_ch
-            image_size = resolution
-            block_channels = [x * ch for x in ch_mult]
-            conv_resample = resamp_with_conv
-            flip_sin_to_cos = False
-            downscale_freq_shift = 1
-            resnet_eps = 1e-6
-
-            block_channels = (32, 64)
-            down_blocks = (
-                "UNetResDownBlock2D",
-                "UNetResAttnDownBlock2D",
-            )
-            up_blocks = ("UNetResUpBlock2D", "UNetResAttnUpBlock2D")
-            downsample_padding = 0
-            num_head_channels = 64

         # register all __init__ params with self.register
         self.register_to_config(
             image_size=image_size,

@@ -250,6 +231,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
                 out_channels,
             )

+        if ddpm:
+            out_channels = out_ch
+            image_size = resolution
+            block_channels = [x * ch for x in ch_mult]
+            conv_resample = resamp_with_conv
             self.init_for_ddpm(
                 ch_mult,
                 ch,

@@ -290,13 +275,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
             # append to tuple
             down_block_res_samples += res_samples

-        print("sample", sample.abs().sum())
         # 4. mid block
         if self.config.ddpm:
             sample = self.mid_new_2(sample, emb)
         else:
             sample = self.mid(sample, emb)
-        print("sample", sample.abs().sum())

         # 5. up blocks
         for upsample_block in self.upsample_blocks:

@@ -373,8 +356,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         elif self.config.ddpm:
             # =============== SET WEIGHTS ===============
             # =============== TIME ======================
-            self.time_embed[0] = self.temb.dense[0]
-            self.time_embed[2] = self.temb.dense[1]
+            self.time_embedding.linear_1.weight.data = self.temb.dense[0].weight.data
+            self.time_embedding.linear_1.bias.data = self.temb.dense[0].bias.data
+            self.time_embedding.linear_2.weight.data = self.temb.dense[1].weight.data
+            self.time_embedding.linear_2.bias.data = self.temb.dense[1].bias.data

             for i, block in enumerate(self.down):
                 if hasattr(block, "downsample"):

@@ -391,6 +376,23 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
             self.mid_new_2.resnets[1].set_weight(self.mid.block_2)
             self.mid_new_2.attentions[0].set_weight(self.mid.attn_1)

+            for i, block in enumerate(self.up):
+                k = len(self.up) - 1 - i
+                if hasattr(block, "upsample"):
+                    self.upsample_blocks[k].upsamplers[0].conv.weight.data = block.upsample.conv.weight.data
+                    self.upsample_blocks[k].upsamplers[0].conv.bias.data = block.upsample.conv.bias.data
+                if hasattr(block, "block") and len(block.block) > 0:
+                    for j in range(self.num_res_blocks + 1):
+                        self.upsample_blocks[k].resnets[j].set_weight(block.block[j])
+                if hasattr(block, "attn") and len(block.attn) > 0:
+                    for j in range(self.num_res_blocks + 1):
+                        self.upsample_blocks[k].attentions[j].set_weight(block.attn[j])
+
+            self.conv_norm_out.weight.data = self.norm_out.weight.data
+            self.conv_norm_out.bias.data = self.norm_out.bias.data
+
+            self.remove_ddpm()
+
     def init_for_ddpm(
         self,
         ch_mult,

@@ -685,3 +687,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         del self.middle_block
         del self.output_blocks
         del self.out
+
+    def remove_ddpm(self):
+        del self.temb
+        del self.down
+        del self.mid_new
+        del self.up
+        del self.norm_out
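The porting code above builds both module trees, copies every tensor across via .weight.data / .bias.data (or the blocks' set_weight helpers), and finally deletes the legacy tree with remove_ddpm(). The essential invariant, in miniature (generic modules, not the commit's classes):

import torch

old = torch.nn.Linear(4, 8)   # stands in for e.g. self.temb.dense[0]
new = torch.nn.Linear(4, 8)   # stands in for self.time_embedding.linear_1

# Copy parameters tensor-by-tensor, exactly as in the elif self.config.ddpm branch.
new.weight.data = old.weight.data
new.bias.data = old.bias.data

x = torch.randn(2, 4)
print(torch.allclose(old(x), new(x)))  # True: the new layout computes the same function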
tests/test_modeling_utils.py

@@ -40,7 +40,6 @@ from diffusers import (
     ScoreSdeVpPipeline,
     ScoreSdeVpScheduler,
     UNetLDMModel,
-    UNetModel,
     UNetUnconditionalModel,
     VQModel,
 )

@@ -209,7 +208,7 @@ class ModelTesterMixin:
 class UnetModelTests(ModelTesterMixin, unittest.TestCase):
-    model_class = UNetModel
+    model_class = UNetUnconditionalModel

     @property
     def dummy_input(self):

@@ -234,15 +233,24 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         init_dict = {
             "ch": 32,
             "ch_mult": (1, 2),
+            "block_channels": (32, 64),
+            "down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
+            "up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
+            "num_head_channels": None,
             "out_channels": 3,
             "in_channels": 3,
             "num_res_blocks": 2,
             "attn_resolutions": (16,),
             "resolution": 32,
+            "image_size": 32,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

     def test_from_pretrained_hub(self):
-        model, loading_info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True)
+        model, loading_info = UNetUnconditionalModel.from_pretrained(
+            "fusing/ddpm_dummy", output_loading_info=True, ddpm=True
+        )

         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)

@@ -252,27 +260,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"

     def test_output_pretrained(self):
-        model = UNetModel.from_pretrained("fusing/ddpm_dummy")
-        model.eval()
-
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
-
-        noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
-        time_step = torch.tensor([10])
-
-        with torch.no_grad():
-            output = model(noise, time_step)
-
-        output_slice = output[0, -1, -3:, -3:].flatten()
-        # fmt: off
-        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
-        # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
-        print("Original success!!!")
-
         model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy", ddpm=True)
         model.eval()

@@ -849,7 +836,9 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
 class PipelineTesterMixin(unittest.TestCase):
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
-        model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
+        model = UNetUnconditionalModel(
+            ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32, ddpm=True
+        )
         schedular = DDPMScheduler(timesteps=10)
         ddpm = DDPMPipeline(model, schedular)

@@ -888,7 +877,7 @@ class PipelineTesterMixin(unittest.TestCase):
     def test_ddpm_cifar10(self):
         model_id = "fusing/ddpm-cifar10"

-        unet = UNetModel.from_pretrained(model_id)
+        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
         noise_scheduler = DDPMScheduler.from_config(model_id)
         noise_scheduler = noise_scheduler.set_format("pt")

@@ -901,7 +890,28 @@ class PipelineTesterMixin(unittest.TestCase):
         assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor(
-            [-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105]
-        )
+        expected_slice = torch.tensor(
+            [-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231]
+        )
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

+    @slow
+    def test_ddim_lsun(self):
+        model_id = "fusing/ddpm-lsun-bedroom-ema"
+
+        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
+        noise_scheduler = DDIMScheduler.from_config(model_id)
+        noise_scheduler = noise_scheduler.set_format("pt")
+
+        ddpm = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)
+
+        generator = torch.manual_seed(0)
+        image = ddpm(generator=generator)
+
+        image_slice = image[0, -1, -3:, -3:].cpu()
+
+        assert image.shape == (1, 3, 256, 256)
+        expected_slice = torch.tensor(
+            [-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863]
+        )
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

@@ -909,7 +919,7 @@ class PipelineTesterMixin(unittest.TestCase):
     def test_ddim_cifar10(self):
         model_id = "fusing/ddpm-cifar10"

-        unet = UNetModel.from_pretrained(model_id)
+        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
         noise_scheduler = DDIMScheduler(tensor_format="pt")

         ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)

@@ -929,7 +939,7 @@ class PipelineTesterMixin(unittest.TestCase):
     def test_pndm_cifar10(self):
         model_id = "fusing/ddpm-cifar10"

-        unet = UNetModel.from_pretrained(model_id)
+        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
         noise_scheduler = PNDMScheduler(tensor_format="pt")

         pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler)
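All of the pipeline tests above follow one regression pattern: seed the generator, sample once, take a fixed 3x3 corner of the last channel, and compare against hard-coded reference values within 1e-2. A sketch of that skeleton, with a stand-in tensor instead of a real pipeline call:

import torch

generator = torch.manual_seed(0)

# In the real tests this is e.g. image = ddpm(generator=generator).
image = torch.randn(1, 3, 32, 32)

assert image.shape == (1, 3, 32, 32)
image_slice = image[0, -1, -3:, -3:].cpu()

# A real test stores literal numbers here to pin model + scheduler behaviour.
expected_slice = image_slice.flatten().clone()
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2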