Unverified Commit b8215b1c authored by Alexey Zolotenkov, committed by GitHub

Fix incorrect seed initialization when args.seed is 0 (#10964)



* Fix seed initialization to handle args.seed = 0 correctly

* Apply style fixes

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 3ee899fa
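For context, a minimal sketch of the bug being fixed (illustrative, not part of the commit): Python treats `0` as falsy, so the old truthiness check silently drops a seed of 0 and validation runs unseeded, whereas the `is not None` check honours it. The `Args` class below is a hypothetical stand-in for the argparse namespace the training scripts actually build.

```python
import torch

class Args:
    # Hypothetical stand-in for the parsed training arguments.
    seed = 0

args = Args()
device = "cuda" if torch.cuda.is_available() else "cpu"

# Old check: 0 is falsy, so no generator is created and validation
# inference becomes non-deterministic despite an explicit --seed 0.
generator_old = torch.Generator(device=device).manual_seed(args.seed) if args.seed else None
print(generator_old)  # None

# Fixed check: only an unset seed (None) skips the generator; seed 0 is
# treated like any other seed value.
generator_new = torch.Generator(device=device).manual_seed(args.seed) if args.seed is not None else None
print(generator_new.initial_seed())  # 0
```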
@@ -227,7 +227,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = nullcontext()
with autocast_ctx:
@@ -1883,7 +1883,11 @@ def main(args):
pipeline.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = (
+    torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    if args.seed is not None
+    else None
+)
pipeline_args = {"prompt": args.validation_prompt}
if torch.backends.mps.is_available():
@@ -1987,7 +1991,9 @@ def main(args):
)
# run inference
pipeline = pipeline.to(accelerator.device)
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = (
+    torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+)
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for _ in range(args.num_validation_images)
@@ -269,7 +269,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
@@ -722,7 +722,7 @@ def log_validation(
# pipe.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
videos = []
for _ in range(args.num_validation_videos):
@@ -739,7 +739,7 @@ def log_validation(
# pipe.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
videos = []
for _ in range(args.num_validation_videos):
@@ -1334,7 +1334,9 @@ def main(args):
# run inference
if args.validation_prompt and args.num_validation_images > 0:
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = (
+    torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+)
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0]
for _ in range(args.num_validation_images)
@@ -172,7 +172,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
@@ -150,7 +150,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
if args.validation_images is None:
images = []
@@ -181,7 +181,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
@@ -167,7 +167,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
with autocast_ctx:
@@ -170,7 +170,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
@@ -199,7 +199,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
@@ -207,7 +207,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
@@ -175,7 +175,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
@@ -137,7 +137,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
pipeline_args = {"prompt": args.validation_prompt}
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
@@ -1241,7 +1241,11 @@ def main(args):
pipeline.set_progress_bar_config(disable=True)
# run inference
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = (
+    torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    if args.seed is not None
+    else None
+)
pipeline_args = {"prompt": args.validation_prompt}
with autocast_ctx:
@@ -1305,7 +1309,9 @@ def main(args):
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline = pipeline.to(accelerator.device)
-generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+generator = (
+    torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+)
with autocast_ctx:
images = [
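A quick way to sanity-check the new behaviour (again illustrative, not part of the commit): with the `is not None` condition, `--seed 0` now yields a seeded generator, so repeated validation calls draw identical noise, while leaving the seed unset still means non-deterministic sampling.

```python
import torch

def make_generator(seed, device="cpu"):
    # Mirrors the fixed conditional used throughout the training scripts.
    return torch.Generator(device=device).manual_seed(seed) if seed is not None else None

# Seed 0 now produces reproducible random draws across separate generators.
g1, g2 = make_generator(0), make_generator(0)
assert torch.equal(torch.randn(4, generator=g1), torch.randn(4, generator=g2))

# An unset seed still disables seeding entirely.
assert make_generator(None) is None
```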