Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
ca72c1f8
Commit
ca72c1f8
authored
Jun 13, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
059a6e9d
55d29ab7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
35 additions
and
23 deletions
+35
-23
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-1
src/diffusers/models/__init__.py
src/diffusers/models/__init__.py
+1
-1
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+3
-3
src/diffusers/trainers/training_ddpm.py
src/diffusers/trainers/training_ddpm.py
+30
-18
No files found.
src/diffusers/__init__.py
View file @
ca72c1f8
...
...
@@ -6,7 +6,7 @@ __version__ = "0.0.3"
from
.modeling_utils
import
ModelMixin
from
.models.unet
import
UNetModel
from
.models.unet_glide
import
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
from
.models.unet_glide
import
GLIDEUNetModel
,
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
from
.models.unet_ldm
import
UNetLDMModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
DDIM
,
DDPM
,
GLIDE
,
LatentDiffusion
,
PNDM
,
BDDM
...
...
src/diffusers/models/__init__.py
View file @
ca72c1f8
...
...
@@ -17,5 +17,5 @@
# limitations under the License.
from
.unet
import
UNetModel
from
.unet_glide
import
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
from
.unet_glide
import
GLIDEUNetModel
,
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
from
.unet_ldm
import
UNetLDMModel
src/diffusers/schedulers/scheduling_ddpm.py
View file @
ca72c1f8
...
...
@@ -63,8 +63,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self
.
alphas
=
1.0
-
self
.
betas
self
.
alphas_cumprod
=
np
.
cumprod
(
self
.
alphas
,
axis
=
0
)
self
.
sqrt_alphas_cumprod
=
np
.
sqrt
(
self
.
alphas_cumprod
)
self
.
sqrt_one_minus_alphas_cumprod
=
np
.
sqrt
(
1
-
self
.
alphas_cumprod
)
self
.
one
=
np
.
array
(
1.0
)
self
.
set_format
(
tensor_format
=
tensor_format
)
...
...
@@ -141,7 +139,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return
pred_prev_image
def
forward_step
(
self
,
original_image
,
noise
,
t
):
noisy_image
=
self
.
sqrt_alphas_cumprod
[
t
]
*
original_image
+
self
.
sqrt_one_minus_alphas_cumprod
[
t
]
*
noise
sqrt_alpha_prod
=
self
.
get_alpha_prod
(
t
)
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
get_alpha_prod
(
t
))
**
0.5
noisy_image
=
sqrt_alpha_prod
*
original_image
+
sqrt_one_minus_alpha_prod
*
noise
return
noisy_image
def
__len__
(
self
):
...
...
src/diffusers/trainers/training_ddpm.py
View file @
ca72c1f8
...
...
@@ -24,20 +24,28 @@ def set_seed(seed):
set_seed
(
0
)
accelerator
=
Accelerator
(
mixed_precision
=
"fp16"
)
model
=
UNetModel
(
ch
=
128
,
ch_mult
=
(
1
,
2
,
4
,
8
),
resolution
=
64
)
accelerator
=
Accelerator
()
model
=
UNetModel
(
attn_resolutions
=
(
16
,),
ch
=
128
,
ch_mult
=
(
1
,
2
,
2
,
2
),
dropout
=
0.1
,
num_res_blocks
=
2
,
resamp_with_conv
=
True
,
resolution
=
32
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
1000
)
optimizer
=
torch
.
optim
.
Adam
W
(
model
.
parameters
(),
lr
=
1e-4
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.0002
)
num_epochs
=
100
batch_size
=
8
gradient_accumulation_steps
=
8
batch_size
=
64
gradient_accumulation_steps
=
2
augmentations
=
Compose
(
[
Resize
(
64
),
CenterCrop
(
64
),
Resize
(
32
),
CenterCrop
(
32
),
RandomHorizontalFlip
(),
ToTensor
(),
Lambda
(
lambda
x
:
x
*
2
-
1
),
...
...
@@ -55,14 +63,14 @@ dataset = dataset.shuffle(seed=0)
dataset
.
set_transform
(
transforms
)
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
batch_size
,
shuffle
=
False
)
lr_scheduler
=
get_linear_schedule_with_warmup
(
optimizer
=
optimizer
,
num_warmup_steps
=
1000
,
num_training_steps
=
(
len
(
train_dataloader
)
*
num_epochs
)
//
gradient_accumulation_steps
,
)
#
lr_scheduler = get_linear_schedule_with_warmup(
#
optimizer=optimizer,
#
num_warmup_steps=1000,
#
num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
#
)
model
,
optimizer
,
train_dataloader
,
lr_scheduler
=
accelerator
.
prepare
(
model
,
optimizer
,
train_dataloader
,
lr_scheduler
model
,
optimizer
,
train_dataloader
=
accelerator
.
prepare
(
model
,
optimizer
,
train_dataloader
)
for
epoch
in
range
(
num_epochs
):
...
...
@@ -72,24 +80,28 @@ for epoch in range(num_epochs):
for
step
,
batch
in
enumerate
(
train_dataloader
):
clean_images
=
batch
[
"input"
]
noisy_images
=
torch
.
empty_like
(
clean_images
)
noise_samples
=
torch
.
empty_like
(
clean_images
)
bsz
=
clean_images
.
shape
[
0
]
timesteps
=
torch
.
randint
(
0
,
noise_scheduler
.
timesteps
,
(
bsz
,),
device
=
clean_images
.
device
).
long
()
for
idx
in
range
(
bsz
):
noise
=
torch
.
randn_like
(
clean_images
[
0
]).
to
(
clean_images
.
device
)
noise
=
torch
.
randn
((
3
,
32
,
32
)).
to
(
clean_images
.
device
)
noise_samples
[
idx
]
=
noise
noisy_images
[
idx
]
=
noise_scheduler
.
forward_step
(
clean_images
[
idx
],
noise
,
timesteps
[
idx
])
if
step
%
gradient_accumulation_steps
==
0
:
with
accelerator
.
no_sync
(
model
):
output
=
model
(
noisy_images
,
timesteps
)
loss
=
F
.
l1_loss
(
output
,
clean_images
)
# predict the noise
loss
=
F
.
l1_loss
(
output
,
noise_samples
)
accelerator
.
backward
(
loss
)
else
:
output
=
model
(
noisy_images
,
timesteps
)
loss
=
F
.
l1_loss
(
output
,
clean_images
)
accelerator
.
backward
(
loss
)
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
1.0
)
optimizer
.
step
()
lr_scheduler
.
step
()
#
lr_scheduler.step()
optimizer
.
zero_grad
()
pbar
.
update
(
1
)
pbar
.
set_postfix
(
loss
=
loss
.
detach
().
item
(),
lr
=
optimizer
.
param_groups
[
0
][
"lr"
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment