Unverified Commit d185b5ed authored by Yinzhen Wang, committed by GitHub

change validation scheduler for train_dreambooth.py when training IF (#4333)

* dreambooth training

* train_dreambooth validation scheduler

* set a particular scheduler via a string

* modify readme after setting a particular scheduler via a string

* modify readme after setting a particular scheduler

* use importlib to set a particular scheduler

* import with correct sort
parent 709a6428
@@ -673,6 +673,8 @@ likely the learning rate can be increased with larger batch sizes.
Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM.
`--validation_scheduler`: Select the scheduler used to generate validation images by passing its class name as a string. We found that `DDPMScheduler` works better than the default `DPMSolverMultistepScheduler` for validation when training DeepFloyd IF.
```sh
export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
@@ -697,6 +699,7 @@ accelerate launch train_dreambooth.py \
--use_8bit_adam \
--set_grads_to_none \
--skip_save_text_encoder \
--validation_scheduler DDPMScheduler \
--push_to_hub
```
@@ -735,6 +738,7 @@ accelerate launch train_dreambooth.py \
--text_encoder_use_attention_mask \
--validation_images $VALIDATION_IMAGES \
--class_labels_conditioning timesteps \
--validation_scheduler DDPMScheduler \
--push_to_hub
```
@@ -17,6 +17,7 @@ import argparse
import copy
import gc
import hashlib
import importlib
import itertools
import logging
import math
@@ -47,7 +48,6 @@ from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
@@ -153,7 +153,9 @@ def log_validation(
scheduler_args["variance_type"] = variance_type
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
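# Resolve the scheduler class named by --validation_scheduler from the top-level
# diffusers module and rebuild the validation scheduler from the current config.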
module = importlib.import_module("diffusers")
scheduler_class = getattr(module, args.validation_scheduler)
pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
@@ -556,6 +558,13 @@ def parse_args(input_args=None):
default=None,
help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
)
parser.add_argument(
"--validation_scheduler",
type=str,
default="DPMSolverMultistepScheduler",
choices=["DPMSolverMultistepScheduler", "DDPMScheduler"],
help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.",
)
if input_args is not None:
args = parser.parse_args(input_args)
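For reference, here is a minimal, self-contained sketch of how the new flag resolves at runtime. The checkpoint id and scheduler name below are placeholder values for illustration; the actual wiring lives in `log_validation` above.

```python
import importlib

from diffusers import DiffusionPipeline

# Placeholder checkpoint; train_dreambooth.py uses args.pretrained_model_name_or_path.
pipeline = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0")

# The string passed via --validation_scheduler (e.g. "DDPMScheduler") is looked up
# on the diffusers module, and the resulting class rebuilds the scheduler from the
# pipeline's existing scheduler config.
validation_scheduler = "DDPMScheduler"
scheduler_class = getattr(importlib.import_module("diffusers"), validation_scheduler)
pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config)
```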