Unverified Commit 3bc3b48c authored by satani99, committed by GitHub

Modularize train_text_to_image_lora SD inferencing during and after training in example (#8283)



* Modularized the train_lora file

* Modularized the train_lora file

* Modularized the train_lora file

* Modularized the train_lora file

* Modularized the train_lora file

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 581d8aac
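
For context, here is a minimal sketch (not part of the commit) of how the extracted log_validation helper is intended to be called, both during training and for the final test pass after the LoRA weights are saved. The model id, output directory, prompt, and the simplified args namespace are illustrative placeholders; in the real script they come from parse_args(), and log_validation is the function added in the diff below, so this snippet assumes it runs inside train_text_to_image_lora.py where that helper is in scope.

# Sketch only: assumes log_validation (plus logger, np, wandb, nullcontext) is
# already defined in the surrounding training script, as in the diff below.
from types import SimpleNamespace

import torch
from accelerate import Accelerator
from diffusers import DiffusionPipeline

accelerator = Accelerator()
args = SimpleNamespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",  # placeholder base model
    output_dir="sd-model-finetuned-lora",                            # placeholder LoRA directory
    validation_prompt="a photo of a corgi",                          # placeholder prompt
    num_validation_images=4,
    seed=42,
)

# During training: build a pipeline from the current weights and log
# "validation" images for the current epoch.
pipeline = DiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.float16)
images = log_validation(pipeline, args, accelerator, epoch=0)

# After training: reload the base pipeline, attach the saved LoRA weights,
# and log the same prompt under the "test" phase instead.
pipeline = DiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.float16)
pipeline.load_lora_weights(args.output_dir)
images = log_validation(pipeline, args, accelerator, epoch=0, is_final_validation=True)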
@@ -52,6 +52,9 @@ from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.29.0.dev0")
@@ -99,6 +102,48 @@ These are LoRA adaption weights for {base_model}. The weights were fine-tuned on
     model_card.save(os.path.join(repo_folder, "README.md"))
 
 
+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    epoch,
+    is_final_validation=False,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+    generator = torch.Generator(device=accelerator.device)
+    if args.seed is not None:
+        generator = generator.manual_seed(args.seed)
+
+    images = []
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
+        for _ in range(args.num_validation_images):
+            images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
+
+    for tracker in accelerator.trackers:
+        phase_name = "test" if is_final_validation else "validation"
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    phase_name: [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+    return images
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
@@ -414,11 +459,6 @@ def main():
     if torch.backends.mps.is_available():
         accelerator.native_amp = False
 
-    if args.report_to == "wandb":
-        if not is_wandb_available():
-            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -864,10 +904,6 @@ def main():
         if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-                logger.info(
-                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                    f" {args.validation_prompt}."
-                )
                 # create pipeline
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
@@ -876,38 +912,7 @@ def main():
                     variant=args.variant,
                     torch_dtype=weight_dtype,
                 )
-                pipeline = pipeline.to(accelerator.device)
-                pipeline.set_progress_bar_config(disable=True)
-
-                # run inference
-                generator = torch.Generator(device=accelerator.device)
-                if args.seed is not None:
-                    generator = generator.manual_seed(args.seed)
-                images = []
-                if torch.backends.mps.is_available():
-                    autocast_ctx = nullcontext()
-                else:
-                    autocast_ctx = torch.autocast(accelerator.device.type)
-
-                with autocast_ctx:
-                    for _ in range(args.num_validation_images):
-                        images.append(
-                            pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
-                        )
-
-                for tracker in accelerator.trackers:
-                    if tracker.name == "tensorboard":
-                        np_images = np.stack([np.asarray(img) for img in images])
-                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                    if tracker.name == "wandb":
-                        tracker.log(
-                            {
-                                "validation": [
-                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                    for i, image in enumerate(images)
-                                ]
-                            }
-                        )
+                images = log_validation(pipeline, args, accelerator, epoch)
 
                 del pipeline
                 torch.cuda.empty_cache()
@@ -925,21 +930,6 @@ def main():
             safe_serialization=True,
         )
 
-        if args.push_to_hub:
-            save_model_card(
-                repo_id,
-                images=images,
-                base_model=args.pretrained_model_name_or_path,
-                dataset_name=args.dataset_name,
-                repo_folder=args.output_dir,
-            )
-            upload_folder(
-                repo_id=repo_id,
-                folder_path=args.output_dir,
-                commit_message="End of training",
-                ignore_patterns=["step_*", "epoch_*"],
-            )
-
         # Final inference
         # Load previous pipeline
         if args.validation_prompt is not None:
@@ -949,41 +939,27 @@ def main():
                 variant=args.variant,
                 torch_dtype=weight_dtype,
             )
-            pipeline = pipeline.to(accelerator.device)
 
             # load attention processors
             pipeline.load_lora_weights(args.output_dir)
 
             # run inference
-            generator = torch.Generator(device=accelerator.device)
-            if args.seed is not None:
-                generator = generator.manual_seed(args.seed)
-            images = []
-            if torch.backends.mps.is_available():
-                autocast_ctx = nullcontext()
-            else:
-                autocast_ctx = torch.autocast(accelerator.device.type)
-
-            with autocast_ctx:
-                for _ in range(args.num_validation_images):
-                    images.append(
-                        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
-                    )
-
-            for tracker in accelerator.trackers:
-                if len(images) != 0:
-                    if tracker.name == "tensorboard":
-                        np_images = np.stack([np.asarray(img) for img in images])
-                        tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
-                    if tracker.name == "wandb":
-                        tracker.log(
-                            {
-                                "test": [
-                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                    for i, image in enumerate(images)
-                                ]
-                            }
-                        )
+            images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True)
+
+        if args.push_to_hub:
+            save_model_card(
+                repo_id,
+                images=images,
+                base_model=args.pretrained_model_name_or_path,
+                dataset_name=args.dataset_name,
+                repo_folder=args.output_dir,
+            )
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
     accelerator.end_training()