Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
cfe6eb16
Commit
cfe6eb16
authored
Jun 15, 2022
by
anton-l
Browse files
Training example parameterization
parent
31a7c75b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
39 deletions
+51
-39
examples/train_ddpm.py
examples/train_ddpm.py
+36
-25
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+15
-14
No files found.
examples/train
ing
_ddpm.py
→
examples/train_ddpm.py
View file @
cfe6eb16
import
argparse
import
os
import
torch
import
PIL.Image
import
argparse
import
torch.nn.functional
as
F
import
PIL.Image
from
accelerate
import
Accelerator
from
datasets
import
load_dataset
from
diffusers
import
DDPM
,
DDPMScheduler
,
UNetModel
...
...
@@ -31,44 +31,40 @@ def main(args):
dropout
=
0.0
,
num_res_blocks
=
2
,
resamp_with_conv
=
True
,
resolution
=
64
,
resolution
=
args
.
resolution
,
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
1000
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
1e-4
)
num_epochs
=
100
batch_size
=
16
gradient_accumulation_steps
=
1
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
augmentations
=
Compose
(
[
Resize
(
64
,
interpolation
=
InterpolationMode
.
BILINEAR
),
RandomCrop
(
64
),
Resize
(
args
.
resolution
,
interpolation
=
InterpolationMode
.
BILINEAR
),
RandomCrop
(
args
.
resolution
),
RandomHorizontalFlip
(),
ToTensor
(),
Lambda
(
lambda
x
:
x
*
2
-
1
),
]
)
dataset
=
load_dataset
(
"huggan/pokemon"
,
split
=
"train"
)
dataset
=
load_dataset
(
args
.
dataset
,
split
=
"train"
)
def
transforms
(
examples
):
images
=
[
augmentations
(
image
.
convert
(
"RGB"
))
for
image
in
examples
[
"image"
]]
return
{
"input"
:
images
}
dataset
.
set_transform
(
transforms
)
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
batch_size
,
shuffle
=
True
)
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
True
)
lr_scheduler
=
get_linear_schedule_with_warmup
(
optimizer
=
optimizer
,
num_warmup_steps
=
500
,
num_training_steps
=
(
len
(
train_dataloader
)
*
num_epochs
)
//
gradient_accumulation_steps
,
num_warmup_steps
=
args
.
warmup_steps
,
num_training_steps
=
(
len
(
train_dataloader
)
*
args
.
num_epochs
)
//
args
.
gradient_accumulation_steps
,
)
model
,
optimizer
,
train_dataloader
,
lr_scheduler
=
accelerator
.
prepare
(
model
,
optimizer
,
train_dataloader
,
lr_scheduler
)
for
epoch
in
range
(
num_epochs
):
for
epoch
in
range
(
args
.
num_epochs
):
model
.
train
()
with
tqdm
(
total
=
len
(
train_dataloader
),
unit
=
"ba"
)
as
pbar
:
pbar
.
set_description
(
f
"Epoch
{
epoch
}
"
)
...
...
@@ -84,14 +80,15 @@ def main(args):
noise_samples
[
idx
]
=
noise
noisy_images
[
idx
]
=
noise_scheduler
.
forward_step
(
clean_images
[
idx
],
noise
,
timesteps
[
idx
])
if
step
%
gradient_accumulation_steps
!=
0
:
if
step
%
args
.
gradient_accumulation_steps
!=
0
:
with
accelerator
.
no_sync
(
model
):
output
=
model
(
noisy_images
,
timesteps
)
# predict the noise
# predict the noise
residual
loss
=
F
.
mse_loss
(
output
,
noise_samples
)
accelerator
.
backward
(
loss
)
else
:
output
=
model
(
noisy_images
,
timesteps
)
# predict the noise residual
loss
=
F
.
mse_loss
(
output
,
noise_samples
)
accelerator
.
backward
(
loss
)
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
1.0
)
...
...
@@ -103,13 +100,18 @@ def main(args):
optimizer
.
step
()
# Generate a sample image for visual inspection
torch
.
distributed
.
barrier
()
if
args
.
local_rank
in
[
-
1
,
0
]:
model
.
eval
()
with
torch
.
no_grad
():
pipeline
=
DDPM
(
unet
=
model
.
module
,
noise_scheduler
=
noise_scheduler
)
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
0
)
if
isinstance
(
model
,
torch
.
nn
.
parallel
.
DistributedDataParallel
):
pipeline
=
DDPM
(
unet
=
model
.
module
,
noise_scheduler
=
noise_scheduler
)
else
:
pipeline
=
DDPM
(
unet
=
model
,
noise_scheduler
=
noise_scheduler
)
pipeline
.
save_pretrained
(
args
.
output_path
)
generator
=
torch
.
manual_seed
(
0
)
# run pipeline in inference (sample random noise and denoise)
image
=
pipeline
(
generator
=
generator
)
...
...
@@ -120,22 +122,31 @@ def main(args):
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
[
0
])
# save image
pipeline
.
save_pretrained
(
"./pokemon-ddpm"
)
image_pil
.
save
(
f
"./pokemon-ddpm/test_
{
epoch
}
.png"
)
test_dir
=
os
.
path
.
join
(
args
.
output_path
,
"test_samples"
)
os
.
makedirs
(
test_dir
,
exist_ok
=
True
)
image_pil
.
save
(
f
"
{
test_dir
}
/
{
epoch
}
.png"
)
torch
.
distributed
.
barrier
()
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Simple example of training script."
)
parser
=
argparse
.
ArgumentParser
(
description
=
"Simple example of
a
training script."
)
parser
.
add_argument
(
"--local_rank"
,
type
=
int
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
default
=
"huggan/flowers-102-categories"
)
parser
.
add_argument
(
"--resolution"
,
type
=
int
,
default
=
64
)
parser
.
add_argument
(
"--output_path"
,
type
=
str
,
default
=
"ddpm-model"
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
16
)
parser
.
add_argument
(
"--num_epochs"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
2
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
1e-4
)
parser
.
add_argument
(
"--warmup_steps"
,
type
=
int
,
default
=
500
)
parser
.
add_argument
(
"--mixed_precision"
,
type
=
str
,
default
=
"no"
,
choices
=
[
"no"
,
"fp16"
,
"bf16"
],
help
=
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
,
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
,
)
args
=
parser
.
parse_args
()
...
...
tests/test_modeling_utils.py
View file @
cfe6eb16
...
...
@@ -214,6 +214,21 @@ class PipelineTesterMixin(unittest.TestCase):
expected_slice
=
torch
.
tensor
([
0.7295
,
0.7358
,
0.7256
,
0.7435
,
0.7095
,
0.6884
,
0.7325
,
0.6921
,
0.6458
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_glide_text2img
(
self
):
model_id
=
"fusing/glide-base"
glide
=
GLIDE
.
from_pretrained
(
model_id
)
prompt
=
"a pencil sketch of a corgi"
generator
=
torch
.
manual_seed
(
0
)
image
=
glide
(
prompt
,
generator
=
generator
,
num_inference_steps_upscale
=
20
)
image_slice
=
image
[
0
,
:
3
,
:
3
,
-
1
].
cpu
()
assert
image
.
shape
==
(
1
,
256
,
256
,
3
)
expected_slice
=
torch
.
tensor
([
0.7119
,
0.7073
,
0.6460
,
0.7780
,
0.7423
,
0.6926
,
0.7378
,
0.7189
,
0.7784
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
def
test_module_from_pipeline
(
self
):
model
=
DiffWave
(
num_res_layers
=
4
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
...
...
@@ -229,17 +244,3 @@ class PipelineTesterMixin(unittest.TestCase):
_
=
BDDM
.
from_pretrained
(
tmpdirname
)
# check if the same works using the DifusionPipeline class
_
=
DiffusionPipeline
.
from_pretrained
(
tmpdirname
)
@
slow
def
test_glide_text2img
(
self
):
model_id
=
"fusing/glide-base"
glide
=
GLIDE
.
from_pretrained
(
model_id
)
prompt
=
"a pencil sketch of a corgi"
generator
=
torch
.
manual_seed
(
0
)
image
=
glide
(
prompt
,
generator
=
generator
,
num_inference_steps_upscale
=
20
)
image_slice
=
image
[
0
,
:
3
,
:
3
,
-
1
].
cpu
()
assert
image
.
shape
==
(
1
,
256
,
256
,
3
)
expected_slice
=
torch
.
tensor
([
0.7119
,
0.7073
,
0.6460
,
0.7780
,
0.7423
,
0.6926
,
0.7378
,
0.7189
,
0.7784
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment