renzhc / diffusers_dcu · Commits

Commit e7fe901e (unverified)
Authored Jul 14, 2022 by Patrick von Platen; committed by GitHub, Jul 14, 2022

    save intermediate (#87)

    * save intermediate
    * up
    * up

Parent: c3d78cd3
Changes: 10 changed files, +801 additions, -173 deletions
Files changed:

  conversion.py                                        +94    -0
  docs/source/examples/diffusers_for_vision.mdx         +2    -1
  src/diffusers/configuration_utils.py                  +2    -0
  src/diffusers/modeling_utils.py                       +3    -0
  src/diffusers/models/attention.py                   +223    -5
  src/diffusers/models/resnet.py                       +22   -17
  src/diffusers/models/unet.py                          +2    -0
  src/diffusers/models/unet_new.py                     +50   -27
  src/diffusers/models/unet_unconditional.py          +376  -119
  tests/test_modeling_utils.py                         +27    -4
conversion.py  (new file, mode 100755)

# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import tempfile
import unittest

import numpy as np
import torch

from diffusers import (
    AutoencoderKL,
    DDIMPipeline,
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
    GlidePipeline,
    GlideSuperResUNetModel,
    GlideTextToImageUNetModel,
    LatentDiffusionPipeline,
    LatentDiffusionUncondPipeline,
    NCSNpp,
    PNDMPipeline,
    PNDMScheduler,
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
    ScoreSdeVpPipeline,
    ScoreSdeVpScheduler,
    UNetLDMModel,
    UNetModel,
    UNetUnconditionalModel,
    VQModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel


def test_output_pretrained_ldm_dummy():
    model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
    model.eval()

    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    time_step = torch.tensor([10] * noise.shape[0])

    with torch.no_grad():
        output = model(noise, time_step)

    print(model)
    import ipdb; ipdb.set_trace()


def test_output_pretrained_ldm():
    model = UNetUnconditionalModel.from_pretrained(
        "fusing/latent-diffusion-celeba-256", subfolder="unet", ldm=True
    )
    model.eval()

    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    time_step = torch.tensor([10] * noise.shape[0])

    with torch.no_grad():
        output = model(noise, time_step)

    print(model)
    import ipdb; ipdb.set_trace()


# To see how the final model should look
test_output_pretrained_ldm_dummy()
test_output_pretrained_ldm()

# => this is the architecture in which the model should be saved in the new format
# -> verify the new repo with the following tests (in `test_modeling_utils.py`)
#    - test_ldm_uncond (in PipelineTesterMixin)
#    - test_output_pretrained (in UNetLDMModelTests)
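
Note: the script stops in `ipdb` rather than asserting anything. For orientation, the kind of numerical check one would run at that breakpoint might look like the sketch below; the slice indices, reference values, and tolerance are illustrative assumptions, not part of this commit:

# Hypothetical check at the ipdb breakpoint (illustrative only):
# `output` comes from the model call above; `expected_slice` would hold a few
# reference values recorded from the original implementation.
import torch

expected_slice = torch.tensor([-13.3258, -20.1100, -15.9873])  # assumed values
assert torch.allclose(output[0, -1, -3:, -1], expected_slice, atol=1e-3)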
docs/source/examples/diffusers_for_vision.mdx

@@ -111,7 +111,7 @@ prompt = "A painting of a squirrel eating a burger"
 image = ldm([prompt], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)

 image_processed = image.cpu().permute(0, 2, 3, 1)
-image_processed = image_processed * 255.
+image_processed = image_processed * 255.0
 image_processed = image_processed.numpy().astype(np.uint8)
 image_pil = PIL.Image.fromarray(image_processed[0])
@@ -143,6 +143,7 @@ audio = bddm(mel_spec, generator, torch_device=torch_device)
 # save generated audio
 from scipy.io.wavfile import write as wavwrite
 sampling_rate = 22050
 wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
 ```
src/diffusers/configuration_utils.py

@@ -116,6 +116,7 @@ class ConfigMixin:
         use_auth_token = kwargs.pop("use_auth_token", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)

         user_agent = {"file_type": "config"}
@@ -150,6 +151,7 @@ class ConfigMixin:
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
+                subfolder=subfolder,
             )
         except RepositoryNotFoundError:
src/diffusers/modeling_utils.py

@@ -321,6 +321,7 @@ class ModelMixin(torch.nn.Module):
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         from_auto_class = kwargs.pop("_from_auto", False)
+        subfolder = kwargs.pop("subfolder", None)

         user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -336,6 +337,7 @@ class ModelMixin(torch.nn.Module):
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
                 revision=revision,
+                subfolder=subfolder,
                 **kwargs,
             )

             model.register_to_config(name_or_path=pretrained_model_name_or_path)
@@ -363,6 +365,7 @@ class ModelMixin(torch.nn.Module):
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
+                subfolder=subfolder,
             )
         except RepositoryNotFoundError:
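
Together with the `configuration_utils.py` change above, this lets `from_pretrained` resolve config and weights from a sub-directory of a Hub repo, which is exactly what `conversion.py` relies on. Minimal usage (same call as in the new script):

from diffusers import UNetUnconditionalModel

# Load the LDM UNet that lives under the "unet/" folder of the repo.
model = UNetUnconditionalModel.from_pretrained(
    "fusing/latent-diffusion-celeba-256", subfolder="unet", ldm=True
)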
src/diffusers/models/attention.py

@@ -51,6 +51,7 @@ class AttentionBlock(nn.Module):
         overwrite_qkv=False,
         overwrite_linear=False,
         rescale_output_factor=1.0,
+        eps=1e-5,
     ):
         super().__init__()
         self.channels = channels
@@ -62,7 +63,7 @@ class AttentionBlock(nn.Module):
         ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
         self.num_heads = channels // num_head_channels

-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
+        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
         self.qkv = nn.Conv1d(channels, channels * 3, 1)
         self.n_heads = self.num_heads
         self.rescale_output_factor = rescale_output_factor
@@ -165,7 +166,7 @@ class AttentionBlock(nn.Module):
         return result


-class AttentionBlockNew(nn.Module):
+class AttentionBlockNew_2(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other.
@@ -180,11 +181,14 @@ class AttentionBlockNew(nn.Module):
         num_groups=32,
         encoder_channels=None,
         rescale_output_factor=1.0,
+        eps=1e-5,
     ):
         super().__init__()
-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
+        self.channels = channels
+        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
         self.qkv = nn.Conv1d(channels, channels * 3, 1)
         self.n_heads = channels // num_head_channels
+        self.num_head_size = num_head_channels
         self.rescale_output_factor = rescale_output_factor

         if encoder_channels is not None:
@@ -192,6 +196,28 @@ class AttentionBlockNew(nn.Module):
         self.proj = zero_module(nn.Conv1d(channels, channels, 1))

+        # ------------------------- new -----------------------
+        num_heads = self.n_heads
+        self.channels = channels
+
+        if num_head_channels is None:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+
+        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
+
+        # define q,k,v as linear layers
+        self.query = nn.Linear(channels, channels)
+        self.key = nn.Linear(channels, channels)
+        self.value = nn.Linear(channels, channels)
+
+        self.rescale_output_factor = rescale_output_factor
+        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+        # ------------------------- new -----------------------
+
     def set_weight(self, attn_layer):
         self.norm.weight.data = attn_layer.norm.weight.data
         self.norm.bias.data = attn_layer.norm.bias.data
@@ -202,6 +228,89 @@ class AttentionBlockNew(nn.Module):
         self.proj.weight.data = attn_layer.proj.weight.data
         self.proj.bias.data = attn_layer.proj.bias.data

+        if hasattr(attn_layer, "q"):
+            module = attn_layer
+            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[:, :, :, 0]
+            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
+
+            self.qkv.weight.data = qkv_weight
+            self.qkv.bias.data = qkv_bias
+
+            proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
+            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
+            proj_out.bias.data = module.proj_out.bias.data
+
+            self.proj = proj_out
+
+        self.set_weights_2(attn_layer)
+
+    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+        new_projection_shape = projection.size()[:-1] + (self.n_heads, self.num_head_size)
+        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+        return new_projection
+
+    def set_weights_2(self, attn_layer):
+        self.group_norm.weight.data = attn_layer.norm.weight.data
+        self.group_norm.bias.data = attn_layer.norm.bias.data
+
+        qkv_weight = attn_layer.qkv.weight.data.reshape(self.n_heads, 3 * self.channels // self.n_heads, self.channels)
+        qkv_bias = attn_layer.qkv.bias.data.reshape(self.n_heads, 3 * self.channels // self.n_heads)
+
+        q_w, k_w, v_w = qkv_weight.split(self.channels // self.n_heads, dim=1)
+        q_b, k_b, v_b = qkv_bias.split(self.channels // self.n_heads, dim=1)
+
+        self.query.weight.data = q_w.reshape(-1, self.channels)
+        self.key.weight.data = k_w.reshape(-1, self.channels)
+        self.value.weight.data = v_w.reshape(-1, self.channels)
+
+        self.query.bias.data = q_b.reshape(-1)
+        self.key.bias.data = k_b.reshape(-1)
+        self.value.bias.data = v_b.reshape(-1)
+
+        self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
+        self.proj_attn.bias.data = attn_layer.proj.bias.data
+
+    def forward_2(self, hidden_states):
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        # transpose
+        query_states = self.transpose_for_scores(query_proj)
+        key_states = self.transpose_for_scores(key_proj)
+        value_states = self.transpose_for_scores(value_proj)
+
+        # get scores
+        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
+        attention_scores = attention_scores / math.sqrt(self.channels // self.n_heads)
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # compute attention output
+        context_states = torch.matmul(attention_probs, value_states)
+
+        context_states = context_states.permute(0, 2, 1, 3).contiguous()
+        new_context_states_shape = context_states.size()[:-2] + (self.channels,)
+        context_states = context_states.view(new_context_states_shape)
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(context_states)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
     def forward(self, x, encoder_out=None):
         b, c, *spatial = x.shape
         hid_states = self.norm(x).view(b, c, -1)
@@ -230,10 +339,119 @@ class AttentionBlockNew(nn.Module):
         h = h.reshape(b, c, *spatial)
         result = x + h
         result = result / self.rescale_output_factor

-        return result
+        result_2 = self.forward_2(x)
+        print((result - result_2).abs().sum())
+        return result_2
+
+
+class AttentionBlockNew(nn.Module):
+    """
+    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
+    to the N-d case.
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+    Uses three q, k, v linear layers to compute attention.
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=None,
+        num_groups=32,
+        rescale_output_factor=1.0,
+        eps=1e-5,
+    ):
+        super().__init__()
+        self.channels = channels
+
+        if num_head_channels is None:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        self.num_head_size = num_head_channels
+
+        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
+
+        # define q,k,v as linear layers
+        self.query = nn.Linear(channels, channels)
+        self.key = nn.Linear(channels, channels)
+        self.value = nn.Linear(channels, channels)
+
+        self.rescale_output_factor = rescale_output_factor
+        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+
+    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+        new_projection_shape = projection.size()[:-1] + (self.num_heads, self.num_head_size)
+        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+        return new_projection
+
+    def forward(self, hidden_states):
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        # transpose
+        query_states = self.transpose_for_scores(query_proj)
+        key_states = self.transpose_for_scores(key_proj)
+        value_states = self.transpose_for_scores(value_proj)
+
+        # get scores
+        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
+        attention_scores = attention_scores / math.sqrt(self.channels // self.num_heads)
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # compute attention output
+        context_states = torch.matmul(attention_probs, value_states)
+
+        context_states = context_states.permute(0, 2, 1, 3).contiguous()
+        new_context_states_shape = context_states.size()[:-2] + (self.channels,)
+        context_states = context_states.view(new_context_states_shape)
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(context_states)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
+    def set_weight(self, attn_layer):
+        self.group_norm.weight.data = attn_layer.norm.weight.data
+        self.group_norm.bias.data = attn_layer.norm.bias.data
+
+        qkv_weight = attn_layer.qkv.weight.data.reshape(self.num_heads, 3 * self.channels // self.num_heads, self.channels)
+        qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
+
+        q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
+        q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
+
+        self.query.weight.data = q_w.reshape(-1, self.channels)
+        self.key.weight.data = k_w.reshape(-1, self.channels)
+        self.value.weight.data = v_w.reshape(-1, self.channels)
+
+        self.query.bias.data = q_b.reshape(-1)
+        self.key.bias.data = k_b.reshape(-1)
+        self.value.bias.data = v_b.reshape(-1)
+
+        self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
+        self.proj_attn.bias.data = attn_layer.proj.bias.data


 class SpatialTransformer(nn.Module):
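
The refactor from a fused `nn.Conv1d` qkv to three `nn.Linear` layers works because a kernel-size-1 Conv1d over (B, C, T) is the same linear map as an `nn.Linear` applied over (B, T, C). A minimal sketch of that equivalence, assuming the fused weight is laid out as [q; k; v] along the output dimension (the per-head interleaved layout in this commit needs the extra reshape shown in `set_weight` above):

import torch
import torch.nn as nn

torch.manual_seed(0)
channels, batch, tokens = 8, 2, 5

qkv = nn.Conv1d(channels, channels * 3, 1)   # fused projection, kernel size 1
x = torch.randn(batch, channels, tokens)

q_ref, _, _ = qkv(x).chunk(3, dim=1)         # q = first `channels` output rows

query = nn.Linear(channels, channels)        # split-out q projection
query.weight.data = qkv.weight.data[:channels, :, 0]
query.bias.data = qkv.bias.data[:channels]

# apply the Linear over (B, T, C), then move channels back to dim 1
q_new = query(x.transpose(1, 2)).transpose(1, 2)
assert torch.allclose(q_ref, q_new, atol=1e-6)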
src/diffusers/models/resnet.py

@@ -81,8 +81,10 @@ class Downsample2D(nn.Module):
             self.conv = conv
         elif name == "Conv2d_0":
             self.Conv2d_0 = conv
+            self.conv = conv
         else:
             self.op = conv
+            self.conv = conv

     def forward(self, x):
         assert x.shape[1] == self.channels
@@ -90,13 +92,16 @@ class Downsample2D(nn.Module):
             pad = (0, 1, 0, 1)
             x = F.pad(x, pad, mode="constant", value=0)

-        return self.conv(x)
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
-        # if self.name == "conv":
-        #     return self.conv(x)
-        # elif self.name == "Conv2d_0":
-        #     return self.Conv2d_0(x)
-        # else:
-        #     return self.op(x)
+        if self.name == "conv":
+            return self.conv(x)
+        elif self.name == "Conv2d_0":
+            return self.Conv2d_0(x)
+        else:
+            return self.op(x)


 class Upsample1D(nn.Module):
@@ -656,9 +661,9 @@ class ResnetBlock(nn.Module):
         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

         if time_embedding_norm == "default" and temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
         elif time_embedding_norm == "scale_shift" and temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
+            self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)

         self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
         self.dropout = torch.nn.Dropout(dropout)
@@ -691,9 +696,9 @@ class ResnetBlock(nn.Module):
         self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut

-        self.nin_shortcut = None
+        self.conv_shortcut = None
         if self.use_nin_shortcut:
-            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

     def forward(self, x, temb):
         h = x
@@ -715,7 +720,7 @@ class ResnetBlock(nn.Module):
         h = self.nonlinearity(h)

         if temb is not None:
-            temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
         else:
             temb = 0
@@ -738,8 +743,8 @@ class ResnetBlock(nn.Module):
         h = self.norm2(h)
         h = self.nonlinearity(h)

-        if self.nin_shortcut is not None:
-            x = self.nin_shortcut(x)
+        if self.conv_shortcut is not None:
+            x = self.conv_shortcut(x)

         return (x + h) / self.output_scale_factor
@@ -750,8 +755,8 @@ class ResnetBlock(nn.Module):
         self.conv1.weight.data = resnet.conv1.weight.data
         self.conv1.bias.data = resnet.conv1.bias.data

-        self.temb_proj.weight.data = resnet.temb_proj.weight.data
-        self.temb_proj.bias.data = resnet.temb_proj.bias.data
+        self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
+        self.time_emb_proj.bias.data = resnet.temb_proj.bias.data

         self.norm2.weight.data = resnet.norm2.weight.data
         self.norm2.bias.data = resnet.norm2.bias.data
@@ -760,8 +765,8 @@ class ResnetBlock(nn.Module):
         self.conv2.bias.data = resnet.conv2.bias.data

         if self.use_nin_shortcut:
-            self.nin_shortcut.weight.data = resnet.nin_shortcut.weight.data
-            self.nin_shortcut.bias.data = resnet.nin_shortcut.bias.data
+            self.conv_shortcut.weight.data = resnet.nin_shortcut.weight.data
+            self.conv_shortcut.bias.data = resnet.nin_shortcut.bias.data

         # TODO(Patrick) - just there to convert the weights; can delete afterward
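
A note on the `Downsample2D` change: the same conv module is now also registered under `conv` next to its legacy attribute name, presumably so checkpoints keyed on either name keep loading while new code can always reach `self.conv`. A minimal sketch of why the aliasing works (class and channel sizes are illustrative):

import torch.nn as nn

class Down(nn.Module):
    def __init__(self):
        super().__init__()
        conv = nn.Conv2d(4, 4, 3, stride=2, padding=1)
        self.Conv2d_0 = conv   # legacy attribute name
        self.conv = conv       # alias to the very same module

d = Down()
# Both prefixes appear in the state dict, backed by the same tensors,
# so either key layout resolves to identical parameters.
assert "Conv2d_0.weight" in d.state_dict() and "conv.weight" in d.state_dict()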
src/diffusers/models/unet.py

@@ -177,7 +177,9 @@ class UNetModel(ModelMixin, ConfigMixin):
                 hs.append(self.down[i_level].downsample(hs[-1]))

         # middle
+        print("hs", hs[-1].abs().sum())
         h = self.mid_new(hs[-1], temb)
+        print("h", h.abs().sum())

         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
src/diffusers/models/unet_new.py

@@ -29,9 +29,10 @@ def get_down_block(
     resnet_eps,
     resnet_act_fn,
     attn_num_head_channels,
+    downsample_padding=None,
 ):
     if down_block_type == "UNetResDownBlock2D":
-        return UNetResAttnDownBlock2D(
+        return UNetResDownBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
@@ -39,6 +40,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            downsample_padding=downsample_padding,
         )
     elif down_block_type == "UNetResAttnDownBlock2D":
         return UNetResAttnDownBlock2D(
@@ -57,7 +59,8 @@ def get_up_block(
     up_block_type,
     num_layers,
     in_channels,
-    next_channels,
+    out_channels,
+    prev_output_channel,
     temb_channels,
     add_upsample,
     resnet_eps,
@@ -68,7 +71,8 @@ def get_up_block(
         return UNetResUpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
-            next_channels=next_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
             temb_channels=temb_channels,
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
@@ -78,7 +82,8 @@ def get_up_block(
         return UNetResAttnUpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
-            next_channels=next_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
             temb_channels=temb_channels,
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
@@ -100,11 +105,14 @@ class UNetMidBlock2D(nn.Module):
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
         attn_num_head_channels=1,
+        attention_type="default",
         output_scale_factor=1.0,
         **kwargs,
     ):
         super().__init__()

+        self.attention_type = attention_type
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock(
@@ -128,6 +136,7 @@ class UNetMidBlock2D(nn.Module):
                 in_channels,
                 num_head_channels=attn_num_head_channels,
                 rescale_output_factor=output_scale_factor,
+                eps=resnet_eps,
             )
         )
         resnets.append(
@@ -148,18 +157,15 @@ class UNetMidBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

-    def forward(self, hidden_states, temb=None, encoder_states=None, mask=None):
-        if mask is not None:
-            hidden_states = self.resnets[0](hidden_states, temb, mask=mask)
-        else:
-            hidden_states = self.resnets[0](hidden_states, temb)
+    def forward(self, hidden_states, temb=None, encoder_states=None):
+        hidden_states = self.resnets[0](hidden_states, temb)

         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            hidden_states = attn(hidden_states, encoder_states)
-            if mask is not None:
-                hidden_states = resnet(hidden_states, temb, mask=mask)
+            if self.attention_type == "default":
+                hidden_states = attn(hidden_states)
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, encoder_states)
+            hidden_states = resnet(hidden_states, temb)

         return hidden_states
@@ -178,6 +184,7 @@ class UNetResAttnDownBlock2D(nn.Module):
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
         attn_num_head_channels=1,
+        attention_type="default",
         output_scale_factor=1.0,
         add_downsample=True,
     ):
@@ -185,6 +192,8 @@ class UNetResAttnDownBlock2D(nn.Module):
         resnets = []
         attentions = []

+        self.attention_type = attention_type
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -206,6 +215,7 @@ class UNetResAttnDownBlock2D(nn.Module):
                 out_channels,
                 num_head_channels=attn_num_head_channels,
                 rescale_output_factor=output_scale_factor,
+                eps=resnet_eps,
             )
         )
@@ -251,6 +261,7 @@ class UNetResDownBlock2D(nn.Module):
         resnet_pre_norm: bool = True,
         output_scale_factor=1.0,
         add_downsample=True,
+        downsample_padding=1,
     ):
         super().__init__()
         resnets = []
@@ -276,7 +287,11 @@ class UNetResDownBlock2D(nn.Module):
         if add_downsample:
             self.downsamplers = nn.ModuleList(
-                [Downsample2D(in_channels, use_conv=True, out_channels=out_channels, padding=1, name="op")]
+                [
+                    Downsample2D(
+                        in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
             )
         else:
             self.downsamplers = None
@@ -301,7 +316,8 @@ class UNetResAttnUpBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        next_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
@@ -310,7 +326,7 @@ class UNetResAttnUpBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attention_layer_type: str = "self",
+        attention_type="default",
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         add_upsample=True,
@@ -319,12 +335,16 @@ class UNetResAttnUpBlock2D(nn.Module):
         resnets = []
         attentions = []

+        self.attention_type = attention_type
+
         for i in range(num_layers):
-            resnet_channels = in_channels if i < num_layers - 1 else next_channels
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels

             resnets.append(
                 ResnetBlock(
-                    in_channels=in_channels + resnet_channels,
-                    out_channels=in_channels,
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
                     temb_channels=temb_channels,
                     eps=resnet_eps,
                     groups=resnet_groups,
@@ -337,9 +357,10 @@ class UNetResAttnUpBlock2D(nn.Module):
             )
             attentions.append(
                 AttentionBlockNew(
-                    in_channels,
+                    out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
                 )
             )
@@ -347,7 +368,7 @@ class UNetResAttnUpBlock2D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

         if add_upsample:
-            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
         else:
             self.upsamplers = None
@@ -373,7 +394,8 @@ class UNetResUpBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        next_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
@@ -382,7 +404,6 @@ class UNetResUpBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attention_layer_type: str = "self",
         output_scale_factor=1.0,
         add_upsample=True,
     ):
@@ -390,11 +411,13 @@ class UNetResUpBlock2D(nn.Module):
         resnets = []

         for i in range(num_layers):
-            resnet_channels = in_channels if i < num_layers - 1 else next_channels
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels

             resnets.append(
                 ResnetBlock(
-                    in_channels=in_channels + resnet_channels,
-                    out_channels=in_channels,
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
                     temb_channels=temb_channels,
                     eps=resnet_eps,
                     groups=resnet_groups,
@@ -409,7 +432,7 @@ class UNetResUpBlock2D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

         if add_upsample:
-            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
         else:
             self.upsamplers = None
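
The switch from `next_channels` to `prev_output_channel`/`out_channels` changes how each up-block resnet sizes its input (the previous block's output concatenated with a skip connection). A worked sketch of the new bookkeeping, mirroring the loops introduced in this commit and using the config defaults `block_channels=(224, 448, 672, 896)` and `num_res_blocks=2`:

# Mirrors the up-path loops added in this commit (values are the defaults).
block_channels = (224, 448, 672, 896)
num_layers = 2 + 1  # num_res_blocks + 1

reversed_block_channels = list(reversed(block_channels))
output_channel = reversed_block_channels[0]
for i in range(len(block_channels)):
    prev_output_channel = output_channel
    output_channel = reversed_block_channels[i]
    in_channels = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
    for j in range(num_layers):
        res_skip_channels = in_channels if j == num_layers - 1 else output_channel
        resnet_in_channels = prev_output_channel if j == 0 else output_channel
        # each resnet consumes hidden states + skip and emits output_channel
        print(f"up block {i}, resnet {j}: {resnet_in_channels + res_skip_channels} -> {output_channel}")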
src/diffusers/models/unet_unconditional.py

@@ -9,6 +9,30 @@ from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
 from .unet_new import UNetMidBlock2D, get_down_block, get_up_block


+def nonlinearity(x):
+    # swish
+    return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(self, channel, time_embed_dim):
+        super().__init__()
+
+        self.linear_1 = nn.Linear(channel, time_embed_dim)
+        self.act = nn.SiLU()
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+
+    def forward(self, sample):
+        sample = self.linear_1(sample)
+        sample = self.act(sample)
+        sample = self.linear_2(sample)
+        return sample
+
+
 class UNetUnconditionalModel(ModelMixin, ConfigMixin):
     """
     The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
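
A quick shape sketch of the new `TimestepEmbedding` module, which mirrors the `nn.Sequential(Linear, SiLU, Linear)` it replaces. It assumes the class from the hunk above is in scope; the widths follow the `block_channels` default (time_embed_dim = block_channels[0] * 4):

import torch

emb = TimestepEmbedding(channel=224, time_embed_dim=896)  # 896 = 224 * 4
t_emb = torch.randn(4, 224)           # sinusoidal features for a batch of 4
assert emb(t_emb).shape == (4, 896)   # projected to the model's time width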
...
@@ -35,35 +59,66 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -35,35 +59,66 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
def
__init__
(
def
__init__
(
self
,
self
,
image_size
,
image_size
=
None
,
in_channels
,
in_channels
=
None
,
out_channels
,
out_channels
=
None
,
num_res_blocks
,
num_res_blocks
=
None
,
dropout
=
0
,
dropout
=
0
,
block_input_channels
=
(
224
,
224
,
448
,
672
),
block_channels
=
(
224
,
448
,
672
,
896
),
block_output_channels
=
(
224
,
448
,
672
,
896
),
down_blocks
=
(
down_blocks
=
(
"UNetResDownBlock2D"
,
"UNetResDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
),
),
downsample_padding
=
1
,
up_blocks
=
(
"UNetResAttnUpBlock2D"
,
"UNetResAttnUpBlock2D"
,
"UNetResAttnUpBlock2D"
,
"UNetResUpBlock2D"
),
up_blocks
=
(
"UNetResAttnUpBlock2D"
,
"UNetResAttnUpBlock2D"
,
"UNetResAttnUpBlock2D"
,
"UNetResUpBlock2D"
),
resnet_act_fn
=
"silu"
,
resnet_act_fn
=
"silu"
,
resnet_eps
=
1e-5
,
resnet_eps
=
1e-5
,
conv_resample
=
True
,
conv_resample
=
True
,
num_head_channels
=
32
,
num_head_channels
=
32
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
,
# To delete once weights are converted
# To delete once weights are converted
# LDM
attention_resolutions
=
(
8
,
4
,
2
),
attention_resolutions
=
(
8
,
4
,
2
),
ldm
=
False
,
# DDPM
out_ch
=
None
,
resolution
=
None
,
attn_resolutions
=
None
,
resamp_with_conv
=
None
,
ch_mult
=
None
,
ch
=
None
,
ddpm
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
# DELETE if statements if not necessary anymore
# DDPM
if
ddpm
:
out_channels
=
out_ch
image_size
=
resolution
block_channels
=
[
x
*
ch
for
x
in
ch_mult
]
conv_resample
=
resamp_with_conv
flip_sin_to_cos
=
False
downscale_freq_shift
=
1
resnet_eps
=
1e-6
block_channels
=
(
32
,
64
)
down_blocks
=
(
"UNetResDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
)
up_blocks
=
(
"UNetResUpBlock2D"
,
"UNetResAttnUpBlock2D"
)
downsample_padding
=
0
num_head_channels
=
64
# register all __init__ params with self.register
# register all __init__ params with self.register
self
.
register_to_config
(
self
.
register_to_config
(
image_size
=
image_size
,
image_size
=
image_size
,
in_channels
=
in_channels
,
in_channels
=
in_channels
,
block_
input_
channels
=
block_
input_
channels
,
block_channels
=
block_channels
,
block_output_channels
=
block_output_channels
,
downsample_padding
=
downsample_padding
,
out_channels
=
out_channels
,
out_channels
=
out_channels
,
num_res_blocks
=
num_res_blocks
,
num_res_blocks
=
num_res_blocks
,
down_blocks
=
down_blocks
,
down_blocks
=
down_blocks
,
...
@@ -71,37 +126,34 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -71,37 +126,34 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
dropout
=
dropout
,
dropout
=
dropout
,
conv_resample
=
conv_resample
,
conv_resample
=
conv_resample
,
num_head_channels
=
num_head_channels
,
num_head_channels
=
num_head_channels
,
flip_sin_to_cos
=
flip_sin_to_cos
,
downscale_freq_shift
=
downscale_freq_shift
,
# (TODO(PVP) - To delete once weights are converted
# (TODO(PVP) - To delete once weights are converted
attention_resolutions
=
attention_resolutions
,
attention_resolutions
=
attention_resolutions
,
ldm
=
ldm
,
ddpm
=
ddpm
,
)
)
# To delete - replace with config values
# To delete - replace with config values
self
.
image_size
=
image_size
self
.
image_size
=
image_size
self
.
in_channels
=
in_channels
time_embed_dim
=
block_channels
[
0
]
*
4
self
.
out_channels
=
out_channels
self
.
num_res_blocks
=
num_res_blocks
self
.
dropout
=
dropout
time_embed_dim
=
block_input_channels
[
0
]
*
4
# # input
self
.
conv_in
=
nn
.
Conv2d
(
in_channels
,
block_channels
[
0
],
kernel_size
=
3
,
padding
=
(
1
,
1
))
# ======================== Input ===================
# # time
self
.
conv_in
=
nn
.
Conv2d
(
in_channels
,
block_input_channels
[
0
],
kernel_size
=
3
,
padding
=
(
1
,
1
))
self
.
time_embedding
=
TimestepEmbedding
(
block_channels
[
0
],
time_embed_dim
)
# ======================== Time ====================
self
.
time_embed
=
nn
.
Sequential
(
nn
.
Linear
(
block_input_channels
[
0
],
time_embed_dim
),
nn
.
SiLU
(),
nn
.
Linear
(
time_embed_dim
,
time_embed_dim
),
)
# ======================== Down ====================
input_channels
=
list
(
block_input_channels
)
output_channels
=
list
(
block_output_channels
)
self
.
downsample_blocks
=
nn
.
ModuleList
([])
self
.
downsample_blocks
=
nn
.
ModuleList
([])
for
i
,
(
input_channel
,
output_channel
)
in
enumerate
(
zip
(
input_channels
,
output_channels
)):
self
.
mid
=
None
down_block_type
=
down_blocks
[
i
]
self
.
upsample_blocks
=
nn
.
ModuleList
([])
is_final_block
=
i
==
len
(
input_channels
)
-
1
# down
output_channel
=
block_channels
[
0
]
for
i
,
down_block_type
in
enumerate
(
down_blocks
):
input_channel
=
output_channel
output_channel
=
block_channels
[
i
]
is_final_block
=
i
==
len
(
block_channels
)
-
1
down_block
=
get_down_block
(
down_block
=
get_down_block
(
down_block_type
,
down_block_type
,
...
@@ -113,30 +165,48 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -113,30 +165,48 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
resnet_eps
=
resnet_eps
,
resnet_eps
=
resnet_eps
,
resnet_act_fn
=
resnet_act_fn
,
resnet_act_fn
=
resnet_act_fn
,
attn_num_head_channels
=
num_head_channels
,
attn_num_head_channels
=
num_head_channels
,
downsample_padding
=
downsample_padding
,
)
)
self
.
downsample_blocks
.
append
(
down_block
)
self
.
downsample_blocks
.
append
(
down_block
)
# ======================== Mid ====================
# mid
self
.
mid
=
UNetMidBlock2D
(
if
self
.
config
.
ddpm
:
in_channels
=
output_channels
[
-
1
],
self
.
mid_new_2
=
UNetMidBlock2D
(
dropout
=
dropout
,
in_channels
=
block_channels
[
-
1
],
temb_channels
=
time_embed_dim
,
dropout
=
dropout
,
resnet_eps
=
resnet_eps
,
temb_channels
=
time_embed_dim
,
resnet_act_fn
=
resnet_act_fn
,
resnet_eps
=
resnet_eps
,
resnet_time_scale_shift
=
"default"
,
resnet_act_fn
=
resnet_act_fn
,
attn_num_head_channels
=
num_head_channels
,
resnet_time_scale_shift
=
"default"
,
)
attn_num_head_channels
=
num_head_channels
,
)
else
:
self
.
mid
=
UNetMidBlock2D
(
in_channels
=
block_channels
[
-
1
],
dropout
=
dropout
,
temb_channels
=
time_embed_dim
,
resnet_eps
=
resnet_eps
,
resnet_act_fn
=
resnet_act_fn
,
resnet_time_scale_shift
=
"default"
,
attn_num_head_channels
=
num_head_channels
,
)
self
.
upsample_blocks
=
nn
.
ModuleList
([])
# up
for
i
,
(
input_channel
,
output_channel
)
in
enumerate
(
zip
(
reversed
(
input_channels
),
reversed
(
output_channels
))):
reversed_block_channels
=
list
(
reversed
(
block_channels
))
up_block_type
=
up_blocks
[
i
]
output_channel
=
reversed_block_channels
[
0
]
is_final_block
=
i
==
len
(
input_channels
)
-
1
for
i
,
up_block_type
in
enumerate
(
up_blocks
):
prev_output_channel
=
output_channel
output_channel
=
reversed_block_channels
[
i
]
input_channel
=
reversed_block_channels
[
min
(
i
+
1
,
len
(
block_channels
)
-
1
)]
is_final_block
=
i
==
len
(
block_channels
)
-
1
up_block
=
get_up_block
(
up_block
=
get_up_block
(
up_block_type
,
up_block_type
,
num_layers
=
num_res_blocks
+
1
,
num_layers
=
num_res_blocks
+
1
,
in_channels
=
output_channel
,
in_channels
=
input_channel
,
next_channels
=
input_channel
,
out_channels
=
output_channel
,
prev_output_channel
=
prev_output_channel
,
temb_channels
=
time_embed_dim
,
temb_channels
=
time_embed_dim
,
add_upsample
=
not
is_final_block
,
add_upsample
=
not
is_final_block
,
resnet_eps
=
resnet_eps
,
resnet_eps
=
resnet_eps
,
...
@@ -144,50 +214,72 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -144,50 +214,72 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
attn_num_head_channels
=
num_head_channels
,
attn_num_head_channels
=
num_head_channels
,
)
)
self
.
upsample_blocks
.
append
(
up_block
)
self
.
upsample_blocks
.
append
(
up_block
)
prev_output_channel
=
output_channel
# out
self
.
conv_norm_out
=
nn
.
GroupNorm
(
num_channels
=
block_channels
[
0
],
num_groups
=
32
,
eps
=
1e-5
)
self
.
conv_act
=
nn
.
SiLU
()
self
.
conv_out
=
nn
.
Conv2d
(
block_channels
[
0
],
out_channels
,
3
,
padding
=
1
)
# ======================== Out ====================
# ======================== Out ====================
self
.
out
=
nn
.
Sequential
(
nn
.
GroupNorm
(
num_channels
=
output_channels
[
0
],
num_groups
=
32
,
eps
=
1e-5
),
nn
.
SiLU
(),
nn
.
Conv2d
(
block_input_channels
[
0
],
out_channels
,
3
,
padding
=
1
),
)
# =========== TO DELETE AFTER CONVERSION ==========
self
.
is_overwritten
=
False
transformer_depth
=
1
if
ldm
:
context_dim
=
None
# =========== TO DELETE AFTER CONVERSION ==========
legacy
=
True
transformer_depth
=
1
num_heads
=
-
1
context_dim
=
None
model_channels
=
block_input_channels
[
0
]
legacy
=
True
channel_mult
=
tuple
([
x
//
model_channels
for
x
in
block_output_channels
])
num_heads
=
-
1
self
.
init_for_ldm
(
model_channels
=
block_channels
[
0
]
in_channels
,
channel_mult
=
tuple
([
x
//
model_channels
for
x
in
block_channels
])
model_channels
,
self
.
init_for_ldm
(
channel_mult
,
in_channels
,
num_res_blocks
,
model_channels
,
dropout
,
channel_mult
,
time_embed_dim
,
num_res_blocks
,
attention_resolutions
,
dropout
,
num_head_channels
,
time_embed_dim
,
num_heads
,
attention_resolutions
,
legacy
,
num_head_channels
,
False
,
num_heads
,
transformer_depth
,
legacy
,
context_dim
,
False
,
conv_resample
,
transformer_depth
,
out_channels
,
context_dim
,
)
conv_resample
,
out_channels
,
)
if
ddpm
:
self
.
init_for_ddpm
(
ch_mult
,
ch
,
num_res_blocks
,
resolution
,
in_channels
,
resamp_with_conv
,
attn_resolutions
,
out_ch
,
dropout
=
0.1
,
)
def
forward
(
self
,
sample
,
timesteps
=
None
):
def
forward
(
self
,
sample
,
timesteps
=
None
):
# TODO(PVP) - to delete later
if
not
self
.
is_overwritten
:
self
.
set_weights
()
# 1. time step embeddings
# 1. time step embeddings
if
not
torch
.
is_tensor
(
timesteps
):
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
t_emb
=
get_timestep_embedding
(
t_emb
=
get_timestep_embedding
(
timesteps
,
self
.
config
.
block_input_channels
[
0
],
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
timesteps
,
self
.
config
.
block_channels
[
0
],
flip_sin_to_cos
=
self
.
config
.
flip_sin_to_cos
,
downscale_freq_shift
=
self
.
config
.
downscale_freq_shift
,
)
)
emb
=
self
.
time_embed
(
t_emb
)
emb
=
self
.
time_embed
ding
(
t_emb
)
# 2. pre-process sample
# 2. pre-process sample
# sample = sample.type(self.dtype_)
sample
=
self
.
conv_in
(
sample
)
sample
=
self
.
conv_in
(
sample
)
# 3. down blocks
# 3. down blocks
...
@@ -198,8 +290,13 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -198,8 +290,13 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
# append to tuple
# append to tuple
down_block_res_samples
+=
res_samples
down_block_res_samples
+=
res_samples
print
(
"sample"
,
sample
.
abs
().
sum
())
# 4. mid block
# 4. mid block
sample
=
self
.
mid
(
sample
,
emb
)
if
self
.
config
.
ddpm
:
sample
=
self
.
mid_new_2
(
sample
,
emb
)
else
:
sample
=
self
.
mid
(
sample
,
emb
)
print
(
"sample"
,
sample
.
abs
().
sum
())
# 5. up blocks
# 5. up blocks
for
upsample_block
in
self
.
upsample_blocks
:
for
upsample_block
in
self
.
upsample_blocks
:
...
@@ -211,10 +308,192 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -211,10 +308,192 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
sample
=
upsample_block
(
sample
,
res_samples
,
emb
)
sample
=
upsample_block
(
sample
,
res_samples
,
emb
)
# 6. post-process sample
# 6. post-process sample
sample
=
self
.
out
(
sample
)
sample
=
self
.
conv_norm_out
(
sample
)
sample
=
self
.
conv_act
(
sample
)
sample
=
self
.
conv_out
(
sample
)
return
sample
return
sample
def
set_weights
(
self
):
self
.
is_overwritten
=
True
if
self
.
config
.
ldm
:
self
.
time_embedding
.
linear_1
.
weight.data = self.time_embed[0].weight.data
    self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data
    self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data
    self.time_embedding.linear_2.bias.data = self.time_embed[2].bias.data

    # ================ SET WEIGHTS OF ALL WEIGHTS ==================
    for i, input_layer in enumerate(self.input_blocks[1:]):
        block_id = i // (self.config.num_res_blocks + 1)
        layer_in_block_id = i % (self.config.num_res_blocks + 1)

        if layer_in_block_id == 2:
            self.downsample_blocks[block_id].downsamplers[0].conv.weight.data = input_layer[0].op.weight.data
            self.downsample_blocks[block_id].downsamplers[0].conv.bias.data = input_layer[0].op.bias.data
        elif len(input_layer) > 1:
            self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
            self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
        else:
            self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])

    self.mid.resnets[0].set_weight(self.middle_block[0])
    self.mid.resnets[1].set_weight(self.middle_block[2])
    self.mid.attentions[0].set_weight(self.middle_block[1])

    for i, input_layer in enumerate(self.output_blocks):
        block_id = i // (self.config.num_res_blocks + 1)
        layer_in_block_id = i % (self.config.num_res_blocks + 1)

        if len(input_layer) > 2:
            self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
            self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
            self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
            self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
        elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
            self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
            self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
            self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
        elif len(input_layer) > 1:
            self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
            self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
        else:
            self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])

    self.conv_in.weight.data = self.input_blocks[0][0].weight.data
    self.conv_in.bias.data = self.input_blocks[0][0].bias.data

    self.conv_norm_out.weight.data = self.out[0].weight.data
    self.conv_norm_out.bias.data = self.out[0].bias.data
    self.conv_out.weight.data = self.out[2].weight.data
    self.conv_out.bias.data = self.out[2].bias.data

    self.remove_ldm()

elif self.config.ddpm:
    # =============== SET WEIGHTS ===============
    # =============== TIME ======================
    self.time_embed[0] = self.temb.dense[0]
    self.time_embed[2] = self.temb.dense[1]

    for i, block in enumerate(self.down):
        if hasattr(block, "downsample"):
            self.downsample_blocks[i].downsamplers[0].conv.weight.data = block.downsample.conv.weight.data
            self.downsample_blocks[i].downsamplers[0].conv.bias.data = block.downsample.conv.bias.data
        if hasattr(block, "block") and len(block.block) > 0:
            for j in range(self.num_res_blocks):
                self.downsample_blocks[i].resnets[j].set_weight(block.block[j])
        if hasattr(block, "attn") and len(block.attn) > 0:
            for j in range(self.num_res_blocks):
                self.downsample_blocks[i].attentions[j].set_weight(block.attn[j])

    self.mid_new_2.resnets[0].set_weight(self.mid.block_1)
    self.mid_new_2.resnets[1].set_weight(self.mid.block_2)
    self.mid_new_2.attentions[0].set_weight(self.mid.attn_1)
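Both branches above port checkpoints between module layouts by aliasing `.data` tensors and calling the blocks' `set_weight` helpers. A quick way to validate such a port is to compare forward passes on a fixed input. This is a hypothetical sanity check, not part of this commit; `old_model`, `new_model`, and the shared `(sample, timestep)` forward signature are assumptions:

import torch

def check_weight_port(old_model, new_model, in_channels=3, image_size=32):
    # Assumed: both models take (sample, timestep) and return a tensor of the same shape.
    old_model.eval()
    new_model.eval()
    torch.manual_seed(0)
    noise = torch.randn(1, in_channels, image_size, image_size)
    time_step = torch.tensor([10])
    with torch.no_grad():
        out_old = old_model(noise, time_step)
        out_new = new_model(noise, time_step)
    # The port only aliases tensors, so the outputs should agree to float precision.
    assert torch.allclose(out_old, out_new, rtol=1e-4), "ported weights diverge"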
def init_for_ddpm(
    self,
    ch_mult,
    ch,
    num_res_blocks,
    resolution,
    in_channels,
    resamp_with_conv,
    attn_resolutions,
    out_ch,
    dropout=0.1,
):
    ch_mult = tuple(ch_mult)
    self.ch = ch
    self.temb_ch = self.ch * 4
    self.num_resolutions = len(ch_mult)
    self.num_res_blocks = num_res_blocks
    self.resolution = resolution

    # timestep embedding
    self.temb = nn.Module()
    self.temb.dense = nn.ModuleList(
        [
            torch.nn.Linear(self.ch, self.temb_ch),
            torch.nn.Linear(self.temb_ch, self.temb_ch),
        ]
    )

    # downsampling
    self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

    curr_res = resolution
    in_ch_mult = (1,) + ch_mult
    self.down = nn.ModuleList()
    for i_level in range(self.num_resolutions):
        block = nn.ModuleList()
        attn = nn.ModuleList()
        block_in = ch * in_ch_mult[i_level]
        block_out = ch * ch_mult[i_level]
        for i_block in range(self.num_res_blocks):
            block.append(
                ResnetBlock2D(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)
            )
            block_in = block_out
            if curr_res in attn_resolutions:
                attn.append(AttentionBlock(block_in, overwrite_qkv=True))
        down = nn.Module()
        down.block = block
        down.attn = attn
        if i_level != self.num_resolutions - 1:
            down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
            curr_res = curr_res // 2
        self.down.append(down)

    # middle
    self.mid = nn.Module()
    self.mid.block_1 = ResnetBlock2D(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
    self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
    self.mid.block_2 = ResnetBlock2D(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

    self.mid_new = UNetMidBlock2D(in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
    self.mid_new.resnets[0] = self.mid.block_1
    self.mid_new.attentions[0] = self.mid.attn_1
    self.mid_new.resnets[1] = self.mid.block_2

    # upsampling
    self.up = nn.ModuleList()
    for i_level in reversed(range(self.num_resolutions)):
        block = nn.ModuleList()
        attn = nn.ModuleList()
        block_out = ch * ch_mult[i_level]
        skip_in = ch * ch_mult[i_level]
        for i_block in range(self.num_res_blocks + 1):
            if i_block == self.num_res_blocks:
                skip_in = ch * in_ch_mult[i_level]
            block.append(
                ResnetBlock2D(
                    in_channels=block_in + skip_in,
                    out_channels=block_out,
                    temb_channels=self.temb_ch,
                    dropout=dropout,
                )
            )
            block_in = block_out
            if curr_res in attn_resolutions:
                attn.append(AttentionBlock(block_in, overwrite_qkv=True))
        up = nn.Module()
        up.block = block
        up.attn = attn
        if i_level != 0:
            up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
            curr_res = curr_res * 2
        self.up.insert(0, up)  # prepend to get consistent order

    # end
    self.norm_out = Normalize(block_in)
    self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
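For intuition, here is the channel schedule the constructor above derives from `ch` and `ch_mult`. The values are assumed for illustration only, not repo defaults: `in_ch_mult` is `ch_mult` shifted by one so each resolution level reads the previous level's output width.

ch, ch_mult = 128, (1, 2, 4)       # assumed example values
in_ch_mult = (1,) + ch_mult        # input multiplier for each level
pairs = [(ch * in_ch_mult[i], ch * ch_mult[i]) for i in range(len(ch_mult))]
print(pairs)  # [(128, 128), (128, 256), (256, 512)]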
def init_for_ldm(
    self,
    in_channels,
...
@@ -234,7 +513,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
    out_channels,
):
    # TODO(PVP) - delete after weight conversion
    class TimestepEmbedSequential(nn.Sequential):
        """
        A sequential module that passes timestep embeddings to the children that support it as an extra input.
        """
...
@@ -255,6 +533,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
            return nn.Conv3d(*args, **kwargs)
        raise ValueError(f"unsupported dimensions: {dims}")

+    self.time_embed = nn.Sequential(
+        nn.Linear(model_channels, time_embed_dim),
+        nn.SiLU(),
+        nn.Linear(time_embed_dim, time_embed_dim),
+    )
    dims = 2
    self.input_blocks = nn.ModuleList(
        [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
    )
...
@@ -389,42 +673,15 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
        self.output_blocks.append(TimestepEmbedSequential(*layers))
        self._feature_size += ch

-    # ================ SET WEIGHTS OF ALL WEIGHTS ==================
-    for i, input_layer in enumerate(self.input_blocks[1:]):
-        block_id = i // (num_res_blocks + 1)
-        layer_in_block_id = i % (num_res_blocks + 1)
-        if layer_in_block_id == 2:
-            self.downsample_blocks[block_id].downsamplers[0].op.weight.data = input_layer[0].op.weight.data
-            self.downsample_blocks[block_id].downsamplers[0].op.bias.data = input_layer[0].op.bias.data
-        elif len(input_layer) > 1:
-            self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
-            self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
-        else:
-            self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
-    self.mid.resnets[0].set_weight(self.middle_block[0])
-    self.mid.resnets[1].set_weight(self.middle_block[2])
-    self.mid.attentions[0].set_weight(self.middle_block[1])
-    for i, input_layer in enumerate(self.output_blocks):
-        block_id = i // (num_res_blocks + 1)
-        layer_in_block_id = i % (num_res_blocks + 1)
-        if len(input_layer) > 2:
-            self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
-            self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
-            self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
-            self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
-        elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
-            self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
-            self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
-            self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
-        elif len(input_layer) > 1:
-            self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
-            self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
-        else:
-            self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
-    self.conv_in.weight.data = self.input_blocks[0][0].weight.data
-    self.conv_in.bias.data = self.input_blocks[0][0].bias.data
+    self.out = nn.Sequential(
+        nn.GroupNorm(num_channels=model_channels, num_groups=32, eps=1e-5),
+        nn.SiLU(),
+        nn.Conv2d(model_channels, out_channels, 3, padding=1),
+    )
+
+    def remove_ldm(self):
+        del self.time_embed
+        del self.input_blocks
+        del self.middle_block
+        del self.output_blocks
+        del self.out
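The `block_id` / `layer_in_block_id` arithmetic that recurs throughout this diff maps the legacy flat list of layers back onto per-resolution blocks. A worked example, assuming `num_res_blocks = 2` so that each resolution owns `num_res_blocks + 1 = 3` consecutive entries, the last being the downsampler:

num_res_blocks = 2  # assumed for illustration
for i in range(6):
    block_id = i // (num_res_blocks + 1)
    layer_in_block_id = i % (num_res_blocks + 1)
    print(i, "->", (block_id, layer_in_block_id))
# 0 -> (0, 0)  resnet
# 1 -> (0, 1)  resnet
# 2 -> (0, 2)  downsampler
# 3 -> (1, 0)  resnet
# 4 -> (1, 1)  resnet
# 5 -> (1, 2)  downsampler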
tests/test_modeling_utils.py View file @ e7fe901e
...
@@ -271,6 +271,27 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

+        print("Original success!!!")
+
+        model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy", ddpm=True)
+        model.eval()
+
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+        time_step = torch.tensor([10])
+
+        with torch.no_grad():
+            output = model(noise, time_step)
+
+        output_slice = output[0, -1, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
+        # fmt: on
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
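The `rtol=1e-2` in these checks is effectively a 1% relative tolerance: `torch.allclose` accepts elementwise `|a - b| <= atol + rtol * |b|`, with a default `atol` of `1e-8`. A standalone illustration (values assumed):

import torch

a = torch.tensor([1.000, 2.000])
b = torch.tensor([1.009, 2.015])
print(torch.allclose(a, b, rtol=1e-2))  # True: each entry of a is within 1% of b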
class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
    model_class = GlideSuperResUNetModel
...
@@ -486,18 +507,20 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
            "out_channels": 4,
            "num_res_blocks": 2,
            "attention_resolutions": (16,),
-            "block_input_channels": [32, 32],
-            "block_output_channels": [32, 64],
+            "block_channels": (32, 64),
            "num_head_channels": 32,
            "conv_resample": True,
            "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
            "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
+            "ldm": True,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
-        model, loading_info = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True)
+        model, loading_info = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True, ldm=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)
...
@@ -507,7 +530,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
-        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy")
+        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
        model.eval()

        torch.manual_seed(0)
...
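Putting the loading flags from this commit together, a minimal usage sketch based on the tests above (the `fusing/*` checkpoints are the dummy models those tests rely on):

from diffusers import UNetUnconditionalModel

# Checkpoints in the legacy LDM layout need ldm=True so their weights get ported on load.
ldm_unet = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)

# DDPM-style checkpoints take the analogous ddpm=True flag.
ddpm_unet = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy", ddpm=True)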