renzhc / diffusers_dcu · Commits · 6d5ef87e
"examples/llama.android/vscode:/vscode.git/clone" did not exist on "4cc1a6143387f41e2466536abcd6a2620b63a35b"
Unverified commit 6d5ef87e, authored Jul 14, 2022 by Patrick von Platen, committed by GitHub on Jul 14, 2022.
[DDPM] Make DDPM work (#88)

* up
* finish
* uP

Parent: e7fe901e
Showing 8 changed files with 188 additions and 293 deletions (+188, −293).
conversion.py                                  +33   −2
src/diffusers/modeling_utils.py                +40   −38
src/diffusers/models/attention.py              +27   −197
src/diffusers/models/resnet.py                 +10   −1
src/diffusers/models/unet.py                   +0    −2
src/diffusers/models/unet_new.py               +7    −1
src/diffusers/models/unet_unconditional.py     +32   −23
tests/test_modeling_utils.py                   +39   −29
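Taken together, the diff retires the old UNetModel entry point for DDPM checkpoints in favour of UNetUnconditionalModel with a ddpm=True conversion flag. A minimal end-to-end sketch of the usage the updated tests exercise, assembled from the test code further down (the keyword form of the pipeline call is taken from this commit's tests, not from a released API, so treat it as an assumption):

    import torch
    from diffusers import DDPMPipeline, DDPMScheduler, UNetUnconditionalModel

    model_id = "fusing/ddpm-cifar10"

    # Load the original DDPM weights through the new unconditional UNet; ddpm=True triggers the conversion path.
    unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
    noise_scheduler = DDPMScheduler.from_config(model_id)
    noise_scheduler = noise_scheduler.set_format("pt")

    # Sample a single CIFAR-10-sized image, as test_ddpm_cifar10 does.
    ddpm = DDPMPipeline(unet=unet, noise_scheduler=noise_scheduler)
    generator = torch.manual_seed(0)
    image = ddpm(generator=generator)
    print(image.shape)  # the test expects (1, 3, 32, 32)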
conversion.py

@@ -50,6 +50,8 @@ from diffusers.testing_utils import floats_tensor, slow, torch_device
 from diffusers.training_utils import EMAModel
 
+
+# 1. LDM
 def test_output_pretrained_ldm_dummy():
     model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
     model.eval()

@@ -86,9 +88,38 @@ def test_output_pretrained_ldm():
     import ipdb; ipdb.set_trace()
 
 # To see the how the final model should look like
-test_output_pretrained_ldm_dummy()
-test_output_pretrained_ldm()
 # => this is the architecture in which the model should be saved in the new format
 # -> verify new repo with the following tests (in `test_modeling_utils.py`)
 # - test_ldm_uncond (in PipelineTesterMixin)
 # - test_output_pretrained ( in UNetLDMModelTests)
+#test_output_pretrained_ldm_dummy()
+#test_output_pretrained_ldm()
+
+
+# 2. DDPM
+def get_model(model_id):
+    model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
+
+    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+    time_step = torch.tensor([10] * noise.shape[0])
+
+    with torch.no_grad():
+        output = model(noise, time_step)
+
+    print(model)
+
+
+# Repos to convert and port to google (part of https://github.com/hojonathanho/diffusion)
+# - fusing/ddpm_dummy
+# - fusing/ddpm-cifar10
+# - https://huggingface.co/fusing/ddpm-lsun-church-ema
+# - https://huggingface.co/fusing/ddpm-lsun-bedroom-ema
+# - https://huggingface.co/fusing/ddpm-celeba-hq
+
+# tests to make sure to pass
+# - test_ddim_cifar10, test_ddim_lsun, test_ddpm_cifar10, test_ddim_cifar10 (in PipelineTesterMixin)
+# - test_output_pretrained ( in UNetModelTests)
+
+# e.g.
+get_model("fusing/ddpm-cifar10")
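Note that the committed get_model helper ignores its model_id argument and reloads the LDM dummy checkpoint; the comment block suggests it is meant to smoke-test each converted DDPM repo. A hedged sketch of what that presumably looks like once the argument is used (the ddpm=True path mirrors the updated tests; this is an assumption, not part of the commit):

    import torch
    from diffusers import UNetUnconditionalModel

    def get_model(model_id):
        # Load whichever converted checkpoint the caller names, e.g. "fusing/ddpm-cifar10".
        model = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)

        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
        time_step = torch.tensor([10] * noise.shape[0])

        with torch.no_grad():
            output = model(noise, time_step)

        # Assumed: the denoiser predicts noise with the same shape as its input.
        print(output.shape)
        return model

    get_model("fusing/ddpm-cifar10")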
src/diffusers/modeling_utils.py

@@ -492,44 +492,46 @@ class ModelMixin(torch.nn.Module):
(The block that reports missing / unexpected / mismatched state_dict keys is re-wrapped, and the
"unexpected keys" warning, together with its logger.info counterpart, is temporarily disabled by
nesting it under `if False:`. Resulting code:)

            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

        if False:
            if len(unexpected_keys) > 0:
                logger.warning(
                    f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                    f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                    f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
                    " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                    " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                    f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
                    " identical (initializing a BertForSequenceClassification model from a"
                    " BertForSequenceClassification model)."
                )
            else:
                logger.info(
                    f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
                )
        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
                f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
                " without further training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
                " able to use it for predictions and inference."
            )

        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
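Because the unexpected-keys warning is now parked behind `if False:`, the loading report is only visible programmatically. The updated tests rely on `output_loading_info=True` for exactly this; a short sketch, assuming the `loading_info` dict mirrors the tuple returned above (only `missing_keys` is confirmed by the tests):

    from diffusers import UNetUnconditionalModel

    model, loading_info = UNetUnconditionalModel.from_pretrained(
        "fusing/ddpm_dummy", output_loading_info=True, ddpm=True
    )

    # Surface what from_pretrained would otherwise only (conditionally) log.
    print(loading_info["missing_keys"])
    print(loading_info["unexpected_keys"])   # assumed key name, mirroring the returned tuple
    print(loading_info["mismatched_keys"])   # assumed key name, mirroring the returned tuple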
src/diffusers/models/attention.py

@@ -166,188 +166,6 @@ class AttentionBlock(nn.Module):
         return result
 
(Removed: the entire experimental AttentionBlockNew_2 class, which kept the old Conv1d-based
attention and the new Linear-based attention side by side and compared their outputs in forward():)

-class AttentionBlockNew_2(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other.
-
-    Originally ported from here, but adapted to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_head_channels=1,
-        num_groups=32,
-        encoder_channels=None,
-        rescale_output_factor=1.0,
-        eps=1e-5,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
-        self.qkv = nn.Conv1d(channels, channels * 3, 1)
-        self.n_heads = channels // num_head_channels
-        self.num_head_size = num_head_channels
-        self.rescale_output_factor = rescale_output_factor
-
-        if encoder_channels is not None:
-            self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-
-        self.proj = zero_module(nn.Conv1d(channels, channels, 1))
-
-        # ------------------------- new -----------------------
-        num_heads = self.n_heads
-        self.channels = channels
-
-        if num_head_channels is None:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-
-        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
-
-        # define q,k,v as linear layers
-        self.query = nn.Linear(channels, channels)
-        self.key = nn.Linear(channels, channels)
-        self.value = nn.Linear(channels, channels)
-
-        self.rescale_output_factor = rescale_output_factor
-        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
-        # ------------------------- new -----------------------
-
-    def set_weight(self, attn_layer):
-        self.norm.weight.data = attn_layer.norm.weight.data
-        self.norm.bias.data = attn_layer.norm.bias.data
-
-        self.qkv.weight.data = attn_layer.qkv.weight.data
-        self.qkv.bias.data = attn_layer.qkv.bias.data
-
-        self.proj.weight.data = attn_layer.proj.weight.data
-        self.proj.bias.data = attn_layer.proj.bias.data
-
-        if hasattr(attn_layer, "q"):
-            module = attn_layer
-            qkv_weight = torch.cat(
-                [module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0
-            )[:, :, :, 0]
-            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
-
-            self.qkv.weight.data = qkv_weight
-            self.qkv.bias.data = qkv_bias
-
-            proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
-            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
-            proj_out.bias.data = module.proj_out.bias.data
-
-            self.proj = proj_out
-
-        self.set_weights_2(attn_layer)
-
-    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
-        new_projection_shape = projection.size()[:-1] + (self.n_heads, self.num_head_size)
-        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
-        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
-        return new_projection
-
-    def set_weights_2(self, attn_layer):
-        self.group_norm.weight.data = attn_layer.norm.weight.data
-        self.group_norm.bias.data = attn_layer.norm.bias.data
-
-        qkv_weight = attn_layer.qkv.weight.data.reshape(
-            self.n_heads, 3 * self.channels // self.n_heads, self.channels
-        )
-        qkv_bias = attn_layer.qkv.bias.data.reshape(self.n_heads, 3 * self.channels // self.n_heads)
-
-        q_w, k_w, v_w = qkv_weight.split(self.channels // self.n_heads, dim=1)
-        q_b, k_b, v_b = qkv_bias.split(self.channels // self.n_heads, dim=1)
-
-        self.query.weight.data = q_w.reshape(-1, self.channels)
-        self.key.weight.data = k_w.reshape(-1, self.channels)
-        self.value.weight.data = v_w.reshape(-1, self.channels)
-
-        self.query.bias.data = q_b.reshape(-1)
-        self.key.bias.data = k_b.reshape(-1)
-        self.value.bias.data = v_b.reshape(-1)
-
-        self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
-        self.proj_attn.bias.data = attn_layer.proj.bias.data
-
-    def forward_2(self, hidden_states):
-        residual = hidden_states
-        batch, channel, height, width = hidden_states.shape
-
-        # norm
-        hidden_states = self.group_norm(hidden_states)
-        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
-
-        # proj to q, k, v
-        query_proj = self.query(hidden_states)
-        key_proj = self.key(hidden_states)
-        value_proj = self.value(hidden_states)
-
-        # transpose
-        query_states = self.transpose_for_scores(query_proj)
-        key_states = self.transpose_for_scores(key_proj)
-        value_states = self.transpose_for_scores(value_proj)
-
-        # get scores
-        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
-        attention_scores = attention_scores / math.sqrt(self.channels // self.n_heads)
-        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-
-        # compute attention output
-        context_states = torch.matmul(attention_probs, value_states)
-
-        context_states = context_states.permute(0, 2, 1, 3).contiguous()
-        new_context_states_shape = context_states.size()[:-2] + (self.channels,)
-        context_states = context_states.view(new_context_states_shape)
-
-        # compute next hidden_states
-        hidden_states = self.proj_attn(context_states)
-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
-
-        # res connect and rescale
-        hidden_states = (hidden_states + residual) / self.rescale_output_factor
-        return hidden_states
-
-    def forward(self, x, encoder_out=None):
-        b, c, *spatial = x.shape
-        hid_states = self.norm(x).view(b, c, -1)
-
-        qkv = self.qkv(hid_states)
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-
-        if encoder_out is not None:
-            encoder_kv = self.encoder_kv(encoder_out)
-            assert encoder_kv.shape[1] == self.n_heads * ch * 2
-            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = torch.cat([ek, k], dim=-1)
-            v = torch.cat([ev, v], dim=-1)
-
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-
-        a = torch.einsum("bts,bcs->bct", weight, v)
-        h = a.reshape(bs, -1, length)
-
-        h = self.proj(h)
-        h = h.reshape(b, c, *spatial)
-
-        result = x + h
-        result = result / self.rescale_output_factor
-
-        result_2 = self.forward_2(x)
-        print((result - result_2).abs().sum())
-
-        return result_2
 
 class AttentionBlockNew(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted

@@ -387,7 +205,7 @@ class AttentionBlockNew(nn.Module):
         self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
 
     def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
-        new_projection_shape = projection.size()[:-1] + (self.num_heads, self.num_head_size)
+        new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
         # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
         new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
         return new_projection

@@ -434,24 +252,36 @@ class AttentionBlockNew(nn.Module):
         self.group_norm.weight.data = attn_layer.norm.weight.data
         self.group_norm.bias.data = attn_layer.norm.bias.data
 
-        qkv_weight = attn_layer.qkv.weight.data.reshape(
-            self.num_heads, 3 * self.channels // self.num_heads, self.channels
-        )
-        qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
-
-        q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
-        q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
-
-        self.query.weight.data = q_w.reshape(-1, self.channels)
-        self.key.weight.data = k_w.reshape(-1, self.channels)
-        self.value.weight.data = v_w.reshape(-1, self.channels)
-
-        self.query.bias.data = q_b.reshape(-1)
-        self.key.bias.data = k_b.reshape(-1)
-        self.value.bias.data = v_b.reshape(-1)
-
-        self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
-        self.proj_attn.bias.data = attn_layer.proj.bias.data
+        if hasattr(attn_layer, "q"):
+            self.query.weight.data = attn_layer.q.weight.data[:, :, 0, 0]
+            self.key.weight.data = attn_layer.k.weight.data[:, :, 0, 0]
+            self.value.weight.data = attn_layer.v.weight.data[:, :, 0, 0]
+
+            self.query.bias.data = attn_layer.q.bias.data
+            self.key.bias.data = attn_layer.k.bias.data
+            self.value.bias.data = attn_layer.v.bias.data
+
+            self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
+            self.proj_attn.bias.data = attn_layer.proj_out.bias.data
+        else:
+            qkv_weight = attn_layer.qkv.weight.data.reshape(
+                self.num_heads, 3 * self.channels // self.num_heads, self.channels
+            )
+            qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
+
+            q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
+            q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
+
+            self.query.weight.data = q_w.reshape(-1, self.channels)
+            self.key.weight.data = k_w.reshape(-1, self.channels)
+            self.value.weight.data = v_w.reshape(-1, self.channels)
+
+            self.query.bias.data = q_b.reshape(-1)
+            self.key.bias.data = k_b.reshape(-1)
+            self.value.bias.data = v_b.reshape(-1)
+
+            self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
+            self.proj_attn.bias.data = attn_layer.proj.bias.data
 
 class SpatialTransformer(nn.Module):
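The core of set_weight() is re-packing a fused 1x1 Conv1d qkv projection into three per-head-ordered Linear layers. A standalone check of that re-packing (not part of the commit; shapes chosen arbitrarily) shows the two layouts produce identical per-head queries:

    import torch
    from torch import nn

    channels, num_head_channels = 8, 4
    num_heads = channels // num_head_channels

    # Old checkpoint layout: one 1x1 Conv1d emitting q, k, v grouped per head.
    qkv = nn.Conv1d(channels, channels * 3, 1)

    # New layout: three Linear layers, filled the same way set_weight() does above.
    query, key, value = nn.Linear(channels, channels), nn.Linear(channels, channels), nn.Linear(channels, channels)
    qkv_weight = qkv.weight.data.reshape(num_heads, 3 * channels // num_heads, channels)
    qkv_bias = qkv.bias.data.reshape(num_heads, 3 * channels // num_heads)
    q_w, k_w, v_w = qkv_weight.split(channels // num_heads, dim=1)
    q_b, k_b, v_b = qkv_bias.split(channels // num_heads, dim=1)
    query.weight.data, query.bias.data = q_w.reshape(-1, channels), q_b.reshape(-1)
    key.weight.data, key.bias.data = k_w.reshape(-1, channels), k_b.reshape(-1)
    value.weight.data, value.bias.data = v_w.reshape(-1, channels), v_b.reshape(-1)

    # Both layouts should give the same per-head queries on a random activation.
    batch, length = 2, 5
    x = torch.randn(batch, channels, length)
    old_q, _, _ = qkv(x).reshape(batch * num_heads, 3 * num_head_channels, length).split(num_head_channels, dim=1)
    new_q = query(x.permute(0, 2, 1)).permute(0, 2, 1).reshape(batch * num_heads, num_head_channels, length)
    print(torch.allclose(old_q, new_q, atol=1e-5))  # True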
src/diffusers/models/resnet.py

@@ -87,12 +87,21 @@ class Downsample2D(nn.Module):
         self.conv = conv
 
     def forward(self, x):
+        # print("use_conv", self.use_conv)
+        # print("padding", self.padding)
         assert x.shape[1] == self.channels
         if self.use_conv and self.padding == 0:
             pad = (0, 1, 0, 1)
             x = F.pad(x, pad, mode="constant", value=0)
-        return self.conv(x)
+
+        # print("x", x.abs().sum())
+        self.hey = x
+
+        assert x.shape[1] == self.channels
+        x = self.conv(x)
+        self.yas = x
+        # print("x", x.abs().sum())
+        return x
 
 # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
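The `padding == 0` branch above is what the new downsample_padding option (plumbed through unet_new.py below) selects for DDPM: instead of symmetric conv padding it pads only the right and bottom edges before the strided convolution. A rough shape check, assuming Downsample2D's usual 3x3, stride-2 convolution:

    import torch
    from diffusers.models.resnet import Downsample2D  # module path as in this commit's tree

    x = torch.randn(1, 32, 16, 16)
    ldm_style = Downsample2D(32, use_conv=True, out_channels=32, padding=1, name="op")   # symmetric padding
    ddpm_style = Downsample2D(32, use_conv=True, out_channels=32, padding=0, name="op")  # asymmetric pad (0, 1, 0, 1)

    # Both halve the spatial resolution; only the padding scheme (and hence the learned weights) differs.
    print(ldm_style(x).shape)   # torch.Size([1, 32, 8, 8])
    print(ddpm_style(x).shape)  # torch.Size([1, 32, 8, 8])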
src/diffusers/models/unet.py

@@ -177,9 +177,7 @@ class UNetModel(ModelMixin, ConfigMixin):
                 hs.append(self.down[i_level].downsample(hs[-1]))
 
         # middle
-        print("hs", hs[-1].abs().sum())
         h = self.mid_new(hs[-1], temb)
-        print("h", h.abs().sum())
 
         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
src/diffusers/models/unet_new.py

@@ -51,6 +51,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            downsample_padding=downsample_padding,
             attn_num_head_channels=attn_num_head_channels,
         )

@@ -186,6 +187,7 @@ class UNetResAttnDownBlock2D(nn.Module):
         attn_num_head_channels=1,
         attention_type="default",
         output_scale_factor=1.0,
+        downsample_padding=1,
         add_downsample=True,
     ):
         super().__init__()

@@ -224,7 +226,11 @@ class UNetResAttnDownBlock2D(nn.Module):
         if add_downsample:
             self.downsamplers = nn.ModuleList(
-                [Downsample2D(in_channels, use_conv=True, out_channels=out_channels, padding=1, name="op")]
+                [
+                    Downsample2D(
+                        in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
             )
         else:
             self.downsamplers = None
src/diffusers/models/unet_unconditional.py

@@ -94,25 +94,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
     ):
         super().__init__()
-        # DELETE if statements if not necessary anymore
-        # DDPM
-        if ddpm:
-            out_channels = out_ch
-            image_size = resolution
-            block_channels = [x * ch for x in ch_mult]
-            conv_resample = resamp_with_conv
-            flip_sin_to_cos = False
-            downscale_freq_shift = 1
-            resnet_eps = 1e-6
-            block_channels = (32, 64)
-            down_blocks = (
-                "UNetResDownBlock2D",
-                "UNetResAttnDownBlock2D",
-            )
-            up_blocks = ("UNetResUpBlock2D", "UNetResAttnUpBlock2D")
-            downsample_padding = 0
-            num_head_channels = 64
-
         # register all __init__ params with self.register
         self.register_to_config(
             image_size=image_size,

@@ -250,6 +231,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
             out_channels,
         )
         if ddpm:
+            out_channels = out_ch
+            image_size = resolution
+            block_channels = [x * ch for x in ch_mult]
+            conv_resample = resamp_with_conv
             self.init_for_ddpm(
                 ch_mult,
                 ch,

@@ -290,13 +275,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
             # append to tuple
             down_block_res_samples += res_samples
 
-        print("sample", sample.abs().sum())
         # 4. mid block
         if self.config.ddpm:
             sample = self.mid_new_2(sample, emb)
         else:
             sample = self.mid(sample, emb)
-        print("sample", sample.abs().sum())
 
         # 5. up blocks
         for upsample_block in self.upsample_blocks:

@@ -373,8 +356,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         elif self.config.ddpm:
             # =============== SET WEIGHTS ===============
             # =============== TIME ======================
-            self.time_embed[0] = self.temb.dense[0]
-            self.time_embed[2] = self.temb.dense[1]
+            self.time_embedding.linear_1.weight.data = self.temb.dense[0].weight.data
+            self.time_embedding.linear_1.bias.data = self.temb.dense[0].bias.data
+            self.time_embedding.linear_2.weight.data = self.temb.dense[1].weight.data
+            self.time_embedding.linear_2.bias.data = self.temb.dense[1].bias.data
 
             for i, block in enumerate(self.down):
                 if hasattr(block, "downsample"):

@@ -391,6 +376,23 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
             self.mid_new_2.resnets[1].set_weight(self.mid.block_2)
             self.mid_new_2.attentions[0].set_weight(self.mid.attn_1)
 
+            for i, block in enumerate(self.up):
+                k = len(self.up) - 1 - i
+                if hasattr(block, "upsample"):
+                    self.upsample_blocks[k].upsamplers[0].conv.weight.data = block.upsample.conv.weight.data
+                    self.upsample_blocks[k].upsamplers[0].conv.bias.data = block.upsample.conv.bias.data
+                if hasattr(block, "block") and len(block.block) > 0:
+                    for j in range(self.num_res_blocks + 1):
+                        self.upsample_blocks[k].resnets[j].set_weight(block.block[j])
+                if hasattr(block, "attn") and len(block.attn) > 0:
+                    for j in range(self.num_res_blocks + 1):
+                        self.upsample_blocks[k].attentions[j].set_weight(block.attn[j])
+
+            self.conv_norm_out.weight.data = self.norm_out.weight.data
+            self.conv_norm_out.bias.data = self.norm_out.bias.data
+
+            self.remove_ddpm()
+
     def init_for_ddpm(
         self,
         ch_mult,

@@ -685,3 +687,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         del self.middle_block
         del self.output_blocks
         del self.out
+
+    def remove_ddpm(self):
+        del self.temb
+        del self.down
+        del self.mid_new
+        del self.up
+        del self.norm_out
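After the one-shot weight port, remove_ddpm() drops the legacy modules and only the `ddpm` entry in the config remains to select the DDPM-specific mid block at run time. A hedged sketch of exercising a converted checkpoint, following conversion.py and the updated tests (the output-shape comment is an assumption):

    import torch
    from diffusers import UNetUnconditionalModel

    model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy", ddpm=True)
    model.eval()

    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    time_step = torch.tensor([10] * noise.shape[0])

    with torch.no_grad():
        output = model(noise, time_step)

    print(model.config.ddpm)  # the conversion flag is registered into the config
    print(output.shape)       # assumed: same shape as the input noise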
tests/test_modeling_utils.py

@@ -40,7 +40,6 @@ from diffusers import (
     ScoreSdeVpPipeline,
     ScoreSdeVpScheduler,
     UNetLDMModel,
-    UNetModel,
     UNetUnconditionalModel,
     VQModel,
 )

@@ -209,7 +208,7 @@ class ModelTesterMixin:
 class UnetModelTests(ModelTesterMixin, unittest.TestCase):
-    model_class = UNetModel
+    model_class = UNetUnconditionalModel
 
     @property
     def dummy_input(self):

@@ -234,15 +233,24 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         init_dict = {
             "ch": 32,
             "ch_mult": (1, 2),
+            "block_channels": (32, 64),
+            "down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
+            "up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
+            "num_head_channels": None,
+            "out_channels": 3,
+            "in_channels": 3,
             "num_res_blocks": 2,
             "attn_resolutions": (16,),
             "resolution": 32,
+            "image_size": 32,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
     def test_from_pretrained_hub(self):
-        model, loading_info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True)
+        model, loading_info = UNetUnconditionalModel.from_pretrained(
+            "fusing/ddpm_dummy", output_loading_info=True, ddpm=True
+        )
         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)

@@ -252,27 +260,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"
 
     def test_output_pretrained(self):
-        model = UNetModel.from_pretrained("fusing/ddpm_dummy")
-        model.eval()
-
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
-
-        noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
-        time_step = torch.tensor([10])
-
-        with torch.no_grad():
-            output = model(noise, time_step)
-
-        output_slice = output[0, -1, -3:, -3:].flatten()
-        # fmt: off
-        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
-        # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
-
-        print("Original success!!!")
-
         model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy", ddpm=True)
         model.eval()

@@ -849,7 +836,9 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
 class PipelineTesterMixin(unittest.TestCase):
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
-        model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
+        model = UNetUnconditionalModel(
+            ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32, ddpm=True
+        )
         schedular = DDPMScheduler(timesteps=10)
         ddpm = DDPMPipeline(model, schedular)

@@ -888,7 +877,7 @@ class PipelineTesterMixin(unittest.TestCase):
     def test_ddpm_cifar10(self):
         model_id = "fusing/ddpm-cifar10"
 
-        unet = UNetModel.from_pretrained(model_id)
+        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
         noise_scheduler = DDPMScheduler.from_config(model_id)
         noise_scheduler = noise_scheduler.set_format("pt")

@@ -901,7 +890,28 @@ class PipelineTesterMixin(unittest.TestCase):
         assert image.shape == (1, 3, 32, 32)
         expected_slice = torch.tensor(
-            [-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105]
+            [-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231]
         )
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
 
+    @slow
+    def test_ddim_lsun(self):
+        model_id = "fusing/ddpm-lsun-bedroom-ema"
+
+        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
+        noise_scheduler = DDIMScheduler.from_config(model_id)
+        noise_scheduler = noise_scheduler.set_format("pt")
+
+        ddpm = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)
+
+        generator = torch.manual_seed(0)
+        image = ddpm(generator=generator)
+
+        image_slice = image[0, -1, -3:, -3:].cpu()
+
+        assert image.shape == (1, 3, 256, 256)
+        expected_slice = torch.tensor(
+            [-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863]
+        )
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

@@ -909,7 +919,7 @@ class PipelineTesterMixin(unittest.TestCase):
     def test_ddim_cifar10(self):
         model_id = "fusing/ddpm-cifar10"
 
-        unet = UNetModel.from_pretrained(model_id)
+        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
         noise_scheduler = DDIMScheduler(tensor_format="pt")
         ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)

@@ -929,7 +939,7 @@ class PipelineTesterMixin(unittest.TestCase):
     def test_pndm_cifar10(self):
         model_id = "fusing/ddpm-cifar10"
 
-        unet = UNetModel.from_pretrained(model_id)
+        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
         noise_scheduler = PNDMScheduler(tensor_format="pt")
         pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler)