Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
bb306642
Commit
bb306642
authored
Jun 14, 2022
by
anton-l
Browse files
Move the training example
parent
418888a5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
12 deletions
+21
-12
Makefile
Makefile
+1
-1
examples/training_ddpm.py
examples/training_ddpm.py
+19
-10
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+1
-1
No files found.
Makefile
View file @
bb306642
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export
PYTHONPATH
=
src
export
PYTHONPATH
=
src
check_dirs
:=
tests src utils
check_dirs
:=
examples
tests src utils
modified_only_fixup
:
modified_only_fixup
:
$(
eval
modified_py_files :
=
$(
shell
python utils/get_modified_files.py
$(check_dirs)
))
$(
eval
modified_py_files :
=
$(
shell
python utils/get_modified_files.py
$(check_dirs)
))
...
...
src/diffusers/trainer
s/training_ddpm.py
→
example
s/training_ddpm.py
View file @
bb306642
...
@@ -8,14 +8,23 @@ import PIL.Image
...
@@ -8,14 +8,23 @@ import PIL.Image
from
accelerate
import
Accelerator
from
accelerate
import
Accelerator
from
datasets
import
load_dataset
from
datasets
import
load_dataset
from
diffusers
import
DDPM
,
DDPMScheduler
,
UNetModel
from
diffusers
import
DDPM
,
DDPMScheduler
,
UNetModel
from
torchvision.transforms
import
InterpolationMode
,
CenterCrop
,
Compose
,
Lambda
,
RandomRotation
,
RandomHorizontalFlip
,
Resize
,
ToTensor
from
torchvision.transforms
import
(
Compose
,
InterpolationMode
,
Lambda
,
RandomCrop
,
RandomHorizontalFlip
,
RandomVerticalFlip
,
Resize
,
ToTensor
,
)
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
from
transformers
import
get_linear_schedule_with_warmup
from
transformers
import
get_linear_schedule_with_warmup
def
set_seed
(
seed
):
def
set_seed
(
seed
):
#torch.backends.cudnn.deterministic = True
#
torch.backends.cudnn.deterministic = True
#torch.backends.cudnn.benchmark = False
#
torch.backends.cudnn.benchmark = False
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
np
.
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
...
@@ -33,7 +42,7 @@ model = UNetModel(
...
@@ -33,7 +42,7 @@ model = UNetModel(
dropout
=
0.0
,
dropout
=
0.0
,
num_res_blocks
=
2
,
num_res_blocks
=
2
,
resamp_with_conv
=
True
,
resamp_with_conv
=
True
,
resolution
=
32
resolution
=
32
,
)
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
1000
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
1000
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
3e-4
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
3e-4
)
...
@@ -44,15 +53,15 @@ gradient_accumulation_steps = 2
...
@@ -44,15 +53,15 @@ gradient_accumulation_steps = 2
augmentations
=
Compose
(
augmentations
=
Compose
(
[
[
RandomHorizontalFlip
(),
RandomRotation
(
15
,
interpolation
=
InterpolationMode
.
BILINEAR
,
fill
=
1
),
Resize
(
32
,
interpolation
=
InterpolationMode
.
BILINEAR
),
Resize
(
32
,
interpolation
=
InterpolationMode
.
BILINEAR
),
CenterCrop
(
32
),
RandomHorizontalFlip
(),
RandomVerticalFlip
(),
RandomCrop
(
32
),
ToTensor
(),
ToTensor
(),
Lambda
(
lambda
x
:
x
*
2
-
1
),
Lambda
(
lambda
x
:
x
*
2
-
1
),
]
]
)
)
dataset
=
load_dataset
(
"huggan/
pokemon
"
,
split
=
"train"
)
dataset
=
load_dataset
(
"huggan/
flowers-102-categories
"
,
split
=
"train"
)
def
transforms
(
examples
):
def
transforms
(
examples
):
...
@@ -127,5 +136,5 @@ for epoch in range(num_epochs):
...
@@ -127,5 +136,5 @@ for epoch in range(num_epochs):
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
[
0
])
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
[
0
])
# save image
# save image
pipeline
.
save_pretrained
(
"./
poke
-ddpm"
)
pipeline
.
save_pretrained
(
"./
flowers
-ddpm"
)
image_pil
.
save
(
f
"./
poke
-ddpm/test_
{
epoch
}
.png"
)
image_pil
.
save
(
f
"./
flowers
-ddpm/test_
{
epoch
}
.png"
)
tests/test_modeling_utils.py
View file @
bb306642
...
@@ -19,7 +19,7 @@ import unittest
...
@@ -19,7 +19,7 @@ import unittest
import
torch
import
torch
from
diffusers
import
DDIM
,
DDPM
,
DDIMScheduler
,
DDPMScheduler
,
LatentDiffusion
,
UNetModel
,
PNDM
,
PNDMScheduler
from
diffusers
import
DDIM
,
DDPM
,
PNDM
,
DDIMScheduler
,
DDPMScheduler
,
LatentDiffusion
,
PNDMScheduler
,
UNetModel
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment