Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b9de7172
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "1e21f061601dda0aa9740e88bfce68bf4aac4acd"
Commit
b9de7172
authored
Jun 27, 2022
by
patil-suraj
Browse files
add Downsample
parent
ee010726
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
4 deletions
+15
-4
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+15
-4
No files found.
src/diffusers/models/resnet.py
View file @
b9de7172
...
@@ -103,7 +103,7 @@ class Downsample(nn.Module):
...
@@ -103,7 +103,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
downsampling occurs in the inner-two dimensions.
"""
"""
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
,
padding
=
1
):
def
__init__
(
self
,
channels
,
use_conv
=
False
,
dims
=
2
,
out_channels
=
None
,
padding
=
1
,
name
=
"conv"
):
super
().
__init__
()
super
().
__init__
()
self
.
channels
=
channels
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
out_channels
=
out_channels
or
channels
...
@@ -111,18 +111,29 @@ class Downsample(nn.Module):
...
@@ -111,18 +111,29 @@ class Downsample(nn.Module):
self
.
dims
=
dims
self
.
dims
=
dims
self
.
padding
=
padding
self
.
padding
=
padding
stride
=
2
if
dims
!=
3
else
(
1
,
2
,
2
)
stride
=
2
if
dims
!=
3
else
(
1
,
2
,
2
)
self
.
name
=
name
if
use_conv
:
if
use_conv
:
self
.
down
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
stride
=
stride
,
padding
=
padding
)
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
stride
=
stride
,
padding
=
padding
)
else
:
else
:
assert
self
.
channels
==
self
.
out_channels
assert
self
.
channels
==
self
.
out_channels
self
.
down
=
avg_pool_nd
(
dims
,
kernel_size
=
stride
,
stride
=
stride
)
conv
=
avg_pool_nd
(
dims
,
kernel_size
=
stride
,
stride
=
stride
)
if
name
==
"conv"
:
self
.
conv
=
conv
else
:
self
.
op
=
conv
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
assert
x
.
shape
[
1
]
==
self
.
channels
if
self
.
use_conv
and
self
.
padding
==
0
and
self
.
dims
==
2
:
if
self
.
use_conv
and
self
.
padding
==
0
and
self
.
dims
==
2
:
pad
=
(
0
,
1
,
0
,
1
)
pad
=
(
0
,
1
,
0
,
1
)
x
=
F
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
x
=
F
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
return
self
.
down
(
x
)
if
self
.
name
==
"conv"
:
return
self
.
conv
(
x
)
else
:
return
self
.
op
(
x
)
# TODO (patil-suraj): needs test
# TODO (patil-suraj): needs test
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment