Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
9c96682a
Commit
9c96682a
authored
Jun 17, 2022
by
Nathan Lambert
Browse files
ddpm changes for rl, add rl unet
parent
1997b908
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
260 additions
and
2 deletions
+260
-2
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+242
-0
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+8
-2
src/diffusers/schedulers/scheduling_utils.py
src/diffusers/schedulers/scheduling_utils.py
+10
-0
No files found.
src/diffusers/models/unet_rl.py
0 → 100644
View file @
9c96682a
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
import
torch
import
torch.nn
as
nn
import
einops
from
einops.layers.torch
import
Rearrange
import
math
class
SinusoidalPosEmb
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
self
.
dim
=
dim
def
forward
(
self
,
x
):
device
=
x
.
device
half_dim
=
self
.
dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
device
=
device
)
*
-
emb
)
emb
=
x
[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
return
emb
class
Downsample1d
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv1d
(
dim
,
dim
,
3
,
2
,
1
)
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
class
Upsample1d
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
self
.
conv
=
nn
.
ConvTranspose1d
(
dim
,
dim
,
4
,
2
,
1
)
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
class
Conv1dBlock
(
nn
.
Module
):
'''
Conv1d --> GroupNorm --> Mish
'''
def
__init__
(
self
,
inp_channels
,
out_channels
,
kernel_size
,
n_groups
=
8
):
super
().
__init__
()
self
.
block
=
nn
.
Sequential
(
nn
.
Conv1d
(
inp_channels
,
out_channels
,
kernel_size
,
padding
=
kernel_size
//
2
),
Rearrange
(
'batch channels horizon -> batch channels 1 horizon'
),
nn
.
GroupNorm
(
n_groups
,
out_channels
),
Rearrange
(
'batch channels 1 horizon -> batch channels horizon'
),
nn
.
Mish
(),
)
def
forward
(
self
,
x
):
return
self
.
block
(
x
)
class
ResidualTemporalBlock
(
nn
.
Module
):
def
__init__
(
self
,
inp_channels
,
out_channels
,
embed_dim
,
horizon
,
kernel_size
=
5
):
super
().
__init__
()
self
.
blocks
=
nn
.
ModuleList
([
Conv1dBlock
(
inp_channels
,
out_channels
,
kernel_size
),
Conv1dBlock
(
out_channels
,
out_channels
,
kernel_size
),
])
self
.
time_mlp
=
nn
.
Sequential
(
nn
.
Mish
(),
nn
.
Linear
(
embed_dim
,
out_channels
),
Rearrange
(
'batch t -> batch t 1'
),
)
self
.
residual_conv
=
nn
.
Conv1d
(
inp_channels
,
out_channels
,
1
)
\
if
inp_channels
!=
out_channels
else
nn
.
Identity
()
def
forward
(
self
,
x
,
t
):
'''
x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
returns:
out : [ batch_size x out_channels x horizon ]
'''
out
=
self
.
blocks
[
0
](
x
)
+
self
.
time_mlp
(
t
)
out
=
self
.
blocks
[
1
](
out
)
return
out
+
self
.
residual_conv
(
x
)
class
TemporalUnet
(
nn
.
Module
):
def
__init__
(
self
,
horizon
,
transition_dim
,
cond_dim
,
dim
=
32
,
dim_mults
=
(
1
,
2
,
4
,
8
),
):
super
().
__init__
()
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
print
(
f
'[ models/temporal ] Channel dimensions:
{
in_out
}
'
)
time_dim
=
dim
self
.
time_mlp
=
nn
.
Sequential
(
SinusoidalPosEmb
(
dim
),
nn
.
Linear
(
dim
,
dim
*
4
),
nn
.
Mish
(),
nn
.
Linear
(
dim
*
4
,
dim
),
)
self
.
downs
=
nn
.
ModuleList
([])
self
.
ups
=
nn
.
ModuleList
([])
num_resolutions
=
len
(
in_out
)
print
(
in_out
)
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
in_out
):
is_last
=
ind
>=
(
num_resolutions
-
1
)
self
.
downs
.
append
(
nn
.
ModuleList
([
ResidualTemporalBlock
(
dim_in
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
Downsample1d
(
dim_out
)
if
not
is_last
else
nn
.
Identity
()
]))
if
not
is_last
:
horizon
=
horizon
//
2
mid_dim
=
dims
[
-
1
]
self
.
mid_block1
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
horizon
)
self
.
mid_block2
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
horizon
)
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
is_last
=
ind
>=
(
num_resolutions
-
1
)
self
.
ups
.
append
(
nn
.
ModuleList
([
ResidualTemporalBlock
(
dim_out
*
2
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
Upsample1d
(
dim_in
)
if
not
is_last
else
nn
.
Identity
()
]))
if
not
is_last
:
horizon
=
horizon
*
2
self
.
final_conv
=
nn
.
Sequential
(
Conv1dBlock
(
dim
,
dim
,
kernel_size
=
5
),
nn
.
Conv1d
(
dim
,
transition_dim
,
1
),
)
def
forward
(
self
,
x
,
cond
,
time
):
'''
x : [ batch x horizon x transition ]
'''
x
=
einops
.
rearrange
(
x
,
'b h t -> b t h'
)
t
=
self
.
time_mlp
(
time
)
h
=
[]
for
resnet
,
resnet2
,
downsample
in
self
.
downs
:
x
=
resnet
(
x
,
t
)
x
=
resnet2
(
x
,
t
)
h
.
append
(
x
)
x
=
downsample
(
x
)
x
=
self
.
mid_block1
(
x
,
t
)
x
=
self
.
mid_block2
(
x
,
t
)
for
resnet
,
resnet2
,
upsample
in
self
.
ups
:
x
=
torch
.
cat
((
x
,
h
.
pop
()),
dim
=
1
)
x
=
resnet
(
x
,
t
)
x
=
resnet2
(
x
,
t
)
x
=
upsample
(
x
)
x
=
self
.
final_conv
(
x
)
x
=
einops
.
rearrange
(
x
,
'b t h -> b h t'
)
return
x
class
TemporalValue
(
nn
.
Module
):
def
__init__
(
self
,
horizon
,
transition_dim
,
cond_dim
,
dim
=
32
,
time_dim
=
None
,
out_dim
=
1
,
dim_mults
=
(
1
,
2
,
4
,
8
),
):
super
().
__init__
()
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
time_dim
=
time_dim
or
dim
self
.
time_mlp
=
nn
.
Sequential
(
SinusoidalPosEmb
(
dim
),
nn
.
Linear
(
dim
,
dim
*
4
),
nn
.
Mish
(),
nn
.
Linear
(
dim
*
4
,
dim
),
)
self
.
blocks
=
nn
.
ModuleList
([])
print
(
in_out
)
for
dim_in
,
dim_out
in
in_out
:
self
.
blocks
.
append
(
nn
.
ModuleList
([
ResidualTemporalBlock
(
dim_in
,
dim_out
,
kernel_size
=
5
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
kernel_size
=
5
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
Downsample1d
(
dim_out
)
]))
horizon
=
horizon
//
2
fc_dim
=
dims
[
-
1
]
*
max
(
horizon
,
1
)
self
.
final_block
=
nn
.
Sequential
(
nn
.
Linear
(
fc_dim
+
time_dim
,
fc_dim
//
2
),
nn
.
Mish
(),
nn
.
Linear
(
fc_dim
//
2
,
out_dim
),
)
def
forward
(
self
,
x
,
cond
,
time
,
*
args
):
'''
x : [ batch x horizon x transition ]
'''
x
=
einops
.
rearrange
(
x
,
'b h t -> b t h'
)
t
=
self
.
time_mlp
(
time
)
for
resnet
,
resnet2
,
downsample
in
self
.
blocks
:
x
=
resnet
(
x
,
t
)
x
=
resnet2
(
x
,
t
)
x
=
downsample
(
x
)
x
=
x
.
view
(
len
(
x
),
-
1
)
out
=
self
.
final_block
(
torch
.
cat
([
x
,
t
],
dim
=-
1
))
return
out
\ No newline at end of file
src/diffusers/schedulers/scheduling_ddpm.py
View file @
9c96682a
...
...
@@ -105,12 +105,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# hacks - were probs added for training stability
if
self
.
config
.
variance_type
==
"fixed_small"
:
variance
=
self
.
clip
(
variance
,
min_value
=
1e-20
)
# for rl-diffuser https://arxiv.org/abs/2205.09991
elif
self
.
config
.
variance_type
==
"fixed_small_log"
:
variance
=
self
.
log
(
self
.
clip
(
variance
,
min_value
=
1e-20
))
elif
self
.
config
.
variance_type
==
"fixed_large"
:
variance
=
self
.
get_beta
(
t
)
return
variance
def
step
(
self
,
residual
,
sample
,
t
):
def
step
(
self
,
residual
,
sample
,
t
,
predict_epsilon
=
True
):
# 1. compute alphas, betas
alpha_prod_t
=
self
.
get_alpha_prod
(
t
)
alpha_prod_t_prev
=
self
.
get_alpha_prod
(
t
-
1
)
...
...
@@ -119,7 +122,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample
=
(
sample
-
beta_prod_t
**
(
0.5
)
*
residual
)
/
alpha_prod_t
**
(
0.5
)
if
predict_epsilon
:
pred_original_sample
=
(
sample
-
beta_prod_t
**
(
0.5
)
*
residual
)
/
alpha_prod_t
**
(
0.5
)
else
:
pred_original_sample
=
residual
# 3. Clip "predicted x_0"
if
self
.
config
.
clip_sample
:
...
...
src/diffusers/schedulers/scheduling_utils.py
View file @
9c96682a
...
...
@@ -64,3 +64,13 @@ class SchedulerMixin:
return
torch
.
clamp
(
tensor
,
min_value
,
max_value
)
raise
ValueError
(
f
"`self.tensor_format`:
{
self
.
tensor_format
}
is not valid."
)
def
log
(
self
,
tensor
):
tensor_format
=
getattr
(
self
,
"tensor_format"
,
"pt"
)
if
tensor_format
==
"np"
:
return
np
.
log
(
tensor
)
elif
tensor_format
==
"pt"
:
return
torch
.
log
(
tensor
)
raise
ValueError
(
f
"`self.tensor_format`:
{
self
.
tensor_format
}
is not valid."
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment