Unverified Commit 9276b1e1 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Replace deprecated hub utils in `train_unconditional_ort` (#1504)

* Replace deprecated hub utils in `train_unconditional_ort`

* typo
parent 2579d421
import argparse import argparse
import math import math
import os import os
from pathlib import Path
from typing import Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -9,9 +11,9 @@ from accelerate import Accelerator ...@@ -9,9 +11,9 @@ from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from datasets import load_dataset from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
from huggingface_hub import HfFolder, Repository, whoami
from onnxruntime.training.ortmodule import ORTModule from onnxruntime.training.ortmodule import ORTModule
from torchvision.transforms import ( from torchvision.transforms import (
CenterCrop, CenterCrop,
...@@ -28,6 +30,16 @@ from tqdm.auto import tqdm ...@@ -28,6 +30,16 @@ from tqdm.auto import tqdm
logger = get_logger(__name__) logger = get_logger(__name__)
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def main(args): def main(args):
logging_dir = os.path.join(args.output_dir, args.logging_dir) logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator = Accelerator( accelerator = Accelerator(
...@@ -113,8 +125,22 @@ def main(args): ...@@ -113,8 +125,22 @@ def main(args):
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay) ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
if args.push_to_hub: # Handle the repository creation
repo = init_git_repo(args, at_init=True) if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if accelerator.is_main_process: if accelerator.is_main_process:
run = os.path.split(__file__)[-1].split(".")[0] run = os.path.split(__file__)[-1].split(".")[0]
...@@ -186,10 +212,9 @@ def main(args): ...@@ -186,10 +212,9 @@ def main(args):
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model # save the model
pipeline.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline.save_pretrained(args.output_dir)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
accelerator.end_training() accelerator.end_training()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment