OpenDAS / diffusers · Commits · e7fe901e

Commit e7fe901e (unverified), authored Jul 14, 2022 by Patrick von Platen, committed by GitHub on Jul 14, 2022.
Parent: c3d78cd3

save intermediate (#87)

* save intermediate
* up
* up
Showing 10 changed files with 801 additions and 173 deletions (+801 −173):

- conversion.py (+94 −0)
- docs/source/examples/diffusers_for_vision.mdx (+2 −1)
- src/diffusers/configuration_utils.py (+2 −0)
- src/diffusers/modeling_utils.py (+3 −0)
- src/diffusers/models/attention.py (+223 −5)
- src/diffusers/models/resnet.py (+22 −17)
- src/diffusers/models/unet.py (+2 −0)
- src/diffusers/models/unet_new.py (+50 −27)
- src/diffusers/models/unet_unconditional.py (+376 −119)
- tests/test_modeling_utils.py (+27 −4)
conversion.py (new file, mode 100755)

```python
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import tempfile
import unittest

import numpy as np
import torch

from diffusers import (
    AutoencoderKL,
    DDIMPipeline,
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
    GlidePipeline,
    GlideSuperResUNetModel,
    GlideTextToImageUNetModel,
    LatentDiffusionPipeline,
    LatentDiffusionUncondPipeline,
    NCSNpp,
    PNDMPipeline,
    PNDMScheduler,
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
    ScoreSdeVpPipeline,
    ScoreSdeVpScheduler,
    UNetLDMModel,
    UNetModel,
    UNetUnconditionalModel,
    VQModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel


def test_output_pretrained_ldm_dummy():
    model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
    model.eval()

    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    time_step = torch.tensor([10] * noise.shape[0])

    with torch.no_grad():
        output = model(noise, time_step)

    print(model)
    import ipdb; ipdb.set_trace()


def test_output_pretrained_ldm():
    model = UNetUnconditionalModel.from_pretrained("fusing/latent-diffusion-celeba-256", subfolder="unet", ldm=True)
    model.eval()

    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    time_step = torch.tensor([10] * noise.shape[0])

    with torch.no_grad():
        output = model(noise, time_step)

    print(model)
    import ipdb; ipdb.set_trace()


# To see how the final model should look like
test_output_pretrained_ldm_dummy()
test_output_pretrained_ldm()

# => this is the architecture in which the model should be saved in the new format
# -> verify the new repo with the following tests (in `test_modeling_utils.py`)
# - test_ldm_uncond (in PipelineTesterMixin)
# - test_output_pretrained (in UNetLDMModelTests)
```
docs/source/examples/diffusers_for_vision.mdx

````diff
@@ -111,7 +111,7 @@ prompt = "A painting of a squirrel eating a burger"
 image = ldm([prompt], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)

 image_processed = image.cpu().permute(0, 2, 3, 1)
-image_processed = image_processed * 255.
+image_processed = image_processed * 255.0
 image_processed = image_processed.numpy().astype(np.uint8)
 image_pil = PIL.Image.fromarray(image_processed[0])
@@ -143,6 +143,7 @@ audio = bddm(mel_spec, generator, torch_device=torch_device)
 # save generated audio
 from scipy.io.wavfile import write as wavwrite

 sampling_rate = 22050
 wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
 ```
````
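The only functional change above is `255.` → `255.0`; the surrounding post-processing is unchanged. As a standalone sketch of that post-processing (a random tensor stands in for the real `ldm(...)` output, assumed to lie in [0, 1]):

```python
import numpy as np
import PIL.Image
import torch

image = torch.rand(1, 3, 64, 64)  # stand-in for the pipeline output, shape (B, C, H, W)
image_processed = image.cpu().permute(0, 2, 3, 1)  # to (B, H, W, C)
image_processed = image_processed * 255.0  # scale [0, 1] floats to pixel range
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])  # first image in the batch
image_pil.save("generated_image.png")
```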
src/diffusers/configuration_utils.py

```diff
@@ -116,6 +116,7 @@ class ConfigMixin:
         use_auth_token = kwargs.pop("use_auth_token", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)

         user_agent = {"file_type": "config"}
@@ -150,6 +151,7 @@ class ConfigMixin:
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
+                subfolder=subfolder,
             )
         except RepositoryNotFoundError:
```
src/diffusers/modeling_utils.py

```diff
@@ -321,6 +321,7 @@ class ModelMixin(torch.nn.Module):
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_auto_class = kwargs.pop("_from_auto", False)
+       subfolder = kwargs.pop("subfolder", None)

        user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -336,6 +337,7 @@ class ModelMixin(torch.nn.Module):
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
+               subfolder=subfolder,
                **kwargs,
            )
            model.register_to_config(name_or_path=pretrained_model_name_or_path)
@@ -363,6 +365,7 @@ class ModelMixin(torch.nn.Module):
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
+               subfolder=subfolder,
            )
        except RepositoryNotFoundError:
```
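Together with the `configuration_utils.py` change above, this threads a new `subfolder` kwarg from `from_pretrained` down to the Hub download calls, so config and weights can live in a sub-directory of a repo. A sketch of the intended call, mirroring the usage in `conversion.py` from this same commit (requires network access to the Hub):

```python
from diffusers import UNetUnconditionalModel

# Load the UNet stored under the "unet" folder of an LDM checkpoint repo.
model = UNetUnconditionalModel.from_pretrained(
    "fusing/latent-diffusion-celeba-256", subfolder="unet", ldm=True
)
```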
src/diffusers/models/attention.py

```diff
@@ -51,6 +51,7 @@ class AttentionBlock(nn.Module):
         overwrite_qkv=False,
         overwrite_linear=False,
         rescale_output_factor=1.0,
+        eps=1e-5,
     ):
         super().__init__()
         self.channels = channels
@@ -62,7 +63,7 @@ class AttentionBlock(nn.Module):
             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
             self.num_heads = channels // num_head_channels

-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
+        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
         self.qkv = nn.Conv1d(channels, channels * 3, 1)
         self.n_heads = self.num_heads
         self.rescale_output_factor = rescale_output_factor
@@ -165,7 +166,7 @@ class AttentionBlock(nn.Module):
         return result


-class AttentionBlockNew(nn.Module):
+class AttentionBlockNew_2(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other.
@@ -180,11 +181,14 @@ class AttentionBlockNew(nn.Module):
         num_groups=32,
         encoder_channels=None,
         rescale_output_factor=1.0,
+        eps=1e-5,
     ):
         super().__init__()
-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
+        self.channels = channels
+        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
         self.qkv = nn.Conv1d(channels, channels * 3, 1)
         self.n_heads = channels // num_head_channels
+        self.num_head_size = num_head_channels
         self.rescale_output_factor = rescale_output_factor

         if encoder_channels is not None:
@@ -192,6 +196,28 @@ class AttentionBlockNew(nn.Module):
         self.proj = zero_module(nn.Conv1d(channels, channels, 1))

+        # ------------------------- new -----------------------
+        num_heads = self.n_heads
+        self.channels = channels
+
+        if num_head_channels is None:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+
+        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
+
+        # define q,k,v as linear layers
+        self.query = nn.Linear(channels, channels)
+        self.key = nn.Linear(channels, channels)
+        self.value = nn.Linear(channels, channels)
+
+        self.rescale_output_factor = rescale_output_factor
+        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+        # ------------------------- new -----------------------
+
     def set_weight(self, attn_layer):
         self.norm.weight.data = attn_layer.norm.weight.data
         self.norm.bias.data = attn_layer.norm.bias.data
@@ -202,6 +228,89 @@ class AttentionBlockNew(nn.Module):
         self.proj.weight.data = attn_layer.proj.weight.data
         self.proj.bias.data = attn_layer.proj.bias.data

+        if hasattr(attn_layer, "q"):
+            module = attn_layer
+            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[:, :, :, 0]
+            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
+
+            self.qkv.weight.data = qkv_weight
+            self.qkv.bias.data = qkv_bias
+
+            proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
+            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
+            proj_out.bias.data = module.proj_out.bias.data
+
+            self.proj = proj_out
+
+        self.set_weights_2(attn_layer)
+
+    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+        new_projection_shape = projection.size()[:-1] + (self.n_heads, self.num_head_size)
+        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+        return new_projection
+
+    def set_weights_2(self, attn_layer):
+        self.group_norm.weight.data = attn_layer.norm.weight.data
+        self.group_norm.bias.data = attn_layer.norm.bias.data
+
+        qkv_weight = attn_layer.qkv.weight.data.reshape(self.n_heads, 3 * self.channels // self.n_heads, self.channels)
+        qkv_bias = attn_layer.qkv.bias.data.reshape(self.n_heads, 3 * self.channels // self.n_heads)
+
+        q_w, k_w, v_w = qkv_weight.split(self.channels // self.n_heads, dim=1)
+        q_b, k_b, v_b = qkv_bias.split(self.channels // self.n_heads, dim=1)
+
+        self.query.weight.data = q_w.reshape(-1, self.channels)
+        self.key.weight.data = k_w.reshape(-1, self.channels)
+        self.value.weight.data = v_w.reshape(-1, self.channels)
+
+        self.query.bias.data = q_b.reshape(-1)
+        self.key.bias.data = k_b.reshape(-1)
+        self.value.bias.data = v_b.reshape(-1)
+
+        self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
+        self.proj_attn.bias.data = attn_layer.proj.bias.data
+
+    def forward_2(self, hidden_states):
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        # transpose
+        query_states = self.transpose_for_scores(query_proj)
+        key_states = self.transpose_for_scores(key_proj)
+        value_states = self.transpose_for_scores(value_proj)
+
+        # get scores
+        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
+        attention_scores = attention_scores / math.sqrt(self.channels // self.n_heads)
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # compute attention output
+        context_states = torch.matmul(attention_probs, value_states)
+
+        context_states = context_states.permute(0, 2, 1, 3).contiguous()
+        new_context_states_shape = context_states.size()[:-2] + (self.channels,)
+        context_states = context_states.view(new_context_states_shape)
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(context_states)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
     def forward(self, x, encoder_out=None):
         b, c, *spatial = x.shape
         hid_states = self.norm(x).view(b, c, -1)
@@ -230,10 +339,119 @@ class AttentionBlockNew(nn.Module):
         h = h.reshape(b, c, *spatial)
         result = x + h
         result = result / self.rescale_output_factor

-        return result
+        result_2 = self.forward_2(x)
+        print((result - result_2).abs().sum())
+
+        return result_2
+
+
+class AttentionBlockNew(nn.Module):
+    """
+    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
+    to the N-d case.
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+    Uses three q, k, v linear layers to compute attention
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=None,
+        num_groups=32,
+        rescale_output_factor=1.0,
+        eps=1e-5,
+    ):
+        super().__init__()
+        self.channels = channels
+
+        if num_head_channels is None:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+
+        self.num_head_size = num_head_channels
+        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
+
+        # define q,k,v as linear layers
+        self.query = nn.Linear(channels, channels)
+        self.key = nn.Linear(channels, channels)
+        self.value = nn.Linear(channels, channels)
+
+        self.rescale_output_factor = rescale_output_factor
+        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+
+    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+        new_projection_shape = projection.size()[:-1] + (self.num_heads, self.num_head_size)
+        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+        return new_projection
+
+    def forward(self, hidden_states):
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        # transpose
+        query_states = self.transpose_for_scores(query_proj)
+        key_states = self.transpose_for_scores(key_proj)
+        value_states = self.transpose_for_scores(value_proj)
+
+        # get scores
+        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
+        attention_scores = attention_scores / math.sqrt(self.channels // self.num_heads)
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # compute attention output
+        context_states = torch.matmul(attention_probs, value_states)
+
+        context_states = context_states.permute(0, 2, 1, 3).contiguous()
+        new_context_states_shape = context_states.size()[:-2] + (self.channels,)
+        context_states = context_states.view(new_context_states_shape)
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(context_states)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
+    def set_weight(self, attn_layer):
+        self.group_norm.weight.data = attn_layer.norm.weight.data
+        self.group_norm.bias.data = attn_layer.norm.bias.data
+
+        qkv_weight = attn_layer.qkv.weight.data.reshape(self.num_heads, 3 * self.channels // self.num_heads, self.channels)
+        qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
+
+        q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
+        q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
+
+        self.query.weight.data = q_w.reshape(-1, self.channels)
+        self.key.weight.data = k_w.reshape(-1, self.channels)
+        self.value.weight.data = v_w.reshape(-1, self.channels)
+
+        self.query.bias.data = q_b.reshape(-1)
+        self.key.bias.data = k_b.reshape(-1)
+        self.value.bias.data = v_b.reshape(-1)
+
+        self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
+        self.proj_attn.bias.data = attn_layer.proj.bias.data


 class SpatialTransformer(nn.Module):
```
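The delicate part of this port is the weight conversion in `set_weight`/`set_weights_2`: the fused `Conv1d` qkv weight is reshaped per head and split into the three new `Linear` layers. A minimal standalone check (not part of the commit; small hypothetical sizes) that this reshape/split reproduces the fused projection exactly, shown here for the query path:

```python
import torch
import torch.nn as nn

channels, n_heads, seq_len = 8, 2, 5
d = channels // n_heads  # channels per head
qkv = nn.Conv1d(channels, channels * 3, 1)  # legacy fused projection

# Reshape/split exactly as in set_weight above.
qkv_weight = qkv.weight.data.reshape(n_heads, 3 * d, channels)
qkv_bias = qkv.bias.data.reshape(n_heads, 3 * d)
q_w, k_w, v_w = qkv_weight.split(d, dim=1)
q_b, k_b, v_b = qkv_bias.split(d, dim=1)

query = nn.Linear(channels, channels)
query.weight.data = q_w.reshape(-1, channels)
query.bias.data = q_b.reshape(-1)

x = torch.randn(1, channels, seq_len)
fused = qkv(x).reshape(1, n_heads, 3 * d, seq_len)
q_fused = fused[:, :, :d].reshape(1, channels, seq_len)  # q rows of each head
q_linear = query(x.transpose(1, 2)).transpose(1, 2)  # Linear acts on (B, T, C)
print((q_fused - q_linear).abs().max())  # ~0 up to float precision
```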
src/diffusers/models/resnet.py

```diff
@@ -81,8 +81,10 @@ class Downsample2D(nn.Module):
             self.conv = conv
         elif name == "Conv2d_0":
             self.Conv2d_0 = conv
+            self.conv = conv
         else:
             self.op = conv
+            self.conv = conv

     def forward(self, x):
         assert x.shape[1] == self.channels
@@ -90,13 +92,16 @@ class Downsample2D(nn.Module):
             pad = (0, 1, 0, 1)
             x = F.pad(x, pad, mode="constant", value=0)

-        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
-        if self.name == "conv":
-            return self.conv(x)
-        elif self.name == "Conv2d_0":
-            return self.Conv2d_0(x)
-        else:
-            return self.op(x)
+        return self.conv(x)
+
+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+        # if self.name == "conv":
+        #     return self.conv(x)
+        # elif self.name == "Conv2d_0":
+        #     return self.Conv2d_0(x)
+        # else:
+        #     return self.op(x)


 class Upsample1D(nn.Module):
@@ -656,9 +661,9 @@ class ResnetBlock(nn.Module):
         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

         if time_embedding_norm == "default" and temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
         elif time_embedding_norm == "scale_shift" and temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
+            self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)

         self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
         self.dropout = torch.nn.Dropout(dropout)
@@ -691,9 +696,9 @@ class ResnetBlock(nn.Module):
         self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut

-        self.nin_shortcut = None
+        self.conv_shortcut = None
         if self.use_nin_shortcut:
-            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

     def forward(self, x, temb):
         h = x
@@ -715,7 +720,7 @@ class ResnetBlock(nn.Module):
         h = self.nonlinearity(h)

         if temb is not None:
-            temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
         else:
             temb = 0
@@ -738,8 +743,8 @@ class ResnetBlock(nn.Module):
         h = self.norm2(h)
         h = self.nonlinearity(h)

-        if self.nin_shortcut is not None:
-            x = self.nin_shortcut(x)
+        if self.conv_shortcut is not None:
+            x = self.conv_shortcut(x)

         return (x + h) / self.output_scale_factor
@@ -750,8 +755,8 @@ class ResnetBlock(nn.Module):
         self.conv1.weight.data = resnet.conv1.weight.data
         self.conv1.bias.data = resnet.conv1.bias.data

-        self.temb_proj.weight.data = resnet.temb_proj.weight.data
-        self.temb_proj.bias.data = resnet.temb_proj.bias.data
+        self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
+        self.time_emb_proj.bias.data = resnet.temb_proj.bias.data

         self.norm2.weight.data = resnet.norm2.weight.data
         self.norm2.bias.data = resnet.norm2.bias.data
@@ -760,8 +765,8 @@ class ResnetBlock(nn.Module):
         self.conv2.bias.data = resnet.conv2.bias.data

         if self.use_nin_shortcut:
-            self.nin_shortcut.weight.data = resnet.nin_shortcut.weight.data
-            self.nin_shortcut.bias.data = resnet.nin_shortcut.bias.data
+            self.conv_shortcut.weight.data = resnet.nin_shortcut.weight.data
+            self.conv_shortcut.bias.data = resnet.nin_shortcut.bias.data

     # TODO(Patrick) - just there to convert the weights; can delete afterward
```
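Two of these renames (`temb_proj` → `time_emb_proj`, `nin_shortcut` → `conv_shortcut`) only change attribute names, which is why the `set_weight` copies still read from the old `resnet.temb_proj`/`resnet.nin_shortcut` fields. The `Downsample2D` change works because every constructor branch now also aliases the conv under `self.conv`, making the name-based dispatch in `forward` redundant. A simplified sketch of that aliasing (hypothetical `TinyDownsample`, not the real class):

```python
import torch
import torch.nn as nn

class TinyDownsample(nn.Module):
    def __init__(self, channels, name="conv"):
        super().__init__()
        conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
        self.name = name
        if name == "conv":
            self.conv = conv
        elif name == "Conv2d_0":
            self.Conv2d_0 = conv  # keeps the checkpoint key name
            self.conv = conv      # alias added in this commit
        else:
            self.op = conv        # keeps the checkpoint key name
            self.conv = conv      # alias added in this commit

    def forward(self, x):
        return self.conv(x)  # one path instead of three

x = torch.randn(1, 4, 8, 8)
for name in ("conv", "Conv2d_0", "op"):
    print(name, TinyDownsample(4, name=name)(x).shape)  # all torch.Size([1, 4, 4, 4])
```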
src/diffusers/models/unet.py

```diff
@@ -177,7 +177,9 @@ class UNetModel(ModelMixin, ConfigMixin):
                 hs.append(self.down[i_level].downsample(hs[-1]))

         # middle
+        print("hs", hs[-1].abs().sum())
         h = self.mid_new(hs[-1], temb)
+        print("h", h.abs().sum())

         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
```
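The two added `print(... .abs().sum())` lines are temporary debug instrumentation (this is a "save intermediate" commit): the sum of absolute values is a cheap scalar fingerprint for comparing activations between the old and new UNet implementations. Illustration with a stand-in module (not from the commit):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 64, 16, 16)

ref = nn.Conv2d(64, 64, 3, padding=1)   # stand-in for the old block
port = nn.Conv2d(64, 64, 3, padding=1)  # stand-in for the ported block
port.load_state_dict(ref.state_dict())  # the port copies the reference weights

print("ref ", ref(x).abs().sum())   # matching fingerprints => matching outputs
print("port", port(x).abs().sum())
```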
src/diffusers/models/unet_new.py

```diff
@@ -29,9 +29,10 @@ def get_down_block(
     resnet_eps,
     resnet_act_fn,
     attn_num_head_channels,
+    downsample_padding=None,
 ):
     if down_block_type == "UNetResDownBlock2D":
-        return UNetResAttnDownBlock2D(
+        return UNetResDownBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
@@ -39,6 +40,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            downsample_padding=downsample_padding,
         )
     elif down_block_type == "UNetResAttnDownBlock2D":
         return UNetResAttnDownBlock2D(
@@ -57,7 +59,8 @@ def get_up_block(
     up_block_type,
     num_layers,
     in_channels,
-    next_channels,
+    out_channels,
+    prev_output_channel,
     temb_channels,
     add_upsample,
     resnet_eps,
@@ -68,7 +71,8 @@ def get_up_block(
         return UNetResUpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
-            next_channels=next_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
             temb_channels=temb_channels,
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
@@ -78,7 +82,8 @@ def get_up_block(
         return UNetResAttnUpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
-            next_channels=next_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
             temb_channels=temb_channels,
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
@@ -100,11 +105,14 @@ class UNetMidBlock2D(nn.Module):
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
         attn_num_head_channels=1,
+        attention_type="default",
         output_scale_factor=1.0,
         **kwargs,
     ):
         super().__init__()

+        self.attention_type = attention_type
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock(
@@ -128,6 +136,7 @@ class UNetMidBlock2D(nn.Module):
                 in_channels,
                 num_head_channels=attn_num_head_channels,
                 rescale_output_factor=output_scale_factor,
+                eps=resnet_eps,
             )
         )
         resnets.append(
@@ -148,17 +157,14 @@ class UNetMidBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

-    def forward(self, hidden_states, temb=None, encoder_states=None, mask=None):
-        if mask is not None:
-            hidden_states = self.resnets[0](hidden_states, temb, mask=mask)
-        else:
-            hidden_states = self.resnets[0](hidden_states, temb)
+    def forward(self, hidden_states, temb=None, encoder_states=None):
+        hidden_states = self.resnets[0](hidden_states, temb)

         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            hidden_states = attn(hidden_states, encoder_states)
-            if mask is not None:
-                hidden_states = resnet(hidden_states, temb, mask=mask)
-            else:
-                hidden_states = resnet(hidden_states, temb)
+            if self.attention_type == "default":
+                hidden_states = attn(hidden_states)
+            else:
+                hidden_states = attn(hidden_states, encoder_states)
+            hidden_states = resnet(hidden_states, temb)

         return hidden_states
@@ -178,6 +184,7 @@ class UNetResAttnDownBlock2D(nn.Module):
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
         attn_num_head_channels=1,
+        attention_type="default",
         output_scale_factor=1.0,
         add_downsample=True,
     ):
@@ -185,6 +192,8 @@ class UNetResAttnDownBlock2D(nn.Module):
         resnets = []
         attentions = []

+        self.attention_type = attention_type
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -206,6 +215,7 @@ class UNetResAttnDownBlock2D(nn.Module):
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
                 )
             )
@@ -251,6 +261,7 @@ class UNetResDownBlock2D(nn.Module):
         resnet_pre_norm: bool = True,
         output_scale_factor=1.0,
         add_downsample=True,
+        downsample_padding=1,
     ):
         super().__init__()
         resnets = []
@@ -276,7 +287,11 @@ class UNetResDownBlock2D(nn.Module):
         if add_downsample:
             self.downsamplers = nn.ModuleList(
-                [Downsample2D(in_channels, use_conv=True, out_channels=out_channels, padding=1, name="op")]
+                [
+                    Downsample2D(
+                        in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
             )
         else:
             self.downsamplers = None
@@ -301,7 +316,8 @@ class UNetResAttnUpBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        next_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
@@ -310,7 +326,7 @@ class UNetResAttnUpBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attention_layer_type: str = "self",
+        attention_type="default",
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         add_upsample=True,
@@ -319,12 +335,16 @@ class UNetResAttnUpBlock2D(nn.Module):
         resnets = []
         attentions = []

+        self.attention_type = attention_type
+
         for i in range(num_layers):
-            resnet_channels = in_channels if i < num_layers - 1 else next_channels
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
             resnets.append(
                 ResnetBlock(
-                    in_channels=in_channels + resnet_channels,
-                    out_channels=in_channels,
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
                     temb_channels=temb_channels,
                     eps=resnet_eps,
                     groups=resnet_groups,
@@ -337,9 +357,10 @@ class UNetResAttnUpBlock2D(nn.Module):
             )
             attentions.append(
                 AttentionBlockNew(
-                    in_channels,
+                    out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
                 )
             )
@@ -347,7 +368,7 @@ class UNetResAttnUpBlock2D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

         if add_upsample:
-            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
         else:
             self.upsamplers = None
@@ -373,7 +394,8 @@ class UNetResUpBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        next_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
@@ -382,7 +404,6 @@ class UNetResUpBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attention_layer_type: str = "self",
         output_scale_factor=1.0,
         add_upsample=True,
     ):
@@ -390,11 +411,13 @@ class UNetResUpBlock2D(nn.Module):
         resnets = []

         for i in range(num_layers):
-            resnet_channels = in_channels if i < num_layers - 1 else next_channels
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
             resnets.append(
                 ResnetBlock(
-                    in_channels=in_channels + resnet_channels,
-                    out_channels=in_channels,
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
                     temb_channels=temb_channels,
                     eps=resnet_eps,
                     groups=resnet_groups,
@@ -409,7 +432,7 @@ class UNetResUpBlock2D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

         if add_upsample:
-            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
         else:
             self.upsamplers = None
```
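The key change in the up blocks is the new channel bookkeeping: each `ResnetBlock` now consumes the previous block's output concatenated with a skip connection (`resnet_in_channels + res_skip_channels`) instead of the old `in_channels + resnet_channels` rule. A worked example with hypothetical channel counts:

```python
# Channel arithmetic of the new up-block construction (illustrative numbers).
num_layers = 3
in_channels, out_channels, prev_output_channel = 64, 32, 128

for i in range(num_layers):
    res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
    resnet_in_channels = prev_output_channel if i == 0 else out_channels
    print(f"layer {i}: resnet in={resnet_in_channels + res_skip_channels}, out={out_channels}")
# layer 0: resnet in=160, out=32   (128 from below + 32 skip)
# layer 1: resnet in=64, out=32    (32 + 32 skip)
# layer 2: resnet in=96, out=32    (32 + 64 skip from the encoder input)
```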
src/diffusers/models/unet_unconditional.py

This diff is collapsed (+376 −119; not shown).
tests/test_modeling_utils.py

```diff
@@ -271,6 +271,27 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         # fmt: on
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

+        print("Original success!!!")
+
+        model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy", ddpm=True)
+        model.eval()
+
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+        time_step = torch.tensor([10])
+
+        with torch.no_grad():
+            output = model(noise, time_step)
+
+        output_slice = output[0, -1, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
+        # fmt: on
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
+

 class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
     model_class = GlideSuperResUNetModel
@@ -486,18 +507,20 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
             "out_channels": 4,
             "num_res_blocks": 2,
             "attention_resolutions": (16,),
-            "block_channels": (32, 64),
+            "block_input_channels": [32, 32],
+            "block_output_channels": [32, 64],
             "num_head_channels": 32,
             "conv_resample": True,
             "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
             "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
+            "ldm": True,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

     def test_from_pretrained_hub(self):
-        model, loading_info = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True)
+        model, loading_info = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True, ldm=True)
         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -507,7 +530,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"

     def test_output_pretrained(self):
-        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy")
+        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
         model.eval()
         torch.manual_seed(0)
```
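The added assertions follow the repo's seeded-slice pattern: fix every RNG seed, run a single forward pass, and compare a 3×3 corner of the output against hard-coded reference values. Standalone sketch of the pattern with a stand-in module (the reference values here are recorded on the fly rather than hard-coded):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Conv2d(3, 3, 3, padding=1)  # stand-in for a pretrained UNet
noise = torch.randn(1, 3, 32, 32)

with torch.no_grad():
    output = model(noise)

output_slice = output[0, -1, -3:, -3:].flatten()  # 9 values from one corner
expected_output_slice = output_slice.clone()  # in a real test: a hard-coded tensor
assert torch.allclose(output_slice, expected_output_slice, rtol=1e-2)
```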