Unverified Commit 41833109 authored by Yuta Hayashibe's avatar Yuta Hayashibe Committed by GitHub
Browse files

Run inference on a specific condition and fix call of manual_seed() (#2074)

parent fc8afa3a
......@@ -972,9 +972,10 @@ def main(args):
pipeline.unet.load_attn_procs(args.output_dir)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
prompt = args.num_validation_images * [args.validation_prompt]
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
if args.validation_prompt and args.num_validation_images > 0:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
prompt = args.num_validation_images * [args.validation_prompt]
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment