Unverified Commit d9b9533c authored by Isamu Isozaki, committed by GitHub

Textual inv make save log both steps (#2178)



* Initial commit

* removed images

* Made logging the same as save

* Removed logging function

* Quality fixes

* Quality fixes

* Tested

* Added support back for validation_epochs

* Fixing styles

* Did changes

* Change to log_validation

* Add extra space after wandb import

* Add extra space after wandb
Co-authored-by: Will Berman <wlbberman@gmail.com>

* Fixed spacing

---------
Co-authored-by: Will Berman <wlbberman@gmail.com>
parent 80148484
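
Before the diff, a minimal, self-contained sketch of the backward-compatibility behavior this commit introduces: the deprecated --validation_epochs value is converted into --validation_steps by multiplying it with the dataset length, and validation then fires whenever global_step hits a multiple of validation_steps. The helper name resolve_validation_steps and the example numbers are made up for illustration; the script itself does this inline in main().

import warnings


def resolve_validation_steps(validation_steps, validation_epochs, dataset_len):
    # Mirrors the shim in the diff below: if the deprecated --validation_epochs
    # is given, it overrides --validation_steps by converting epochs into
    # dataset-length step intervals.
    if validation_epochs is not None:
        warnings.warn(
            f"validation_epochs={validation_epochs} is deprecated in favor of validation_steps;"
            f" setting validation_steps to {validation_epochs * dataset_len}",
            FutureWarning,
            stacklevel=2,
        )
        return validation_epochs * dataset_len
    return validation_steps


# Hypothetical numbers: 50 epochs over a 10-image dataset -> validate every 500 steps,
# triggered in the training loop whenever global_step % validation_steps == 0.
print(resolve_validation_steps(validation_steps=100, validation_epochs=50, dataset_len=10))  # 500
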
@@ -18,6 +18,7 @@ import logging
 import math
 import os
 import random
+import warnings
 from pathlib import Path
 from typing import Optional
@@ -54,6 +55,9 @@ from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
+if is_wandb_available():
+    import wandb
+
 if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
     PIL_INTERPOLATION = {
         "linear": PIL.Image.Resampling.BILINEAR,
@@ -79,6 +83,50 @@ check_min_version("0.14.0.dev0")
 logger = get_logger(__name__)
 
 
+def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    # create pipeline (note: unet and vae are loaded again in float32)
+    pipeline = DiffusionPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        text_encoder=accelerator.unwrap_model(text_encoder),
+        tokenizer=tokenizer,
+        unet=unet,
+        vae=vae,
+        revision=args.revision,
+        torch_dtype=weight_dtype,
+    )
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    images = []
+    for _ in range(args.num_validation_images):
+        with torch.autocast("cuda"):
+            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+        images.append(image)
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    "validation": [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+    del pipeline
+    torch.cuda.empty_cache()
+
+
 def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
     logger.info("Saving embeddings")
     learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
@@ -268,12 +316,22 @@ def parse_args():
         default=4,
         help="Number of images that should be generated during validation with `validation_prompt`.",
     )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=100,
+        help=(
+            "Run validation every X steps. Validation consists of running the prompt"
+            " `args.validation_prompt` multiple times: `args.num_validation_images`"
+            " and logging the images."
+        ),
+    )
     parser.add_argument(
         "--validation_epochs",
         type=int,
-        default=50,
+        default=None,
         help=(
-            "Run validation every X epochs. Validation consists of running the prompt"
+            "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
             " `args.validation_prompt` multiple times: `args.num_validation_images`"
             " and logging the images."
         ),
@@ -488,7 +546,6 @@ def main():
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
 
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
@@ -627,6 +684,15 @@ def main():
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
    )
+    if args.validation_epochs is not None:
+        warnings.warn(
+            f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
+            " Deprecated validation_epochs in favor of `validation_steps`"
+            f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
+            FutureWarning,
+            stacklevel=2,
+        )
+        args.validation_steps = args.validation_epochs * len(train_dataset)
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -683,7 +749,6 @@ def main():
     logger.info(f" Total optimization steps = {args.max_train_steps}")
     global_step = 0
     first_epoch = 0
-
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint != "latest":
@@ -783,6 +848,8 @@ def main():
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
+                if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+                    log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
@@ -790,53 +857,6 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-        if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-            logger.info(
-                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                f" {args.validation_prompt}."
-            )
-            # create pipeline (note: unet and vae are loaded again in float32)
-            pipeline = DiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path,
-                text_encoder=accelerator.unwrap_model(text_encoder),
-                tokenizer=tokenizer,
-                unet=unet,
-                vae=vae,
-                revision=args.revision,
-                torch_dtype=weight_dtype,
-            )
-            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-            pipeline = pipeline.to(accelerator.device)
-            pipeline.set_progress_bar_config(disable=True)
-
-            # run inference
-            generator = (
-                None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
-            )
-            images = []
-            for _ in range(args.num_validation_images):
-                with torch.autocast("cuda"):
-                    image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
-                images.append(image)
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "validation": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )
-            del pipeline
-            torch.cuda.empty_cache()
-
     # Create the pipeline using using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process: