Unverified Commit 906e4105 authored by YaYaB's avatar YaYaB Committed by GitHub
Browse files

Fix push_to_hub for dreambooth and textual_inversion (#748)

* Fix push_to_hub for dreambooth and textual_inversion

* Use repo.push_to_hub instead of push_to_hub
parent 7258dc49
...@@ -575,9 +575,7 @@ def main(): ...@@ -575,9 +575,7 @@ def main():
pipeline.save_pretrained(args.output_dir) pipeline.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
repo.push_to_hub( repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
)
accelerator.end_training() accelerator.end_training()
......
...@@ -569,9 +569,7 @@ def main(): ...@@ -569,9 +569,7 @@ def main():
save_progress(text_encoder, placeholder_token_id, accelerator, args) save_progress(text_encoder, placeholder_token_id, accelerator, args)
if args.push_to_hub: if args.push_to_hub:
repo.push_to_hub( repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
)
accelerator.end_training() accelerator.end_training()
......
...@@ -9,7 +9,7 @@ from accelerate import Accelerator ...@@ -9,7 +9,7 @@ from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from datasets import load_dataset from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo, push_to_hub from diffusers.hub_utils import init_git_repo
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
from torchvision.transforms import ( from torchvision.transforms import (
...@@ -185,7 +185,7 @@ def main(args): ...@@ -185,7 +185,7 @@ def main(args):
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model # save the model
if args.push_to_hub: if args.push_to_hub:
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
else: else:
pipeline.save_pretrained(args.output_dir) pipeline.save_pretrained(args.output_dir)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment