Unverified Commit e2ead7cd authored by Leo Jiang's avatar Leo Jiang Committed by GitHub
Browse files

Fix the issue on sd3 dreambooth w./w.t. lora training (#9419)



* Fix dtype error

* [bugfix] Fixed the issue on sd3 dreambooth training

* [bugfix] Fixed the issue on sd3 dreambooth training

---------
Co-authored-by: default avatar蒋硕 <jiangshuo9@h-partners.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 48e36353
...@@ -154,13 +154,14 @@ def log_validation( ...@@ -154,13 +154,14 @@ def log_validation(
accelerator, accelerator,
pipeline_args, pipeline_args,
epoch, epoch,
torch_dtype,
is_final_validation=False, is_final_validation=False,
): ):
logger.info( logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}." f" {args.validation_prompt}."
) )
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
...@@ -1717,6 +1718,7 @@ def main(args): ...@@ -1717,6 +1718,7 @@ def main(args):
accelerator=accelerator, accelerator=accelerator,
pipeline_args=pipeline_args, pipeline_args=pipeline_args,
epoch=epoch, epoch=epoch,
torch_dtype=weight_dtype,
) )
if not args.train_text_encoder: if not args.train_text_encoder:
del text_encoder_one, text_encoder_two del text_encoder_one, text_encoder_two
...@@ -1761,6 +1763,7 @@ def main(args): ...@@ -1761,6 +1763,7 @@ def main(args):
pipeline_args=pipeline_args, pipeline_args=pipeline_args,
epoch=epoch, epoch=epoch,
is_final_validation=True, is_final_validation=True,
torch_dtype=weight_dtype,
) )
if args.push_to_hub: if args.push_to_hub:
......
...@@ -122,6 +122,7 @@ def log_validation( ...@@ -122,6 +122,7 @@ def log_validation(
accelerator, accelerator,
pipeline_args, pipeline_args,
epoch, epoch,
torch_dtype,
is_final_validation=False, is_final_validation=False,
): ):
logger.info( logger.info(
...@@ -141,7 +142,7 @@ def log_validation( ...@@ -141,7 +142,7 @@ def log_validation(
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
...@@ -1360,6 +1361,7 @@ def main(args): ...@@ -1360,6 +1361,7 @@ def main(args):
accelerator, accelerator,
pipeline_args, pipeline_args,
epoch, epoch,
torch_dtype=weight_dtype,
) )
# Save the lora layers # Save the lora layers
...@@ -1402,6 +1404,7 @@ def main(args): ...@@ -1402,6 +1404,7 @@ def main(args):
pipeline_args, pipeline_args,
epoch, epoch,
is_final_validation=True, is_final_validation=True,
torch_dtype=weight_dtype,
) )
if args.push_to_hub: if args.push_to_hub:
......
...@@ -170,13 +170,14 @@ def log_validation( ...@@ -170,13 +170,14 @@ def log_validation(
accelerator, accelerator,
pipeline_args, pipeline_args,
epoch, epoch,
torch_dtype,
is_final_validation=False, is_final_validation=False,
): ):
logger.info( logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}." f" {args.validation_prompt}."
) )
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
...@@ -1785,6 +1786,7 @@ def main(args): ...@@ -1785,6 +1786,7 @@ def main(args):
accelerator=accelerator, accelerator=accelerator,
pipeline_args=pipeline_args, pipeline_args=pipeline_args,
epoch=epoch, epoch=epoch,
torch_dtype=weight_dtype,
) )
if not args.train_text_encoder: if not args.train_text_encoder:
del text_encoder_one, text_encoder_two del text_encoder_one, text_encoder_two
...@@ -1832,6 +1834,7 @@ def main(args): ...@@ -1832,6 +1834,7 @@ def main(args):
pipeline_args=pipeline_args, pipeline_args=pipeline_args,
epoch=epoch, epoch=epoch,
is_final_validation=True, is_final_validation=True,
torch_dtype=weight_dtype,
) )
if args.push_to_hub: if args.push_to_hub:
......
...@@ -179,13 +179,14 @@ def log_validation( ...@@ -179,13 +179,14 @@ def log_validation(
accelerator, accelerator,
pipeline_args, pipeline_args,
epoch, epoch,
torch_dtype,
is_final_validation=False, is_final_validation=False,
): ):
logger.info( logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}." f" {args.validation_prompt}."
) )
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
...@@ -1788,6 +1789,7 @@ def main(args): ...@@ -1788,6 +1789,7 @@ def main(args):
accelerator=accelerator, accelerator=accelerator,
pipeline_args=pipeline_args, pipeline_args=pipeline_args,
epoch=epoch, epoch=epoch,
torch_dtype=weight_dtype,
) )
objs = [] objs = []
if not args.train_text_encoder: if not args.train_text_encoder:
...@@ -1840,6 +1842,7 @@ def main(args): ...@@ -1840,6 +1842,7 @@ def main(args):
pipeline_args=pipeline_args, pipeline_args=pipeline_args,
epoch=epoch, epoch=epoch,
is_final_validation=True, is_final_validation=True,
torch_dtype=weight_dtype,
) )
if args.push_to_hub: if args.push_to_hub:
......
...@@ -180,6 +180,7 @@ def log_validation( ...@@ -180,6 +180,7 @@ def log_validation(
accelerator, accelerator,
pipeline_args, pipeline_args,
epoch, epoch,
torch_dtype,
is_final_validation=False, is_final_validation=False,
): ):
logger.info( logger.info(
...@@ -201,7 +202,7 @@ def log_validation( ...@@ -201,7 +202,7 @@ def log_validation(
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
...@@ -1890,6 +1891,7 @@ def main(args): ...@@ -1890,6 +1891,7 @@ def main(args):
accelerator, accelerator,
pipeline_args, pipeline_args,
epoch, epoch,
torch_dtype=weight_dtype,
) )
# Save the lora layers # Save the lora layers
...@@ -1955,6 +1957,7 @@ def main(args): ...@@ -1955,6 +1957,7 @@ def main(args):
pipeline_args, pipeline_args,
epoch, epoch,
is_final_validation=True, is_final_validation=True,
torch_dtype=weight_dtype,
) )
if args.push_to_hub: if args.push_to_hub:
......
...@@ -157,13 +157,14 @@ def log_validation( ...@@ -157,13 +157,14 @@ def log_validation(
accelerator, accelerator,
pipeline_args, pipeline_args,
epoch, epoch,
torch_dtype,
is_final_validation=False, is_final_validation=False,
): ):
logger.info( logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}." f" {args.validation_prompt}."
) )
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
...@@ -1725,6 +1726,7 @@ def main(args): ...@@ -1725,6 +1726,7 @@ def main(args):
accelerator=accelerator, accelerator=accelerator,
pipeline_args=pipeline_args, pipeline_args=pipeline_args,
epoch=epoch, epoch=epoch,
torch_dtype=weight_dtype,
) )
if not args.train_text_encoder: if not args.train_text_encoder:
del text_encoder_one, text_encoder_two, text_encoder_three del text_encoder_one, text_encoder_two, text_encoder_three
...@@ -1775,6 +1777,7 @@ def main(args): ...@@ -1775,6 +1777,7 @@ def main(args):
pipeline_args=pipeline_args, pipeline_args=pipeline_args,
epoch=epoch, epoch=epoch,
is_final_validation=True, is_final_validation=True,
torch_dtype=weight_dtype,
) )
if args.push_to_hub: if args.push_to_hub:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment