Unverified Commit 41833109 authored by Yuta Hayashibe's avatar Yuta Hayashibe Committed by GitHub
Browse files

Run inference on a specific condition and fix call of manual_seed() (#2074)

parent fc8afa3a
...@@ -972,9 +972,10 @@ def main(args): ...@@ -972,9 +972,10 @@ def main(args):
pipeline.unet.load_attn_procs(args.output_dir) pipeline.unet.load_attn_procs(args.output_dir)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.validation_prompt and args.num_validation_images > 0:
prompt = args.num_validation_images * [args.validation_prompt] generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
images = pipeline(prompt, num_inference_steps=25, generator=generator).images prompt = args.num_validation_images * [args.validation_prompt]
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
for tracker in accelerator.trackers: for tracker in accelerator.trackers:
if tracker.name == "tensorboard": if tracker.name == "tensorboard":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment