renzhc / diffusers_dcu · Commits

Commit 639b8611, authored Jun 30, 2022 by patil-suraj

get rid of the custom conv2d layer for up/down sampling

Parent: 663393e2
Showing 1 changed file with 20 additions and 180 deletions (+20 -180)

src/diffusers/models/unet_sde_score_estimation.py
```diff
@@ -17,7 +17,6 @@
 import functools
 import math
-import string
 
 import numpy as np
 import torch
```
```diff
@@ -28,99 +27,21 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import ResnetBlockBigGANpp
-
-
-def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
-    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
-
-
-def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
-    _, channel, in_h, in_w = input.shape
-    input = input.reshape(-1, in_h, in_w, 1)
-
-    _, in_h, in_w, minor = input.shape
-    kernel_h, kernel_w = kernel.shape
-
-    out = input.view(-1, in_h, 1, in_w, 1, minor)
-    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
-    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
-
-    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
-    out = out[
-        :,
-        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
-        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
-        :,
-    ]
-
-    out = out.permute(0, 3, 1, 2)
-    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
-    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
-    out = F.conv2d(out, w)
-    out = out.reshape(
-        -1,
-        minor,
-        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
-        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
-    )
-    out = out.permute(0, 2, 3, 1)
-    out = out[:, ::down_y, ::down_x, :]
-
-    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
-    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
-
-    return out.view(-1, channel, out_h, out_w)
-
-
-# Function ported from StyleGAN2
-def get_weight(module, shape, weight_var="weight", kernel_init=None):
-    """Get/create weight tensor for a convolution or fully-connected layer."""
-    return module.param(weight_var, kernel_init, shape)
-
-
-class Conv2d(nn.Module):
-    """Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
-
-    def __init__(
-        self,
-        in_ch,
-        out_ch,
-        kernel,
-        up=False,
-        down=False,
-        resample_kernel=(1, 3, 3, 1),
-        use_bias=True,
-        kernel_init=None,
-    ):
-        super().__init__()
-        assert not (up and down)
-        assert kernel >= 1 and kernel % 2 == 1
-        self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
-        if kernel_init is not None:
-            self.weight.data = kernel_init(self.weight.data.shape)
-        if use_bias:
-            self.bias = nn.Parameter(torch.zeros(out_ch))
-
-        self.up = up
-        self.down = down
-        self.resample_kernel = resample_kernel
-        self.kernel = kernel
-        self.use_bias = use_bias
-
-    def forward(self, x):
-        if self.up:
-            x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
-        elif self.down:
-            x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
-        else:
-            x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
-
-        if self.use_bias:
-            x = x + self.bias.reshape(1, -1, 1, 1)
-
-        return x
+from .resnet import ResnetBlockBigGANpp, downsample_2d, upfirdn2d, upsample_2d
+
+
+def _setup_kernel(k):
+    k = np.asarray(k, dtype=np.float32)
+    if k.ndim == 1:
+        k = np.outer(k, k)
+    k /= np.sum(k)
+    assert k.ndim == 2
+    assert k.shape[0] == k.shape[1]
+    return k
+
+
+def _shape(x, dim):
+    return x.shape[dim]
 
 
 def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
```
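The deleted wrapper only re-exported `upfirdn2d_native`, which fuses zero-insertion upsampling, FIR filtering, and strided downsampling into a single pass; the replacement simply imports the same helper from `.resnet`, and the removed `Conv2d` class existed mostly to route through it. A minimal sketch of the semantics, assuming the moved helper keeps the `(input, kernel, up=1, down=1, pad=(0, 0))` signature shown above: with `up=down=1` and the kernel footprint padded away, `upfirdn2d` reduces to a plain per-channel FIR filter.

```python
import torch
import torch.nn.functional as F

from diffusers.models.resnet import upfirdn2d  # import path as of this commit; may differ in later versions

x = torch.randn(1, 3, 8, 8)
k = torch.tensor([1.0, 3.0, 3.0, 1.0])
k2d = torch.outer(k, k) / k.sum() ** 2  # normalized separable FIR kernel, as _setup_kernel builds it

# With up=down=1 and (kernel_size - 1) padding split across both sides, the spatial
# size is preserved and upfirdn2d is just a per-channel FIR filter:
y = upfirdn2d(x, k2d, pad=(2, 1))

# Reference: the same filtering written as a depthwise conv with the flipped kernel,
# mirroring what upfirdn2d_native did internally
w = torch.flip(k2d, [0, 1]).view(1, 1, 4, 4).repeat(3, 1, 1, 1)
y_ref = F.conv2d(F.pad(x, (2, 1, 2, 1)), w, groups=3)
print(torch.allclose(y, y_ref, atol=1e-5))  # expected: True
```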
```diff
@@ -222,71 +143,6 @@ def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
     return F.conv2d(x, w, stride=s, padding=0)
 
 
-def _setup_kernel(k):
-    k = np.asarray(k, dtype=np.float32)
-    if k.ndim == 1:
-        k = np.outer(k, k)
-    k /= np.sum(k)
-    assert k.ndim == 2
-    assert k.shape[0] == k.shape[1]
-    return k
-
-
-def _shape(x, dim):
-    return x.shape[dim]
-
-
-def upsample_2d(x, k=None, factor=2, gain=1):
-    r"""Upsample a batch of 2D images with the given filter.
-
-    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the
-    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
-    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that
-    its shape is a multiple of the upsampling factor.
-
-    Args:
-        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
-            corresponds to nearest-neighbor upsampling.
-        factor: Integer upsampling factor (default: 2).
-        gain: Scaling factor for signal magnitude (default: 1.0).
-
-    Returns:
-        Tensor of the shape `[N, C, H * factor, W * factor]`
-    """
-    assert isinstance(factor, int) and factor >= 1
-    if k is None:
-        k = [1] * factor
-    k = _setup_kernel(k) * (gain * (factor**2))
-    p = k.shape[0] - factor
-    return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
-
-
-def downsample_2d(x, k=None, factor=2, gain=1):
-    r"""Downsample a batch of 2D images with the given filter.
-
-    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
-    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
-    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that
-    its shape is a multiple of the downsampling factor.
-
-    Args:
-        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
-            corresponds to average pooling.
-        factor: Integer downsampling factor (default: 2).
-        gain: Scaling factor for signal magnitude (default: 1.0).
-
-    Returns:
-        Tensor of the shape `[N, C, H // factor, W // factor]`
-    """
-    assert isinstance(factor, int) and factor >= 1
-    if k is None:
-        k = [1] * factor
-    k = _setup_kernel(k) * gain
-    p = k.shape[0] - factor
-    return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
-
-
 def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
     """nXn convolution with DDPM initialization."""
     conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
```
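`upsample_2d` and `downsample_2d` move out of this file unchanged (they are now pulled in from `.resnet`), so the docstring claims above stay checkable: with the default `k = [1] * factor`, the FIR path degenerates to nearest-neighbor upsampling and average pooling respectively. A quick sketch, assuming the moved helpers keep the behavior shown:

```python
import torch
import torch.nn.functional as F

from diffusers.models.resnet import downsample_2d, upsample_2d  # imports added by this commit

x = torch.randn(2, 3, 8, 8)

# default k = [1] * factor: upsampling degenerates to nearest-neighbor...
print(torch.allclose(upsample_2d(x, factor=2), F.interpolate(x, scale_factor=2, mode="nearest"), atol=1e-5))

# ...and downsampling to 2x2 average pooling
print(torch.allclose(downsample_2d(x, factor=2), F.avg_pool2d(x, 2), atol=1e-5))
```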
```diff
@@ -338,24 +194,16 @@ class Upsample(nn.Module):
         super().__init__()
         out_ch = out_ch if out_ch else in_ch
         if with_conv:
-            self.Conv2d_0 = Conv2d(
-                in_ch,
-                out_ch,
-                kernel=3,
-                up=True,
-                resample_kernel=fir_kernel,
-                use_bias=True,
-                kernel_init=variance_scaling(),
-            )
+            self.Conv2d_0 = conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
         self.with_conv = with_conv
         self.fir_kernel = fir_kernel
         self.out_ch = out_ch
 
     def forward(self, x):
-        if not self.with_conv:
-            h = upsample_2d(x, self.fir_kernel, factor=2)
+        if self.with_conv:
+            h = upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
         else:
-            h = self.Conv2d_0(x)
+            h = upsample_2d(x, self.fir_kernel, factor=2)
 
         return h
```
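After the change, the `with_conv=True` branch no longer calls the custom layer; it hands the plain conv's weight to `upsample_conv_2d`, which fuses the 2x FIR upsample with the 3x3 convolution. A shape-level sketch (constructor argument names assumed from the fields set in `__init__` above, since the full signature is not shown in this hunk):

```python
import torch

from diffusers.models.unet_sde_score_estimation import Upsample  # argument names assumed

up = Upsample(in_ch=16, out_ch=32, with_conv=True, fir_kernel=(1, 3, 3, 1))
x = torch.randn(1, 16, 32, 32)
print(up(x).shape)  # expected: torch.Size([1, 32, 64, 64])
```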
```diff
@@ -365,24 +213,16 @@ class Downsample(nn.Module):
         super().__init__()
         out_ch = out_ch if out_ch else in_ch
         if with_conv:
-            self.Conv2d_0 = Conv2d(
-                in_ch,
-                out_ch,
-                kernel=3,
-                down=True,
-                resample_kernel=fir_kernel,
-                use_bias=True,
-                kernel_init=variance_scaling(),
-            )
+            self.Conv2d_0 = conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
         self.fir_kernel = fir_kernel
         self.with_conv = with_conv
         self.out_ch = out_ch
 
     def forward(self, x):
-        if not self.with_conv:
-            x = downsample_2d(x, self.fir_kernel, factor=2)
+        if self.with_conv:
+            x = conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
         else:
-            x = self.Conv2d_0(x)
+            x = downsample_2d(x, self.fir_kernel, factor=2)
 
         return x
```
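`Downsample` gets the mirror-image treatment, with `conv_downsample_2d` replacing the custom layer. With `with_conv=False` neither module holds any parameters, so a round trip is a cheap sanity check under the same signature assumptions as above:

```python
import torch

from diffusers.models.unet_sde_score_estimation import Downsample, Upsample  # argument names assumed as above

down = Downsample(in_ch=16, with_conv=False)
up = Upsample(in_ch=16, with_conv=False)
x = torch.randn(1, 16, 32, 32)
print(down(x).shape)      # expected: torch.Size([1, 16, 16, 16])
print(up(down(x)).shape)  # back to torch.Size([1, 16, 32, 32])
```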
```diff
@@ -400,7 +240,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
         conv_size=3,
         dropout=0.0,
         embedding_type="fourier",
-        fir=True,
+        fir=True,  # TODO (patil-suraj) remove this option from here and pre-trained model configs
         fir_kernel=(1, 3, 3, 1),
         fourier_scale=16,
         init_scale=0.0,
```