Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
848c86ca
Commit
848c86ca
authored
Jun 22, 2022
by
anton-l
Browse files
batched forward diffusion step
parent
62c2c547
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
14 deletions
+20
-14
examples/train_unconditional.py
examples/train_unconditional.py
+7
-9
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+13
-5
No files found.
examples/train_unconditional.py
View file @
848c86ca
...
@@ -39,7 +39,7 @@ def main(args):
...
@@ -39,7 +39,7 @@ def main(args):
resamp_with_conv
=
True
,
resamp_with_conv
=
True
,
resolution
=
args
.
resolution
,
resolution
=
args
.
resolution
,
)
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
1000
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
1000
,
tensor_format
=
"pt"
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
augmentations
=
Compose
(
augmentations
=
Compose
(
...
@@ -93,15 +93,13 @@ def main(args):
...
@@ -93,15 +93,13 @@ def main(args):
pbar
.
set_description
(
f
"Epoch
{
epoch
}
"
)
pbar
.
set_description
(
f
"Epoch
{
epoch
}
"
)
for
step
,
batch
in
enumerate
(
train_dataloader
):
for
step
,
batch
in
enumerate
(
train_dataloader
):
clean_images
=
batch
[
"input"
]
clean_images
=
batch
[
"input"
]
noisy_images
=
torch
.
empty_like
(
clean_images
)
noise_samples
=
torch
.
randn
(
clean_images
.
shape
).
to
(
clean_images
.
device
)
noise_samples
=
torch
.
empty_like
(
clean_images
)
bsz
=
clean_images
.
shape
[
0
]
bsz
=
clean_images
.
shape
[
0
]
timesteps
=
torch
.
randint
(
0
,
noise_scheduler
.
timesteps
,
(
bsz
,),
device
=
clean_images
.
device
).
long
()
timesteps
=
torch
.
randint
(
0
,
noise_scheduler
.
timesteps
,
(
bsz
,),
device
=
clean_images
.
device
).
long
()
for
idx
in
range
(
bsz
):
noise
=
torch
.
randn
(
clean
_
images
.
shape
[
1
:]).
to
(
clean_images
.
device
)
# add
noise
onto the
clean
images
according to the noise magnitude at each timestep
noise_samples
[
idx
]
=
noise
# (this is the forward diffusion process)
noisy_images
[
idx
]
=
noise_scheduler
.
forward
_step
(
clean_images
[
idx
]
,
noise
,
timesteps
[
idx
]
)
noisy_images
=
noise_scheduler
.
training
_step
(
clean_images
,
noise
_samples
,
timesteps
)
if
step
%
args
.
gradient_accumulation_steps
!=
0
:
if
step
%
args
.
gradient_accumulation_steps
!=
0
:
with
accelerator
.
no_sync
(
model
):
with
accelerator
.
no_sync
(
model
):
...
@@ -146,7 +144,7 @@ def main(args):
...
@@ -146,7 +144,7 @@ def main(args):
# save image
# save image
test_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"test_samples"
)
test_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"test_samples"
)
os
.
makedirs
(
test_dir
,
exist_ok
=
True
)
os
.
makedirs
(
test_dir
,
exist_ok
=
True
)
image_pil
.
save
(
f
"
{
test_dir
}
/
{
epoch
}
.png"
)
image_pil
.
save
(
f
"
{
test_dir
}
/
{
epoch
:
04
d
}
.png"
)
# save the model
# save the model
if
args
.
push_to_hub
:
if
args
.
push_to_hub
:
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
848c86ca
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
import
math
import
math
import
numpy
as
np
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
...
@@ -142,11 +143,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -142,11 +143,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return
pred_prev_sample
return
pred_prev_sample
def
forward_step
(
self
,
original_sample
,
noise
,
t
):
def
training_step
(
self
,
original_samples
:
torch
.
Tensor
,
noise
:
torch
.
Tensor
,
timesteps
:
torch
.
Tensor
):
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
t
]
**
0.5
if
timesteps
.
dim
()
!=
1
:
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
t
])
**
0.5
raise
ValueError
(
"`timesteps` must be a 1D tensor"
)
noisy_sample
=
sqrt_alpha_prod
*
original_sample
+
sqrt_one_minus_alpha_prod
*
noise
return
noisy_sample
device
=
original_samples
.
device
batch_size
=
original_samples
.
shape
[
0
]
timesteps
=
timesteps
.
reshape
(
batch_size
,
1
,
1
,
1
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
noisy_samples
=
sqrt_alpha_prod
.
to
(
device
)
*
original_samples
+
sqrt_one_minus_alpha_prod
.
to
(
device
)
*
noise
return
noisy_samples
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
config
.
timesteps
return
self
.
config
.
timesteps
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment