Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
ee20d1f8
Unverified
Commit
ee20d1f8
authored
Apr 04, 2023
by
YiYi Xu
Committed by
GitHub
Apr 04, 2023
Browse files
update flax controlnet training script (#2951)
* load_from_disk + checkpointing_steps * apply feedback
parent
0d0fa2a3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
9 deletions
+40
-9
examples/controlnet/train_controlnet_flax.py
examples/controlnet/train_controlnet_flax.py
+40
-9
No files found.
examples/controlnet/train_controlnet_flax.py
View file @
ee20d1f8
...
@@ -27,13 +27,13 @@ import optax
...
@@ -27,13 +27,13 @@ import optax
import
torch
import
torch
import
torch.utils.checkpoint
import
torch.utils.checkpoint
import
transformers
import
transformers
from
datasets
import
load_dataset
from
datasets
import
load_dataset
,
load_from_disk
from
flax
import
jax_utils
from
flax
import
jax_utils
from
flax.core.frozen_dict
import
unfreeze
from
flax.core.frozen_dict
import
unfreeze
from
flax.training
import
train_state
from
flax.training
import
train_state
from
flax.training.common_utils
import
shard
from
flax.training.common_utils
import
shard
from
huggingface_hub
import
create_repo
,
upload_folder
from
huggingface_hub
import
create_repo
,
upload_folder
from
PIL
import
Image
from
PIL
import
Image
,
PngImagePlugin
from
torch.utils.data
import
IterableDataset
from
torch.utils.data
import
IterableDataset
from
torchvision
import
transforms
from
torchvision
import
transforms
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
...
@@ -49,6 +49,11 @@ from diffusers import (
...
@@ -49,6 +49,11 @@ from diffusers import (
from
diffusers.utils
import
check_min_version
,
is_wandb_available
from
diffusers.utils
import
check_min_version
,
is_wandb_available
# To prevent an error that occurs when there are abnormally large compressed data chunk in the png image
# see more https://github.com/python-pillow/Pillow/issues/5610
LARGE_ENOUGH_NUMBER
=
100
PngImagePlugin
.
MAX_TEXT_CHUNK
=
LARGE_ENOUGH_NUMBER
*
(
1024
**
2
)
if
is_wandb_available
():
if
is_wandb_available
():
import
wandb
import
wandb
...
@@ -246,6 +251,12 @@ def parse_args():
...
@@ -246,6 +251,12 @@ def parse_args():
default
=
None
,
default
=
None
,
help
=
"Total number of training steps to perform."
,
help
=
"Total number of training steps to perform."
,
)
)
parser
.
add_argument
(
"--checkpointing_steps"
,
type
=
int
,
default
=
5000
,
help
=
(
"Save a checkpoint of the training state every X updates."
),
)
parser
.
add_argument
(
parser
.
add_argument
(
"--learning_rate"
,
"--learning_rate"
,
type
=
float
,
type
=
float
,
...
@@ -344,9 +355,17 @@ def parse_args():
...
@@ -344,9 +355,17 @@ def parse_args():
type
=
str
,
type
=
str
,
default
=
None
,
default
=
None
,
help
=
(
help
=
(
"A folder containing the training data. Folder contents must follow the structure described in"
"A folder containing the training dataset. By default it will use `load_dataset` method to load a custom dataset from the folder."
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
"Folder must contain a dataset script as described here https://huggingface.co/docs/datasets/dataset_script) ."
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
"If `--load_from_disk` flag is passed, it will use `load_from_disk` method instead. Ignored if `dataset_name` is specified."
),
)
parser
.
add_argument
(
"--load_from_disk"
,
action
=
"store_true"
,
help
=
(
"If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`"
"See more https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk"
),
),
)
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -478,10 +497,15 @@ def make_train_dataset(args, tokenizer, batch_size=None):
...
@@ -478,10 +497,15 @@ def make_train_dataset(args, tokenizer, batch_size=None):
)
)
else
:
else
:
if
args
.
train_data_dir
is
not
None
:
if
args
.
train_data_dir
is
not
None
:
dataset
=
load_dataset
(
if
args
.
load_from_disk
:
args
.
train_data_dir
,
dataset
=
load_from_disk
(
cache_dir
=
args
.
cache_dir
,
args
.
train_data_dir
,
)
)
else
:
dataset
=
load_dataset
(
args
.
train_data_dir
,
cache_dir
=
args
.
cache_dir
,
)
# See more about loading custom images at
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
...
@@ -545,6 +569,7 @@ def make_train_dataset(args, tokenizer, batch_size=None):
...
@@ -545,6 +569,7 @@ def make_train_dataset(args, tokenizer, batch_size=None):
image_transforms
=
transforms
.
Compose
(
image_transforms
=
transforms
.
Compose
(
[
[
transforms
.
Resize
(
args
.
resolution
,
interpolation
=
transforms
.
InterpolationMode
.
BILINEAR
),
transforms
.
Resize
(
args
.
resolution
,
interpolation
=
transforms
.
InterpolationMode
.
BILINEAR
),
transforms
.
CenterCrop
(
args
.
resolution
),
transforms
.
ToTensor
(),
transforms
.
ToTensor
(),
transforms
.
Normalize
([
0.5
],
[
0.5
]),
transforms
.
Normalize
([
0.5
],
[
0.5
]),
]
]
...
@@ -553,6 +578,7 @@ def make_train_dataset(args, tokenizer, batch_size=None):
...
@@ -553,6 +578,7 @@ def make_train_dataset(args, tokenizer, batch_size=None):
conditioning_image_transforms
=
transforms
.
Compose
(
conditioning_image_transforms
=
transforms
.
Compose
(
[
[
transforms
.
Resize
(
args
.
resolution
,
interpolation
=
transforms
.
InterpolationMode
.
BILINEAR
),
transforms
.
Resize
(
args
.
resolution
,
interpolation
=
transforms
.
InterpolationMode
.
BILINEAR
),
transforms
.
CenterCrop
(
args
.
resolution
),
transforms
.
ToTensor
(),
transforms
.
ToTensor
(),
]
]
)
)
...
@@ -981,6 +1007,11 @@ def main():
...
@@ -981,6 +1007,11 @@ def main():
"train/loss"
:
jax_utils
.
unreplicate
(
train_metric
)[
"loss"
],
"train/loss"
:
jax_utils
.
unreplicate
(
train_metric
)[
"loss"
],
}
}
)
)
if
global_step
%
args
.
checkpointing_steps
==
0
and
jax
.
process_index
()
==
0
:
controlnet
.
save_pretrained
(
f
"
{
args
.
output_dir
}
/
{
global_step
}
"
,
params
=
get_params_to_save
(
state
.
params
),
)
train_metric
=
jax_utils
.
unreplicate
(
train_metric
)
train_metric
=
jax_utils
.
unreplicate
(
train_metric
)
train_step_progress_bar
.
close
()
train_step_progress_bar
.
close
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment