Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
a7058f42
Unverified
Commit
a7058f42
authored
Sep 30, 2022
by
Partho
Committed by
GitHub
Sep 29, 2022
Browse files
Renamed x -> hidden_states in resnet.py (#676)
renamed x to hidden_states
parent
3dacbb94
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
21 deletions
+21
-21
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+21
-21
No files found.
src/diffusers/models/resnet.py
View file @
a7058f42
...
@@ -34,21 +34,21 @@ class Upsample2D(nn.Module):
...
@@ -34,21 +34,21 @@ class Upsample2D(nn.Module):
else
:
else
:
self
.
Conv2d_0
=
conv
self
.
Conv2d_0
=
conv
def
forward
(
self
,
x
):
def
forward
(
self
,
hidden_states
):
assert
x
.
shape
[
1
]
==
self
.
channels
assert
hidden_states
.
shape
[
1
]
==
self
.
channels
if
self
.
use_conv_transpose
:
if
self
.
use_conv_transpose
:
return
self
.
conv
(
x
)
return
self
.
conv
(
hidden_states
)
x
=
F
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
hidden_states
=
F
.
interpolate
(
hidden_states
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if
self
.
use_conv
:
if
self
.
use_conv
:
if
self
.
name
==
"conv"
:
if
self
.
name
==
"conv"
:
x
=
self
.
conv
(
x
)
hidden_states
=
self
.
conv
(
hidden_states
)
else
:
else
:
x
=
self
.
Conv2d_0
(
x
)
hidden_states
=
self
.
Conv2d_0
(
hidden_states
)
return
x
return
hidden_states
class
Downsample2D
(
nn
.
Module
):
class
Downsample2D
(
nn
.
Module
):
...
@@ -84,16 +84,16 @@ class Downsample2D(nn.Module):
...
@@ -84,16 +84,16 @@ class Downsample2D(nn.Module):
else
:
else
:
self
.
conv
=
conv
self
.
conv
=
conv
def
forward
(
self
,
x
):
def
forward
(
self
,
hidden_states
):
assert
x
.
shape
[
1
]
==
self
.
channels
assert
hidden_states
.
shape
[
1
]
==
self
.
channels
if
self
.
use_conv
and
self
.
padding
==
0
:
if
self
.
use_conv
and
self
.
padding
==
0
:
pad
=
(
0
,
1
,
0
,
1
)
pad
=
(
0
,
1
,
0
,
1
)
x
=
F
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
hidden_states
=
F
.
pad
(
hidden_states
,
pad
,
mode
=
"constant"
,
value
=
0
)
assert
x
.
shape
[
1
]
==
self
.
channels
assert
hidden_states
.
shape
[
1
]
==
self
.
channels
x
=
self
.
conv
(
x
)
hidden_states
=
self
.
conv
(
hidden_states
)
return
x
return
hidden_states
class
FirUpsample2D
(
nn
.
Module
):
class
FirUpsample2D
(
nn
.
Module
):
...
@@ -174,12 +174,12 @@ class FirUpsample2D(nn.Module):
...
@@ -174,12 +174,12 @@ class FirUpsample2D(nn.Module):
return
x
return
x
def
forward
(
self
,
x
):
def
forward
(
self
,
hidden_states
):
if
self
.
use_conv
:
if
self
.
use_conv
:
height
=
self
.
_upsample_2d
(
x
,
self
.
Conv2d_0
.
weight
,
kernel
=
self
.
fir_kernel
)
height
=
self
.
_upsample_2d
(
hidden_states
,
self
.
Conv2d_0
.
weight
,
kernel
=
self
.
fir_kernel
)
height
=
height
+
self
.
Conv2d_0
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
height
=
height
+
self
.
Conv2d_0
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
else
:
else
:
height
=
self
.
_upsample_2d
(
x
,
kernel
=
self
.
fir_kernel
,
factor
=
2
)
height
=
self
.
_upsample_2d
(
hidden_states
,
kernel
=
self
.
fir_kernel
,
factor
=
2
)
return
height
return
height
...
@@ -236,14 +236,14 @@ class FirDownsample2D(nn.Module):
...
@@ -236,14 +236,14 @@ class FirDownsample2D(nn.Module):
return
x
return
x
def
forward
(
self
,
x
):
def
forward
(
self
,
hidden_states
):
if
self
.
use_conv
:
if
self
.
use_conv
:
x
=
self
.
_downsample_2d
(
x
,
weight
=
self
.
Conv2d_0
.
weight
,
kernel
=
self
.
fir_kernel
)
hidden_states
=
self
.
_downsample_2d
(
hidden_states
,
weight
=
self
.
Conv2d_0
.
weight
,
kernel
=
self
.
fir_kernel
)
x
=
x
+
self
.
Conv2d_0
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
hidden_states
=
hidden_states
+
self
.
Conv2d_0
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
else
:
else
:
x
=
self
.
_downsample_2d
(
x
,
kernel
=
self
.
fir_kernel
,
factor
=
2
)
hidden_states
=
self
.
_downsample_2d
(
hidden_states
,
kernel
=
self
.
fir_kernel
,
factor
=
2
)
return
x
return
hidden_states
class
ResnetBlock2D
(
nn
.
Module
):
class
ResnetBlock2D
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment