Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
e45dae7d
Commit
e45dae7d
authored
Jun 22, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
d0032c60
33abc795
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
16 deletions
+22
-16
examples/README.md
examples/README.md
+2
-2
examples/train_unconditional.py
examples/train_unconditional.py
+7
-9
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+13
-5
No files found.
examples/README.md
View file @
e45dae7d
...
...
@@ -10,7 +10,7 @@ python -m torch.distributed.launch \
train_unconditional.py
\
--dataset
=
"huggan/flowers-102-categories"
\
--resolution
=
64
\
--output_
path
=
"flowers-ddpm"
\
--output_
dir
=
"flowers-ddpm"
\
--batch_size
=
16
\
--num_epochs
=
100
\
--gradient_accumulation_steps
=
1
\
...
...
@@ -34,7 +34,7 @@ python -m torch.distributed.launch \
train_unconditional.py
\
--dataset
=
"huggan/pokemon"
\
--resolution
=
64
\
--output_
path
=
"pokemon-ddpm"
\
--output_
dir
=
"pokemon-ddpm"
\
--batch_size
=
16
\
--num_epochs
=
100
\
--gradient_accumulation_steps
=
1
\
...
...
examples/train_unconditional.py
View file @
e45dae7d
...
...
@@ -39,7 +39,7 @@ def main(args):
resamp_with_conv
=
True
,
resolution
=
args
.
resolution
,
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
1000
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
1000
,
tensor_format
=
"pt"
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
augmentations
=
Compose
(
...
...
@@ -93,15 +93,13 @@ def main(args):
pbar
.
set_description
(
f
"Epoch
{
epoch
}
"
)
for
step
,
batch
in
enumerate
(
train_dataloader
):
clean_images
=
batch
[
"input"
]
noisy_images
=
torch
.
empty_like
(
clean_images
)
noise_samples
=
torch
.
empty_like
(
clean_images
)
noise_samples
=
torch
.
randn
(
clean_images
.
shape
).
to
(
clean_images
.
device
)
bsz
=
clean_images
.
shape
[
0
]
timesteps
=
torch
.
randint
(
0
,
noise_scheduler
.
timesteps
,
(
bsz
,),
device
=
clean_images
.
device
).
long
()
for
idx
in
range
(
bsz
):
noise
=
torch
.
randn
(
clean
_
images
.
shape
[
1
:]).
to
(
clean_images
.
device
)
noise_samples
[
idx
]
=
noise
noisy_images
[
idx
]
=
noise_scheduler
.
forward
_step
(
clean_images
[
idx
]
,
noise
,
timesteps
[
idx
]
)
# add
noise
onto the
clean
images
according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_images
=
noise_scheduler
.
training
_step
(
clean_images
,
noise
_samples
,
timesteps
)
if
step
%
args
.
gradient_accumulation_steps
!=
0
:
with
accelerator
.
no_sync
(
model
):
...
...
@@ -146,7 +144,7 @@ def main(args):
# save image
test_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"test_samples"
)
os
.
makedirs
(
test_dir
,
exist_ok
=
True
)
image_pil
.
save
(
f
"
{
test_dir
}
/
{
epoch
}
.png"
)
image_pil
.
save
(
f
"
{
test_dir
}
/
{
epoch
:
04
d
}
.png"
)
# save the model
if
args
.
push_to_hub
:
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
e45dae7d
...
...
@@ -17,6 +17,7 @@
import
math
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -142,11 +143,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return
pred_prev_sample
def
forward_step
(
self
,
original_sample
,
noise
,
t
):
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
t
]
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
t
])
**
0.5
noisy_sample
=
sqrt_alpha_prod
*
original_sample
+
sqrt_one_minus_alpha_prod
*
noise
return
noisy_sample
def
training_step
(
self
,
original_samples
:
torch
.
Tensor
,
noise
:
torch
.
Tensor
,
timesteps
:
torch
.
Tensor
):
if
timesteps
.
dim
()
!=
1
:
raise
ValueError
(
"`timesteps` must be a 1D tensor"
)
device
=
original_samples
.
device
batch_size
=
original_samples
.
shape
[
0
]
timesteps
=
timesteps
.
reshape
(
batch_size
,
1
,
1
,
1
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
noisy_samples
=
sqrt_alpha_prod
.
to
(
device
)
*
original_samples
+
sqrt_one_minus_alpha_prod
.
to
(
device
)
*
noise
return
noisy_samples
def
__len__
(
self
):
return
self
.
config
.
timesteps
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment