"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9c8eca702c2fa811fba1ccff82a6aee6a04a2556"
Unverified Commit cc2c4ae7 authored by Pu Cao's avatar Pu Cao Committed by GitHub
Browse files

fix inference in custom diffusion (#5329)



* Update train_custom_diffusion.py

* make style

* Empty-Commit

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 6bd55b54
...@@ -1214,50 +1214,52 @@ def main(args): ...@@ -1214,50 +1214,52 @@ def main(args):
if global_step >= args.max_train_steps: if global_step >= args.max_train_steps:
break break
if accelerator.is_main_process: if accelerator.is_main_process:
images = [] images = []
if args.validation_prompt is not None and global_step % args.validation_steps == 0: if args.validation_prompt is not None and global_step % args.validation_steps == 0:
logger.info( logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}." f" {args.validation_prompt}."
) )
# create pipeline # create pipeline
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet), unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder), text_encoder=accelerator.unwrap_model(text_encoder),
tokenizer=tokenizer, tokenizer=tokenizer,
revision=args.revision, revision=args.revision,
torch_dtype=weight_dtype, torch_dtype=weight_dtype,
) )
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0]
for _ in range(args.num_validation_images)
]
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)
del pipeline # run inference
torch.cuda.empty_cache() generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[
0
]
for _ in range(args.num_validation_images)
]
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)
del pipeline
torch.cuda.empty_cache()
# Save the custom diffusion layers # Save the custom diffusion layers
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment