Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
0bf6aeb8
Unverified
Commit
0bf6aeb8
authored
Jun 28, 2023
by
Saurav Maheshkar
Committed by
GitHub
Jun 28, 2023
Browse files
feat: rename single-letter vars in `resnet.py` (#3868)
feat: rename single-letter vars
parent
9a45d7fb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
26 deletions
+26
-26
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+26
-26
No files found.
src/diffusers/models/resnet.py
View file @
0bf6aeb8
...
@@ -95,9 +95,9 @@ class Downsample1D(nn.Module):
...
@@ -95,9 +95,9 @@ class Downsample1D(nn.Module):
assert
self
.
channels
==
self
.
out_channels
assert
self
.
channels
==
self
.
out_channels
self
.
conv
=
nn
.
AvgPool1d
(
kernel_size
=
stride
,
stride
=
stride
)
self
.
conv
=
nn
.
AvgPool1d
(
kernel_size
=
stride
,
stride
=
stride
)
def
forward
(
self
,
x
):
def
forward
(
self
,
inputs
):
assert
x
.
shape
[
1
]
==
self
.
channels
assert
inputs
.
shape
[
1
]
==
self
.
channels
return
self
.
conv
(
x
)
return
self
.
conv
(
inputs
)
class
Upsample2D
(
nn
.
Module
):
class
Upsample2D
(
nn
.
Module
):
...
@@ -431,13 +431,13 @@ class KDownsample2D(nn.Module):
...
@@ -431,13 +431,13 @@ class KDownsample2D(nn.Module):
self
.
pad
=
kernel_1d
.
shape
[
1
]
//
2
-
1
self
.
pad
=
kernel_1d
.
shape
[
1
]
//
2
-
1
self
.
register_buffer
(
"kernel"
,
kernel_1d
.
T
@
kernel_1d
,
persistent
=
False
)
self
.
register_buffer
(
"kernel"
,
kernel_1d
.
T
@
kernel_1d
,
persistent
=
False
)
def
forward
(
self
,
x
):
def
forward
(
self
,
inputs
):
x
=
F
.
pad
(
x
,
(
self
.
pad
,)
*
4
,
self
.
pad_mode
)
inputs
=
F
.
pad
(
inputs
,
(
self
.
pad
,)
*
4
,
self
.
pad_mode
)
weight
=
x
.
new_zeros
([
x
.
shape
[
1
],
x
.
shape
[
1
],
self
.
kernel
.
shape
[
0
],
self
.
kernel
.
shape
[
1
]])
weight
=
inputs
.
new_zeros
([
inputs
.
shape
[
1
],
inputs
.
shape
[
1
],
self
.
kernel
.
shape
[
0
],
self
.
kernel
.
shape
[
1
]])
indices
=
torch
.
arange
(
x
.
shape
[
1
],
device
=
x
.
device
)
indices
=
torch
.
arange
(
inputs
.
shape
[
1
],
device
=
inputs
.
device
)
kernel
=
self
.
kernel
.
to
(
weight
)[
None
,
:].
expand
(
x
.
shape
[
1
],
-
1
,
-
1
)
kernel
=
self
.
kernel
.
to
(
weight
)[
None
,
:].
expand
(
inputs
.
shape
[
1
],
-
1
,
-
1
)
weight
[
indices
,
indices
]
=
kernel
weight
[
indices
,
indices
]
=
kernel
return
F
.
conv2d
(
x
,
weight
,
stride
=
2
)
return
F
.
conv2d
(
inputs
,
weight
,
stride
=
2
)
class
KUpsample2D
(
nn
.
Module
):
class
KUpsample2D
(
nn
.
Module
):
...
@@ -448,13 +448,13 @@ class KUpsample2D(nn.Module):
...
@@ -448,13 +448,13 @@ class KUpsample2D(nn.Module):
self
.
pad
=
kernel_1d
.
shape
[
1
]
//
2
-
1
self
.
pad
=
kernel_1d
.
shape
[
1
]
//
2
-
1
self
.
register_buffer
(
"kernel"
,
kernel_1d
.
T
@
kernel_1d
,
persistent
=
False
)
self
.
register_buffer
(
"kernel"
,
kernel_1d
.
T
@
kernel_1d
,
persistent
=
False
)
def
forward
(
self
,
x
):
def
forward
(
self
,
inputs
):
x
=
F
.
pad
(
x
,
((
self
.
pad
+
1
)
//
2
,)
*
4
,
self
.
pad_mode
)
inputs
=
F
.
pad
(
inputs
,
((
self
.
pad
+
1
)
//
2
,)
*
4
,
self
.
pad_mode
)
weight
=
x
.
new_zeros
([
x
.
shape
[
1
],
x
.
shape
[
1
],
self
.
kernel
.
shape
[
0
],
self
.
kernel
.
shape
[
1
]])
weight
=
inputs
.
new_zeros
([
inputs
.
shape
[
1
],
inputs
.
shape
[
1
],
self
.
kernel
.
shape
[
0
],
self
.
kernel
.
shape
[
1
]])
indices
=
torch
.
arange
(
x
.
shape
[
1
],
device
=
x
.
device
)
indices
=
torch
.
arange
(
inputs
.
shape
[
1
],
device
=
inputs
.
device
)
kernel
=
self
.
kernel
.
to
(
weight
)[
None
,
:].
expand
(
x
.
shape
[
1
],
-
1
,
-
1
)
kernel
=
self
.
kernel
.
to
(
weight
)[
None
,
:].
expand
(
inputs
.
shape
[
1
],
-
1
,
-
1
)
weight
[
indices
,
indices
]
=
kernel
weight
[
indices
,
indices
]
=
kernel
return
F
.
conv_transpose2d
(
x
,
weight
,
stride
=
2
,
padding
=
self
.
pad
*
2
+
1
)
return
F
.
conv_transpose2d
(
inputs
,
weight
,
stride
=
2
,
padding
=
self
.
pad
*
2
+
1
)
class
ResnetBlock2D
(
nn
.
Module
):
class
ResnetBlock2D
(
nn
.
Module
):
...
@@ -664,13 +664,13 @@ class Conv1dBlock(nn.Module):
...
@@ -664,13 +664,13 @@ class Conv1dBlock(nn.Module):
self
.
group_norm
=
nn
.
GroupNorm
(
n_groups
,
out_channels
)
self
.
group_norm
=
nn
.
GroupNorm
(
n_groups
,
out_channels
)
self
.
mish
=
nn
.
Mish
()
self
.
mish
=
nn
.
Mish
()
def
forward
(
self
,
x
):
def
forward
(
self
,
inputs
):
x
=
self
.
conv1d
(
x
)
intermediate_repr
=
self
.
conv1d
(
inputs
)
x
=
rearrange_dims
(
x
)
intermediate_repr
=
rearrange_dims
(
intermediate_repr
)
x
=
self
.
group_norm
(
x
)
intermediate_repr
=
self
.
group_norm
(
intermediate_repr
)
x
=
rearrange_dims
(
x
)
intermediate_repr
=
rearrange_dims
(
intermediate_repr
)
x
=
self
.
mish
(
x
)
output
=
self
.
mish
(
intermediate_repr
)
return
x
return
output
# unet_rl.py
# unet_rl.py
...
@@ -687,10 +687,10 @@ class ResidualTemporalBlock1D(nn.Module):
...
@@ -687,10 +687,10 @@ class ResidualTemporalBlock1D(nn.Module):
nn
.
Conv1d
(
inp_channels
,
out_channels
,
1
)
if
inp_channels
!=
out_channels
else
nn
.
Identity
()
nn
.
Conv1d
(
inp_channels
,
out_channels
,
1
)
if
inp_channels
!=
out_channels
else
nn
.
Identity
()
)
)
def
forward
(
self
,
x
,
t
):
def
forward
(
self
,
inputs
,
t
):
"""
"""
Args:
Args:
x
: [ batch_size x inp_channels x horizon ]
inputs
: [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
t : [ batch_size x embed_dim ]
returns:
returns:
...
@@ -698,9 +698,9 @@ class ResidualTemporalBlock1D(nn.Module):
...
@@ -698,9 +698,9 @@ class ResidualTemporalBlock1D(nn.Module):
"""
"""
t
=
self
.
time_emb_act
(
t
)
t
=
self
.
time_emb_act
(
t
)
t
=
self
.
time_emb
(
t
)
t
=
self
.
time_emb
(
t
)
out
=
self
.
conv_in
(
x
)
+
rearrange_dims
(
t
)
out
=
self
.
conv_in
(
inputs
)
+
rearrange_dims
(
t
)
out
=
self
.
conv_out
(
out
)
out
=
self
.
conv_out
(
out
)
return
out
+
self
.
residual_conv
(
x
)
return
out
+
self
.
residual_conv
(
inputs
)
def
upsample_2d
(
hidden_states
,
kernel
=
None
,
factor
=
2
,
gain
=
1
):
def
upsample_2d
(
hidden_states
,
kernel
=
None
,
factor
=
2
,
gain
=
1
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment