Commit 2ea64a08 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

Prepare code for big cleaning

parents 37fe8e00 0ea78f0d
include diffusers/utils/model_card_template.md
\ No newline at end of file
...@@ -50,15 +50,16 @@ def init_git_repo(args, at_init: bool = False): ...@@ -50,15 +50,16 @@ def init_git_repo(args, at_init: bool = False):
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True` Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out. and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
""" """
if args.local_rank not in [-1, 0]: if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
return return
use_auth_token = True if args.hub_token is None else args.hub_token hub_token = args.hub_token if hasattr(args, "hub_token") else None
if args.hub_model_id is None: use_auth_token = True if hub_token is None else hub_token
if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
repo_name = Path(args.output_dir).absolute().name repo_name = Path(args.output_dir).absolute().name
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
if "/" not in repo_name: if "/" not in repo_name:
repo_name = get_full_repo_name(repo_name, token=args.hub_token) repo_name = get_full_repo_name(repo_name, token=hub_token)
try: try:
repo = Repository( repo = Repository(
...@@ -111,7 +112,7 @@ def push_to_hub( ...@@ -111,7 +112,7 @@ def push_to_hub(
commit and an object to track the progress of the commit if `blocking=True` commit and an object to track the progress of the commit if `blocking=True`
""" """
if args.hub_model_id is None: if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
model_name = Path(args.output_dir).name model_name = Path(args.output_dir).name
else: else:
model_name = args.hub_model_id.split("/")[-1] model_name = args.hub_model_id.split("/")[-1]
...@@ -122,7 +123,7 @@ def push_to_hub( ...@@ -122,7 +123,7 @@ def push_to_hub(
pipeline.save_pretrained(output_dir) pipeline.save_pretrained(output_dir)
# Only push from one node. # Only push from one node.
if args.local_rank not in [-1, 0]: if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
return return
# Cancel any async push in progress if blocking=True. The commits will all be pushed together. # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
...@@ -146,10 +147,11 @@ def push_to_hub( ...@@ -146,10 +147,11 @@ def push_to_hub(
def create_model_card(args, model_name): def create_model_card(args, model_name):
if args.local_rank not in [-1, 0]: if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
return return
repo_name = get_full_repo_name(model_name, token=args.hub_token) hub_token = args.hub_token if hasattr(args, "hub_token") else None
repo_name = get_full_repo_name(model_name, token=hub_token)
model_card = ModelCard.from_template( model_card = ModelCard.from_template(
card_data=CardData( # Card metadata object that will be converted to YAML block card_data=CardData( # Card metadata object that will be converted to YAML block
...@@ -163,20 +165,22 @@ def create_model_card(args, model_name): ...@@ -163,20 +165,22 @@ def create_model_card(args, model_name):
template_path=MODEL_CARD_TEMPLATE_PATH, template_path=MODEL_CARD_TEMPLATE_PATH,
model_name=model_name, model_name=model_name,
repo_name=repo_name, repo_name=repo_name,
dataset_name=args.dataset, dataset_name=args.dataset if hasattr(args, "dataset") else None,
learning_rate=args.learning_rate, learning_rate=args.learning_rate,
train_batch_size=args.train_batch_size, train_batch_size=args.train_batch_size,
eval_batch_size=args.eval_batch_size, eval_batch_size=args.eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps
adam_beta1=args.adam_beta1, if hasattr(args, "gradient_accumulation_steps")
adam_beta2=args.adam_beta2, else None,
adam_weight_decay=args.adam_weight_decay, adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
adam_epsilon=args.adam_epsilon, adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
lr_scheduler=args.lr_scheduler, adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
lr_warmup_steps=args.lr_warmup_steps, adam_epsilon=args.adam_epsilon if hasattr(args, "adam_weight_decay") else None,
ema_inv_gamma=args.ema_inv_gamma, lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
ema_power=args.ema_power, lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
ema_max_decay=args.ema_max_decay, ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
ema_power=args.ema_power if hasattr(args, "ema_power") else None,
ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
mixed_precision=args.mixed_precision, mixed_precision=args.mixed_precision,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment