Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
a73f8b72
Commit
a73f8b72
authored
Oct 10, 2022
by
Nathan Lambert
Committed by
GitHub
Oct 10, 2022
Browse files
Clean up resnet.py file (#780)
* clean up resnet.py * make style and quality * minor formatting
parent
5af6eed9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
63 additions
and
57 deletions
+63
-57
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+63
-57
No files found.
src/diffusers/models/resnet.py
View file @
a73f8b72
...
@@ -9,9 +9,10 @@ class Upsample2D(nn.Module):
...
@@ -9,9 +9,10 @@ class Upsample2D(nn.Module):
"""
"""
An upsampling layer with an optional convolution.
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
Parameters:
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
channels: channels in the inputs and outputs.
upsampling occurs in the inner-two dimensions.
use_conv: a bool determining if a convolution is applied.
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
"""
"""
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
out_channels
=
None
,
name
=
"conv"
):
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
out_channels
=
None
,
name
=
"conv"
):
...
@@ -61,9 +62,10 @@ class Downsample2D(nn.Module):
...
@@ -61,9 +62,10 @@ class Downsample2D(nn.Module):
"""
"""
A downsampling layer with an optional convolution.
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
Parameters:
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
channels: channels in the inputs and outputs.
downsampling occurs in the inner-two dimensions.
use_conv: a bool determining if a convolution is applied.
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
"""
"""
def
__init__
(
self
,
channels
,
use_conv
=
False
,
out_channels
=
None
,
padding
=
1
,
name
=
"conv"
):
def
__init__
(
self
,
channels
,
use_conv
=
False
,
out_channels
=
None
,
padding
=
1
,
name
=
"conv"
):
...
@@ -115,21 +117,22 @@ class FirUpsample2D(nn.Module):
...
@@ -115,21 +117,22 @@ class FirUpsample2D(nn.Module):
def
_upsample_2d
(
self
,
hidden_states
,
weight
=
None
,
kernel
=
None
,
factor
=
2
,
gain
=
1
):
def
_upsample_2d
(
self
,
hidden_states
,
weight
=
None
,
kernel
=
None
,
factor
=
2
,
gain
=
1
):
"""Fused `upsample_2d()` followed by `Conv2d()`.
"""Fused `upsample_2d()` followed by `Conv2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
order.
arbitrary order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight: Weight tensor of the shape `[filterH, filterW, inChannels,
weight: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
factor: Integer upsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
datatype as
output:
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
`x
`.
datatype as `hidden_states
`.
"""
"""
assert
isinstance
(
factor
,
int
)
and
factor
>=
1
assert
isinstance
(
factor
,
int
)
and
factor
>=
1
...
@@ -164,7 +167,6 @@ class FirUpsample2D(nn.Module):
...
@@ -164,7 +167,6 @@ class FirUpsample2D(nn.Module):
output_shape
[
1
]
-
(
hidden_states
.
shape
[
3
]
-
1
)
*
stride
[
1
]
-
convW
,
output_shape
[
1
]
-
(
hidden_states
.
shape
[
3
]
-
1
)
*
stride
[
1
]
-
convW
,
)
)
assert
output_padding
[
0
]
>=
0
and
output_padding
[
1
]
>=
0
assert
output_padding
[
0
]
>=
0
and
output_padding
[
1
]
>=
0
inC
=
weight
.
shape
[
1
]
num_groups
=
hidden_states
.
shape
[
1
]
//
inC
num_groups
=
hidden_states
.
shape
[
1
]
//
inC
# Transpose weights.
# Transpose weights.
...
@@ -214,20 +216,23 @@ class FirDownsample2D(nn.Module):
...
@@ -214,20 +216,23 @@ class FirDownsample2D(nn.Module):
def
_downsample_2d
(
self
,
hidden_states
,
weight
=
None
,
kernel
=
None
,
factor
=
2
,
gain
=
1
):
def
_downsample_2d
(
self
,
hidden_states
,
weight
=
None
,
kernel
=
None
,
factor
=
2
,
gain
=
1
):
"""Fused `Conv2d()` followed by `downsample_2d()`.
"""Fused `Conv2d()` followed by `downsample_2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
arbitrary order.
Args:
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
:
weight
:
order.
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
performed by `inChannels = x.shape[0] // numGroups`.
filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
factor`, which corresponds to average pooling.
factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2).
gain:
factor: Integer downsampling factor (default: 2).
Scaling factor for signal magnitude (default: 1.0).
gain:
Scaling factor for signal magnitude (default: 1.0).
Returns:
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
same
output:
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
datatype as `x`.
same
datatype as `hidden_states`.
"""
"""
assert
isinstance
(
factor
,
int
)
and
factor
>=
1
assert
isinstance
(
factor
,
int
)
and
factor
>=
1
...
@@ -251,17 +256,17 @@ class FirDownsample2D(nn.Module):
...
@@ -251,17 +256,17 @@ class FirDownsample2D(nn.Module):
torch
.
tensor
(
kernel
,
device
=
hidden_states
.
device
),
torch
.
tensor
(
kernel
,
device
=
hidden_states
.
device
),
pad
=
((
pad_value
+
1
)
//
2
,
pad_value
//
2
),
pad
=
((
pad_value
+
1
)
//
2
,
pad_value
//
2
),
)
)
hidden_states
=
F
.
conv2d
(
upfirdn_input
,
weight
,
stride
=
stride_value
,
padding
=
0
)
output
=
F
.
conv2d
(
upfirdn_input
,
weight
,
stride
=
stride_value
,
padding
=
0
)
else
:
else
:
pad_value
=
kernel
.
shape
[
0
]
-
factor
pad_value
=
kernel
.
shape
[
0
]
-
factor
hidden_states
=
upfirdn2d_native
(
output
=
upfirdn2d_native
(
hidden_states
,
hidden_states
,
torch
.
tensor
(
kernel
,
device
=
hidden_states
.
device
),
torch
.
tensor
(
kernel
,
device
=
hidden_states
.
device
),
down
=
factor
,
down
=
factor
,
pad
=
((
pad_value
+
1
)
//
2
,
pad_value
//
2
),
pad
=
((
pad_value
+
1
)
//
2
,
pad_value
//
2
),
)
)
return
hidden_states
return
output
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
):
if
self
.
use_conv
:
if
self
.
use_conv
:
...
@@ -393,20 +398,20 @@ class Mish(torch.nn.Module):
...
@@ -393,20 +398,20 @@ class Mish(torch.nn.Module):
def
upsample_2d
(
hidden_states
,
kernel
=
None
,
factor
=
2
,
gain
=
1
):
def
upsample_2d
(
hidden_states
,
kernel
=
None
,
factor
=
2
,
gain
=
1
):
r
"""Upsample2D a batch of 2D images with the given filter.
r
"""Upsample2D a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
multiple of the upsampling factor.
a multiple of the upsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
Args:
k: FIR filter of the shape `[firH, firW]` or `[firN]`
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
factor: Integer upsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]`
output:
Tensor of the shape `[N, C, H * factor, W * factor]`
"""
"""
assert
isinstance
(
factor
,
int
)
and
factor
>=
1
assert
isinstance
(
factor
,
int
)
and
factor
>=
1
if
kernel
is
None
:
if
kernel
is
None
:
...
@@ -419,30 +424,31 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
...
@@ -419,30 +424,31 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
kernel
=
kernel
*
(
gain
*
(
factor
**
2
))
kernel
=
kernel
*
(
gain
*
(
factor
**
2
))
pad_value
=
kernel
.
shape
[
0
]
-
factor
pad_value
=
kernel
.
shape
[
0
]
-
factor
return
upfirdn2d_native
(
output
=
upfirdn2d_native
(
hidden_states
,
hidden_states
,
kernel
.
to
(
device
=
hidden_states
.
device
),
kernel
.
to
(
device
=
hidden_states
.
device
),
up
=
factor
,
up
=
factor
,
pad
=
((
pad_value
+
1
)
//
2
+
factor
-
1
,
pad_value
//
2
),
pad
=
((
pad_value
+
1
)
//
2
+
factor
-
1
,
pad_value
//
2
),
)
)
return
output
def
downsample_2d
(
hidden_states
,
kernel
=
None
,
factor
=
2
,
gain
=
1
):
def
downsample_2d
(
hidden_states
,
kernel
=
None
,
factor
=
2
,
gain
=
1
):
r
"""Downsample2D a batch of 2D images with the given filter.
r
"""Downsample2D a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
shape is a multiple of the downsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
factor: Integer downsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]`
output:
Tensor of the shape `[N, C, H // factor, W // factor]`
"""
"""
assert
isinstance
(
factor
,
int
)
and
factor
>=
1
assert
isinstance
(
factor
,
int
)
and
factor
>=
1
...
@@ -456,34 +462,34 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
...
@@ -456,34 +462,34 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
kernel
=
kernel
*
gain
kernel
=
kernel
*
gain
pad_value
=
kernel
.
shape
[
0
]
-
factor
pad_value
=
kernel
.
shape
[
0
]
-
factor
return
upfirdn2d_native
(
output
=
upfirdn2d_native
(
hidden_states
,
kernel
.
to
(
device
=
hidden_states
.
device
),
down
=
factor
,
pad
=
((
pad_value
+
1
)
//
2
,
pad_value
//
2
)
hidden_states
,
kernel
.
to
(
device
=
hidden_states
.
device
),
down
=
factor
,
pad
=
((
pad_value
+
1
)
//
2
,
pad_value
//
2
)
)
)
return
output
def
upfirdn2d_native
(
input
,
kernel
,
up
=
1
,
down
=
1
,
pad
=
(
0
,
0
)):
def
upfirdn2d_native
(
tensor
,
kernel
,
up
=
1
,
down
=
1
,
pad
=
(
0
,
0
)):
up_x
=
up_y
=
up
up_x
=
up_y
=
up
down_x
=
down_y
=
down
down_x
=
down_y
=
down
pad_x0
=
pad_y0
=
pad
[
0
]
pad_x0
=
pad_y0
=
pad
[
0
]
pad_x1
=
pad_y1
=
pad
[
1
]
pad_x1
=
pad_y1
=
pad
[
1
]
_
,
channel
,
in_h
,
in_w
=
input
.
shape
_
,
channel
,
in_h
,
in_w
=
tensor
.
shape
input
=
input
.
reshape
(
-
1
,
in_h
,
in_w
,
1
)
tensor
=
tensor
.
reshape
(
-
1
,
in_h
,
in_w
,
1
)
# Renamed from `input`, which shadows a Python builtin (sonarlint python:S5806)
_
,
in_h
,
in_w
,
minor
=
input
.
shape
_
,
in_h
,
in_w
,
minor
=
tensor
.
shape
kernel_h
,
kernel_w
=
kernel
.
shape
kernel_h
,
kernel_w
=
kernel
.
shape
out
=
input
.
view
(
-
1
,
in_h
,
1
,
in_w
,
1
,
minor
)
out
=
tensor
.
view
(
-
1
,
in_h
,
1
,
in_w
,
1
,
minor
)
# Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
# Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
if
input
.
device
.
type
==
"mps"
:
if
tensor
.
device
.
type
==
"mps"
:
out
=
out
.
to
(
"cpu"
)
out
=
out
.
to
(
"cpu"
)
out
=
F
.
pad
(
out
,
[
0
,
0
,
0
,
up_x
-
1
,
0
,
0
,
0
,
up_y
-
1
])
out
=
F
.
pad
(
out
,
[
0
,
0
,
0
,
up_x
-
1
,
0
,
0
,
0
,
up_y
-
1
])
out
=
out
.
view
(
-
1
,
in_h
*
up_y
,
in_w
*
up_x
,
minor
)
out
=
out
.
view
(
-
1
,
in_h
*
up_y
,
in_w
*
up_x
,
minor
)
out
=
F
.
pad
(
out
,
[
0
,
0
,
max
(
pad_x0
,
0
),
max
(
pad_x1
,
0
),
max
(
pad_y0
,
0
),
max
(
pad_y1
,
0
)])
out
=
F
.
pad
(
out
,
[
0
,
0
,
max
(
pad_x0
,
0
),
max
(
pad_x1
,
0
),
max
(
pad_y0
,
0
),
max
(
pad_y1
,
0
)])
out
=
out
.
to
(
input
.
device
)
# Move back to mps if necessary
out
=
out
.
to
(
tensor
.
device
)
# Move back to mps if necessary
out
=
out
[
out
=
out
[
:,
:,
max
(
-
pad_y0
,
0
)
:
out
.
shape
[
1
]
-
max
(
-
pad_y1
,
0
),
max
(
-
pad_y0
,
0
)
:
out
.
shape
[
1
]
-
max
(
-
pad_y1
,
0
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment