Unverified Commit b8215b1c authored by Alexey Zolotenkov's avatar Alexey Zolotenkov Committed by GitHub
Browse files

Fix incorrect seed initialization when args.seed is 0 (#10964)



* Fix seed initialization to handle args.seed = 0 correctly

* Apply style fixes

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 3ee899fa
...@@ -227,7 +227,7 @@ def log_validation( ...@@ -227,7 +227,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = nullcontext() autocast_ctx = nullcontext()
with autocast_ctx: with autocast_ctx:
......
...@@ -1883,7 +1883,11 @@ def main(args): ...@@ -1883,7 +1883,11 @@ def main(args):
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.seed is not None
else None
)
pipeline_args = {"prompt": args.validation_prompt} pipeline_args = {"prompt": args.validation_prompt}
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
...@@ -1987,7 +1991,9 @@ def main(args): ...@@ -1987,7 +1991,9 @@ def main(args):
) )
# run inference # run inference
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
)
images = [ images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for _ in range(args.num_validation_images) for _ in range(args.num_validation_images)
......
...@@ -269,7 +269,7 @@ def log_validation( ...@@ -269,7 +269,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
......
...@@ -722,7 +722,7 @@ def log_validation( ...@@ -722,7 +722,7 @@ def log_validation(
# pipe.set_progress_bar_config(disable=True) # pipe.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
videos = [] videos = []
for _ in range(args.num_validation_videos): for _ in range(args.num_validation_videos):
......
...@@ -739,7 +739,7 @@ def log_validation( ...@@ -739,7 +739,7 @@ def log_validation(
# pipe.set_progress_bar_config(disable=True) # pipe.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
videos = [] videos = []
for _ in range(args.num_validation_videos): for _ in range(args.num_validation_videos):
......
...@@ -1334,7 +1334,9 @@ def main(args): ...@@ -1334,7 +1334,9 @@ def main(args):
# run inference # run inference
if args.validation_prompt and args.num_validation_images > 0: if args.validation_prompt and args.num_validation_images > 0:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
)
images = [ images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0] pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0]
for _ in range(args.num_validation_images) for _ in range(args.num_validation_images)
......
...@@ -172,7 +172,7 @@ def log_validation( ...@@ -172,7 +172,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext() autocast_ctx = nullcontext()
......
...@@ -150,7 +150,7 @@ def log_validation( ...@@ -150,7 +150,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
if args.validation_images is None: if args.validation_images is None:
images = [] images = []
......
...@@ -181,7 +181,7 @@ def log_validation( ...@@ -181,7 +181,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext() autocast_ctx = nullcontext()
......
...@@ -167,7 +167,7 @@ def log_validation( ...@@ -167,7 +167,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
with autocast_ctx: with autocast_ctx:
......
...@@ -170,7 +170,7 @@ def log_validation( ...@@ -170,7 +170,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
......
...@@ -199,7 +199,7 @@ def log_validation( ...@@ -199,7 +199,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext() autocast_ctx = nullcontext()
......
...@@ -207,7 +207,7 @@ def log_validation( ...@@ -207,7 +207,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
......
...@@ -175,7 +175,7 @@ def log_validation( ...@@ -175,7 +175,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext() autocast_ctx = nullcontext()
......
...@@ -137,7 +137,7 @@ def log_validation( ...@@ -137,7 +137,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
pipeline_args = {"prompt": args.validation_prompt} pipeline_args = {"prompt": args.validation_prompt}
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
autocast_ctx = nullcontext() autocast_ctx = nullcontext()
......
...@@ -1241,7 +1241,11 @@ def main(args): ...@@ -1241,7 +1241,11 @@ def main(args):
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.seed is not None
else None
)
pipeline_args = {"prompt": args.validation_prompt} pipeline_args = {"prompt": args.validation_prompt}
with autocast_ctx: with autocast_ctx:
...@@ -1305,7 +1309,9 @@ def main(args): ...@@ -1305,7 +1309,9 @@ def main(args):
images = [] images = []
if args.validation_prompt and args.num_validation_images > 0: if args.validation_prompt and args.num_validation_images > 0:
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
)
with autocast_ctx: with autocast_ctx:
images = [ images = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment