renzhc / diffusers_dcu · Commits · b344c953

Unverified commit b344c953, authored Aug 10, 2022 by Suraj Patil, committed by GitHub on Aug 10, 2022.

add attention up/down blocks for VAE (#161)

Parent: dd10da76
Showing 1 changed file with 140 additions and 0 deletions.
src/diffusers/models/unet_blocks.py
...

@@ -640,6 +640,79 @@ class DownEncoderBlock2D(nn.Module):
        return hidden_states


class AttnDownEncoderBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()
        resnets = []
        attentions = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                AttentionBlockNew(
                    out_channels,
                    num_head_channels=attn_num_head_channels,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    num_groups=resnet_groups,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        in_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
        else:
            self.downsamplers = None

    def forward(self, hidden_states):
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb=None)
            hidden_states = attn(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states


class AttnSkipDownBlock2D(nn.Module):
    def __init__(
        self,

...
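As a quick sanity check of the new encoder block above, here is a minimal usage sketch. It assumes a diffusers checkout containing this commit, so that AttnDownEncoderBlock2D is importable from diffusers.models.unet_blocks; the channel sizes, input resolution, and expected output shape are illustrative choices, not taken from the repository's tests.

import torch
from diffusers.models.unet_blocks import AttnDownEncoderBlock2D

# Illustrative sizes only: 64 -> 128 channels, two resnet+attention layers.
block = AttnDownEncoderBlock2D(
    in_channels=64,
    out_channels=128,
    num_layers=2,
    add_downsample=True,
)

x = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width)
out = block(x)

# Each layer applies a ResnetBlock (with temb=None) followed by spatial
# self-attention; the final Downsample2D halves height and width.
print(out.shape)  # expected: torch.Size([1, 128, 16, 16])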
@@ -1087,6 +1160,73 @@ class UpDecoderBlock2D(nn.Module):
        return hidden_states


class AttnUpDecoderBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        resnets = []
        attentions = []

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels

            resnets.append(
                ResnetBlock(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                AttentionBlockNew(
                    out_channels,
                    num_head_channels=attn_num_head_channels,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    num_groups=resnet_groups,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

    def forward(self, hidden_states):
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb=None)
            hidden_states = attn(hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states


class AttnSkipUpBlock2D(nn.Module):
    def __init__(
        self,

...
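And a matching sketch for the decoder-side block, under the same assumption about the checkout; it simply reverses the illustrative shapes used above.

import torch
from diffusers.models.unet_blocks import AttnUpDecoderBlock2D

# Illustrative sizes only: 128 -> 64 channels, two resnet+attention layers.
block = AttnUpDecoderBlock2D(
    in_channels=128,
    out_channels=64,
    num_layers=2,
    add_upsample=True,
)

latent = torch.randn(1, 128, 16, 16)
out = block(latent)

# ResnetBlock + attention per layer, then Upsample2D doubles height and width.
print(out.shape)  # expected: torch.Size([1, 64, 32, 32])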