Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
185347e4
"vscode:/vscode.git/clone" did not exist on "86a14cbad46f6f026ffcee7f504ffaca8da33929"
Commit
185347e4
authored
Jun 30, 2022
by
Patrick von Platen
Browse files
up
parent
c1c4dea9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
21 deletions
+18
-21
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+18
-21
No files found.
src/diffusers/models/resnet.py
View file @
185347e4
...
...
@@ -207,9 +207,6 @@ class ResBlock(TimestepBlock):
self
.
updown
=
up
or
down
# if self.updown:
# import ipdb; ipdb.set_trace()
if
up
:
self
.
h_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
self
.
x_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
...
...
@@ -227,8 +224,10 @@ class ResBlock(TimestepBlock):
),
)
self
.
out_layers
=
nn
.
Sequential
(
normalization
(
self
.
out_channels
,
swish
=
0.0
if
use_scale_shift_norm
else
1.0
),
nn
.
SiLU
()
if
use_scale_shift_norm
else
nn
.
Identity
(),
# normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
# nn.SiLU() if use_scale_shift_norm else nn.Identity(),
normalization
(
self
.
out_channels
,
swish
=
0.0
),
nn
.
SiLU
(),
nn
.
Dropout
(
p
=
dropout
),
zero_module
(
conv_nd
(
dims
,
self
.
out_channels
,
self
.
out_channels
,
3
,
padding
=
1
)),
)
...
...
@@ -322,6 +321,7 @@ class ResBlock(TimestepBlock):
emb_out
=
self
.
emb_layers
(
emb
).
type
(
h
.
dtype
)
while
len
(
emb_out
.
shape
)
<
len
(
h
.
shape
):
emb_out
=
emb_out
[...,
None
]
if
self
.
use_scale_shift_norm
:
out_norm
,
out_rest
=
self
.
out_layers
[
0
],
self
.
out_layers
[
1
:]
scale
,
shift
=
torch
.
chunk
(
emb_out
,
2
,
dim
=
1
)
...
...
@@ -338,35 +338,31 @@ class ResBlock(TimestepBlock):
return
result
def
forward_2
(
self
,
x
,
temb
,
mask
=
1.0
):
def
forward_2
(
self
,
x
,
temb
):
if
self
.
overwrite
and
not
self
.
is_overwritten
:
self
.
set_weights
()
self
.
is_overwritten
=
True
h
=
x
if
self
.
pre_norm
:
h
=
self
.
norm1
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
conv1
(
h
)
if
not
self
.
pre_norm
:
h
=
self
.
norm1
(
h
)
h
=
self
.
nonlinearity
(
h
)
temb
=
self
.
temb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
h
=
h
+
self
.
temb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
scale
,
shift
=
torch
.
chunk
(
temb
,
2
,
dim
=
1
)
h
=
self
.
norm2
(
h
)
h
=
h
*
scale
+
shift
if
self
.
pre_norm
:
h
=
self
.
norm2
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
not
self
.
pre_norm
:
h
=
self
.
norm2
(
h
)
h
=
self
.
nonlinearity
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
x
=
self
.
conv_shortcut
(
x
)
...
...
@@ -376,7 +372,7 @@ class ResBlock(TimestepBlock):
return
x
+
h
# unet.py
and
unet_grad_tts.py
# unet.py
,
unet_grad_tts.py
, unet_ldm.py
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
...
...
@@ -410,6 +406,7 @@ class ResnetBlock(nn.Module):
self
.
norm2
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
non_linearity
==
"swish"
:
self
.
nonlinearity
=
nonlinearity
elif
non_linearity
==
"mish"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment