Unverified commit d185b5ed authored by Yinzhen Wang, committed by GitHub

change validation scheduler for train_dreambooth.py when training IF (#4333)

* dreambooth training

* train_dreambooth validation scheduler

* set a particular scheduler via a string

* modify readme after setting a particular scheduler via a string

* modify readme after setting a particular scheduler

* use importlib to set a particular scheduler

* import with correct sort
parent 709a6428
@@ -673,6 +673,8 @@ likely the learning rate can be increased with larger batch sizes.
 Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM.
+`--validation_scheduler`: Set a particular scheduler via a string. We found that it is better to use the DDPMScheduler for validation when training DeepFloyd IF.
 ```sh
 export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
@@ -697,6 +699,7 @@ accelerate launch train_dreambooth.py \
 --use_8bit_adam \
 --set_grads_to_none \
 --skip_save_text_encoder \
+--validation_scheduler DDPMScheduler \
 --push_to_hub
 ```
@@ -735,6 +738,7 @@ accelerate launch train_dreambooth.py \
 --text_encoder_use_attention_mask \
 --validation_images $VALIDATION_IMAGES \
 --class_labels_conditioning timesteps \
+--validation_scheduler DDPMScheduler \
 --push_to_hub
 ```
@@ -17,6 +17,7 @@ import argparse
 import copy
 import gc
 import hashlib
+import importlib
 import itertools
 import logging
 import math
@@ -47,7 +48,6 @@ from diffusers import (
     AutoencoderKL,
     DDPMScheduler,
     DiffusionPipeline,
-    DPMSolverMultistepScheduler,
     StableDiffusionPipeline,
     UNet2DConditionModel,
 )
@@ -153,7 +153,9 @@ def log_validation(
         scheduler_args["variance_type"] = variance_type
-    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+    module = importlib.import_module("diffusers")
+    scheduler_class = getattr(module, args.validation_scheduler)
+    pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args)
     pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)
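
The new validation code looks the scheduler class up by name in the `diffusers` namespace instead of hard-coding `DPMSolverMultistepScheduler`. Below is a minimal standalone sketch of that lookup (not part of the commit); it assumes `diffusers` is installed and uses `DDPMScheduler`'s default config in place of a real pipeline's scheduler config.

```python
import importlib

# Resolve the scheduler class from its string name, as the script does with
# args.validation_scheduler.
module = importlib.import_module("diffusers")
scheduler_class = getattr(module, "DDPMScheduler")

# from_config rebuilds a scheduler from an existing config; in the training
# script the config comes from pipeline.scheduler.config, so settings such as
# prediction_type carry over. Here we simply reuse the class defaults.
scheduler = scheduler_class.from_config(scheduler_class().config)
print(type(scheduler).__name__)  # -> DDPMScheduler
```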
@@ -556,6 +558,13 @@ def parse_args(input_args=None):
         default=None,
         help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
     )
+    parser.add_argument(
+        "--validation_scheduler",
+        type=str,
+        default="DPMSolverMultistepScheduler",
+        choices=["DPMSolverMultistepScheduler", "DDPMScheduler"],
+        help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.",
+    )
     if input_args is not None:
         args = parser.parse_args(input_args)
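
For reference, here is a small self-contained sketch (not part of the commit) of how the new flag parses: with no flag the default `DPMSolverMultistepScheduler` keeps the previous behaviour, while DeepFloyd IF training passes `DDPMScheduler` explicitly; values outside `choices` are rejected by argparse.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--validation_scheduler",
    type=str,
    default="DPMSolverMultistepScheduler",
    choices=["DPMSolverMultistepScheduler", "DDPMScheduler"],
)

# Default keeps the previous validation behaviour.
print(parser.parse_args([]).validation_scheduler)  # DPMSolverMultistepScheduler
# DeepFloyd IF training selects the DDPM scheduler explicitly.
print(parser.parse_args(["--validation_scheduler", "DDPMScheduler"]).validation_scheduler)  # DDPMScheduler
```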