Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
e45dae7d
Commit
e45dae7d
authored
Jun 22, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
d0032c60
33abc795
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
16 deletions
+22
-16
examples/README.md
examples/README.md
+2
-2
examples/train_unconditional.py
examples/train_unconditional.py
+7
-9
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+13
-5
No files found.
examples/README.md
View file @
e45dae7d
...
@@ -10,7 +10,7 @@ python -m torch.distributed.launch \
...
@@ -10,7 +10,7 @@ python -m torch.distributed.launch \
train_unconditional.py
\
train_unconditional.py
\
--dataset
=
"huggan/flowers-102-categories"
\
--dataset
=
"huggan/flowers-102-categories"
\
--resolution
=
64
\
--resolution
=
64
\
--output_
path
=
"flowers-ddpm"
\
--output_
dir
=
"flowers-ddpm"
\
--batch_size
=
16
\
--batch_size
=
16
\
--num_epochs
=
100
\
--num_epochs
=
100
\
--gradient_accumulation_steps
=
1
\
--gradient_accumulation_steps
=
1
\
...
@@ -34,7 +34,7 @@ python -m torch.distributed.launch \
...
@@ -34,7 +34,7 @@ python -m torch.distributed.launch \
train_unconditional.py
\
train_unconditional.py
\
--dataset
=
"huggan/pokemon"
\
--dataset
=
"huggan/pokemon"
\
--resolution
=
64
\
--resolution
=
64
\
--output_
path
=
"pokemon-ddpm"
\
--output_
dir
=
"pokemon-ddpm"
\
--batch_size
=
16
\
--batch_size
=
16
\
--num_epochs
=
100
\
--num_epochs
=
100
\
--gradient_accumulation_steps
=
1
\
--gradient_accumulation_steps
=
1
\
...
...
examples/train_unconditional.py
View file @
e45dae7d
...
@@ -39,7 +39,7 @@ def main(args):
...
@@ -39,7 +39,7 @@ def main(args):
resamp_with_conv
=
True
,
resamp_with_conv
=
True
,
resolution
=
args
.
resolution
,
resolution
=
args
.
resolution
,
)
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
1000
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
1000
,
tensor_format
=
"pt"
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
augmentations
=
Compose
(
augmentations
=
Compose
(
...
@@ -93,15 +93,13 @@ def main(args):
...
@@ -93,15 +93,13 @@ def main(args):
pbar
.
set_description
(
f
"Epoch
{
epoch
}
"
)
pbar
.
set_description
(
f
"Epoch
{
epoch
}
"
)
for
step
,
batch
in
enumerate
(
train_dataloader
):
for
step
,
batch
in
enumerate
(
train_dataloader
):
clean_images
=
batch
[
"input"
]
clean_images
=
batch
[
"input"
]
noisy_images
=
torch
.
empty_like
(
clean_images
)
noise_samples
=
torch
.
randn
(
clean_images
.
shape
).
to
(
clean_images
.
device
)
noise_samples
=
torch
.
empty_like
(
clean_images
)
bsz
=
clean_images
.
shape
[
0
]
bsz
=
clean_images
.
shape
[
0
]
timesteps
=
torch
.
randint
(
0
,
noise_scheduler
.
timesteps
,
(
bsz
,),
device
=
clean_images
.
device
).
long
()
timesteps
=
torch
.
randint
(
0
,
noise_scheduler
.
timesteps
,
(
bsz
,),
device
=
clean_images
.
device
).
long
()
for
idx
in
range
(
bsz
):
noise
=
torch
.
randn
(
clean
_
images
.
shape
[
1
:]).
to
(
clean_images
.
device
)
# add
noise
onto the
clean
images
according to the noise magnitude at each timestep
noise_samples
[
idx
]
=
noise
# (this is the forward diffusion process)
noisy_images
[
idx
]
=
noise_scheduler
.
forward
_step
(
clean_images
[
idx
]
,
noise
,
timesteps
[
idx
]
)
noisy_images
=
noise_scheduler
.
training
_step
(
clean_images
,
noise
_samples
,
timesteps
)
if
step
%
args
.
gradient_accumulation_steps
!=
0
:
if
step
%
args
.
gradient_accumulation_steps
!=
0
:
with
accelerator
.
no_sync
(
model
):
with
accelerator
.
no_sync
(
model
):
...
@@ -146,7 +144,7 @@ def main(args):
...
@@ -146,7 +144,7 @@ def main(args):
# save image
# save image
test_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"test_samples"
)
test_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"test_samples"
)
os
.
makedirs
(
test_dir
,
exist_ok
=
True
)
os
.
makedirs
(
test_dir
,
exist_ok
=
True
)
image_pil
.
save
(
f
"
{
test_dir
}
/
{
epoch
}
.png"
)
image_pil
.
save
(
f
"
{
test_dir
}
/
{
epoch
:
04
d
}
.png"
)
# save the model
# save the model
if
args
.
push_to_hub
:
if
args
.
push_to_hub
:
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
e45dae7d
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
import
math
import
math
import
numpy
as
np
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
...
@@ -142,11 +143,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -142,11 +143,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return
pred_prev_sample
return
pred_prev_sample
def
forward_step
(
self
,
original_sample
,
noise
,
t
):
def
training_step
(
self
,
original_samples
:
torch
.
Tensor
,
noise
:
torch
.
Tensor
,
timesteps
:
torch
.
Tensor
):
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
t
]
**
0.5
if
timesteps
.
dim
()
!=
1
:
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
t
])
**
0.5
raise
ValueError
(
"`timesteps` must be a 1D tensor"
)
noisy_sample
=
sqrt_alpha_prod
*
original_sample
+
sqrt_one_minus_alpha_prod
*
noise
return
noisy_sample
device
=
original_samples
.
device
batch_size
=
original_samples
.
shape
[
0
]
timesteps
=
timesteps
.
reshape
(
batch_size
,
1
,
1
,
1
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
noisy_samples
=
sqrt_alpha_prod
.
to
(
device
)
*
original_samples
+
sqrt_one_minus_alpha_prod
.
to
(
device
)
*
noise
return
noisy_samples
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
config
.
timesteps
return
self
.
config
.
timesteps
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment