OpenDAS / diffusers

Commit 9c96682a
authored Jun 17, 2022 by Nathan Lambert

ddpm changes for rl, add rl unet

parent 1997b908
Showing 3 changed files with 260 additions and 2 deletions (+260 -2)

  src/diffusers/models/unet_rl.py               +242  -0
  src/diffusers/schedulers/scheduling_ddpm.py     +8  -2
  src/diffusers/schedulers/scheduling_utils.py   +10  -0
src/diffusers/models/unet_rl.py  (new file, mode 0 → 100644)
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
import torch
import torch.nn as nn
import einops
from einops.layers.torch import Rearrange
import math


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            Rearrange('batch channels horizon -> batch channels 1 horizon'),
            nn.GroupNorm(n_groups, out_channels),
            Rearrange('batch channels 1 horizon -> batch channels horizon'),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class ResidualTemporalBlock(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(inp_channels, out_channels, kernel_size),
            Conv1dBlock(out_channels, out_channels, kernel_size),
        ])

        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            Rearrange('batch t -> batch t 1'),
        )

        self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
            if inp_channels != out_channels else nn.Identity()

    def forward(self, x, t):
        '''
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]
            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        out = self.blocks[0](x) + self.time_mlp(t)
        out = self.blocks[1](out)
        return out + self.residual_conv(x)


class TemporalUnet(nn.Module):
    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        print(f'[ models/temporal ] Channel dimensions: {in_out}')

        time_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        print(in_out)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

            if not is_last:
                horizon = horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

            if not is_last:
                horizon = horizon * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=5),
            nn.Conv1d(dim, transition_dim, 1),
        )

    def forward(self, x, cond, time):
        '''
            x : [ batch x horizon x transition ]
        '''
        x = einops.rearrange(x, 'b h t -> b t h')

        t = self.time_mlp(time)
        h = []

        for resnet, resnet2, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)

        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)

        x = self.final_conv(x)

        x = einops.rearrange(x, 'b t h -> b h t')
        return x


class TemporalValue(nn.Module):
    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        time_dim=None,
        out_dim=1,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = time_dim or dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.blocks = nn.ModuleList([])

        print(in_out)
        for dim_in, dim_out in in_out:
            self.blocks.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                Downsample1d(dim_out)
            ]))

            horizon = horizon // 2

        fc_dim = dims[-1] * max(horizon, 1)

        self.final_block = nn.Sequential(
            nn.Linear(fc_dim + time_dim, fc_dim // 2),
            nn.Mish(),
            nn.Linear(fc_dim // 2, out_dim),
        )

    def forward(self, x, cond, time, *args):
        '''
            x : [ batch x horizon x transition ]
        '''
        x = einops.rearrange(x, 'b h t -> b t h')

        t = self.time_mlp(time)

        for resnet, resnet2, downsample in self.blocks:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = downsample(x)

        x = x.view(len(x), -1)

        out = self.final_block(torch.cat([x, t], dim=-1))
        return out
\ No newline at end of file
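
For orientation, here is a minimal smoke-test sketch of the new TemporalUnet. It is not part of the commit: the import path simply follows the file added above, the batch size, horizon, and transition_dim values are illustrative assumptions, and cond is accepted but unused by forward(), as in the code.

import torch

from diffusers.models.unet_rl import TemporalUnet  # module path added in this commit

batch_size, horizon, transition_dim = 4, 128, 14   # illustrative sizes, not from the commit
model = TemporalUnet(horizon=horizon, transition_dim=transition_dim, cond_dim=0)

x = torch.randn(batch_size, horizon, transition_dim)  # [ batch x horizon x transition ]
timesteps = torch.randint(0, 1000, (batch_size,))     # one diffusion step index per sample
out = model(x, cond=None, time=timesteps)             # cond is not used inside forward()

assert out.shape == x.shape  # the U-Net maps trajectories back to their input shape

With dim_mults=(1, 2, 4, 8) the horizon is halved three times on the way down, so it should be divisible by 8 for the skip connections to line up.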
src/diffusers/schedulers/scheduling_ddpm.py
@@ -105,12 +105,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # hacks - were probs added for training stability
         if self.config.variance_type == "fixed_small":
             variance = self.clip(variance, min_value=1e-20)
+        # for rl-diffuser https://arxiv.org/abs/2205.09991
+        elif self.config.variance_type == "fixed_small_log":
+            variance = self.log(self.clip(variance, min_value=1e-20))
         elif self.config.variance_type == "fixed_large":
             variance = self.get_beta(t)
 
         return variance
 
-    def step(self, residual, sample, t):
+    def step(self, residual, sample, t, predict_epsilon=True):
         # 1. compute alphas, betas
         alpha_prod_t = self.get_alpha_prod(t)
         alpha_prod_t_prev = self.get_alpha_prod(t - 1)
@@ -119,7 +122,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
+        if predict_epsilon:
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
+        else:
+            pred_original_sample = residual
 
         # 3. Clip "predicted x_0"
         if self.config.clip_sample:
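
The net effect of this diff: get_variance() gains a "fixed_small_log" branch that returns the log of the clipped variance (for rl-diffuser, arXiv:2205.09991), and step() gains a predict_epsilon flag. With the default (True) the model output is still treated as predicted noise and x_0 is recovered via formula (15) of the DDPM paper; with predict_epsilon=False the model output is taken directly as the predicted x_0. A rough usage sketch, assuming an already configured DDPMScheduler and a denoising model whose output convention matches the flag (both hypothetical, not part of the diff):

model_output = model(sample, t)  # hypothetical model call

# default: interpret model_output as predicted noise (epsilon)
pred_prev = scheduler.step(model_output, sample, t)

# rl-diffuser style: interpret model_output directly as the predicted x_0
pred_prev = scheduler.step(model_output, sample, t, predict_epsilon=False)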
src/diffusers/schedulers/scheduling_utils.py
@@ -64,3 +64,13 @@ class SchedulerMixin:
             return torch.clamp(tensor, min_value, max_value)
 
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def log(self, tensor):
+        tensor_format = getattr(self, "tensor_format", "pt")
+
+        if tensor_format == "np":
+            return np.log(tensor)
+        elif tensor_format == "pt":
+            return torch.log(tensor)
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
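
This hunk adds a log() helper to SchedulerMixin that mirrors the existing clip(): it dispatches on tensor_format ("np" vs "pt", defaulting to "pt") and raises on anything else. It is what the new "fixed_small_log" branch in scheduling_ddpm.py calls. A small illustrative sketch (scheduler and variance assumed to exist; not part of the diff):

variance = scheduler.clip(variance, min_value=1e-20)  # existing helper shown in the context lines
log_variance = scheduler.log(variance)                # np.log(...) or torch.log(...),
                                                      # depending on scheduler.tensor_format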