Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
418888a5
Commit
418888a5
authored
Jun 14, 2022
by
anton-l
Browse files
Pokemon DDPM training
parent
55d29ab7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
18 deletions
+21
-18
src/diffusers/trainers/training_ddpm.py
src/diffusers/trainers/training_ddpm.py
+21
-18
No files found.
src/diffusers/trainers/training_ddpm.py
View file @
418888a5
...
...
@@ -8,14 +8,14 @@ import PIL.Image
from
accelerate
import
Accelerator
from
datasets
import
load_dataset
from
diffusers
import
DDPM
,
DDPMScheduler
,
UNetModel
from
torchvision.transforms
import
CenterCrop
,
Compose
,
Lambda
,
RandomHorizontalFlip
,
Resize
,
ToTensor
from
torchvision.transforms
import
InterpolationMode
,
CenterCrop
,
Compose
,
Lambda
,
RandomRotation
,
RandomHorizontalFlip
,
Resize
,
ToTensor
from
tqdm.auto
import
tqdm
from
transformers
import
get_linear_schedule_with_warmup
def
set_seed
(
seed
):
torch
.
backends
.
cudnn
.
deterministic
=
True
torch
.
backends
.
cudnn
.
benchmark
=
False
#
torch.backends.cudnn.deterministic = True
#
torch.backends.cudnn.benchmark = False
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
np
.
random
.
seed
(
seed
)
...
...
@@ -30,13 +30,13 @@ model = UNetModel(
attn_resolutions
=
(
16
,),
ch
=
128
,
ch_mult
=
(
1
,
2
,
2
,
2
),
dropout
=
0.
1
,
dropout
=
0.
0
,
num_res_blocks
=
2
,
resamp_with_conv
=
True
,
resolution
=
32
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
1000
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.0002
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
3e-4
)
num_epochs
=
100
batch_size
=
64
...
...
@@ -44,9 +44,10 @@ gradient_accumulation_steps = 2
augmentations
=
Compose
(
[
Resize
(
32
),
CenterCrop
(
32
),
RandomHorizontalFlip
(),
RandomRotation
(
15
,
interpolation
=
InterpolationMode
.
BILINEAR
,
fill
=
1
),
Resize
(
32
,
interpolation
=
InterpolationMode
.
BILINEAR
),
CenterCrop
(
32
),
ToTensor
(),
Lambda
(
lambda
x
:
x
*
2
-
1
),
]
...
...
@@ -59,24 +60,24 @@ def transforms(examples):
return
{
"input"
:
images
}
dataset
=
dataset
.
shuffle
(
seed
=
0
)
dataset
.
set_transform
(
transforms
)
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
batch_size
,
shuffle
=
Fals
e
)
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
batch_size
,
shuffle
=
Tru
e
)
#
lr_scheduler = get_linear_schedule_with_warmup(
#
optimizer=optimizer,
#
num_warmup_steps=
10
00,
#
num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
#
)
lr_scheduler
=
get_linear_schedule_with_warmup
(
optimizer
=
optimizer
,
num_warmup_steps
=
5
00
,
num_training_steps
=
(
len
(
train_dataloader
)
*
num_epochs
)
//
gradient_accumulation_steps
,
)
model
,
optimizer
,
train_dataloader
=
accelerator
.
prepare
(
model
,
optimizer
,
train_dataloader
model
,
optimizer
,
train_dataloader
,
lr_scheduler
=
accelerator
.
prepare
(
model
,
optimizer
,
train_dataloader
,
lr_scheduler
)
for
epoch
in
range
(
num_epochs
):
model
.
train
()
pbar
=
tqdm
(
total
=
len
(
train_dataloader
),
unit
=
"ba"
)
pbar
.
set_description
(
f
"Epoch
{
epoch
}
"
)
losses
=
[]
for
step
,
batch
in
enumerate
(
train_dataloader
):
clean_images
=
batch
[
"input"
]
noisy_images
=
torch
.
empty_like
(
clean_images
)
...
...
@@ -101,10 +102,12 @@ for epoch in range(num_epochs):
accelerator
.
backward
(
loss
)
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
1.0
)
optimizer
.
step
()
#
lr_scheduler.step()
lr_scheduler
.
step
()
optimizer
.
zero_grad
()
loss
=
loss
.
detach
().
item
()
losses
.
append
(
loss
)
pbar
.
update
(
1
)
pbar
.
set_postfix
(
loss
=
loss
.
detach
().
item
(
),
lr
=
optimizer
.
param_groups
[
0
][
"lr"
])
pbar
.
set_postfix
(
loss
=
loss
,
avg_loss
=
np
.
mean
(
losses
),
lr
=
optimizer
.
param_groups
[
0
][
"lr"
])
optimizer
.
step
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment