Unverified commit c7da8fd2, authored by YiYi Xu, committed by GitHub

add intermediate logging for dreambooth training script (#2557)



* add intermediate logging
---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent b8bfef2a
@@ -24,6 +24,7 @@ from pathlib import Path
 from typing import Optional

 import accelerate
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -40,18 +41,71 @@ from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig

 import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DDPMScheduler,
+    DiffusionPipeline,
+    DPMSolverMultistepScheduler,
+    UNet2DConditionModel,
+)
 from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version
+from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available

+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.15.0.dev0")

 logger = get_logger(__name__)
+
+
+def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    # create pipeline (note: unet and vae are loaded again in float32)
+    pipeline = DiffusionPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        text_encoder=accelerator.unwrap_model(text_encoder),
+        tokenizer=tokenizer,
+        unet=accelerator.unwrap_model(unet),
+        vae=vae,
+        revision=args.revision,
+        torch_dtype=weight_dtype,
+    )
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    images = []
+    for _ in range(args.num_validation_images):
+        with torch.autocast("cuda"):
+            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+        images.append(image)
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    "validation": [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+                        for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+    del pipeline
+    torch.cuda.empty_cache()
+
+
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path,
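For reference, the tracker loop in log_validation dispatches the same PIL images to whichever trackers the Accelerator was initialized with. Below is a minimal standalone sketch of the tensorboard branch; the images, sizes, and log directory are illustrative placeholders, not part of the commit:

import numpy as np
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

# Stand-ins for the PIL images returned by pipeline(...).images in log_validation.
images = [Image.new("RGB", (64, 64), color=(60 * i, 0, 0)) for i in range(4)]

# Same pattern as the tensorboard branch above: stack the PIL images into one
# (N, H, W, C) uint8 array and log them under a single tag.
np_images = np.stack([np.asarray(img) for img in images])

writer = SummaryWriter(log_dir="./logs")  # illustrative log directory
writer.add_images("validation", np_images, global_step=0, dataformats="NHWC")
writer.close()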
@@ -306,6 +360,28 @@ def parse_args(input_args=None):
             ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
         ),
     )
+    parser.add_argument(
+        "--validation_prompt",
+        type=str,
+        default=None,
+        help="A prompt that is used during validation to verify that the model is learning.",
+    )
+    parser.add_argument(
+        "--num_validation_images",
+        type=int,
+        default=4,
+        help="Number of images that should be generated during validation with `validation_prompt`.",
+    )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=100,
+        help=(
+            "Run validation every X steps. Validation consists of generating"
+            " `args.num_validation_images` images with the prompt `args.validation_prompt`"
+            " and logging them."
+        ),
+    )
     parser.add_argument(
         "--mixed_precision",
         type=str,
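The three new flags are opt-in: validation only runs when `--validation_prompt` is set. A hypothetical smoke test of the parsing, assuming the usual train_dreambooth module name for this script (the model name and paths below are placeholders):

from train_dreambooth import parse_args  # assumed module name for this script

# Required DreamBooth arguments plus the new validation flags; all values
# below are illustrative.
args = parse_args(
    [
        "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",
        "--instance_data_dir", "./instance_images",
        "--instance_prompt", "a photo of sks dog",
        "--validation_prompt", "a photo of sks dog in a bucket",
        "--num_validation_images", "4",
        "--validation_steps", "100",
    ]
)
assert args.validation_prompt is not None
assert args.num_validation_images == 4 and args.validation_steps == 100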
@@ -508,6 +584,10 @@ def main(args):
         project_config=accelerator_project_config,
     )

+    if args.report_to == "wandb":
+        if not is_wandb_available():
+            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
@@ -920,6 +1000,8 @@ def main(args):
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")

+                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+                        log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)

             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
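One detail worth noting: the hook fires on optimizer steps (global_step), while log_validation logs images under the epoch index, so several validation runs within one epoch share the same tracker step. A minimal sketch of the gating itself; the prompt and step values are illustrative:

# Mirrors the gating added above: with validation_steps=100, validation
# runs at global steps 100, 200, 300, ... whenever a prompt is configured.
validation_prompt = "a photo of sks dog in a bucket"  # stands in for args.validation_prompt
validation_steps = 100

for global_step in range(1, 501):
    if validation_prompt is not None and global_step % validation_steps == 0:
        print(f"validation would run at step {global_step}")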