renzhc / diffusers_dcu · Commits · e7fe901e

Commit e7fe901e (unverified), authored Jul 14, 2022 by Patrick von Platen, committed via GitHub on Jul 14, 2022. Parent: c3d78cd3

save intermediate (#87)

* save intermediate
* up
* up
Showing 10 changed files with 801 additions and 173 deletions (+801 −173).
- conversion.py (+94 −0)
- docs/source/examples/diffusers_for_vision.mdx (+2 −1)
- src/diffusers/configuration_utils.py (+2 −0)
- src/diffusers/modeling_utils.py (+3 −0)
- src/diffusers/models/attention.py (+223 −5)
- src/diffusers/models/resnet.py (+22 −17)
- src/diffusers/models/unet.py (+2 −0)
- src/diffusers/models/unet_new.py (+50 −27)
- src/diffusers/models/unet_unconditional.py (+376 −119)
- tests/test_modeling_utils.py (+27 −4)
conversion.py (new file, mode 100755)

```python
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import tempfile
import unittest

import numpy as np
import torch

from diffusers import (
    AutoencoderKL,
    DDIMPipeline,
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
    GlidePipeline,
    GlideSuperResUNetModel,
    GlideTextToImageUNetModel,
    LatentDiffusionPipeline,
    LatentDiffusionUncondPipeline,
    NCSNpp,
    PNDMPipeline,
    PNDMScheduler,
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
    ScoreSdeVpPipeline,
    ScoreSdeVpScheduler,
    UNetLDMModel,
    UNetModel,
    UNetUnconditionalModel,
    VQModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel


def test_output_pretrained_ldm_dummy():
    model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
    model.eval()

    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    time_step = torch.tensor([10] * noise.shape[0])

    with torch.no_grad():
        output = model(noise, time_step)

    print(model)
    import ipdb; ipdb.set_trace()


def test_output_pretrained_ldm():
    model = UNetUnconditionalModel.from_pretrained("fusing/latent-diffusion-celeba-256", subfolder="unet", ldm=True)
    model.eval()

    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    time_step = torch.tensor([10] * noise.shape[0])

    with torch.no_grad():
        output = model(noise, time_step)

    print(model)
    import ipdb; ipdb.set_trace()


# To see how the final model should look like
test_output_pretrained_ldm_dummy()
test_output_pretrained_ldm()

# => this is the architecture in which the model should be saved in the new format
# -> verify new repo with the following tests (in `test_modeling_utils.py`)
#    - test_ldm_uncond (in PipelineTesterMixin)
#    - test_output_pretrained (in UNetLDMModelTests)
```
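Because the script ends in interactive `ipdb.set_trace()` breakpoints, it has to be run by hand. Below is a non-interactive sketch of the same smoke check (repo id and the `ldm=True` flag are taken from the script above; the shape assertion is our assumption about the UNet's input/output contract, not something the script itself verifies):

```python
import torch
from diffusers import UNetUnconditionalModel

# Load the dummy LDM checkpoint through the new unconditional UNet class.
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
model.eval()

torch.manual_seed(0)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
time_step = torch.tensor([10] * noise.shape[0])

with torch.no_grad():
    output = model(noise, time_step)

# Assumption: the denoising UNet preserves the spatial shape of its input.
assert output.shape == noise.shape
```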
docs/source/examples/diffusers_for_vision.mdx

````diff
@@ -111,7 +111,7 @@ prompt = "A painting of a squirrel eating a burger"
 image = ldm([prompt], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)

 image_processed = image.cpu().permute(0, 2, 3, 1)
-image_processed = image_processed * 255.
+image_processed = image_processed * 255.0
 image_processed = image_processed.numpy().astype(np.uint8)
 image_pil = PIL.Image.fromarray(image_processed[0])
@@ -143,6 +143,7 @@ audio = bddm(mel_spec, generator, torch_device=torch_device)
 # save generated audio
 from scipy.io.wavfile import write as wavwrite

 sampling_rate = 22050
 wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
 ```
````
src/diffusers/configuration_utils.py

```diff
@@ -116,6 +116,7 @@ class ConfigMixin:
         use_auth_token = kwargs.pop("use_auth_token", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)

         user_agent = {"file_type": "config"}
@@ -150,6 +151,7 @@ class ConfigMixin:
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
+                subfolder=subfolder,
             )
         except RepositoryNotFoundError:
```
src/diffusers/modeling_utils.py

```diff
@@ -321,6 +321,7 @@ class ModelMixin(torch.nn.Module):
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         from_auto_class = kwargs.pop("_from_auto", False)
+        subfolder = kwargs.pop("subfolder", None)

         user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -336,6 +337,7 @@ class ModelMixin(torch.nn.Module):
             local_files_only=local_files_only,
             use_auth_token=use_auth_token,
             revision=revision,
+            subfolder=subfolder,
             **kwargs,
         )
         model.register_to_config(name_or_path=pretrained_model_name_or_path)
@@ -363,6 +365,7 @@ class ModelMixin(torch.nn.Module):
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
+                subfolder=subfolder,
             )
         except RepositoryNotFoundError:
```
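Taken together, the two `subfolder` additions let `from_pretrained` address a model that lives in a subdirectory of a Hub repo. A minimal usage sketch (repo id and subfolder name taken from `conversion.py` above):

```python
from diffusers import UNetUnconditionalModel

# The celeba-256 latent-diffusion repo keeps its UNet weights under "unet/";
# `subfolder` is popped in ConfigMixin/ModelMixin and forwarded to the Hub
# download helpers, so both the config and the weights resolve there.
model = UNetUnconditionalModel.from_pretrained(
    "fusing/latent-diffusion-celeba-256", subfolder="unet", ldm=True
)
```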
src/diffusers/models/attention.py

```diff
@@ -51,6 +51,7 @@ class AttentionBlock(nn.Module):
         overwrite_qkv=False,
         overwrite_linear=False,
         rescale_output_factor=1.0,
+        eps=1e-5,
     ):
         super().__init__()
         self.channels = channels
@@ -62,7 +63,7 @@ class AttentionBlock(nn.Module):
             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
             self.num_heads = channels // num_head_channels

-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
+        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
         self.qkv = nn.Conv1d(channels, channels * 3, 1)
         self.n_heads = self.num_heads
         self.rescale_output_factor = rescale_output_factor
@@ -165,7 +166,7 @@ class AttentionBlock(nn.Module):
         return result


-class AttentionBlockNew(nn.Module):
+class AttentionBlockNew_2(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other.
@@ -180,11 +181,14 @@ class AttentionBlockNew(nn.Module):
         num_groups=32,
         encoder_channels=None,
         rescale_output_factor=1.0,
+        eps=1e-5,
     ):
         super().__init__()
-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
+        self.channels = channels
+        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
         self.qkv = nn.Conv1d(channels, channels * 3, 1)
         self.n_heads = channels // num_head_channels
+        self.num_head_size = num_head_channels
         self.rescale_output_factor = rescale_output_factor

         if encoder_channels is not None:
@@ -192,6 +196,28 @@ class AttentionBlockNew(nn.Module):
         self.proj = zero_module(nn.Conv1d(channels, channels, 1))

+        # ------------------------- new -----------------------
+        num_heads = self.n_heads
+        self.channels = channels
+
+        if num_head_channels is None:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+
+        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
+
+        # define q,k,v as linear layers
+        self.query = nn.Linear(channels, channels)
+        self.key = nn.Linear(channels, channels)
+        self.value = nn.Linear(channels, channels)
+
+        self.rescale_output_factor = rescale_output_factor
+        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+        # ------------------------- new -----------------------
+
     def set_weight(self, attn_layer):
         self.norm.weight.data = attn_layer.norm.weight.data
         self.norm.bias.data = attn_layer.norm.bias.data
@@ -202,6 +228,89 @@ class AttentionBlockNew(nn.Module):
         self.proj.weight.data = attn_layer.proj.weight.data
         self.proj.bias.data = attn_layer.proj.bias.data

+        if hasattr(attn_layer, "q"):
+            module = attn_layer
+            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[:, :, :, 0]
+            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
+
+            self.qkv.weight.data = qkv_weight
+            self.qkv.bias.data = qkv_bias
+
+            proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
+            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
+            proj_out.bias.data = module.proj_out.bias.data
+
+            self.proj = proj_out
+
+        self.set_weights_2(attn_layer)
+
+    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+        new_projection_shape = projection.size()[:-1] + (self.n_heads, self.num_head_size)
+        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+        return new_projection
+
+    def set_weights_2(self, attn_layer):
+        self.group_norm.weight.data = attn_layer.norm.weight.data
+        self.group_norm.bias.data = attn_layer.norm.bias.data
+
+        qkv_weight = attn_layer.qkv.weight.data.reshape(self.n_heads, 3 * self.channels // self.n_heads, self.channels)
+        qkv_bias = attn_layer.qkv.bias.data.reshape(self.n_heads, 3 * self.channels // self.n_heads)
+
+        q_w, k_w, v_w = qkv_weight.split(self.channels // self.n_heads, dim=1)
+        q_b, k_b, v_b = qkv_bias.split(self.channels // self.n_heads, dim=1)
+
+        self.query.weight.data = q_w.reshape(-1, self.channels)
+        self.key.weight.data = k_w.reshape(-1, self.channels)
+        self.value.weight.data = v_w.reshape(-1, self.channels)
+
+        self.query.bias.data = q_b.reshape(-1)
+        self.key.bias.data = k_b.reshape(-1)
+        self.value.bias.data = v_b.reshape(-1)
+
+        self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
+        self.proj_attn.bias.data = attn_layer.proj.bias.data
+
+    def forward_2(self, hidden_states):
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        # transpose
+        query_states = self.transpose_for_scores(query_proj)
+        key_states = self.transpose_for_scores(key_proj)
+        value_states = self.transpose_for_scores(value_proj)
+
+        # get scores
+        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
+        attention_scores = attention_scores / math.sqrt(self.channels // self.n_heads)
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # compute attention output
+        context_states = torch.matmul(attention_probs, value_states)
+
+        context_states = context_states.permute(0, 2, 1, 3).contiguous()
+        new_context_states_shape = context_states.size()[:-2] + (self.channels,)
+        context_states = context_states.view(new_context_states_shape)
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(context_states)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
     def forward(self, x, encoder_out=None):
         b, c, *spatial = x.shape
         hid_states = self.norm(x).view(b, c, -1)
@@ -230,10 +339,119 @@ class AttentionBlockNew(nn.Module):
         h = h.reshape(b, c, *spatial)
         result = x + h
         result = result / self.rescale_output_factor

-        return result
+        result_2 = self.forward_2(x)
+        print((result - result_2).abs().sum())
+
+        return result_2
+
+
+class AttentionBlockNew(nn.Module):
+    """
+    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
+    to the N-d case.
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+    Uses three q, k, v linear layers to compute attention
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=None,
+        num_groups=32,
+        rescale_output_factor=1.0,
+        eps=1e-5,
+    ):
+        super().__init__()
+        self.channels = channels
+
+        if num_head_channels is None:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+
+        self.num_head_size = num_head_channels
+        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
+
+        # define q,k,v as linear layers
+        self.query = nn.Linear(channels, channels)
+        self.key = nn.Linear(channels, channels)
+        self.value = nn.Linear(channels, channels)
+
+        self.rescale_output_factor = rescale_output_factor
+        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+
+    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+        new_projection_shape = projection.size()[:-1] + (self.num_heads, self.num_head_size)
+        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+        return new_projection
+
+    def forward(self, hidden_states):
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        # transpose
+        query_states = self.transpose_for_scores(query_proj)
+        key_states = self.transpose_for_scores(key_proj)
+        value_states = self.transpose_for_scores(value_proj)
+
+        # get scores
+        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
+        attention_scores = attention_scores / math.sqrt(self.channels // self.num_heads)
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # compute attention output
+        context_states = torch.matmul(attention_probs, value_states)
+
+        context_states = context_states.permute(0, 2, 1, 3).contiguous()
+        new_context_states_shape = context_states.size()[:-2] + (self.channels,)
+        context_states = context_states.view(new_context_states_shape)
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(context_states)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
+    def set_weight(self, attn_layer):
+        self.group_norm.weight.data = attn_layer.norm.weight.data
+        self.group_norm.bias.data = attn_layer.norm.bias.data
+
+        qkv_weight = attn_layer.qkv.weight.data.reshape(self.num_heads, 3 * self.channels // self.num_heads, self.channels)
+        qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
+
+        q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
+        q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
+
+        self.query.weight.data = q_w.reshape(-1, self.channels)
+        self.key.weight.data = k_w.reshape(-1, self.channels)
+        self.value.weight.data = v_w.reshape(-1, self.channels)
+
+        self.query.bias.data = q_b.reshape(-1)
+        self.key.bias.data = k_b.reshape(-1)
+        self.value.bias.data = v_b.reshape(-1)
+
+        self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
+        self.proj_attn.bias.data = attn_layer.proj.bias.data


 class SpatialTransformer(nn.Module):
```
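The core of this refactor is weight surgery: `set_weight` re-expresses a fused `nn.Conv1d(channels, channels * 3, 1)` qkv projection as three separate `nn.Linear` layers, splitting head by head, and the `print((result - result_2).abs().sum())` left in `forward` is the commit's built-in sanity check that the new path reproduces the old one. A standalone sketch of the reshape/split (sizes invented for illustration):

```python
import torch

channels, n_heads = 8, 2

# A fused qkv conv weight has shape (3 * channels, channels, 1); dropping the
# trailing kernel dimension leaves an ordinary (3 * channels, channels) matrix.
qkv_weight = torch.randn(3 * channels, channels, 1)[:, :, 0]

# Within each head, the fused weight stacks the q, k and v rows back to back,
# so reshape to (n_heads, 3 * head_dim, channels) and split along dim=1.
head_dim = channels // n_heads
per_head = qkv_weight.reshape(n_heads, 3 * head_dim, channels)
q_w, k_w, v_w = per_head.split(head_dim, dim=1)

# Flattening the head dimension back out yields (channels, channels) matrices
# that can be assigned to the query/key/value nn.Linear weights.
q_w, k_w, v_w = (w.reshape(-1, channels) for w in (q_w, k_w, v_w))
assert q_w.shape == (channels, channels)
```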
src/diffusers/models/resnet.py

```diff
@@ -81,8 +81,10 @@ class Downsample2D(nn.Module):
             self.conv = conv
         elif name == "Conv2d_0":
             self.Conv2d_0 = conv
+            self.conv = conv
         else:
             self.op = conv
+            self.conv = conv

     def forward(self, x):
         assert x.shape[1] == self.channels
@@ -90,13 +92,16 @@ class Downsample2D(nn.Module):
             pad = (0, 1, 0, 1)
             x = F.pad(x, pad, mode="constant", value=0)

-        if self.name == "conv":
-            return self.conv(x)
-        elif self.name == "Conv2d_0":
-            return self.Conv2d_0(x)
-        else:
-            return self.op(x)
+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+        return self.conv(x)
+
+        # if self.name == "conv":
+        #     return self.conv(x)
+        # elif self.name == "Conv2d_0":
+        #     return self.Conv2d_0(x)
+        # else:
+        #     return self.op(x)


 class Upsample1D(nn.Module):
@@ -656,9 +661,9 @@ class ResnetBlock(nn.Module):
         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

         if time_embedding_norm == "default" and temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
         elif time_embedding_norm == "scale_shift" and temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
+            self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)

         self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
         self.dropout = torch.nn.Dropout(dropout)
@@ -691,9 +696,9 @@ class ResnetBlock(nn.Module):
         self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut

-        self.nin_shortcut = None
+        self.conv_shortcut = None
         if self.use_nin_shortcut:
-            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

     def forward(self, x, temb):
         h = x
@@ -715,7 +720,7 @@ class ResnetBlock(nn.Module):
         h = self.nonlinearity(h)

         if temb is not None:
-            temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
         else:
             temb = 0
@@ -738,8 +743,8 @@ class ResnetBlock(nn.Module):
         h = self.norm2(h)
         h = self.nonlinearity(h)

-        if self.nin_shortcut is not None:
-            x = self.nin_shortcut(x)
+        if self.conv_shortcut is not None:
+            x = self.conv_shortcut(x)

         return (x + h) / self.output_scale_factor
@@ -750,8 +755,8 @@ class ResnetBlock(nn.Module):
         self.conv1.weight.data = resnet.conv1.weight.data
         self.conv1.bias.data = resnet.conv1.bias.data

-        self.temb_proj.weight.data = resnet.temb_proj.weight.data
-        self.temb_proj.bias.data = resnet.temb_proj.bias.data
+        self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
+        self.time_emb_proj.bias.data = resnet.temb_proj.bias.data

         self.norm2.weight.data = resnet.norm2.weight.data
         self.norm2.bias.data = resnet.norm2.bias.data
@@ -760,8 +765,8 @@ class ResnetBlock(nn.Module):
         self.conv2.bias.data = resnet.conv2.bias.data

         if self.use_nin_shortcut:
-            self.nin_shortcut.weight.data = resnet.nin_shortcut.weight.data
-            self.nin_shortcut.bias.data = resnet.nin_shortcut.bias.data
+            self.conv_shortcut.weight.data = resnet.nin_shortcut.weight.data
+            self.conv_shortcut.bias.data = resnet.nin_shortcut.bias.data

     # TODO(Patrick) - just there to convert the weights; can delete afterward
```
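Note that `set_weight` writes into the renamed `time_emb_proj`/`conv_shortcut` modules while still reading the old `temb_proj`/`nin_shortcut` attributes of the source block, which is what keeps old checkpoints convertible during the transition. A minimal sketch of the pattern (toy modules, not the actual diffusers classes):

```python
import torch
from torch import nn

class OldBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.temb_proj = nn.Linear(4, 8)  # old attribute name

class NewBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_emb_proj = nn.Linear(4, 8)  # renamed attribute

    def set_weight(self, resnet):
        # Copy from the old name into the new one, as the diff above does.
        self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
        self.time_emb_proj.bias.data = resnet.temb_proj.bias.data

old_block, new_block = OldBlock(), NewBlock()
new_block.set_weight(old_block)
assert torch.equal(new_block.time_emb_proj.weight, old_block.temb_proj.weight)
```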
src/diffusers/models/unet.py

```diff
@@ -177,7 +177,9 @@ class UNetModel(ModelMixin, ConfigMixin):
                 hs.append(self.down[i_level].downsample(hs[-1]))

         # middle
+        print("hs", hs[-1].abs().sum())
         h = self.mid_new(hs[-1], temb)
+        print("h", h.abs().sum())

         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
```
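The two `print` lines are activation fingerprints: `.abs().sum()` reduces a whole feature map to one scalar that can be eyeballed across the old and new UNet implementations, which is how this "save intermediate" commit checks its intermediates. A sketch of the idea (tensors invented for illustration):

```python
import torch

torch.manual_seed(0)
h_old = torch.randn(1, 64, 32, 32)  # activation from the reference path
h_new = h_old.clone()               # activation from the refactored path

# Identical computation paths should produce identical fingerprints.
print("old", h_old.abs().sum())
print("new", h_new.abs().sum())
assert torch.allclose(h_old.abs().sum(), h_new.abs().sum())
```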
src/diffusers/models/unet_new.py

```diff
@@ -29,9 +29,10 @@ def get_down_block(
     resnet_eps,
     resnet_act_fn,
     attn_num_head_channels,
+    downsample_padding=None,
 ):
     if down_block_type == "UNetResDownBlock2D":
-        return UNetResAttnDownBlock2D(
+        return UNetResDownBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
@@ -39,6 +40,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            downsample_padding=downsample_padding,
         )
     elif down_block_type == "UNetResAttnDownBlock2D":
         return UNetResAttnDownBlock2D(
@@ -57,7 +59,8 @@ def get_up_block(
     up_block_type,
     num_layers,
     in_channels,
-    next_channels,
+    out_channels,
+    prev_output_channel,
     temb_channels,
     add_upsample,
     resnet_eps,
@@ -68,7 +71,8 @@ def get_up_block(
         return UNetResUpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
-            next_channels=next_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
             temb_channels=temb_channels,
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
@@ -78,7 +82,8 @@ def get_up_block(
         return UNetResAttnUpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
-            next_channels=next_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
             temb_channels=temb_channels,
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
@@ -100,11 +105,14 @@ class UNetMidBlock2D(nn.Module):
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
         attn_num_head_channels=1,
+        attention_type="default",
         output_scale_factor=1.0,
         **kwargs,
     ):
         super().__init__()

+        self.attention_type = attention_type
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock(
@@ -128,6 +136,7 @@ class UNetMidBlock2D(nn.Module):
                 in_channels,
                 num_head_channels=attn_num_head_channels,
                 rescale_output_factor=output_scale_factor,
+                eps=resnet_eps,
             )
         )
         resnets.append(
@@ -148,17 +157,14 @@ class UNetMidBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

-    def forward(self, hidden_states, temb=None, encoder_states=None, mask=None):
-        if mask is not None:
-            hidden_states = self.resnets[0](hidden_states, temb, mask=mask)
-        else:
-            hidden_states = self.resnets[0](hidden_states, temb)
+    def forward(self, hidden_states, temb=None, encoder_states=None):
+        hidden_states = self.resnets[0](hidden_states, temb)

         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            hidden_states = attn(hidden_states, encoder_states)
-            if mask is not None:
-                hidden_states = resnet(hidden_states, temb, mask=mask)
-            else:
-                hidden_states = resnet(hidden_states, temb)
+            if self.attention_type == "default":
+                hidden_states = attn(hidden_states)
+            else:
+                hidden_states = attn(hidden_states, encoder_states)
+            hidden_states = resnet(hidden_states, temb)

         return hidden_states
@@ -178,6 +184,7 @@ class UNetResAttnDownBlock2D(nn.Module):
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
         attn_num_head_channels=1,
+        attention_type="default",
         output_scale_factor=1.0,
         add_downsample=True,
     ):
@@ -185,6 +192,8 @@ class UNetResAttnDownBlock2D(nn.Module):
         resnets = []
         attentions = []

+        self.attention_type = attention_type
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -206,6 +215,7 @@ class UNetResAttnDownBlock2D(nn.Module):
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
                 )
             )
@@ -251,6 +261,7 @@ class UNetResDownBlock2D(nn.Module):
         resnet_pre_norm: bool = True,
         output_scale_factor=1.0,
         add_downsample=True,
+        downsample_padding=1,
     ):
         super().__init__()
         resnets = []
@@ -276,7 +287,11 @@ class UNetResDownBlock2D(nn.Module):
         if add_downsample:
             self.downsamplers = nn.ModuleList(
-                [Downsample2D(in_channels, use_conv=True, out_channels=out_channels, padding=1, name="op")]
+                [
+                    Downsample2D(
+                        in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
             )
         else:
             self.downsamplers = None
@@ -301,7 +316,8 @@ class UNetResAttnUpBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        next_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
@@ -310,7 +326,7 @@ class UNetResAttnUpBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attention_layer_type: str = "self",
+        attention_type="default",
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         add_upsample=True,
@@ -319,12 +335,16 @@ class UNetResAttnUpBlock2D(nn.Module):
         resnets = []
         attentions = []

+        self.attention_type = attention_type
+
         for i in range(num_layers):
-            resnet_channels = in_channels if i < num_layers - 1 else next_channels
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels

             resnets.append(
                 ResnetBlock(
-                    in_channels=in_channels + resnet_channels,
-                    out_channels=in_channels,
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
                     temb_channels=temb_channels,
                     eps=resnet_eps,
                     groups=resnet_groups,
@@ -337,9 +357,10 @@ class UNetResAttnUpBlock2D(nn.Module):
             )
             attentions.append(
                 AttentionBlockNew(
-                    in_channels,
+                    out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
                 )
             )
@@ -347,7 +368,7 @@ class UNetResAttnUpBlock2D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

         if add_upsample:
-            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
         else:
             self.upsamplers = None
@@ -373,7 +394,8 @@ class UNetResUpBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        next_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
@@ -382,7 +404,6 @@ class UNetResUpBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attention_layer_type: str = "self",
         output_scale_factor=1.0,
         add_upsample=True,
     ):
@@ -390,11 +411,13 @@ class UNetResUpBlock2D(nn.Module):
         resnets = []

         for i in range(num_layers):
-            resnet_channels = in_channels if i < num_layers - 1 else next_channels
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels

             resnets.append(
                 ResnetBlock(
-                    in_channels=in_channels + resnet_channels,
-                    out_channels=in_channels,
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
                     temb_channels=temb_channels,
                     eps=resnet_eps,
                     groups=resnet_groups,
@@ -409,7 +432,7 @@ class UNetResUpBlock2D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

         if add_upsample:
-            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
         else:
             self.upsamplers = None
```
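The replacement of `next_channels` by the `res_skip_channels`/`resnet_in_channels` pair is easiest to see with concrete numbers. A worked sketch for a hypothetical three-layer up block (channel counts invented for illustration):

```python
# Hypothetical up block: the encoder level above contributed 128-channel skip
# tensors, this level works at 256 channels, and 512 channels arrive from below.
in_channels, out_channels, prev_output_channel, num_layers = 128, 256, 512, 3

for i in range(num_layers):
    # The last resnet concatenates the skip tensor from the encoder level above.
    res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
    # The first resnet receives the upsampled feature map from the block below.
    resnet_in_channels = prev_output_channel if i == 0 else out_channels
    print(f"layer {i}: {resnet_in_channels + res_skip_channels} -> {out_channels}")
# layer 0: 768 -> 256
# layer 1: 512 -> 256
# layer 2: 384 -> 256
```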
src/diffusers/models/unet_unconditional.py (+376 −119): diff collapsed in the original page view; contents not shown.
tests/test_modeling_utils.py

```diff
@@ -271,6 +271,27 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         # fmt: on

         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
+        print("Original success!!!")
+
+        model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy", ddpm=True)
+        model.eval()
+
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+        time_step = torch.tensor([10])
+
+        with torch.no_grad():
+            output = model(noise, time_step)
+
+        output_slice = output[0, -1, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
+        # fmt: on
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))


 class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
     model_class = GlideSuperResUNetModel
@@ -486,18 +507,20 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
             "out_channels": 4,
             "num_res_blocks": 2,
             "attention_resolutions": (16,),
-            "block_input_channels": [32, 32],
-            "block_output_channels": [32, 64],
+            "block_channels": (32, 64),
             "num_head_channels": 32,
             "conv_resample": True,
             "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
             "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
+            "ldm": True,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

     def test_from_pretrained_hub(self):
-        model, loading_info = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True)
+        model, loading_info = UNetUnconditionalModel.from_pretrained(
+            "fusing/unet-ldm-dummy", output_loading_info=True, ldm=True
+        )
         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -507,7 +530,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"

     def test_output_pretrained(self):
-        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy")
+        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
         model.eval()

         torch.manual_seed(0)
```