renzhc / diffusers_dcu: Commit 0bf6aeb8 (unverified)

feat: rename single-letter vars in `resnet.py` (#3868)

feat: rename single-letter vars

Authored by Saurav Maheshkar on Jun 28, 2023; committed by GitHub on Jun 28, 2023.
Parent commit: 9a45d7fb

Showing 1 changed file with 26 additions and 26 deletions.

src/diffusers/models/resnet.py (+26, -26)
...
@@ -95,9 +95,9 @@ class Downsample1D(nn.Module):
             assert self.channels == self.out_channels
             self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
 
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        return self.conv(x)
+    def forward(self, inputs):
+        assert inputs.shape[1] == self.channels
+        return self.conv(inputs)
 
 
 class Upsample2D(nn.Module):
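For reference, a minimal usage sketch of the renamed Downsample1D.forward; the tensor sizes and the use_conv=False constructor argument are illustrative assumptions, not part of the diff:

import torch
from diffusers.models.resnet import Downsample1D

# Hypothetical example: with use_conv=False the module takes the
# nn.AvgPool1d path shown above (stride is 2 in this module), so the
# last dimension is halved.
down = Downsample1D(channels=32, use_conv=False)
inputs = torch.randn(4, 32, 64)   # (batch, channels, length)
out = down(inputs)                # forward(self, inputs) after this commit
print(out.shape)                  # torch.Size([4, 32, 32])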
...
@@ -431,13 +431,13 @@ class KDownsample2D(nn.Module):
         self.pad = kernel_1d.shape[1] // 2 - 1
         self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
 
-    def forward(self, x):
-        x = F.pad(x, (self.pad,) * 4, self.pad_mode)
-        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
-        indices = torch.arange(x.shape[1], device=x.device)
-        kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
+    def forward(self, inputs):
+        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
+        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+        indices = torch.arange(inputs.shape[1], device=inputs.device)
+        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
         weight[indices, indices] = kernel
-        return F.conv2d(x, weight, stride=2)
+        return F.conv2d(inputs, weight, stride=2)
 
 
 class KUpsample2D(nn.Module):
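The KDownsample2D.forward shown above builds a (C, C, 4, 4) weight whose diagonal holds a 2-D binomial blur kernel, so each channel is filtered independently before a stride-2 convolution halves the resolution. A self-contained sketch of that pattern; the channel count, batch size, and spatial size are made up for illustration:

import torch
import torch.nn.functional as F

# A 1-D binomial kernel is turned into a 2-D blur kernel, placed on the
# diagonal of a (C, C, kh, kw) weight so each channel is filtered
# independently, then applied with stride 2 to halve H and W.
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
kernel_2d = kernel_1d.T @ kernel_1d                 # (4, 4) blur kernel
pad = kernel_1d.shape[1] // 2 - 1                   # = 1

inputs = torch.randn(2, 8, 32, 32)                  # (batch, C, H, W)
inputs = F.pad(inputs, (pad,) * 4, mode="reflect")
weight = inputs.new_zeros(8, 8, 4, 4)
idx = torch.arange(8)
weight[idx, idx] = kernel_2d                        # diagonal = per-channel filter
out = F.conv2d(inputs, weight, stride=2)
print(out.shape)                                    # torch.Size([2, 8, 16, 16])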
...
@@ -448,13 +448,13 @@ class KUpsample2D(nn.Module):
         self.pad = kernel_1d.shape[1] // 2 - 1
         self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
 
-    def forward(self, x):
-        x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
-        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
-        indices = torch.arange(x.shape[1], device=x.device)
-        kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
+    def forward(self, inputs):
+        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
+        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+        indices = torch.arange(inputs.shape[1], device=inputs.device)
+        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
         weight[indices, indices] = kernel
-        return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
+        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
 
 
 class ResnetBlock2D(nn.Module):
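KUpsample2D.forward uses the same diagonal per-channel weight but applies it through F.conv_transpose2d with stride 2, which doubles the spatial resolution. A standalone sketch of the shape arithmetic only; the actual module's kernel normalization may differ, and all sizes are illustrative:

import torch
import torch.nn.functional as F

# The diagonal per-channel kernel applied with conv_transpose2d(stride=2)
# doubles H and W: (16 + 2*1 - 1) * 2 - 2*3 + 4 = 32.
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
kernel_2d = kernel_1d.T @ kernel_1d
pad = kernel_1d.shape[1] // 2 - 1                   # = 1

inputs = torch.randn(2, 8, 16, 16)
inputs = F.pad(inputs, ((pad + 1) // 2,) * 4, mode="reflect")
weight = inputs.new_zeros(8, 8, 4, 4)
idx = torch.arange(8)
weight[idx, idx] = kernel_2d
out = F.conv_transpose2d(inputs, weight, stride=2, padding=pad * 2 + 1)
print(out.shape)                                    # torch.Size([2, 8, 32, 32])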
...
@@ -664,13 +664,13 @@ class Conv1dBlock(nn.Module):
         self.group_norm = nn.GroupNorm(n_groups, out_channels)
         self.mish = nn.Mish()
 
-    def forward(self, x):
-        x = self.conv1d(x)
-        x = rearrange_dims(x)
-        x = self.group_norm(x)
-        x = rearrange_dims(x)
-        x = self.mish(x)
-        return x
+    def forward(self, inputs):
+        intermediate_repr = self.conv1d(inputs)
+        intermediate_repr = rearrange_dims(intermediate_repr)
+        intermediate_repr = self.group_norm(intermediate_repr)
+        intermediate_repr = rearrange_dims(intermediate_repr)
+        output = self.mish(intermediate_repr)
+        return output
 
 
 # unet_rl.py
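Conv1dBlock.forward, after the rename, still runs Conv1d, GroupNorm, and Mish in sequence. A minimal equivalent sketch that leaves out the rearrange_dims() reshapes, since nn.GroupNorm also accepts (batch, channels, length) input directly; the kernel size, channel counts, and group count below are illustrative assumptions:

import torch
import torch.nn as nn

# Conv1d -> GroupNorm -> Mish, the same pipeline as Conv1dBlock.forward.
inp_channels, out_channels, n_groups = 16, 32, 8
conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size=5, padding=2)
group_norm = nn.GroupNorm(n_groups, out_channels)
mish = nn.Mish()

inputs = torch.randn(4, inp_channels, 64)           # (batch, channels, horizon)
intermediate_repr = conv1d(inputs)
intermediate_repr = group_norm(intermediate_repr)
output = mish(intermediate_repr)
print(output.shape)                                 # torch.Size([4, 32, 64])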
...
@@ -687,10 +687,10 @@ class ResidualTemporalBlock1D(nn.Module):
             nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
         )
 
-    def forward(self, x, t):
+    def forward(self, inputs, t):
         """
         Args:
-            x : [ batch_size x inp_channels x horizon ]
+            inputs : [ batch_size x inp_channels x horizon ]
             t : [ batch_size x embed_dim ]
 
         returns:
...
@@ -698,9 +698,9 @@ class ResidualTemporalBlock1D(nn.Module):
         """
         t = self.time_emb_act(t)
         t = self.time_emb(t)
-        out = self.conv_in(x) + rearrange_dims(t)
+        out = self.conv_in(inputs) + rearrange_dims(t)
         out = self.conv_out(out)
-        return out + self.residual_conv(x)
+        return out + self.residual_conv(inputs)
 
 
 def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
...
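ResidualTemporalBlock1D.forward, with the renamed inputs argument, projects the time embedding to out_channels, broadcasts it over the horizon, and adds a 1x1 residual convolution of the original inputs. A hedged standalone sketch of that flow: the plain Conv1d layers and sizes below stand in for the block's actual Conv1dBlock submodules, and t[:, :, None] stands in for rearrange_dims(t).

import torch
import torch.nn as nn

# Sketch of the residual temporal block pattern in the hunk above.
batch, inp_channels, out_channels, horizon, embed_dim = 4, 16, 32, 64, 128

conv_in = nn.Conv1d(inp_channels, out_channels, 5, padding=2)
conv_out = nn.Conv1d(out_channels, out_channels, 5, padding=2)
time_emb_act = nn.Mish()
time_emb = nn.Linear(embed_dim, out_channels)
residual_conv = nn.Conv1d(inp_channels, out_channels, 1)    # inp != out here

inputs = torch.randn(batch, inp_channels, horizon)          # [ batch_size x inp_channels x horizon ]
t = torch.randn(batch, embed_dim)                           # [ batch_size x embed_dim ]

t = time_emb(time_emb_act(t))                               # -> [ batch_size x out_channels ]
out = conv_in(inputs) + t[:, :, None]                       # broadcast over horizon
out = conv_out(out)
out = out + residual_conv(inputs)                           # [ batch_size x out_channels x horizon ]
print(out.shape)                                            # torch.Size([4, 32, 64])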