Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
ee20d1f8
Unverified
Commit
ee20d1f8
authored
Apr 04, 2023
by
YiYi Xu
Committed by
GitHub
Apr 04, 2023
Browse files
update flax controlnet training script (#2951)
* load_from_disk + checkpointing_steps * apply feedback
parent
0d0fa2a3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
9 deletions
+40
-9
examples/controlnet/train_controlnet_flax.py
examples/controlnet/train_controlnet_flax.py
+40
-9
No files found.
examples/controlnet/train_controlnet_flax.py
View file @
ee20d1f8
...
...
@@ -27,13 +27,13 @@ import optax
import
torch
import
torch.utils.checkpoint
import
transformers
from
datasets
import
load_dataset
from
datasets
import
load_dataset
,
load_from_disk
from
flax
import
jax_utils
from
flax.core.frozen_dict
import
unfreeze
from
flax.training
import
train_state
from
flax.training.common_utils
import
shard
from
huggingface_hub
import
create_repo
,
upload_folder
from
PIL
import
Image
from
PIL
import
Image
,
PngImagePlugin
from
torch.utils.data
import
IterableDataset
from
torchvision
import
transforms
from
tqdm.auto
import
tqdm
...
...
@@ -49,6 +49,11 @@ from diffusers import (
from
diffusers.utils
import
check_min_version
,
is_wandb_available
# To prevent an error that occurs when there are abnormally large compressed data chunk in the png image
# see more https://github.com/python-pillow/Pillow/issues/5610
LARGE_ENOUGH_NUMBER
=
100
PngImagePlugin
.
MAX_TEXT_CHUNK
=
LARGE_ENOUGH_NUMBER
*
(
1024
**
2
)
if
is_wandb_available
():
import
wandb
...
...
@@ -246,6 +251,12 @@ def parse_args():
default
=
None
,
help
=
"Total number of training steps to perform."
,
)
parser
.
add_argument
(
"--checkpointing_steps"
,
type
=
int
,
default
=
5000
,
help
=
(
"Save a checkpoint of the training state every X updates."
),
)
parser
.
add_argument
(
"--learning_rate"
,
type
=
float
,
...
...
@@ -344,9 +355,17 @@ def parse_args():
type
=
str
,
default
=
None
,
help
=
(
"A folder containing the training data. Folder contents must follow the structure described in"
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
"A folder containing the training dataset. By default it will use `load_dataset` method to load a custom dataset from the folder."
"Folder must contain a dataset script as described here https://huggingface.co/docs/datasets/dataset_script) ."
"If `--load_from_disk` flag is passed, it will use `load_from_disk` method instead. Ignored if `dataset_name` is specified."
),
)
parser
.
add_argument
(
"--load_from_disk"
,
action
=
"store_true"
,
help
=
(
"If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`"
"See more https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk"
),
)
parser
.
add_argument
(
...
...
@@ -478,10 +497,15 @@ def make_train_dataset(args, tokenizer, batch_size=None):
)
else
:
if
args
.
train_data_dir
is
not
None
:
dataset
=
load_dataset
(
args
.
train_data_dir
,
cache_dir
=
args
.
cache_dir
,
)
if
args
.
load_from_disk
:
dataset
=
load_from_disk
(
args
.
train_data_dir
,
)
else
:
dataset
=
load_dataset
(
args
.
train_data_dir
,
cache_dir
=
args
.
cache_dir
,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
...
...
@@ -545,6 +569,7 @@ def make_train_dataset(args, tokenizer, batch_size=None):
image_transforms
=
transforms
.
Compose
(
[
transforms
.
Resize
(
args
.
resolution
,
interpolation
=
transforms
.
InterpolationMode
.
BILINEAR
),
transforms
.
CenterCrop
(
args
.
resolution
),
transforms
.
ToTensor
(),
transforms
.
Normalize
([
0.5
],
[
0.5
]),
]
...
...
@@ -553,6 +578,7 @@ def make_train_dataset(args, tokenizer, batch_size=None):
conditioning_image_transforms
=
transforms
.
Compose
(
[
transforms
.
Resize
(
args
.
resolution
,
interpolation
=
transforms
.
InterpolationMode
.
BILINEAR
),
transforms
.
CenterCrop
(
args
.
resolution
),
transforms
.
ToTensor
(),
]
)
...
...
@@ -981,6 +1007,11 @@ def main():
"train/loss"
:
jax_utils
.
unreplicate
(
train_metric
)[
"loss"
],
}
)
if
global_step
%
args
.
checkpointing_steps
==
0
and
jax
.
process_index
()
==
0
:
controlnet
.
save_pretrained
(
f
"
{
args
.
output_dir
}
/
{
global_step
}
"
,
params
=
get_params_to_save
(
state
.
params
),
)
train_metric
=
jax_utils
.
unreplicate
(
train_metric
)
train_step_progress_bar
.
close
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment