Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
e7fe901e
You need to sign in or sign up before continuing.
Unverified
Commit
e7fe901e
authored
Jul 14, 2022
by
Patrick von Platen
Committed by
GitHub
Jul 14, 2022
Browse files
save intermediate (#87)
* save intermediate * up * up
parent
c3d78cd3
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
801 additions
and
173 deletions
+801
-173
conversion.py
conversion.py
+94
-0
docs/source/examples/diffusers_for_vision.mdx
docs/source/examples/diffusers_for_vision.mdx
+2
-1
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+2
-0
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+3
-0
src/diffusers/models/attention.py
src/diffusers/models/attention.py
+223
-5
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+22
-17
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+2
-0
src/diffusers/models/unet_new.py
src/diffusers/models/unet_new.py
+50
-27
src/diffusers/models/unet_unconditional.py
src/diffusers/models/unet_unconditional.py
+376
-119
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+27
-4
No files found.
conversion.py
0 → 100755
View file @
e7fe901e
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
inspect
import
tempfile
import
unittest
import
numpy
as
np
import
torch
from
diffusers
import
(
AutoencoderKL
,
DDIMPipeline
,
DDIMScheduler
,
DDPMPipeline
,
DDPMScheduler
,
GlidePipeline
,
GlideSuperResUNetModel
,
GlideTextToImageUNetModel
,
LatentDiffusionPipeline
,
LatentDiffusionUncondPipeline
,
NCSNpp
,
PNDMPipeline
,
PNDMScheduler
,
ScoreSdeVePipeline
,
ScoreSdeVeScheduler
,
ScoreSdeVpPipeline
,
ScoreSdeVpScheduler
,
UNetLDMModel
,
UNetModel
,
UNetUnconditionalModel
,
VQModel
,
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
from
diffusers.training_utils
import
EMAModel
def
test_output_pretrained_ldm_dummy
():
model
=
UNetUnconditionalModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
,
ldm
=
True
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
image_size
,
model
.
config
.
image_size
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
print
(
model
)
import
ipdb
;
ipdb
.
set_trace
()
def
test_output_pretrained_ldm
():
model
=
UNetUnconditionalModel
.
from_pretrained
(
"fusing/latent-diffusion-celeba-256"
,
subfolder
=
"unet"
,
ldm
=
True
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
image_size
,
model
.
config
.
image_size
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
print
(
model
)
import
ipdb
;
ipdb
.
set_trace
()
# To see the how the final model should look like
test_output_pretrained_ldm_dummy
()
test_output_pretrained_ldm
()
# => this is the architecture in which the model should be saved in the new format
# -> verify new repo with the following tests (in `test_modeling_utils.py`)
# - test_ldm_uncond (in PipelineTesterMixin)
# - test_output_pretrained ( in UNetLDMModelTests)
docs/source/examples/diffusers_for_vision.mdx
View file @
e7fe901e
...
@@ -111,7 +111,7 @@ prompt = "A painting of a squirrel eating a burger"
...
@@ -111,7 +111,7 @@ prompt = "A painting of a squirrel eating a burger"
image
=
ldm
([
prompt
],
generator
=
generator
,
eta
=
0.3
,
guidance_scale
=
6.0
,
num_inference_steps
=
50
)
image
=
ldm
([
prompt
],
generator
=
generator
,
eta
=
0.3
,
guidance_scale
=
6.0
,
num_inference_steps
=
50
)
image_processed
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
)
image_processed
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
)
image_processed
=
image_processed
*
255.
image_processed
=
image_processed
*
255.
0
image_processed
=
image_processed
.
numpy
().
astype
(
np
.
uint8
)
image_processed
=
image_processed
.
numpy
().
astype
(
np
.
uint8
)
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
[
0
])
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
[
0
])
...
@@ -143,6 +143,7 @@ audio = bddm(mel_spec, generator, torch_device=torch_device)
...
@@ -143,6 +143,7 @@ audio = bddm(mel_spec, generator, torch_device=torch_device)
#
save
generated
audio
#
save
generated
audio
from
scipy
.
io
.
wavfile
import
write
as
wavwrite
from
scipy
.
io
.
wavfile
import
write
as
wavwrite
sampling_rate
=
22050
sampling_rate
=
22050
wavwrite
(
"generated_audio.wav"
,
sampling_rate
,
audio
.
squeeze
().
cpu
().
numpy
())
wavwrite
(
"generated_audio.wav"
,
sampling_rate
,
audio
.
squeeze
().
cpu
().
numpy
())
```
```
...
...
src/diffusers/configuration_utils.py
View file @
e7fe901e
...
@@ -116,6 +116,7 @@ class ConfigMixin:
...
@@ -116,6 +116,7 @@ class ConfigMixin:
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
subfolder
=
kwargs
.
pop
(
"subfolder"
,
None
)
user_agent
=
{
"file_type"
:
"config"
}
user_agent
=
{
"file_type"
:
"config"
}
...
@@ -150,6 +151,7 @@ class ConfigMixin:
...
@@ -150,6 +151,7 @@ class ConfigMixin:
local_files_only
=
local_files_only
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
user_agent
=
user_agent
,
subfolder
=
subfolder
,
)
)
except
RepositoryNotFoundError
:
except
RepositoryNotFoundError
:
...
...
src/diffusers/modeling_utils.py
View file @
e7fe901e
...
@@ -321,6 +321,7 @@ class ModelMixin(torch.nn.Module):
...
@@ -321,6 +321,7 @@ class ModelMixin(torch.nn.Module):
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
from_auto_class
=
kwargs
.
pop
(
"_from_auto"
,
False
)
from_auto_class
=
kwargs
.
pop
(
"_from_auto"
,
False
)
subfolder
=
kwargs
.
pop
(
"subfolder"
,
None
)
user_agent
=
{
"file_type"
:
"model"
,
"framework"
:
"pytorch"
,
"from_auto_class"
:
from_auto_class
}
user_agent
=
{
"file_type"
:
"model"
,
"framework"
:
"pytorch"
,
"from_auto_class"
:
from_auto_class
}
...
@@ -336,6 +337,7 @@ class ModelMixin(torch.nn.Module):
...
@@ -336,6 +337,7 @@ class ModelMixin(torch.nn.Module):
local_files_only
=
local_files_only
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
use_auth_token
=
use_auth_token
,
revision
=
revision
,
revision
=
revision
,
subfolder
=
subfolder
,
**
kwargs
,
**
kwargs
,
)
)
model
.
register_to_config
(
name_or_path
=
pretrained_model_name_or_path
)
model
.
register_to_config
(
name_or_path
=
pretrained_model_name_or_path
)
...
@@ -363,6 +365,7 @@ class ModelMixin(torch.nn.Module):
...
@@ -363,6 +365,7 @@ class ModelMixin(torch.nn.Module):
local_files_only
=
local_files_only
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
user_agent
=
user_agent
,
subfolder
=
subfolder
,
)
)
except
RepositoryNotFoundError
:
except
RepositoryNotFoundError
:
...
...
src/diffusers/models/attention.py
View file @
e7fe901e
...
@@ -51,6 +51,7 @@ class AttentionBlock(nn.Module):
...
@@ -51,6 +51,7 @@ class AttentionBlock(nn.Module):
overwrite_qkv
=
False
,
overwrite_qkv
=
False
,
overwrite_linear
=
False
,
overwrite_linear
=
False
,
rescale_output_factor
=
1.0
,
rescale_output_factor
=
1.0
,
eps
=
1e-5
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
channels
=
channels
self
.
channels
=
channels
...
@@ -62,7 +63,7 @@ class AttentionBlock(nn.Module):
...
@@ -62,7 +63,7 @@ class AttentionBlock(nn.Module):
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
self
.
num_heads
=
channels
//
num_head_channels
self
.
num_heads
=
channels
//
num_head_channels
self
.
norm
=
nn
.
GroupNorm
(
num_channels
=
channels
,
num_groups
=
num_groups
,
eps
=
1e-5
,
affine
=
True
)
self
.
norm
=
nn
.
GroupNorm
(
num_channels
=
channels
,
num_groups
=
num_groups
,
eps
=
eps
,
affine
=
True
)
self
.
qkv
=
nn
.
Conv1d
(
channels
,
channels
*
3
,
1
)
self
.
qkv
=
nn
.
Conv1d
(
channels
,
channels
*
3
,
1
)
self
.
n_heads
=
self
.
num_heads
self
.
n_heads
=
self
.
num_heads
self
.
rescale_output_factor
=
rescale_output_factor
self
.
rescale_output_factor
=
rescale_output_factor
...
@@ -165,7 +166,7 @@ class AttentionBlock(nn.Module):
...
@@ -165,7 +166,7 @@ class AttentionBlock(nn.Module):
return
result
return
result
class
AttentionBlockNew
(
nn
.
Module
):
class
AttentionBlockNew
_2
(
nn
.
Module
):
"""
"""
An attention block that allows spatial positions to attend to each other.
An attention block that allows spatial positions to attend to each other.
...
@@ -180,11 +181,14 @@ class AttentionBlockNew(nn.Module):
...
@@ -180,11 +181,14 @@ class AttentionBlockNew(nn.Module):
num_groups
=
32
,
num_groups
=
32
,
encoder_channels
=
None
,
encoder_channels
=
None
,
rescale_output_factor
=
1.0
,
rescale_output_factor
=
1.0
,
eps
=
1e-5
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
norm
=
nn
.
GroupNorm
(
num_channels
=
channels
,
num_groups
=
num_groups
,
eps
=
1e-5
,
affine
=
True
)
self
.
channels
=
channels
self
.
norm
=
nn
.
GroupNorm
(
num_channels
=
channels
,
num_groups
=
num_groups
,
eps
=
eps
,
affine
=
True
)
self
.
qkv
=
nn
.
Conv1d
(
channels
,
channels
*
3
,
1
)
self
.
qkv
=
nn
.
Conv1d
(
channels
,
channels
*
3
,
1
)
self
.
n_heads
=
channels
//
num_head_channels
self
.
n_heads
=
channels
//
num_head_channels
self
.
num_head_size
=
num_head_channels
self
.
rescale_output_factor
=
rescale_output_factor
self
.
rescale_output_factor
=
rescale_output_factor
if
encoder_channels
is
not
None
:
if
encoder_channels
is
not
None
:
...
@@ -192,6 +196,28 @@ class AttentionBlockNew(nn.Module):
...
@@ -192,6 +196,28 @@ class AttentionBlockNew(nn.Module):
self
.
proj
=
zero_module
(
nn
.
Conv1d
(
channels
,
channels
,
1
))
self
.
proj
=
zero_module
(
nn
.
Conv1d
(
channels
,
channels
,
1
))
# ------------------------- new -----------------------
num_heads
=
self
.
n_heads
self
.
channels
=
channels
if
num_head_channels
is
None
:
self
.
num_heads
=
num_heads
else
:
assert
(
channels
%
num_head_channels
==
0
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
self
.
num_heads
=
channels
//
num_head_channels
self
.
group_norm
=
nn
.
GroupNorm
(
num_channels
=
channels
,
num_groups
=
num_groups
,
eps
=
eps
,
affine
=
True
)
# define q,k,v as linear layers
self
.
query
=
nn
.
Linear
(
channels
,
channels
)
self
.
key
=
nn
.
Linear
(
channels
,
channels
)
self
.
value
=
nn
.
Linear
(
channels
,
channels
)
self
.
rescale_output_factor
=
rescale_output_factor
self
.
proj_attn
=
zero_module
(
nn
.
Linear
(
channels
,
channels
,
1
))
# ------------------------- new -----------------------
def
set_weight
(
self
,
attn_layer
):
def
set_weight
(
self
,
attn_layer
):
self
.
norm
.
weight
.
data
=
attn_layer
.
norm
.
weight
.
data
self
.
norm
.
weight
.
data
=
attn_layer
.
norm
.
weight
.
data
self
.
norm
.
bias
.
data
=
attn_layer
.
norm
.
bias
.
data
self
.
norm
.
bias
.
data
=
attn_layer
.
norm
.
bias
.
data
...
@@ -202,6 +228,89 @@ class AttentionBlockNew(nn.Module):
...
@@ -202,6 +228,89 @@ class AttentionBlockNew(nn.Module):
self
.
proj
.
weight
.
data
=
attn_layer
.
proj
.
weight
.
data
self
.
proj
.
weight
.
data
=
attn_layer
.
proj
.
weight
.
data
self
.
proj
.
bias
.
data
=
attn_layer
.
proj
.
bias
.
data
self
.
proj
.
bias
.
data
=
attn_layer
.
proj
.
bias
.
data
if
hasattr
(
attn_layer
,
"q"
):
module
=
attn_layer
qkv_weight
=
torch
.
cat
([
module
.
q
.
weight
.
data
,
module
.
k
.
weight
.
data
,
module
.
v
.
weight
.
data
],
dim
=
0
)[
:,
:,
:,
0
]
qkv_bias
=
torch
.
cat
([
module
.
q
.
bias
.
data
,
module
.
k
.
bias
.
data
,
module
.
v
.
bias
.
data
],
dim
=
0
)
self
.
qkv
.
weight
.
data
=
qkv_weight
self
.
qkv
.
bias
.
data
=
qkv_bias
proj_out
=
zero_module
(
nn
.
Conv1d
(
self
.
channels
,
self
.
channels
,
1
))
proj_out
.
weight
.
data
=
module
.
proj_out
.
weight
.
data
[:,
:,
:,
0
]
proj_out
.
bias
.
data
=
module
.
proj_out
.
bias
.
data
self
.
proj
=
proj_out
self
.
set_weights_2
(
attn_layer
)
def
transpose_for_scores
(
self
,
projection
:
torch
.
Tensor
)
->
torch
.
Tensor
:
new_projection_shape
=
projection
.
size
()[:
-
1
]
+
(
self
.
n_heads
,
self
.
num_head_size
)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection
=
projection
.
view
(
new_projection_shape
).
permute
(
0
,
2
,
1
,
3
)
return
new_projection
def
set_weights_2
(
self
,
attn_layer
):
self
.
group_norm
.
weight
.
data
=
attn_layer
.
norm
.
weight
.
data
self
.
group_norm
.
bias
.
data
=
attn_layer
.
norm
.
bias
.
data
qkv_weight
=
attn_layer
.
qkv
.
weight
.
data
.
reshape
(
self
.
n_heads
,
3
*
self
.
channels
//
self
.
n_heads
,
self
.
channels
)
qkv_bias
=
attn_layer
.
qkv
.
bias
.
data
.
reshape
(
self
.
n_heads
,
3
*
self
.
channels
//
self
.
n_heads
)
q_w
,
k_w
,
v_w
=
qkv_weight
.
split
(
self
.
channels
//
self
.
n_heads
,
dim
=
1
)
q_b
,
k_b
,
v_b
=
qkv_bias
.
split
(
self
.
channels
//
self
.
n_heads
,
dim
=
1
)
self
.
query
.
weight
.
data
=
q_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
key
.
weight
.
data
=
k_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
value
.
weight
.
data
=
v_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
query
.
bias
.
data
=
q_b
.
reshape
(
-
1
)
self
.
key
.
bias
.
data
=
k_b
.
reshape
(
-
1
)
self
.
value
.
bias
.
data
=
v_b
.
reshape
(
-
1
)
self
.
proj_attn
.
weight
.
data
=
attn_layer
.
proj
.
weight
.
data
[:,
:,
0
]
self
.
proj_attn
.
bias
.
data
=
attn_layer
.
proj
.
bias
.
data
def
forward_2
(
self
,
hidden_states
):
residual
=
hidden_states
batch
,
channel
,
height
,
width
=
hidden_states
.
shape
# norm
hidden_states
=
self
.
group_norm
(
hidden_states
)
hidden_states
=
hidden_states
.
view
(
batch
,
channel
,
height
*
width
).
transpose
(
1
,
2
)
# proj to q, k, v
query_proj
=
self
.
query
(
hidden_states
)
key_proj
=
self
.
key
(
hidden_states
)
value_proj
=
self
.
value
(
hidden_states
)
# transpose
query_states
=
self
.
transpose_for_scores
(
query_proj
)
key_states
=
self
.
transpose_for_scores
(
key_proj
)
value_states
=
self
.
transpose_for_scores
(
value_proj
)
# get scores
attention_scores
=
torch
.
matmul
(
query_states
,
key_states
.
transpose
(
-
1
,
-
2
))
attention_scores
=
attention_scores
/
math
.
sqrt
(
self
.
channels
//
self
.
n_heads
)
attention_probs
=
nn
.
functional
.
softmax
(
attention_scores
,
dim
=-
1
)
# compute attention output
context_states
=
torch
.
matmul
(
attention_probs
,
value_states
)
context_states
=
context_states
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
new_context_states_shape
=
context_states
.
size
()[:
-
2
]
+
(
self
.
channels
,)
context_states
=
context_states
.
view
(
new_context_states_shape
)
# compute next hidden_states
hidden_states
=
self
.
proj_attn
(
context_states
)
hidden_states
=
hidden_states
.
transpose
(
-
1
,
-
2
).
reshape
(
batch
,
channel
,
height
,
width
)
# res connect and rescale
hidden_states
=
(
hidden_states
+
residual
)
/
self
.
rescale_output_factor
return
hidden_states
def
forward
(
self
,
x
,
encoder_out
=
None
):
def
forward
(
self
,
x
,
encoder_out
=
None
):
b
,
c
,
*
spatial
=
x
.
shape
b
,
c
,
*
spatial
=
x
.
shape
hid_states
=
self
.
norm
(
x
).
view
(
b
,
c
,
-
1
)
hid_states
=
self
.
norm
(
x
).
view
(
b
,
c
,
-
1
)
...
@@ -230,10 +339,119 @@ class AttentionBlockNew(nn.Module):
...
@@ -230,10 +339,119 @@ class AttentionBlockNew(nn.Module):
h
=
h
.
reshape
(
b
,
c
,
*
spatial
)
h
=
h
.
reshape
(
b
,
c
,
*
spatial
)
result
=
x
+
h
result
=
x
+
h
result
=
result
/
self
.
rescale_output_factor
result
=
result
/
self
.
rescale_output_factor
return
result
result_2
=
self
.
forward_2
(
x
)
print
((
result
-
result_2
).
abs
().
sum
())
return
result_2
class
AttentionBlockNew
(
nn
.
Module
):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Uses three q, k, v linear layers to compute attention
"""
def
__init__
(
self
,
channels
,
num_heads
=
1
,
num_head_channels
=
None
,
num_groups
=
32
,
rescale_output_factor
=
1.0
,
eps
=
1e-5
,
):
super
().
__init__
()
self
.
channels
=
channels
if
num_head_channels
is
None
:
self
.
num_heads
=
num_heads
else
:
assert
(
channels
%
num_head_channels
==
0
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
self
.
num_heads
=
channels
//
num_head_channels
self
.
num_head_size
=
num_head_channels
self
.
group_norm
=
nn
.
GroupNorm
(
num_channels
=
channels
,
num_groups
=
num_groups
,
eps
=
eps
,
affine
=
True
)
# define q,k,v as linear layers
self
.
query
=
nn
.
Linear
(
channels
,
channels
)
self
.
key
=
nn
.
Linear
(
channels
,
channels
)
self
.
value
=
nn
.
Linear
(
channels
,
channels
)
self
.
rescale_output_factor
=
rescale_output_factor
self
.
proj_attn
=
zero_module
(
nn
.
Linear
(
channels
,
channels
,
1
))
def
transpose_for_scores
(
self
,
projection
:
torch
.
Tensor
)
->
torch
.
Tensor
:
new_projection_shape
=
projection
.
size
()[:
-
1
]
+
(
self
.
num_heads
,
self
.
num_head_size
)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection
=
projection
.
view
(
new_projection_shape
).
permute
(
0
,
2
,
1
,
3
)
return
new_projection
def
forward
(
self
,
hidden_states
):
residual
=
hidden_states
batch
,
channel
,
height
,
width
=
hidden_states
.
shape
# norm
hidden_states
=
self
.
group_norm
(
hidden_states
)
hidden_states
=
hidden_states
.
view
(
batch
,
channel
,
height
*
width
).
transpose
(
1
,
2
)
# proj to q, k, v
query_proj
=
self
.
query
(
hidden_states
)
key_proj
=
self
.
key
(
hidden_states
)
value_proj
=
self
.
value
(
hidden_states
)
# transpose
query_states
=
self
.
transpose_for_scores
(
query_proj
)
key_states
=
self
.
transpose_for_scores
(
key_proj
)
value_states
=
self
.
transpose_for_scores
(
value_proj
)
# get scores
attention_scores
=
torch
.
matmul
(
query_states
,
key_states
.
transpose
(
-
1
,
-
2
))
attention_scores
=
attention_scores
/
math
.
sqrt
(
self
.
channels
//
self
.
num_heads
)
attention_probs
=
nn
.
functional
.
softmax
(
attention_scores
,
dim
=-
1
)
# compute attention output
context_states
=
torch
.
matmul
(
attention_probs
,
value_states
)
context_states
=
context_states
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
new_context_states_shape
=
context_states
.
size
()[:
-
2
]
+
(
self
.
channels
,)
context_states
=
context_states
.
view
(
new_context_states_shape
)
# compute next hidden_states
hidden_states
=
self
.
proj_attn
(
context_states
)
hidden_states
=
hidden_states
.
transpose
(
-
1
,
-
2
).
reshape
(
batch
,
channel
,
height
,
width
)
# res connect and rescale
hidden_states
=
(
hidden_states
+
residual
)
/
self
.
rescale_output_factor
return
hidden_states
def
set_weight
(
self
,
attn_layer
):
self
.
group_norm
.
weight
.
data
=
attn_layer
.
norm
.
weight
.
data
self
.
group_norm
.
bias
.
data
=
attn_layer
.
norm
.
bias
.
data
qkv_weight
=
attn_layer
.
qkv
.
weight
.
data
.
reshape
(
self
.
num_heads
,
3
*
self
.
channels
//
self
.
num_heads
,
self
.
channels
)
qkv_bias
=
attn_layer
.
qkv
.
bias
.
data
.
reshape
(
self
.
num_heads
,
3
*
self
.
channels
//
self
.
num_heads
)
q_w
,
k_w
,
v_w
=
qkv_weight
.
split
(
self
.
channels
//
self
.
num_heads
,
dim
=
1
)
q_b
,
k_b
,
v_b
=
qkv_bias
.
split
(
self
.
channels
//
self
.
num_heads
,
dim
=
1
)
self
.
query
.
weight
.
data
=
q_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
key
.
weight
.
data
=
k_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
value
.
weight
.
data
=
v_w
.
reshape
(
-
1
,
self
.
channels
)
self
.
query
.
bias
.
data
=
q_b
.
reshape
(
-
1
)
self
.
key
.
bias
.
data
=
k_b
.
reshape
(
-
1
)
self
.
value
.
bias
.
data
=
v_b
.
reshape
(
-
1
)
self
.
proj_attn
.
weight
.
data
=
attn_layer
.
proj
.
weight
.
data
[:,
:,
0
]
self
.
proj_attn
.
bias
.
data
=
attn_layer
.
proj
.
bias
.
data
class
SpatialTransformer
(
nn
.
Module
):
class
SpatialTransformer
(
nn
.
Module
):
...
...
src/diffusers/models/resnet.py
View file @
e7fe901e
...
@@ -81,8 +81,10 @@ class Downsample2D(nn.Module):
...
@@ -81,8 +81,10 @@ class Downsample2D(nn.Module):
self
.
conv
=
conv
self
.
conv
=
conv
elif
name
==
"Conv2d_0"
:
elif
name
==
"Conv2d_0"
:
self
.
Conv2d_0
=
conv
self
.
Conv2d_0
=
conv
self
.
conv
=
conv
else
:
else
:
self
.
op
=
conv
self
.
op
=
conv
self
.
conv
=
conv
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
assert
x
.
shape
[
1
]
==
self
.
channels
...
@@ -90,13 +92,16 @@ class Downsample2D(nn.Module):
...
@@ -90,13 +92,16 @@ class Downsample2D(nn.Module):
pad
=
(
0
,
1
,
0
,
1
)
pad
=
(
0
,
1
,
0
,
1
)
x
=
F
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
x
=
F
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
return
self
.
conv
(
x
)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if
self
.
name
==
"conv"
:
return
self
.
conv
(
x
)
elif
self
.
name
==
"Conv2d_0"
:
# if self.name == "conv":
return
self
.
Conv2d_0
(
x
)
# return self.conv(x)
else
:
# elif self.name == "Conv2d_0":
return
self
.
op
(
x
)
# return self.Conv2d_0(x)
# else:
# return self.op(x)
class
Upsample1D
(
nn
.
Module
):
class
Upsample1D
(
nn
.
Module
):
...
@@ -656,9 +661,9 @@ class ResnetBlock(nn.Module):
...
@@ -656,9 +661,9 @@ class ResnetBlock(nn.Module):
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
time_embedding_norm
==
"default"
and
temb_channels
>
0
:
if
time_embedding_norm
==
"default"
and
temb_channels
>
0
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
self
.
t
ime_
emb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
elif
time_embedding_norm
==
"scale_shift"
and
temb_channels
>
0
:
elif
time_embedding_norm
==
"scale_shift"
and
temb_channels
>
0
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
2
*
out_channels
)
self
.
t
ime_
emb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
2
*
out_channels
)
self
.
norm2
=
torch
.
nn
.
GroupNorm
(
num_groups
=
groups_out
,
num_channels
=
out_channels
,
eps
=
eps
,
affine
=
True
)
self
.
norm2
=
torch
.
nn
.
GroupNorm
(
num_groups
=
groups_out
,
num_channels
=
out_channels
,
eps
=
eps
,
affine
=
True
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
...
@@ -691,9 +696,9 @@ class ResnetBlock(nn.Module):
...
@@ -691,9 +696,9 @@ class ResnetBlock(nn.Module):
self
.
use_nin_shortcut
=
self
.
in_channels
!=
self
.
out_channels
if
use_nin_shortcut
is
None
else
use_nin_shortcut
self
.
use_nin_shortcut
=
self
.
in_channels
!=
self
.
out_channels
if
use_nin_shortcut
is
None
else
use_nin_shortcut
self
.
nin
_shortcut
=
None
self
.
conv
_shortcut
=
None
if
self
.
use_nin_shortcut
:
if
self
.
use_nin_shortcut
:
self
.
nin
_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
conv
_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
,
temb
):
def
forward
(
self
,
x
,
temb
):
h
=
x
h
=
x
...
@@ -715,7 +720,7 @@ class ResnetBlock(nn.Module):
...
@@ -715,7 +720,7 @@ class ResnetBlock(nn.Module):
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
nonlinearity
(
h
)
if
temb
is
not
None
:
if
temb
is
not
None
:
temb
=
self
.
temb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
temb
=
self
.
t
ime_
emb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
else
:
else
:
temb
=
0
temb
=
0
...
@@ -738,8 +743,8 @@ class ResnetBlock(nn.Module):
...
@@ -738,8 +743,8 @@ class ResnetBlock(nn.Module):
h
=
self
.
norm2
(
h
)
h
=
self
.
norm2
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
nonlinearity
(
h
)
if
self
.
nin
_shortcut
is
not
None
:
if
self
.
conv
_shortcut
is
not
None
:
x
=
self
.
nin
_shortcut
(
x
)
x
=
self
.
conv
_shortcut
(
x
)
return
(
x
+
h
)
/
self
.
output_scale_factor
return
(
x
+
h
)
/
self
.
output_scale_factor
...
@@ -750,8 +755,8 @@ class ResnetBlock(nn.Module):
...
@@ -750,8 +755,8 @@ class ResnetBlock(nn.Module):
self
.
conv1
.
weight
.
data
=
resnet
.
conv1
.
weight
.
data
self
.
conv1
.
weight
.
data
=
resnet
.
conv1
.
weight
.
data
self
.
conv1
.
bias
.
data
=
resnet
.
conv1
.
bias
.
data
self
.
conv1
.
bias
.
data
=
resnet
.
conv1
.
bias
.
data
self
.
temb_proj
.
weight
.
data
=
resnet
.
temb_proj
.
weight
.
data
self
.
t
ime_
emb_proj
.
weight
.
data
=
resnet
.
temb_proj
.
weight
.
data
self
.
temb_proj
.
bias
.
data
=
resnet
.
temb_proj
.
bias
.
data
self
.
t
ime_
emb_proj
.
bias
.
data
=
resnet
.
temb_proj
.
bias
.
data
self
.
norm2
.
weight
.
data
=
resnet
.
norm2
.
weight
.
data
self
.
norm2
.
weight
.
data
=
resnet
.
norm2
.
weight
.
data
self
.
norm2
.
bias
.
data
=
resnet
.
norm2
.
bias
.
data
self
.
norm2
.
bias
.
data
=
resnet
.
norm2
.
bias
.
data
...
@@ -760,8 +765,8 @@ class ResnetBlock(nn.Module):
...
@@ -760,8 +765,8 @@ class ResnetBlock(nn.Module):
self
.
conv2
.
bias
.
data
=
resnet
.
conv2
.
bias
.
data
self
.
conv2
.
bias
.
data
=
resnet
.
conv2
.
bias
.
data
if
self
.
use_nin_shortcut
:
if
self
.
use_nin_shortcut
:
self
.
nin
_shortcut
.
weight
.
data
=
resnet
.
nin_shortcut
.
weight
.
data
self
.
conv
_shortcut
.
weight
.
data
=
resnet
.
nin_shortcut
.
weight
.
data
self
.
nin
_shortcut
.
bias
.
data
=
resnet
.
nin_shortcut
.
bias
.
data
self
.
conv
_shortcut
.
bias
.
data
=
resnet
.
nin_shortcut
.
bias
.
data
# TODO(Patrick) - just there to convert the weights; can delete afterward
# TODO(Patrick) - just there to convert the weights; can delete afterward
...
...
src/diffusers/models/unet.py
View file @
e7fe901e
...
@@ -177,7 +177,9 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -177,7 +177,9 @@ class UNetModel(ModelMixin, ConfigMixin):
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
# middle
print
(
"hs"
,
hs
[
-
1
].
abs
().
sum
())
h
=
self
.
mid_new
(
hs
[
-
1
],
temb
)
h
=
self
.
mid_new
(
hs
[
-
1
],
temb
)
print
(
"h"
,
h
.
abs
().
sum
())
# upsampling
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
...
...
src/diffusers/models/unet_new.py
View file @
e7fe901e
...
@@ -29,9 +29,10 @@ def get_down_block(
...
@@ -29,9 +29,10 @@ def get_down_block(
resnet_eps
,
resnet_eps
,
resnet_act_fn
,
resnet_act_fn
,
attn_num_head_channels
,
attn_num_head_channels
,
downsample_padding
=
None
,
):
):
if
down_block_type
==
"UNetResDownBlock2D"
:
if
down_block_type
==
"UNetResDownBlock2D"
:
return
UNetRes
Attn
DownBlock2D
(
return
UNetResDownBlock2D
(
num_layers
=
num_layers
,
num_layers
=
num_layers
,
in_channels
=
in_channels
,
in_channels
=
in_channels
,
out_channels
=
out_channels
,
out_channels
=
out_channels
,
...
@@ -39,6 +40,7 @@ def get_down_block(
...
@@ -39,6 +40,7 @@ def get_down_block(
add_downsample
=
add_downsample
,
add_downsample
=
add_downsample
,
resnet_eps
=
resnet_eps
,
resnet_eps
=
resnet_eps
,
resnet_act_fn
=
resnet_act_fn
,
resnet_act_fn
=
resnet_act_fn
,
downsample_padding
=
downsample_padding
,
)
)
elif
down_block_type
==
"UNetResAttnDownBlock2D"
:
elif
down_block_type
==
"UNetResAttnDownBlock2D"
:
return
UNetResAttnDownBlock2D
(
return
UNetResAttnDownBlock2D
(
...
@@ -57,7 +59,8 @@ def get_up_block(
...
@@ -57,7 +59,8 @@ def get_up_block(
up_block_type
,
up_block_type
,
num_layers
,
num_layers
,
in_channels
,
in_channels
,
next_channels
,
out_channels
,
prev_output_channel
,
temb_channels
,
temb_channels
,
add_upsample
,
add_upsample
,
resnet_eps
,
resnet_eps
,
...
@@ -68,7 +71,8 @@ def get_up_block(
...
@@ -68,7 +71,8 @@ def get_up_block(
return
UNetResUpBlock2D
(
return
UNetResUpBlock2D
(
num_layers
=
num_layers
,
num_layers
=
num_layers
,
in_channels
=
in_channels
,
in_channels
=
in_channels
,
next_channels
=
next_channels
,
out_channels
=
out_channels
,
prev_output_channel
=
prev_output_channel
,
temb_channels
=
temb_channels
,
temb_channels
=
temb_channels
,
add_upsample
=
add_upsample
,
add_upsample
=
add_upsample
,
resnet_eps
=
resnet_eps
,
resnet_eps
=
resnet_eps
,
...
@@ -78,7 +82,8 @@ def get_up_block(
...
@@ -78,7 +82,8 @@ def get_up_block(
return
UNetResAttnUpBlock2D
(
return
UNetResAttnUpBlock2D
(
num_layers
=
num_layers
,
num_layers
=
num_layers
,
in_channels
=
in_channels
,
in_channels
=
in_channels
,
next_channels
=
next_channels
,
out_channels
=
out_channels
,
prev_output_channel
=
prev_output_channel
,
temb_channels
=
temb_channels
,
temb_channels
=
temb_channels
,
add_upsample
=
add_upsample
,
add_upsample
=
add_upsample
,
resnet_eps
=
resnet_eps
,
resnet_eps
=
resnet_eps
,
...
@@ -100,11 +105,14 @@ class UNetMidBlock2D(nn.Module):
...
@@ -100,11 +105,14 @@ class UNetMidBlock2D(nn.Module):
resnet_groups
:
int
=
32
,
resnet_groups
:
int
=
32
,
resnet_pre_norm
:
bool
=
True
,
resnet_pre_norm
:
bool
=
True
,
attn_num_head_channels
=
1
,
attn_num_head_channels
=
1
,
attention_type
=
"default"
,
output_scale_factor
=
1.0
,
output_scale_factor
=
1.0
,
**
kwargs
,
**
kwargs
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
attention_type
=
attention_type
# there is always at least one resnet
# there is always at least one resnet
resnets
=
[
resnets
=
[
ResnetBlock
(
ResnetBlock
(
...
@@ -128,6 +136,7 @@ class UNetMidBlock2D(nn.Module):
...
@@ -128,6 +136,7 @@ class UNetMidBlock2D(nn.Module):
in_channels
,
in_channels
,
num_head_channels
=
attn_num_head_channels
,
num_head_channels
=
attn_num_head_channels
,
rescale_output_factor
=
output_scale_factor
,
rescale_output_factor
=
output_scale_factor
,
eps
=
resnet_eps
,
)
)
)
)
resnets
.
append
(
resnets
.
append
(
...
@@ -148,18 +157,15 @@ class UNetMidBlock2D(nn.Module):
...
@@ -148,18 +157,15 @@ class UNetMidBlock2D(nn.Module):
self
.
attentions
=
nn
.
ModuleList
(
attentions
)
self
.
attentions
=
nn
.
ModuleList
(
attentions
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
def
forward
(
self
,
hidden_states
,
temb
=
None
,
encoder_states
=
None
,
mask
=
None
):
def
forward
(
self
,
hidden_states
,
temb
=
None
,
encoder_states
=
None
):
if
mask
is
not
None
:
hidden_states
=
self
.
resnets
[
0
](
hidden_states
,
temb
)
hidden_states
=
self
.
resnets
[
0
](
hidden_states
,
temb
,
mask
=
mask
)
else
:
hidden_states
=
self
.
resnets
[
0
](
hidden_states
,
temb
)
for
attn
,
resnet
in
zip
(
self
.
attentions
,
self
.
resnets
[
1
:]):
for
attn
,
resnet
in
zip
(
self
.
attentions
,
self
.
resnets
[
1
:]):
hidden_states
=
attn
(
hidden_states
,
encoder_states
)
if
self
.
attention_type
==
"default"
:
if
mask
is
not
None
:
hidden_states
=
attn
(
hidden_states
)
hidden_states
=
resnet
(
hidden_states
,
temb
,
mask
=
mask
)
else
:
else
:
hidden_states
=
resnet
(
hidden_states
,
temb
)
hidden_states
=
attn
(
hidden_states
,
encoder_states
)
hidden_states
=
resnet
(
hidden_states
,
temb
)
return
hidden_states
return
hidden_states
...
@@ -178,6 +184,7 @@ class UNetResAttnDownBlock2D(nn.Module):
...
@@ -178,6 +184,7 @@ class UNetResAttnDownBlock2D(nn.Module):
resnet_groups
:
int
=
32
,
resnet_groups
:
int
=
32
,
resnet_pre_norm
:
bool
=
True
,
resnet_pre_norm
:
bool
=
True
,
attn_num_head_channels
=
1
,
attn_num_head_channels
=
1
,
attention_type
=
"default"
,
output_scale_factor
=
1.0
,
output_scale_factor
=
1.0
,
add_downsample
=
True
,
add_downsample
=
True
,
):
):
...
@@ -185,6 +192,8 @@ class UNetResAttnDownBlock2D(nn.Module):
...
@@ -185,6 +192,8 @@ class UNetResAttnDownBlock2D(nn.Module):
resnets
=
[]
resnets
=
[]
attentions
=
[]
attentions
=
[]
self
.
attention_type
=
attention_type
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
in_channels
=
in_channels
if
i
==
0
else
out_channels
in_channels
=
in_channels
if
i
==
0
else
out_channels
resnets
.
append
(
resnets
.
append
(
...
@@ -206,6 +215,7 @@ class UNetResAttnDownBlock2D(nn.Module):
...
@@ -206,6 +215,7 @@ class UNetResAttnDownBlock2D(nn.Module):
out_channels
,
out_channels
,
num_head_channels
=
attn_num_head_channels
,
num_head_channels
=
attn_num_head_channels
,
rescale_output_factor
=
output_scale_factor
,
rescale_output_factor
=
output_scale_factor
,
eps
=
resnet_eps
,
)
)
)
)
...
@@ -251,6 +261,7 @@ class UNetResDownBlock2D(nn.Module):
...
@@ -251,6 +261,7 @@ class UNetResDownBlock2D(nn.Module):
resnet_pre_norm
:
bool
=
True
,
resnet_pre_norm
:
bool
=
True
,
output_scale_factor
=
1.0
,
output_scale_factor
=
1.0
,
add_downsample
=
True
,
add_downsample
=
True
,
downsample_padding
=
1
,
):
):
super
().
__init__
()
super
().
__init__
()
resnets
=
[]
resnets
=
[]
...
@@ -276,7 +287,11 @@ class UNetResDownBlock2D(nn.Module):
...
@@ -276,7 +287,11 @@ class UNetResDownBlock2D(nn.Module):
if
add_downsample
:
if
add_downsample
:
self
.
downsamplers
=
nn
.
ModuleList
(
self
.
downsamplers
=
nn
.
ModuleList
(
[
Downsample2D
(
in_channels
,
use_conv
=
True
,
out_channels
=
out_channels
,
padding
=
1
,
name
=
"op"
)]
[
Downsample2D
(
in_channels
,
use_conv
=
True
,
out_channels
=
out_channels
,
padding
=
downsample_padding
,
name
=
"op"
)
]
)
)
else
:
else
:
self
.
downsamplers
=
None
self
.
downsamplers
=
None
...
@@ -301,7 +316,8 @@ class UNetResAttnUpBlock2D(nn.Module):
...
@@ -301,7 +316,8 @@ class UNetResAttnUpBlock2D(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
in_channels
:
int
,
in_channels
:
int
,
next_channels
:
int
,
prev_output_channel
:
int
,
out_channels
:
int
,
temb_channels
:
int
,
temb_channels
:
int
,
dropout
:
float
=
0.0
,
dropout
:
float
=
0.0
,
num_layers
:
int
=
1
,
num_layers
:
int
=
1
,
...
@@ -310,7 +326,7 @@ class UNetResAttnUpBlock2D(nn.Module):
...
@@ -310,7 +326,7 @@ class UNetResAttnUpBlock2D(nn.Module):
resnet_act_fn
:
str
=
"swish"
,
resnet_act_fn
:
str
=
"swish"
,
resnet_groups
:
int
=
32
,
resnet_groups
:
int
=
32
,
resnet_pre_norm
:
bool
=
True
,
resnet_pre_norm
:
bool
=
True
,
attention_
layer_type
:
str
=
"self
"
,
attention_
type
=
"default
"
,
attn_num_head_channels
=
1
,
attn_num_head_channels
=
1
,
output_scale_factor
=
1.0
,
output_scale_factor
=
1.0
,
add_upsample
=
True
,
add_upsample
=
True
,
...
@@ -319,12 +335,16 @@ class UNetResAttnUpBlock2D(nn.Module):
...
@@ -319,12 +335,16 @@ class UNetResAttnUpBlock2D(nn.Module):
resnets
=
[]
resnets
=
[]
attentions
=
[]
attentions
=
[]
self
.
attention_type
=
attention_type
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
resnet_channels
=
in_channels
if
i
<
num_layers
-
1
else
next_channels
res_skip_channels
=
in_channels
if
(
i
==
num_layers
-
1
)
else
out_channels
resnet_in_channels
=
prev_output_channel
if
i
==
0
else
out_channels
resnets
.
append
(
resnets
.
append
(
ResnetBlock
(
ResnetBlock
(
in_channels
=
in_channels
+
res
net
_channels
,
in_channels
=
resnet_
in_channels
+
res
_skip
_channels
,
out_channels
=
in
_channels
,
out_channels
=
out
_channels
,
temb_channels
=
temb_channels
,
temb_channels
=
temb_channels
,
eps
=
resnet_eps
,
eps
=
resnet_eps
,
groups
=
resnet_groups
,
groups
=
resnet_groups
,
...
@@ -337,9 +357,10 @@ class UNetResAttnUpBlock2D(nn.Module):
...
@@ -337,9 +357,10 @@ class UNetResAttnUpBlock2D(nn.Module):
)
)
attentions
.
append
(
attentions
.
append
(
AttentionBlockNew
(
AttentionBlockNew
(
in
_channels
,
out
_channels
,
num_head_channels
=
attn_num_head_channels
,
num_head_channels
=
attn_num_head_channels
,
rescale_output_factor
=
output_scale_factor
,
rescale_output_factor
=
output_scale_factor
,
eps
=
resnet_eps
,
)
)
)
)
...
@@ -347,7 +368,7 @@ class UNetResAttnUpBlock2D(nn.Module):
...
@@ -347,7 +368,7 @@ class UNetResAttnUpBlock2D(nn.Module):
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
if
add_upsample
:
if
add_upsample
:
self
.
upsamplers
=
nn
.
ModuleList
([
Upsample2D
(
in
_channels
,
use_conv
=
True
,
out_channels
=
in
_channels
)])
self
.
upsamplers
=
nn
.
ModuleList
([
Upsample2D
(
out
_channels
,
use_conv
=
True
,
out_channels
=
out
_channels
)])
else
:
else
:
self
.
upsamplers
=
None
self
.
upsamplers
=
None
...
@@ -373,7 +394,8 @@ class UNetResUpBlock2D(nn.Module):
...
@@ -373,7 +394,8 @@ class UNetResUpBlock2D(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
in_channels
:
int
,
in_channels
:
int
,
next_channels
:
int
,
prev_output_channel
:
int
,
out_channels
:
int
,
temb_channels
:
int
,
temb_channels
:
int
,
dropout
:
float
=
0.0
,
dropout
:
float
=
0.0
,
num_layers
:
int
=
1
,
num_layers
:
int
=
1
,
...
@@ -382,7 +404,6 @@ class UNetResUpBlock2D(nn.Module):
...
@@ -382,7 +404,6 @@ class UNetResUpBlock2D(nn.Module):
resnet_act_fn
:
str
=
"swish"
,
resnet_act_fn
:
str
=
"swish"
,
resnet_groups
:
int
=
32
,
resnet_groups
:
int
=
32
,
resnet_pre_norm
:
bool
=
True
,
resnet_pre_norm
:
bool
=
True
,
attention_layer_type
:
str
=
"self"
,
output_scale_factor
=
1.0
,
output_scale_factor
=
1.0
,
add_upsample
=
True
,
add_upsample
=
True
,
):
):
...
@@ -390,11 +411,13 @@ class UNetResUpBlock2D(nn.Module):
...
@@ -390,11 +411,13 @@ class UNetResUpBlock2D(nn.Module):
resnets
=
[]
resnets
=
[]
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
resnet_channels
=
in_channels
if
i
<
num_layers
-
1
else
next_channels
res_skip_channels
=
in_channels
if
(
i
==
num_layers
-
1
)
else
out_channels
resnet_in_channels
=
prev_output_channel
if
i
==
0
else
out_channels
resnets
.
append
(
resnets
.
append
(
ResnetBlock
(
ResnetBlock
(
in_channels
=
in_channels
+
res
net
_channels
,
in_channels
=
resnet_
in_channels
+
res
_skip
_channels
,
out_channels
=
in
_channels
,
out_channels
=
out
_channels
,
temb_channels
=
temb_channels
,
temb_channels
=
temb_channels
,
eps
=
resnet_eps
,
eps
=
resnet_eps
,
groups
=
resnet_groups
,
groups
=
resnet_groups
,
...
@@ -409,7 +432,7 @@ class UNetResUpBlock2D(nn.Module):
...
@@ -409,7 +432,7 @@ class UNetResUpBlock2D(nn.Module):
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
if
add_upsample
:
if
add_upsample
:
self
.
upsamplers
=
nn
.
ModuleList
([
Upsample2D
(
in
_channels
,
use_conv
=
True
,
out_channels
=
in
_channels
)])
self
.
upsamplers
=
nn
.
ModuleList
([
Upsample2D
(
out
_channels
,
use_conv
=
True
,
out_channels
=
out
_channels
)])
else
:
else
:
self
.
upsamplers
=
None
self
.
upsamplers
=
None
...
...
src/diffusers/models/unet_unconditional.py
View file @
e7fe901e
...
@@ -9,6 +9,30 @@ from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
...
@@ -9,6 +9,30 @@ from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from
.unet_new
import
UNetMidBlock2D
,
get_down_block
,
get_up_block
from
.unet_new
import
UNetMidBlock2D
,
get_down_block
,
get_up_block
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
TimestepEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
channel
,
time_embed_dim
):
super
().
__init__
()
self
.
linear_1
=
nn
.
Linear
(
channel
,
time_embed_dim
)
self
.
act
=
nn
.
SiLU
()
self
.
linear_2
=
nn
.
Linear
(
time_embed_dim
,
time_embed_dim
)
def
forward
(
self
,
sample
):
sample
=
self
.
linear_1
(
sample
)
sample
=
self
.
act
(
sample
)
sample
=
self
.
linear_2
(
sample
)
return
sample
class
UNetUnconditionalModel
(
ModelMixin
,
ConfigMixin
):
class
UNetUnconditionalModel
(
ModelMixin
,
ConfigMixin
):
"""
"""
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
...
@@ -35,35 +59,66 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -35,35 +59,66 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
def
__init__
(
def
__init__
(
self
,
self
,
image_size
,
image_size
=
None
,
in_channels
,
in_channels
=
None
,
out_channels
,
out_channels
=
None
,
num_res_blocks
,
num_res_blocks
=
None
,
dropout
=
0
,
dropout
=
0
,
block_input_channels
=
(
224
,
224
,
448
,
672
),
block_channels
=
(
224
,
448
,
672
,
896
),
block_output_channels
=
(
224
,
448
,
672
,
896
),
down_blocks
=
(
down_blocks
=
(
"UNetResDownBlock2D"
,
"UNetResDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
),
),
downsample_padding
=
1
,
up_blocks
=
(
"UNetResAttnUpBlock2D"
,
"UNetResAttnUpBlock2D"
,
"UNetResAttnUpBlock2D"
,
"UNetResUpBlock2D"
),
up_blocks
=
(
"UNetResAttnUpBlock2D"
,
"UNetResAttnUpBlock2D"
,
"UNetResAttnUpBlock2D"
,
"UNetResUpBlock2D"
),
resnet_act_fn
=
"silu"
,
resnet_act_fn
=
"silu"
,
resnet_eps
=
1e-5
,
resnet_eps
=
1e-5
,
conv_resample
=
True
,
conv_resample
=
True
,
num_head_channels
=
32
,
num_head_channels
=
32
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
,
# To delete once weights are converted
# To delete once weights are converted
# LDM
attention_resolutions
=
(
8
,
4
,
2
),
attention_resolutions
=
(
8
,
4
,
2
),
ldm
=
False
,
# DDPM
out_ch
=
None
,
resolution
=
None
,
attn_resolutions
=
None
,
resamp_with_conv
=
None
,
ch_mult
=
None
,
ch
=
None
,
ddpm
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
# DELETE if statements if not necessary anymore
# DDPM
if
ddpm
:
out_channels
=
out_ch
image_size
=
resolution
block_channels
=
[
x
*
ch
for
x
in
ch_mult
]
conv_resample
=
resamp_with_conv
flip_sin_to_cos
=
False
downscale_freq_shift
=
1
resnet_eps
=
1e-6
block_channels
=
(
32
,
64
)
down_blocks
=
(
"UNetResDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
)
up_blocks
=
(
"UNetResUpBlock2D"
,
"UNetResAttnUpBlock2D"
)
downsample_padding
=
0
num_head_channels
=
64
# register all __init__ params with self.register
# register all __init__ params with self.register
self
.
register_to_config
(
self
.
register_to_config
(
image_size
=
image_size
,
image_size
=
image_size
,
in_channels
=
in_channels
,
in_channels
=
in_channels
,
block_
input_
channels
=
block_
input_
channels
,
block_channels
=
block_channels
,
block_output_channels
=
block_output_channels
,
downsample_padding
=
downsample_padding
,
out_channels
=
out_channels
,
out_channels
=
out_channels
,
num_res_blocks
=
num_res_blocks
,
num_res_blocks
=
num_res_blocks
,
down_blocks
=
down_blocks
,
down_blocks
=
down_blocks
,
...
@@ -71,37 +126,34 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -71,37 +126,34 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
dropout
=
dropout
,
dropout
=
dropout
,
conv_resample
=
conv_resample
,
conv_resample
=
conv_resample
,
num_head_channels
=
num_head_channels
,
num_head_channels
=
num_head_channels
,
flip_sin_to_cos
=
flip_sin_to_cos
,
downscale_freq_shift
=
downscale_freq_shift
,
# (TODO(PVP) - To delete once weights are converted
# (TODO(PVP) - To delete once weights are converted
attention_resolutions
=
attention_resolutions
,
attention_resolutions
=
attention_resolutions
,
ldm
=
ldm
,
ddpm
=
ddpm
,
)
)
# To delete - replace with config values
# To delete - replace with config values
self
.
image_size
=
image_size
self
.
image_size
=
image_size
self
.
in_channels
=
in_channels
time_embed_dim
=
block_channels
[
0
]
*
4
self
.
out_channels
=
out_channels
self
.
num_res_blocks
=
num_res_blocks
self
.
dropout
=
dropout
time_embed_dim
=
block_input_channels
[
0
]
*
4
# # input
self
.
conv_in
=
nn
.
Conv2d
(
in_channels
,
block_channels
[
0
],
kernel_size
=
3
,
padding
=
(
1
,
1
))
# ======================== Input ===================
# # time
self
.
conv_in
=
nn
.
Conv2d
(
in_channels
,
block_input_channels
[
0
],
kernel_size
=
3
,
padding
=
(
1
,
1
))
self
.
time_embedding
=
TimestepEmbedding
(
block_channels
[
0
],
time_embed_dim
)
# ======================== Time ====================
self
.
time_embed
=
nn
.
Sequential
(
nn
.
Linear
(
block_input_channels
[
0
],
time_embed_dim
),
nn
.
SiLU
(),
nn
.
Linear
(
time_embed_dim
,
time_embed_dim
),
)
# ======================== Down ====================
input_channels
=
list
(
block_input_channels
)
output_channels
=
list
(
block_output_channels
)
self
.
downsample_blocks
=
nn
.
ModuleList
([])
self
.
downsample_blocks
=
nn
.
ModuleList
([])
for
i
,
(
input_channel
,
output_channel
)
in
enumerate
(
zip
(
input_channels
,
output_channels
)):
self
.
mid
=
None
down_block_type
=
down_blocks
[
i
]
self
.
upsample_blocks
=
nn
.
ModuleList
([])
is_final_block
=
i
==
len
(
input_channels
)
-
1
# down
output_channel
=
block_channels
[
0
]
for
i
,
down_block_type
in
enumerate
(
down_blocks
):
input_channel
=
output_channel
output_channel
=
block_channels
[
i
]
is_final_block
=
i
==
len
(
block_channels
)
-
1
down_block
=
get_down_block
(
down_block
=
get_down_block
(
down_block_type
,
down_block_type
,
...
@@ -113,30 +165,48 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -113,30 +165,48 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
resnet_eps
=
resnet_eps
,
resnet_eps
=
resnet_eps
,
resnet_act_fn
=
resnet_act_fn
,
resnet_act_fn
=
resnet_act_fn
,
attn_num_head_channels
=
num_head_channels
,
attn_num_head_channels
=
num_head_channels
,
downsample_padding
=
downsample_padding
,
)
)
self
.
downsample_blocks
.
append
(
down_block
)
self
.
downsample_blocks
.
append
(
down_block
)
# ======================== Mid ====================
# mid
self
.
mid
=
UNetMidBlock2D
(
if
self
.
config
.
ddpm
:
in_channels
=
output_channels
[
-
1
],
self
.
mid_new_2
=
UNetMidBlock2D
(
dropout
=
dropout
,
in_channels
=
block_channels
[
-
1
],
temb_channels
=
time_embed_dim
,
dropout
=
dropout
,
resnet_eps
=
resnet_eps
,
temb_channels
=
time_embed_dim
,
resnet_act_fn
=
resnet_act_fn
,
resnet_eps
=
resnet_eps
,
resnet_time_scale_shift
=
"default"
,
resnet_act_fn
=
resnet_act_fn
,
attn_num_head_channels
=
num_head_channels
,
resnet_time_scale_shift
=
"default"
,
)
attn_num_head_channels
=
num_head_channels
,
)
else
:
self
.
mid
=
UNetMidBlock2D
(
in_channels
=
block_channels
[
-
1
],
dropout
=
dropout
,
temb_channels
=
time_embed_dim
,
resnet_eps
=
resnet_eps
,
resnet_act_fn
=
resnet_act_fn
,
resnet_time_scale_shift
=
"default"
,
attn_num_head_channels
=
num_head_channels
,
)
self
.
upsample_blocks
=
nn
.
ModuleList
([])
# up
for
i
,
(
input_channel
,
output_channel
)
in
enumerate
(
zip
(
reversed
(
input_channels
),
reversed
(
output_channels
))):
reversed_block_channels
=
list
(
reversed
(
block_channels
))
up_block_type
=
up_blocks
[
i
]
output_channel
=
reversed_block_channels
[
0
]
is_final_block
=
i
==
len
(
input_channels
)
-
1
for
i
,
up_block_type
in
enumerate
(
up_blocks
):
prev_output_channel
=
output_channel
output_channel
=
reversed_block_channels
[
i
]
input_channel
=
reversed_block_channels
[
min
(
i
+
1
,
len
(
block_channels
)
-
1
)]
is_final_block
=
i
==
len
(
block_channels
)
-
1
up_block
=
get_up_block
(
up_block
=
get_up_block
(
up_block_type
,
up_block_type
,
num_layers
=
num_res_blocks
+
1
,
num_layers
=
num_res_blocks
+
1
,
in_channels
=
output_channel
,
in_channels
=
input_channel
,
next_channels
=
input_channel
,
out_channels
=
output_channel
,
prev_output_channel
=
prev_output_channel
,
temb_channels
=
time_embed_dim
,
temb_channels
=
time_embed_dim
,
add_upsample
=
not
is_final_block
,
add_upsample
=
not
is_final_block
,
resnet_eps
=
resnet_eps
,
resnet_eps
=
resnet_eps
,
...
@@ -144,50 +214,72 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -144,50 +214,72 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
attn_num_head_channels
=
num_head_channels
,
attn_num_head_channels
=
num_head_channels
,
)
)
self
.
upsample_blocks
.
append
(
up_block
)
self
.
upsample_blocks
.
append
(
up_block
)
prev_output_channel
=
output_channel
# out
self
.
conv_norm_out
=
nn
.
GroupNorm
(
num_channels
=
block_channels
[
0
],
num_groups
=
32
,
eps
=
1e-5
)
self
.
conv_act
=
nn
.
SiLU
()
self
.
conv_out
=
nn
.
Conv2d
(
block_channels
[
0
],
out_channels
,
3
,
padding
=
1
)
# ======================== Out ====================
# ======================== Out ====================
self
.
out
=
nn
.
Sequential
(
nn
.
GroupNorm
(
num_channels
=
output_channels
[
0
],
num_groups
=
32
,
eps
=
1e-5
),
nn
.
SiLU
(),
nn
.
Conv2d
(
block_input_channels
[
0
],
out_channels
,
3
,
padding
=
1
),
)
# =========== TO DELETE AFTER CONVERSION ==========
self
.
is_overwritten
=
False
transformer_depth
=
1
if
ldm
:
context_dim
=
None
# =========== TO DELETE AFTER CONVERSION ==========
legacy
=
True
transformer_depth
=
1
num_heads
=
-
1
context_dim
=
None
model_channels
=
block_input_channels
[
0
]
legacy
=
True
channel_mult
=
tuple
([
x
//
model_channels
for
x
in
block_output_channels
])
num_heads
=
-
1
self
.
init_for_ldm
(
model_channels
=
block_channels
[
0
]
in_channels
,
channel_mult
=
tuple
([
x
//
model_channels
for
x
in
block_channels
])
model_channels
,
self
.
init_for_ldm
(
channel_mult
,
in_channels
,
num_res_blocks
,
model_channels
,
dropout
,
channel_mult
,
time_embed_dim
,
num_res_blocks
,
attention_resolutions
,
dropout
,
num_head_channels
,
time_embed_dim
,
num_heads
,
attention_resolutions
,
legacy
,
num_head_channels
,
False
,
num_heads
,
transformer_depth
,
legacy
,
context_dim
,
False
,
conv_resample
,
transformer_depth
,
out_channels
,
context_dim
,
)
conv_resample
,
out_channels
,
)
if
ddpm
:
self
.
init_for_ddpm
(
ch_mult
,
ch
,
num_res_blocks
,
resolution
,
in_channels
,
resamp_with_conv
,
attn_resolutions
,
out_ch
,
dropout
=
0.1
,
)
def
forward
(
self
,
sample
,
timesteps
=
None
):
def
forward
(
self
,
sample
,
timesteps
=
None
):
# TODO(PVP) - to delete later
if
not
self
.
is_overwritten
:
self
.
set_weights
()
# 1. time step embeddings
# 1. time step embeddings
if
not
torch
.
is_tensor
(
timesteps
):
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
t_emb
=
get_timestep_embedding
(
t_emb
=
get_timestep_embedding
(
timesteps
,
self
.
config
.
block_input_channels
[
0
],
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
timesteps
,
self
.
config
.
block_channels
[
0
],
flip_sin_to_cos
=
self
.
config
.
flip_sin_to_cos
,
downscale_freq_shift
=
self
.
config
.
downscale_freq_shift
,
)
)
emb
=
self
.
time_embed
(
t_emb
)
emb
=
self
.
time_embed
ding
(
t_emb
)
# 2. pre-process sample
# 2. pre-process sample
# sample = sample.type(self.dtype_)
sample
=
self
.
conv_in
(
sample
)
sample
=
self
.
conv_in
(
sample
)
# 3. down blocks
# 3. down blocks
...
@@ -198,8 +290,13 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -198,8 +290,13 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
# append to tuple
# append to tuple
down_block_res_samples
+=
res_samples
down_block_res_samples
+=
res_samples
print
(
"sample"
,
sample
.
abs
().
sum
())
# 4. mid block
# 4. mid block
sample
=
self
.
mid
(
sample
,
emb
)
if
self
.
config
.
ddpm
:
sample
=
self
.
mid_new_2
(
sample
,
emb
)
else
:
sample
=
self
.
mid
(
sample
,
emb
)
print
(
"sample"
,
sample
.
abs
().
sum
())
# 5. up blocks
# 5. up blocks
for
upsample_block
in
self
.
upsample_blocks
:
for
upsample_block
in
self
.
upsample_blocks
:
...
@@ -211,10 +308,192 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -211,10 +308,192 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
sample
=
upsample_block
(
sample
,
res_samples
,
emb
)
sample
=
upsample_block
(
sample
,
res_samples
,
emb
)
# 6. post-process sample
# 6. post-process sample
sample
=
self
.
out
(
sample
)
sample
=
self
.
conv_norm_out
(
sample
)
sample
=
self
.
conv_act
(
sample
)
sample
=
self
.
conv_out
(
sample
)
return
sample
return
sample
def
set_weights
(
self
):
self
.
is_overwritten
=
True
if
self
.
config
.
ldm
:
self
.
time_embedding
.
linear_1
.
weight
.
data
=
self
.
time_embed
[
0
].
weight
.
data
self
.
time_embedding
.
linear_1
.
bias
.
data
=
self
.
time_embed
[
0
].
bias
.
data
self
.
time_embedding
.
linear_2
.
weight
.
data
=
self
.
time_embed
[
2
].
weight
.
data
self
.
time_embedding
.
linear_2
.
bias
.
data
=
self
.
time_embed
[
2
].
bias
.
data
# ================ SET WEIGHTS OF ALL WEIGHTS ==================
for
i
,
input_layer
in
enumerate
(
self
.
input_blocks
[
1
:]):
block_id
=
i
//
(
self
.
config
.
num_res_blocks
+
1
)
layer_in_block_id
=
i
%
(
self
.
config
.
num_res_blocks
+
1
)
if
layer_in_block_id
==
2
:
self
.
downsample_blocks
[
block_id
].
downsamplers
[
0
].
conv
.
weight
.
data
=
input_layer
[
0
].
op
.
weight
.
data
self
.
downsample_blocks
[
block_id
].
downsamplers
[
0
].
conv
.
bias
.
data
=
input_layer
[
0
].
op
.
bias
.
data
elif
len
(
input_layer
)
>
1
:
self
.
downsample_blocks
[
block_id
].
resnets
[
layer_in_block_id
].
set_weight
(
input_layer
[
0
])
self
.
downsample_blocks
[
block_id
].
attentions
[
layer_in_block_id
].
set_weight
(
input_layer
[
1
])
else
:
self
.
downsample_blocks
[
block_id
].
resnets
[
layer_in_block_id
].
set_weight
(
input_layer
[
0
])
self
.
mid
.
resnets
[
0
].
set_weight
(
self
.
middle_block
[
0
])
self
.
mid
.
resnets
[
1
].
set_weight
(
self
.
middle_block
[
2
])
self
.
mid
.
attentions
[
0
].
set_weight
(
self
.
middle_block
[
1
])
for
i
,
input_layer
in
enumerate
(
self
.
output_blocks
):
block_id
=
i
//
(
self
.
config
.
num_res_blocks
+
1
)
layer_in_block_id
=
i
%
(
self
.
config
.
num_res_blocks
+
1
)
if
len
(
input_layer
)
>
2
:
self
.
upsample_blocks
[
block_id
].
resnets
[
layer_in_block_id
].
set_weight
(
input_layer
[
0
])
self
.
upsample_blocks
[
block_id
].
attentions
[
layer_in_block_id
].
set_weight
(
input_layer
[
1
])
self
.
upsample_blocks
[
block_id
].
upsamplers
[
0
].
conv
.
weight
.
data
=
input_layer
[
2
].
conv
.
weight
.
data
self
.
upsample_blocks
[
block_id
].
upsamplers
[
0
].
conv
.
bias
.
data
=
input_layer
[
2
].
conv
.
bias
.
data
elif
len
(
input_layer
)
>
1
and
"Upsample2D"
in
input_layer
[
1
].
__class__
.
__name__
:
self
.
upsample_blocks
[
block_id
].
resnets
[
layer_in_block_id
].
set_weight
(
input_layer
[
0
])
self
.
upsample_blocks
[
block_id
].
upsamplers
[
0
].
conv
.
weight
.
data
=
input_layer
[
1
].
conv
.
weight
.
data
self
.
upsample_blocks
[
block_id
].
upsamplers
[
0
].
conv
.
bias
.
data
=
input_layer
[
1
].
conv
.
bias
.
data
elif
len
(
input_layer
)
>
1
:
self
.
upsample_blocks
[
block_id
].
resnets
[
layer_in_block_id
].
set_weight
(
input_layer
[
0
])
self
.
upsample_blocks
[
block_id
].
attentions
[
layer_in_block_id
].
set_weight
(
input_layer
[
1
])
else
:
self
.
upsample_blocks
[
block_id
].
resnets
[
layer_in_block_id
].
set_weight
(
input_layer
[
0
])
self
.
conv_in
.
weight
.
data
=
self
.
input_blocks
[
0
][
0
].
weight
.
data
self
.
conv_in
.
bias
.
data
=
self
.
input_blocks
[
0
][
0
].
bias
.
data
self
.
conv_norm_out
.
weight
.
data
=
self
.
out
[
0
].
weight
.
data
self
.
conv_norm_out
.
bias
.
data
=
self
.
out
[
0
].
bias
.
data
self
.
conv_out
.
weight
.
data
=
self
.
out
[
2
].
weight
.
data
self
.
conv_out
.
bias
.
data
=
self
.
out
[
2
].
bias
.
data
self
.
remove_ldm
()
elif
self
.
config
.
ddpm
:
# =============== SET WEIGHTS ===============
# =============== TIME ======================
self
.
time_embed
[
0
]
=
self
.
temb
.
dense
[
0
]
self
.
time_embed
[
2
]
=
self
.
temb
.
dense
[
1
]
for
i
,
block
in
enumerate
(
self
.
down
):
if
hasattr
(
block
,
"downsample"
):
self
.
downsample_blocks
[
i
].
downsamplers
[
0
].
conv
.
weight
.
data
=
block
.
downsample
.
conv
.
weight
.
data
self
.
downsample_blocks
[
i
].
downsamplers
[
0
].
conv
.
bias
.
data
=
block
.
downsample
.
conv
.
bias
.
data
if
hasattr
(
block
,
"block"
)
and
len
(
block
.
block
)
>
0
:
for
j
in
range
(
self
.
num_res_blocks
):
self
.
downsample_blocks
[
i
].
resnets
[
j
].
set_weight
(
block
.
block
[
j
])
if
hasattr
(
block
,
"attn"
)
and
len
(
block
.
attn
)
>
0
:
for
j
in
range
(
self
.
num_res_blocks
):
self
.
downsample_blocks
[
i
].
attentions
[
j
].
set_weight
(
block
.
attn
[
j
])
self
.
mid_new_2
.
resnets
[
0
].
set_weight
(
self
.
mid
.
block_1
)
self
.
mid_new_2
.
resnets
[
1
].
set_weight
(
self
.
mid
.
block_2
)
self
.
mid_new_2
.
attentions
[
0
].
set_weight
(
self
.
mid
.
attn_1
)
def
init_for_ddpm
(
self
,
ch_mult
,
ch
,
num_res_blocks
,
resolution
,
in_channels
,
resamp_with_conv
,
attn_resolutions
,
out_ch
,
dropout
=
0.1
,
):
ch_mult
=
tuple
(
ch_mult
)
self
.
ch
=
ch
self
.
temb_ch
=
self
.
ch
*
4
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
# timestep embedding
self
.
temb
=
nn
.
Module
()
self
.
temb
.
dense
=
nn
.
ModuleList
(
[
torch
.
nn
.
Linear
(
self
.
ch
,
self
.
temb_ch
),
torch
.
nn
.
Linear
(
self
.
temb_ch
,
self
.
temb_ch
),
]
)
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
ch_mult
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample2D
(
block_in
,
use_conv
=
resamp_with_conv
,
padding
=
0
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
block_2
=
ResnetBlock2D
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid_new
=
UNetMidBlock2D
(
in_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid_new
.
resnets
[
0
]
=
self
.
mid
.
block_1
self
.
mid_new
.
attentions
[
0
]
=
self
.
mid
.
attn_1
self
.
mid_new
.
resnets
[
1
]
=
self
.
mid
.
block_2
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
skip_in
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
if
i_block
==
self
.
num_res_blocks
:
skip_in
=
ch
*
in_ch_mult
[
i_level
]
block
.
append
(
ResnetBlock2D
(
in_channels
=
block_in
+
skip_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample2D
(
block_in
,
use_conv
=
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
init_for_ldm
(
def
init_for_ldm
(
self
,
self
,
in_channels
,
in_channels
,
...
@@ -234,7 +513,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -234,7 +513,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
out_channels
,
out_channels
,
):
):
# TODO(PVP) - delete after weight conversion
# TODO(PVP) - delete after weight conversion
class
TimestepEmbedSequential
(
nn
.
Sequential
):
class
TimestepEmbedSequential
(
nn
.
Sequential
):
"""
"""
A sequential module that passes timestep embeddings to the children that support it as an extra input.
A sequential module that passes timestep embeddings to the children that support it as an extra input.
...
@@ -255,6 +533,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -255,6 +533,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
return
nn
.
Conv3d
(
*
args
,
**
kwargs
)
return
nn
.
Conv3d
(
*
args
,
**
kwargs
)
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
self
.
time_embed
=
nn
.
Sequential
(
nn
.
Linear
(
model_channels
,
time_embed_dim
),
nn
.
SiLU
(),
nn
.
Linear
(
time_embed_dim
,
time_embed_dim
),
)
dims
=
2
dims
=
2
self
.
input_blocks
=
nn
.
ModuleList
(
self
.
input_blocks
=
nn
.
ModuleList
(
[
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
model_channels
,
3
,
padding
=
1
))]
[
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
model_channels
,
3
,
padding
=
1
))]
...
@@ -389,42 +673,15 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -389,42 +673,15 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self
.
output_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
self
.
output_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
self
.
_feature_size
+=
ch
self
.
_feature_size
+=
ch
# ================ SET WEIGHTS OF ALL WEIGHTS ==================
self
.
out
=
nn
.
Sequential
(
for
i
,
input_layer
in
enumerate
(
self
.
input_blocks
[
1
:]):
nn
.
GroupNorm
(
num_channels
=
model_channels
,
num_groups
=
32
,
eps
=
1e-5
),
block_id
=
i
//
(
num_res_blocks
+
1
)
nn
.
SiLU
(),
layer_in_block_id
=
i
%
(
num_res_blocks
+
1
)
nn
.
Conv2d
(
model_channels
,
out_channels
,
3
,
padding
=
1
),
)
if
layer_in_block_id
==
2
:
self
.
downsample_blocks
[
block_id
].
downsamplers
[
0
].
op
.
weight
.
data
=
input_layer
[
0
].
op
.
weight
.
data
def
remove_ldm
(
self
):
self
.
downsample_blocks
[
block_id
].
downsamplers
[
0
].
op
.
bias
.
data
=
input_layer
[
0
].
op
.
bias
.
data
del
self
.
time_embed
elif
len
(
input_layer
)
>
1
:
del
self
.
input_blocks
self
.
downsample_blocks
[
block_id
].
resnets
[
layer_in_block_id
].
set_weight
(
input_layer
[
0
])
del
self
.
middle_block
self
.
downsample_blocks
[
block_id
].
attentions
[
layer_in_block_id
].
set_weight
(
input_layer
[
1
])
del
self
.
output_blocks
else
:
del
self
.
out
self
.
downsample_blocks
[
block_id
].
resnets
[
layer_in_block_id
].
set_weight
(
input_layer
[
0
])
self
.
mid
.
resnets
[
0
].
set_weight
(
self
.
middle_block
[
0
])
self
.
mid
.
resnets
[
1
].
set_weight
(
self
.
middle_block
[
2
])
self
.
mid
.
attentions
[
0
].
set_weight
(
self
.
middle_block
[
1
])
for
i
,
input_layer
in
enumerate
(
self
.
output_blocks
):
block_id
=
i
//
(
num_res_blocks
+
1
)
layer_in_block_id
=
i
%
(
num_res_blocks
+
1
)
if
len
(
input_layer
)
>
2
:
self
.
upsample_blocks
[
block_id
].
resnets
[
layer_in_block_id
].
set_weight
(
input_layer
[
0
])
self
.
upsample_blocks
[
block_id
].
attentions
[
layer_in_block_id
].
set_weight
(
input_layer
[
1
])
self
.
upsample_blocks
[
block_id
].
upsamplers
[
0
].
conv
.
weight
.
data
=
input_layer
[
2
].
conv
.
weight
.
data
self
.
upsample_blocks
[
block_id
].
upsamplers
[
0
].
conv
.
bias
.
data
=
input_layer
[
2
].
conv
.
bias
.
data
elif
len
(
input_layer
)
>
1
and
"Upsample2D"
in
input_layer
[
1
].
__class__
.
__name__
:
self
.
upsample_blocks
[
block_id
].
resnets
[
layer_in_block_id
].
set_weight
(
input_layer
[
0
])
self
.
upsample_blocks
[
block_id
].
upsamplers
[
0
].
conv
.
weight
.
data
=
input_layer
[
1
].
conv
.
weight
.
data
self
.
upsample_blocks
[
block_id
].
upsamplers
[
0
].
conv
.
bias
.
data
=
input_layer
[
1
].
conv
.
bias
.
data
elif
len
(
input_layer
)
>
1
:
self
.
upsample_blocks
[
block_id
].
resnets
[
layer_in_block_id
].
set_weight
(
input_layer
[
0
])
self
.
upsample_blocks
[
block_id
].
attentions
[
layer_in_block_id
].
set_weight
(
input_layer
[
1
])
else
:
self
.
upsample_blocks
[
block_id
].
resnets
[
layer_in_block_id
].
set_weight
(
input_layer
[
0
])
self
.
conv_in
.
weight
.
data
=
self
.
input_blocks
[
0
][
0
].
weight
.
data
self
.
conv_in
.
bias
.
data
=
self
.
input_blocks
[
0
][
0
].
bias
.
data
tests/test_modeling_utils.py
View file @
e7fe901e
...
@@ -271,6 +271,27 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -271,6 +271,27 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
# fmt: on
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
rtol
=
1e-2
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
rtol
=
1e-2
))
print
(
"Original success!!!"
)
model
=
UNetUnconditionalModel
.
from_pretrained
(
"fusing/ddpm_dummy"
,
ddpm
=
True
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
image_size
,
model
.
config
.
image_size
)
time_step
=
torch
.
tensor
([
10
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2595
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
rtol
=
1e-2
))
class
GlideSuperResUNetTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
class
GlideSuperResUNetTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
GlideSuperResUNetModel
model_class
=
GlideSuperResUNetModel
...
@@ -486,18 +507,20 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -486,18 +507,20 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
"out_channels"
:
4
,
"out_channels"
:
4
,
"num_res_blocks"
:
2
,
"num_res_blocks"
:
2
,
"attention_resolutions"
:
(
16
,),
"attention_resolutions"
:
(
16
,),
"block_input_channels"
:
[
32
,
32
],
"block_channels"
:
(
32
,
64
),
"block_output_channels"
:
[
32
,
64
],
"num_head_channels"
:
32
,
"num_head_channels"
:
32
,
"conv_resample"
:
True
,
"conv_resample"
:
True
,
"down_blocks"
:
(
"UNetResDownBlock2D"
,
"UNetResDownBlock2D"
),
"down_blocks"
:
(
"UNetResDownBlock2D"
,
"UNetResDownBlock2D"
),
"up_blocks"
:
(
"UNetResUpBlock2D"
,
"UNetResUpBlock2D"
),
"up_blocks"
:
(
"UNetResUpBlock2D"
,
"UNetResUpBlock2D"
),
"ldm"
:
True
,
}
}
inputs_dict
=
self
.
dummy_input
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetUnconditionalModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
,
output_loading_info
=
True
)
model
,
loading_info
=
UNetUnconditionalModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
,
output_loading_info
=
True
,
ldm
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
...
@@ -507,7 +530,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -507,7 +530,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
assert
image
is
not
None
,
"Make sure output is not None"
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
def
test_output_pretrained
(
self
):
model
=
UNetUnconditionalModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
)
model
=
UNetUnconditionalModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
,
ldm
=
True
)
model
.
eval
()
model
.
eval
()
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment