Unverified Commit 2de36fae authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

minor fix in controlnet flax example (#2986)

* fix the error when push_to_hub but not log validation

* contronet_from_pt & controlnet_revision

* add intermediate checkpointing to the guide
parent e4052643
...@@ -320,6 +320,12 @@ Then cd in the example folder and run ...@@ -320,6 +320,12 @@ Then cd in the example folder and run
pip install -U -r requirements_flax.txt pip install -U -r requirements_flax.txt
``` ```
If you want to use Weights and Biases logging, you should also install `wandb` now
```bash
pip install wandb
```
Now let's downloading two conditioning images that we will use to run validation during the training in order to track our progress Now let's downloading two conditioning images that we will use to run validation during the training in order to track our progress
``` ```
...@@ -390,3 +396,16 @@ Note, however, that the performance of the TPUs might get bottlenecked as stream ...@@ -390,3 +396,16 @@ Note, however, that the performance of the TPUs might get bottlenecked as stream
* [Webdataset](https://webdataset.github.io/webdataset/) * [Webdataset](https://webdataset.github.io/webdataset/)
* [TorchData](https://github.com/pytorch/data) * [TorchData](https://github.com/pytorch/data)
* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds) * [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds)
When work with a larger dataset, you may need to run training process for a long time and it’s useful to save regular checkpoints during the process. You can use the following argument to enable intermediate checkpointing:
```bash
--checkpointing_steps=500
```
This will save the trained model in subfolders of your output_dir. Subfolder names is the number of steps performed so far; for example: a checkpoint saved after 500 training steps would be saved in a subfolder named 500
You can then start your training from this saved checkpoint with
```bash
--controlnet_model_name_or_path="./control_out/500"
```
\ No newline at end of file
...@@ -154,6 +154,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d ...@@ -154,6 +154,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
img_str = "" img_str = ""
if image_logs is not None:
for i, log in enumerate(image_logs): for i, log in enumerate(image_logs):
images = log["images"] images = log["images"]
validation_prompt = log["validation_prompt"] validation_prompt = log["validation_prompt"]
...@@ -213,6 +214,17 @@ def parse_args(): ...@@ -213,6 +214,17 @@ def parse_args():
action="store_true", action="store_true",
help="Load the pretrained model from a PyTorch checkpoint.", help="Load the pretrained model from a PyTorch checkpoint.",
) )
parser.add_argument(
"--controlnet_revision",
type=str,
default=None,
help="Revision of controlnet model identifier from huggingface.co/models.",
)
parser.add_argument(
"--controlnet_from_pt",
action="store_true",
help="Load the controlnet model from a PyTorch checkpoint.",
)
parser.add_argument( parser.add_argument(
"--tokenizer_name", "--tokenizer_name",
type=str, type=str,
...@@ -731,7 +743,10 @@ def main(): ...@@ -731,7 +743,10 @@ def main():
if args.controlnet_model_name_or_path: if args.controlnet_model_name_or_path:
logger.info("Loading existing controlnet weights") logger.info("Loading existing controlnet weights")
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
args.controlnet_model_name_or_path, from_pt=True, dtype=jnp.float32 args.controlnet_model_name_or_path,
revision=args.controlnet_revision,
from_pt=args.controlnet_from_pt,
dtype=jnp.float32,
) )
else: else:
logger.info("Initializing controlnet weights from unet") logger.info("Initializing controlnet weights from unet")
...@@ -1021,6 +1036,8 @@ def main(): ...@@ -1021,6 +1036,8 @@ def main():
if jax.process_index() == 0: if jax.process_index() == 0:
if args.validation_prompt is not None: if args.validation_prompt is not None:
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype) image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
else:
image_logs = None
controlnet.save_pretrained( controlnet.save_pretrained(
args.output_dir, args.output_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment