Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
07ff0abf
Commit
07ff0abf
authored
Jun 27, 2022
by
anton-l
Browse files
Glide and LDM training experiments
parent
3286dac6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
246 additions
and
31 deletions
+246
-31
examples/experimental/train_glide_text_to_image.py
examples/experimental/train_glide_text_to_image.py
+201
-0
examples/train_latent_text_to_image.py
examples/train_latent_text_to_image.py
+45
-31
No files found.
examples/experimental/train_glide_text_to_image.py
0 → 100644
View file @
07ff0abf
import
argparse
import
os
import
torch
import
torch.nn.functional
as
F
import
bitsandbytes
as
bnb
import
PIL.Image
from
accelerate
import
Accelerator
from
datasets
import
load_dataset
from
diffusers
import
DDPMScheduler
,
Glide
,
GlideUNetModel
from
diffusers.hub_utils
import
init_git_repo
,
push_to_hub
from
diffusers.optimization
import
get_scheduler
from
diffusers.utils
import
logging
from
torchvision.transforms
import
(
CenterCrop
,
Compose
,
InterpolationMode
,
Normalize
,
RandomHorizontalFlip
,
Resize
,
ToTensor
,
)
from
tqdm.auto
import
tqdm
logger
=
logging
.
get_logger
(
__name__
)
def
main
(
args
):
accelerator
=
Accelerator
(
mixed_precision
=
args
.
mixed_precision
)
pipeline
=
Glide
.
from_pretrained
(
"fusing/glide-base"
)
model
=
pipeline
.
text_unet
noise_scheduler
=
DDPMScheduler
(
timesteps
=
1000
,
tensor_format
=
"pt"
)
optimizer
=
bnb
.
optim
.
Adam8bit
(
model
.
parameters
(),
lr
=
args
.
lr
)
augmentations
=
Compose
(
[
Resize
(
args
.
resolution
,
interpolation
=
InterpolationMode
.
BILINEAR
),
CenterCrop
(
args
.
resolution
),
RandomHorizontalFlip
(),
ToTensor
(),
Normalize
([
0.5
],
[
0.5
]),
]
)
dataset
=
load_dataset
(
args
.
dataset
,
split
=
"train"
)
text_encoder
=
pipeline
.
text_encoder
.
eval
()
def
transforms
(
examples
):
images
=
[
augmentations
(
image
.
convert
(
"RGB"
))
for
image
in
examples
[
"image"
]]
text_inputs
=
pipeline
.
tokenizer
(
examples
[
"caption"
],
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
"pt"
)
text_inputs
=
text_inputs
.
input_ids
.
to
(
accelerator
.
device
)
with
torch
.
no_grad
():
text_embeddings
=
accelerator
.
unwrap_model
(
text_encoder
)(
text_inputs
).
last_hidden_state
return
{
"images"
:
images
,
"text_embeddings"
:
text_embeddings
}
dataset
.
set_transform
(
transforms
)
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
True
)
lr_scheduler
=
get_scheduler
(
"linear"
,
optimizer
=
optimizer
,
num_warmup_steps
=
args
.
warmup_steps
,
num_training_steps
=
(
len
(
train_dataloader
)
*
args
.
num_epochs
)
//
args
.
gradient_accumulation_steps
,
)
model
,
text_encoder
,
optimizer
,
train_dataloader
,
lr_scheduler
=
accelerator
.
prepare
(
model
,
text_encoder
,
optimizer
,
train_dataloader
,
lr_scheduler
)
if
args
.
push_to_hub
:
repo
=
init_git_repo
(
args
,
at_init
=
True
)
# Train!
is_distributed
=
torch
.
distributed
.
is_available
()
and
torch
.
distributed
.
is_initialized
()
world_size
=
torch
.
distributed
.
get_world_size
()
if
is_distributed
else
1
total_train_batch_size
=
args
.
batch_size
*
args
.
gradient_accumulation_steps
*
world_size
max_steps
=
len
(
train_dataloader
)
//
args
.
gradient_accumulation_steps
*
args
.
num_epochs
logger
.
info
(
"***** Running training *****"
)
logger
.
info
(
f
" Num examples =
{
len
(
train_dataloader
.
dataset
)
}
"
)
logger
.
info
(
f
" Num Epochs =
{
args
.
num_epochs
}
"
)
logger
.
info
(
f
" Instantaneous batch size per device =
{
args
.
batch_size
}
"
)
logger
.
info
(
f
" Total train batch size (w. parallel, distributed & accumulation) =
{
total_train_batch_size
}
"
)
logger
.
info
(
f
" Gradient Accumulation steps =
{
args
.
gradient_accumulation_steps
}
"
)
logger
.
info
(
f
" Total optimization steps =
{
max_steps
}
"
)
for
epoch
in
range
(
args
.
num_epochs
):
model
.
train
()
with
tqdm
(
total
=
len
(
train_dataloader
),
unit
=
"ba"
)
as
pbar
:
pbar
.
set_description
(
f
"Epoch
{
epoch
}
"
)
for
step
,
batch
in
enumerate
(
train_dataloader
):
clean_images
=
batch
[
"images"
]
batch_size
,
n_channels
,
height
,
width
=
clean_images
.
shape
noise_samples
=
torch
.
randn
(
clean_images
.
shape
).
to
(
clean_images
.
device
)
timesteps
=
torch
.
randint
(
0
,
noise_scheduler
.
timesteps
,
(
batch_size
,),
device
=
clean_images
.
device
).
long
()
# add noise onto the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_images
=
noise_scheduler
.
training_step
(
clean_images
,
noise_samples
,
timesteps
)
if
step
%
args
.
gradient_accumulation_steps
!=
0
:
with
accelerator
.
no_sync
(
model
):
model_output
=
model
(
noisy_images
,
timesteps
,
batch
[
"text_embeddings"
])
model_output
,
model_var_values
=
torch
.
split
(
model_output
,
n_channels
,
dim
=
1
)
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out
=
torch
.
cat
([
model_output
.
detach
(),
model_var_values
],
dim
=
1
)
# predict the noise residual
loss
=
F
.
mse_loss
(
model_output
,
noise_samples
)
loss
=
loss
/
args
.
gradient_accumulation_steps
accelerator
.
backward
(
loss
)
optimizer
.
step
()
else
:
model_output
=
model
(
noisy_images
,
timesteps
,
batch
[
"text_embeddings"
])
model_output
,
model_var_values
=
torch
.
split
(
model_output
,
n_channels
,
dim
=
1
)
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out
=
torch
.
cat
([
model_output
.
detach
(),
model_var_values
],
dim
=
1
)
# predict the noise residual
loss
=
F
.
mse_loss
(
model_output
,
noise_samples
)
loss
=
loss
/
args
.
gradient_accumulation_steps
accelerator
.
backward
(
loss
)
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
1.0
)
optimizer
.
step
()
lr_scheduler
.
step
()
optimizer
.
zero_grad
()
pbar
.
update
(
1
)
pbar
.
set_postfix
(
loss
=
loss
.
detach
().
item
(),
lr
=
optimizer
.
param_groups
[
0
][
"lr"
])
accelerator
.
wait_for_everyone
()
# Generate a sample image for visual inspection
if
accelerator
.
is_main_process
:
model
.
eval
()
with
torch
.
no_grad
():
pipeline
.
unet
=
accelerator
.
unwrap_model
(
model
)
generator
=
torch
.
manual_seed
(
0
)
# run pipeline in inference (sample random noise and denoise)
image
=
pipeline
(
"a clip art of a corgi"
,
generator
=
generator
,
num_upscale_inference_steps
=
50
)
# process image to PIL
image_processed
=
image
.
squeeze
(
0
)
image_processed
=
((
image_processed
+
1
)
*
127.5
).
round
().
clamp
(
0
,
255
).
to
(
torch
.
uint8
).
cpu
().
numpy
()
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
)
# save image
test_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"test_samples"
)
os
.
makedirs
(
test_dir
,
exist_ok
=
True
)
image_pil
.
save
(
f
"
{
test_dir
}
/
{
epoch
:
04
d
}
.png"
)
# save the model
if
args
.
push_to_hub
:
push_to_hub
(
args
,
pipeline
,
repo
,
commit_message
=
f
"Epoch
{
epoch
}
"
,
blocking
=
False
)
else
:
pipeline
.
save_pretrained
(
args
.
output_dir
)
accelerator
.
wait_for_everyone
()
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Simple example of a training script."
)
parser
.
add_argument
(
"--local_rank"
,
type
=
int
,
default
=-
1
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
default
=
"fusing/dog_captions"
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"glide-text2image"
)
parser
.
add_argument
(
"--overwrite_output_dir"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--resolution"
,
type
=
int
,
default
=
64
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
4
)
parser
.
add_argument
(
"--num_epochs"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
4
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
1e-4
)
parser
.
add_argument
(
"--warmup_steps"
,
type
=
int
,
default
=
500
)
parser
.
add_argument
(
"--push_to_hub"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--hub_token"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--hub_model_id"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--hub_private_repo"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--mixed_precision"
,
type
=
str
,
default
=
"no"
,
choices
=
[
"no"
,
"fp16"
,
"bf16"
],
help
=
(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
args
=
parser
.
parse_args
()
env_local_rank
=
int
(
os
.
environ
.
get
(
"LOCAL_RANK"
,
-
1
))
if
env_local_rank
!=
-
1
and
env_local_rank
!=
args
.
local_rank
:
args
.
local_rank
=
env_local_rank
main
(
args
)
examples/train_latent_text_to_image.py
View file @
07ff0abf
...
...
@@ -4,19 +4,19 @@ import os
import
torch
import
torch.nn.functional
as
F
import
bitsandbytes
as
bnb
import
PIL.Image
from
accelerate
import
Accelerator
from
datasets
import
load_dataset
from
diffusers
import
DDPM
,
DDPMScheduler
,
UNetLDMModel
from
diffusers
import
DDPMScheduler
,
LatentDiffusion
,
UNetLDMModel
from
diffusers.hub_utils
import
init_git_repo
,
push_to_hub
from
diffusers.modeling_utils
import
unwrap_model
from
diffusers.optimization
import
get_scheduler
from
diffusers.utils
import
logging
from
torchvision.transforms
import
(
CenterCrop
,
Compose
,
InterpolationMode
,
Lambda
,
Normalize
,
RandomHorizontalFlip
,
Resize
,
ToTensor
,
...
...
@@ -30,6 +30,8 @@ logger = logging.get_logger(__name__)
def
main
(
args
):
accelerator
=
Accelerator
(
mixed_precision
=
args
.
mixed_precision
)
pipeline
=
LatentDiffusion
.
from_pretrained
(
"fusing/latent-diffusion-text2im-large"
)
pipeline
.
unet
=
None
# this model will be trained from scratch now
model
=
UNetLDMModel
(
attention_resolutions
=
[
4
,
2
,
1
],
channel_mult
=
[
1
,
2
,
4
,
4
],
...
...
@@ -37,7 +39,7 @@ def main(args):
conv_resample
=
True
,
dims
=
2
,
dropout
=
0
,
image_size
=
32
,
image_size
=
8
,
in_channels
=
4
,
model_channels
=
320
,
num_heads
=
8
,
...
...
@@ -51,7 +53,7 @@ def main(args):
legacy
=
False
,
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
1000
,
tensor_format
=
"pt"
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
optimizer
=
bnb
.
optim
.
Adam
8bit
(
model
.
parameters
(),
lr
=
args
.
lr
)
augmentations
=
Compose
(
[
...
...
@@ -59,14 +61,22 @@ def main(args):
CenterCrop
(
args
.
resolution
),
RandomHorizontalFlip
(),
ToTensor
(),
Lambda
(
lambda
x
:
x
*
2
-
1
),
Normalize
([
0.5
],
[
0.5
]
),
]
)
dataset
=
load_dataset
(
args
.
dataset
,
split
=
"train"
)
text_encoder
=
pipeline
.
bert
.
eval
()
vqvae
=
pipeline
.
vqvae
.
eval
()
def
transforms
(
examples
):
images
=
[
augmentations
(
image
.
convert
(
"RGB"
))
for
image
in
examples
[
"image"
]]
return
{
"input"
:
images
}
text_inputs
=
pipeline
.
tokenizer
(
examples
[
"caption"
],
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
"pt"
)
with
torch
.
no_grad
():
text_embeddings
=
accelerator
.
unwrap_model
(
text_encoder
)(
text_inputs
.
input_ids
.
cpu
()).
last_hidden_state
images
=
1
/
0.18215
*
torch
.
stack
(
images
,
dim
=
0
)
latents
=
accelerator
.
unwrap_model
(
vqvae
).
encode
(
images
.
cpu
()).
mode
()
return
{
"images"
:
images
,
"text_embeddings"
:
text_embeddings
,
"latents"
:
latents
}
dataset
.
set_transform
(
transforms
)
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
True
)
...
...
@@ -78,9 +88,11 @@ def main(args):
num_training_steps
=
(
len
(
train_dataloader
)
*
args
.
num_epochs
)
//
args
.
gradient_accumulation_steps
,
)
model
,
optimizer
,
train_dataloader
,
lr_scheduler
=
accelerator
.
prepare
(
model
,
optimizer
,
train_dataloader
,
lr_scheduler
model
,
text_encoder
,
vqvae
,
optimizer
,
train_dataloader
,
lr_scheduler
=
accelerator
.
prepare
(
model
,
text_encoder
,
vqvae
,
optimizer
,
train_dataloader
,
lr_scheduler
)
text_encoder
=
text_encoder
.
cpu
()
vqvae
=
vqvae
.
cpu
()
if
args
.
push_to_hub
:
repo
=
init_git_repo
(
args
,
at_init
=
True
)
...
...
@@ -98,29 +110,31 @@ def main(args):
logger
.
info
(
f
" Gradient Accumulation steps =
{
args
.
gradient_accumulation_steps
}
"
)
logger
.
info
(
f
" Total optimization steps =
{
max_steps
}
"
)
global_step
=
0
for
epoch
in
range
(
args
.
num_epochs
):
model
.
train
()
with
tqdm
(
total
=
len
(
train_dataloader
),
unit
=
"ba"
)
as
pbar
:
pbar
.
set_description
(
f
"Epoch
{
epoch
}
"
)
for
step
,
batch
in
enumerate
(
train_dataloader
):
clean_
image
s
=
batch
[
"
input
"
]
noise_samples
=
torch
.
randn
(
clean_
image
s
.
shape
).
to
(
clean_
image
s
.
device
)
bsz
=
clean_
image
s
.
shape
[
0
]
timesteps
=
torch
.
randint
(
0
,
noise_scheduler
.
timesteps
,
(
bsz
,),
device
=
clean_
image
s
.
device
).
long
()
clean_
latent
s
=
batch
[
"
latents
"
]
noise_samples
=
torch
.
randn
(
clean_
latent
s
.
shape
).
to
(
clean_
latent
s
.
device
)
bsz
=
clean_
latent
s
.
shape
[
0
]
timesteps
=
torch
.
randint
(
0
,
noise_scheduler
.
timesteps
,
(
bsz
,),
device
=
clean_
latent
s
.
device
).
long
()
# add noise onto the clean
image
s according to the noise magnitude at each timestep
# add noise onto the clean
latent
s according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_
image
s
=
noise_scheduler
.
training_step
(
clean_
image
s
,
noise_samples
,
timesteps
)
noisy_
latent
s
=
noise_scheduler
.
training_step
(
clean_
latent
s
,
noise_samples
,
timesteps
)
if
step
%
args
.
gradient_accumulation_steps
!=
0
:
with
accelerator
.
no_sync
(
model
):
output
=
model
(
noisy_
image
s
,
timesteps
)
output
=
model
(
noisy_
latent
s
,
timesteps
,
context
=
batch
[
"text_embeddings"
]
)
# predict the noise residual
loss
=
F
.
mse_loss
(
output
,
noise_samples
)
loss
=
loss
/
args
.
gradient_accumulation_steps
accelerator
.
backward
(
loss
)
optimizer
.
step
()
else
:
output
=
model
(
noisy_
image
s
,
timesteps
)
output
=
model
(
noisy_
latent
s
,
timesteps
,
context
=
batch
[
"text_embeddings"
]
)
# predict the noise residual
loss
=
F
.
mse_loss
(
output
,
noise_samples
)
loss
=
loss
/
args
.
gradient_accumulation_steps
...
...
@@ -131,24 +145,25 @@ def main(args):
optimizer
.
zero_grad
()
pbar
.
update
(
1
)
pbar
.
set_postfix
(
loss
=
loss
.
detach
().
item
(),
lr
=
optimizer
.
param_groups
[
0
][
"lr"
])
global_step
+=
1
optimizer
.
step
()
if
is_distributed
:
torch
.
distributed
.
barrier
()
accelerator
.
wait_for_everyone
()
# Generate a sample image for visual inspection
if
a
rgs
.
local_rank
in
[
-
1
,
0
]
:
if
a
ccelerator
.
is_main_process
:
model
.
eval
()
with
torch
.
no_grad
():
pipeline
=
DDPM
(
unet
=
unwrap_model
(
model
),
noise_scheduler
=
noise_scheduler
)
pipeline
.
unet
=
accelerator
.
unwrap_model
(
model
)
generator
=
torch
.
manual_seed
(
0
)
# run pipeline in inference (sample random noise and denoise)
image
=
pipeline
(
generator
=
generator
)
image
=
pipeline
(
[
"a clip art of a corgi"
],
generator
=
generator
,
eta
=
0.3
,
guidance_scale
=
6.0
,
num_inference_steps
=
50
)
# process image to PIL
image_processed
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
)
image_processed
=
(
image_processed
+
1.0
)
*
127.5
image_processed
=
image_processed
*
255.0
image_processed
=
image_processed
.
type
(
torch
.
uint8
).
numpy
()
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
[
0
])
...
...
@@ -162,20 +177,19 @@ def main(args):
push_to_hub
(
args
,
pipeline
,
repo
,
commit_message
=
f
"Epoch
{
epoch
}
"
,
blocking
=
False
)
else
:
pipeline
.
save_pretrained
(
args
.
output_dir
)
if
is_distributed
:
torch
.
distributed
.
barrier
()
accelerator
.
wait_for_everyone
()
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Simple example of a training script."
)
parser
.
add_argument
(
"--local_rank"
,
type
=
int
,
default
=-
1
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
default
=
"
huggan/flowers-102-categorie
s"
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"
ddpm-model
"
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
default
=
"
fusing/dog_caption
s"
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"
ldm-text2image
"
)
parser
.
add_argument
(
"--overwrite_output_dir"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--resolution"
,
type
=
int
,
default
=
64
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
1
6
)
parser
.
add_argument
(
"--resolution"
,
type
=
int
,
default
=
128
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--num_epochs"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
1
6
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
1e-4
)
parser
.
add_argument
(
"--warmup_steps"
,
type
=
int
,
default
=
500
)
parser
.
add_argument
(
"--push_to_hub"
,
action
=
"store_true"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment