Unverified Commit 43672b4a authored by Andy W's avatar Andy W Committed by GitHub
Browse files

Fix "push_to_hub only create repo in consistency model lora SDXL training script" (#6102)



* fix

* style fix

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 9df3d843
...@@ -38,7 +38,7 @@ from accelerate import Accelerator ...@@ -38,7 +38,7 @@ from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed from accelerate.utils import ProjectConfiguration, set_seed
from braceexpand import braceexpand from braceexpand import braceexpand
from huggingface_hub import create_repo from huggingface_hub import create_repo, upload_folder
from packaging import version from packaging import version
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
from torch.utils.data import default_collate from torch.utils.data import default_collate
...@@ -847,7 +847,7 @@ def main(args): ...@@ -847,7 +847,7 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub: if args.push_to_hub:
create_repo( repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True, exist_ok=True,
token=args.hub_token, token=args.hub_token,
...@@ -1366,6 +1366,14 @@ def main(args): ...@@ -1366,6 +1366,14 @@ def main(args):
lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default") lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict) StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)
if args.push_to_hub:
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
accelerator.end_training() accelerator.end_training()
......
...@@ -39,7 +39,7 @@ from accelerate import Accelerator ...@@ -39,7 +39,7 @@ from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed from accelerate.utils import ProjectConfiguration, set_seed
from braceexpand import braceexpand from braceexpand import braceexpand
from huggingface_hub import create_repo from huggingface_hub import create_repo, upload_folder
from packaging import version from packaging import version
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
from torch.utils.data import default_collate from torch.utils.data import default_collate
...@@ -842,7 +842,7 @@ def main(args): ...@@ -842,7 +842,7 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub: if args.push_to_hub:
create_repo( repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True, exist_ok=True,
token=args.hub_token, token=args.hub_token,
...@@ -1424,6 +1424,14 @@ def main(args): ...@@ -1424,6 +1424,14 @@ def main(args):
lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default") lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict) StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)
if args.push_to_hub:
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
accelerator.end_training() accelerator.end_training()
......
...@@ -38,7 +38,7 @@ from accelerate import Accelerator ...@@ -38,7 +38,7 @@ from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed from accelerate.utils import ProjectConfiguration, set_seed
from braceexpand import braceexpand from braceexpand import braceexpand
from huggingface_hub import create_repo from huggingface_hub import create_repo, upload_folder
from packaging import version from packaging import version
from torch.utils.data import default_collate from torch.utils.data import default_collate
from torchvision import transforms from torchvision import transforms
...@@ -835,7 +835,7 @@ def main(args): ...@@ -835,7 +835,7 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub: if args.push_to_hub:
create_repo( repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True, exist_ok=True,
token=args.hub_token, token=args.hub_token,
...@@ -1354,6 +1354,14 @@ def main(args): ...@@ -1354,6 +1354,14 @@ def main(args):
target_unet = accelerator.unwrap_model(target_unet) target_unet = accelerator.unwrap_model(target_unet)
target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target")) target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
if args.push_to_hub:
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
accelerator.end_training() accelerator.end_training()
......
...@@ -39,7 +39,7 @@ from accelerate import Accelerator ...@@ -39,7 +39,7 @@ from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed from accelerate.utils import ProjectConfiguration, set_seed
from braceexpand import braceexpand from braceexpand import braceexpand
from huggingface_hub import create_repo from huggingface_hub import create_repo, upload_folder
from packaging import version from packaging import version
from torch.utils.data import default_collate from torch.utils.data import default_collate
from torchvision import transforms from torchvision import transforms
...@@ -875,7 +875,7 @@ def main(args): ...@@ -875,7 +875,7 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub: if args.push_to_hub:
create_repo( repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True, exist_ok=True,
token=args.hub_token, token=args.hub_token,
...@@ -1457,6 +1457,14 @@ def main(args): ...@@ -1457,6 +1457,14 @@ def main(args):
target_unet = accelerator.unwrap_model(target_unet) target_unet = accelerator.unwrap_model(target_unet)
target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target")) target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
if args.push_to_hub:
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
accelerator.end_training() accelerator.end_training()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment