Unverified commit e2ead7cd, authored by Leo Jiang, committed by GitHub

Fix the issue on sd3 dreambooth w/ and w/o lora training (#9419)



* Fix dtype error

* [bugfix] Fixed the issue on sd3 dreambooth training

* [bugfix] Fixed the issue on sd3 dreambooth training

---------
Co-authored-by: 蒋硕 <jiangshuo9@h-partners.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 48e36353
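
The change is the same small pattern applied to each training script below: log_validation gains a torch_dtype parameter, the pipeline is moved to the accelerator device and cast to that dtype in a single .to() call, and every call site passes the training weight_dtype. This avoids the dtype mismatch that occurred when modules kept in float32 for mixed-precision training met fp16/bf16 components during validation. A minimal, self-contained sketch of the idea follows (the checkpoint id and the standalone scaffolding are illustrative, not taken from the scripts):

    import torch
    from diffusers import StableDiffusion3Pipeline

    # Illustrative scaffolding: in the training scripts the device comes from
    # accelerator.device and the dtype is the training weight_dtype.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    weight_dtype = torch.float16  # e.g. what --mixed_precision fp16 implies

    pipeline = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers"  # example checkpoint
    )

    # Before the fix only the device was set, so trained fp32 modules and frozen
    # fp16/bf16 modules could meet inside one forward pass and raise a dtype error:
    #     pipeline = pipeline.to(device)
    #
    # After the fix the device and dtype are set together, casting every
    # component of the pipeline to the same dtype:
    pipeline = pipeline.to(device, dtype=weight_dtype)
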
@@ -154,13 +154,14 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)

     # run inference
@@ -1717,6 +1718,7 @@ def main(args):
         accelerator=accelerator,
         pipeline_args=pipeline_args,
         epoch=epoch,
+        torch_dtype=weight_dtype,
     )
     if not args.train_text_encoder:
         del text_encoder_one, text_encoder_two
@@ -1761,6 +1763,7 @@ def main(args):
         pipeline_args=pipeline_args,
         epoch=epoch,
         is_final_validation=True,
+        torch_dtype=weight_dtype,
     )

     if args.push_to_hub:
@@ -122,6 +122,7 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
@@ -141,7 +142,7 @@ def log_validation(
     pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)

     # run inference
@@ -1360,6 +1361,7 @@ def main(args):
         accelerator,
         pipeline_args,
         epoch,
+        torch_dtype=weight_dtype,
     )

     # Save the lora layers
@@ -1402,6 +1404,7 @@ def main(args):
         pipeline_args,
         epoch,
         is_final_validation=True,
+        torch_dtype=weight_dtype,
     )

     if args.push_to_hub:
@@ -170,13 +170,14 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)

     # run inference
@@ -1785,6 +1786,7 @@ def main(args):
         accelerator=accelerator,
         pipeline_args=pipeline_args,
         epoch=epoch,
+        torch_dtype=weight_dtype,
     )
     if not args.train_text_encoder:
         del text_encoder_one, text_encoder_two
@@ -1832,6 +1834,7 @@ def main(args):
         pipeline_args=pipeline_args,
         epoch=epoch,
         is_final_validation=True,
+        torch_dtype=weight_dtype,
     )

     if args.push_to_hub:
@@ -179,13 +179,14 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)

     # run inference
@@ -1788,6 +1789,7 @@ def main(args):
         accelerator=accelerator,
         pipeline_args=pipeline_args,
         epoch=epoch,
+        torch_dtype=weight_dtype,
     )
     objs = []
     if not args.train_text_encoder:
@@ -1840,6 +1842,7 @@ def main(args):
         pipeline_args=pipeline_args,
         epoch=epoch,
         is_final_validation=True,
+        torch_dtype=weight_dtype,
     )

     if args.push_to_hub:
@@ -180,6 +180,7 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
@@ -201,7 +202,7 @@ def log_validation(
     pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)

     # run inference
@@ -1890,6 +1891,7 @@ def main(args):
         accelerator,
         pipeline_args,
         epoch,
+        torch_dtype=weight_dtype,
     )

     # Save the lora layers
@@ -1955,6 +1957,7 @@ def main(args):
         pipeline_args,
         epoch,
         is_final_validation=True,
+        torch_dtype=weight_dtype,
     )

     if args.push_to_hub:
@@ -157,13 +157,14 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)

     # run inference
@@ -1725,6 +1726,7 @@ def main(args):
         accelerator=accelerator,
         pipeline_args=pipeline_args,
         epoch=epoch,
+        torch_dtype=weight_dtype,
     )
     if not args.train_text_encoder:
         del text_encoder_one, text_encoder_two, text_encoder_three
@@ -1775,6 +1777,7 @@ def main(args):
         pipeline_args=pipeline_args,
         epoch=epoch,
         is_final_validation=True,
+        torch_dtype=weight_dtype,
     )

     if args.push_to_hub:
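
For context on where weight_dtype comes from: in these scripts it mirrors the accelerate mixed-precision setting, so passing it into log_validation keeps validation in the same precision as training. A hedged sketch of that mapping, simplified from the scripts' setup code:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator(mixed_precision="fp16")  # or "bf16" / "no"

    # Frozen model weights are loaded in this dtype; trainable parameters
    # usually stay in float32 under mixed precision.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # With this commit, validation casts the assembled pipeline back to
    # weight_dtype:
    #     images = log_validation(..., torch_dtype=weight_dtype)
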