Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
c9bd4d43
Commit
c9bd4d43
authored
Jun 30, 2022
by
patil-suraj
Browse files
remove if fir from resent block and upsample, downsample for sde unet
parent
7e0fd19f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
58 deletions
+30
-58
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+4
-12
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+26
-46
No files found.
src/diffusers/models/resnet.py
View file @
c9bd4d43
...
@@ -614,19 +614,11 @@ class ResnetBlockBigGANpp(nn.Module):
...
@@ -614,19 +614,11 @@ class ResnetBlockBigGANpp(nn.Module):
h
=
self
.
act
(
self
.
GroupNorm_0
(
x
))
h
=
self
.
act
(
self
.
GroupNorm_0
(
x
))
if
self
.
up
:
if
self
.
up
:
if
self
.
fir
:
h
=
upsample_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
h
=
upsample_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
upsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
x
=
upsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
else
:
h
=
naive_upsample_2d
(
h
,
factor
=
2
)
x
=
naive_upsample_2d
(
x
,
factor
=
2
)
elif
self
.
down
:
elif
self
.
down
:
if
self
.
fir
:
h
=
downsample_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
h
=
downsample_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
downsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
x
=
downsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
else
:
h
=
naive_downsample_2d
(
h
,
factor
=
2
)
x
=
naive_downsample_2d
(
x
,
factor
=
2
)
h
=
self
.
Conv_0
(
h
)
h
=
self
.
Conv_0
(
h
)
# Add bias to each feature map conditioned on the time embedding
# Add bias to each feature map conditioned on the time embedding
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
c9bd4d43
...
@@ -417,20 +417,16 @@ class Upsample(nn.Module):
...
@@ -417,20 +417,16 @@ class Upsample(nn.Module):
def
__init__
(
self
,
in_ch
=
None
,
out_ch
=
None
,
with_conv
=
False
,
fir
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
def
__init__
(
self
,
in_ch
=
None
,
out_ch
=
None
,
with_conv
=
False
,
fir
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
super
().
__init__
()
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
out_ch
=
out_ch
if
out_ch
else
in_ch
if
not
fir
:
if
with_conv
:
if
with_conv
:
self
.
Conv2d_0
=
Conv2d
(
self
.
Conv_0
=
conv3x3
(
in_ch
,
out_ch
)
in_ch
,
else
:
out_ch
,
if
with_conv
:
kernel
=
3
,
self
.
Conv2d_0
=
Conv2d
(
up
=
True
,
in_ch
,
resample_kernel
=
fir_kernel
,
out_ch
,
use_bias
=
True
,
kernel
=
3
,
kernel_init
=
default_init
(),
up
=
True
,
)
resample_kernel
=
fir_kernel
,
use_bias
=
True
,
kernel_init
=
default_init
(),
)
self
.
fir
=
fir
self
.
fir
=
fir
self
.
with_conv
=
with_conv
self
.
with_conv
=
with_conv
self
.
fir_kernel
=
fir_kernel
self
.
fir_kernel
=
fir_kernel
...
@@ -438,15 +434,10 @@ class Upsample(nn.Module):
...
@@ -438,15 +434,10 @@ class Upsample(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
B
,
C
,
H
,
W
=
x
.
shape
if
not
self
.
fir
:
if
not
self
.
with_conv
:
h
=
F
.
interpolate
(
x
,
(
H
*
2
,
W
*
2
),
"nearest"
)
h
=
upsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
if
self
.
with_conv
:
h
=
self
.
Conv_0
(
h
)
else
:
else
:
if
not
self
.
with_conv
:
h
=
self
.
Conv2d_0
(
x
)
h
=
upsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
else
:
h
=
self
.
Conv2d_0
(
x
)
return
h
return
h
...
@@ -455,20 +446,16 @@ class Downsample(nn.Module):
...
@@ -455,20 +446,16 @@ class Downsample(nn.Module):
def
__init__
(
self
,
in_ch
=
None
,
out_ch
=
None
,
with_conv
=
False
,
fir
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
def
__init__
(
self
,
in_ch
=
None
,
out_ch
=
None
,
with_conv
=
False
,
fir
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
super
().
__init__
()
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
out_ch
=
out_ch
if
out_ch
else
in_ch
if
not
fir
:
if
with_conv
:
if
with_conv
:
self
.
Conv2d_0
=
Conv2d
(
self
.
Conv_0
=
conv3x3
(
in_ch
,
out_ch
,
stride
=
2
,
padding
=
0
)
in_ch
,
else
:
out_ch
,
if
with_conv
:
kernel
=
3
,
self
.
Conv2d_0
=
Conv2d
(
down
=
True
,
in_ch
,
resample_kernel
=
fir_kernel
,
out_ch
,
use_bias
=
True
,
kernel
=
3
,
kernel_init
=
default_init
(),
down
=
True
,
)
resample_kernel
=
fir_kernel
,
use_bias
=
True
,
kernel_init
=
default_init
(),
)
self
.
fir
=
fir
self
.
fir
=
fir
self
.
fir_kernel
=
fir_kernel
self
.
fir_kernel
=
fir_kernel
self
.
with_conv
=
with_conv
self
.
with_conv
=
with_conv
...
@@ -476,17 +463,10 @@ class Downsample(nn.Module):
...
@@ -476,17 +463,10 @@ class Downsample(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
B
,
C
,
H
,
W
=
x
.
shape
if
not
self
.
fir
:
if
not
self
.
with_conv
:
if
self
.
with_conv
:
x
=
downsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
x
=
F
.
pad
(
x
,
(
0
,
1
,
0
,
1
))
x
=
self
.
Conv_0
(
x
)
else
:
x
=
F
.
avg_pool2d
(
x
,
2
,
stride
=
2
)
else
:
else
:
if
not
self
.
with_conv
:
x
=
self
.
Conv2d_0
(
x
)
x
=
downsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
else
:
x
=
self
.
Conv2d_0
(
x
)
return
x
return
x
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment