Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
2ea64a08
Commit
2ea64a08
authored
Jul 19, 2022
by
Patrick von Platen
Browse files
Prepare code for big cleaning
parents
37fe8e00
0ea78f0d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
19 deletions
+24
-19
MANIFEST.in
MANIFEST.in
+1
-0
src/diffusers/hub_utils.py
src/diffusers/hub_utils.py
+23
-19
No files found.
MANIFEST.in
0 → 100644
View file @
2ea64a08
include diffusers/utils/model_card_template.md
\ No newline at end of file
src/diffusers/hub_utils.py
View file @
2ea64a08
...
@@ -50,15 +50,16 @@ def init_git_repo(args, at_init: bool = False):
...
@@ -50,15 +50,16 @@ def init_git_repo(args, at_init: bool = False):
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
"""
"""
if
args
.
local_rank
not
in
[
-
1
,
0
]:
if
hasattr
(
args
,
"local_rank"
)
and
args
.
local_rank
not
in
[
-
1
,
0
]:
return
return
use_auth_token
=
True
if
args
.
hub_token
is
None
else
args
.
hub_token
hub_token
=
args
.
hub_token
if
hasattr
(
args
,
"hub_token"
)
else
None
if
args
.
hub_model_id
is
None
:
use_auth_token
=
True
if
hub_token
is
None
else
hub_token
if
not
hasattr
(
args
,
"hub_model_id"
)
or
args
.
hub_model_id
is
None
:
repo_name
=
Path
(
args
.
output_dir
).
absolute
().
name
repo_name
=
Path
(
args
.
output_dir
).
absolute
().
name
else
:
else
:
repo_name
=
args
.
hub_model_id
repo_name
=
args
.
hub_model_id
if
"/"
not
in
repo_name
:
if
"/"
not
in
repo_name
:
repo_name
=
get_full_repo_name
(
repo_name
,
token
=
args
.
hub_token
)
repo_name
=
get_full_repo_name
(
repo_name
,
token
=
hub_token
)
try
:
try
:
repo
=
Repository
(
repo
=
Repository
(
...
@@ -111,7 +112,7 @@ def push_to_hub(
...
@@ -111,7 +112,7 @@ def push_to_hub(
commit and an object to track the progress of the commit if `blocking=True`
commit and an object to track the progress of the commit if `blocking=True`
"""
"""
if
args
.
hub_model_id
is
None
:
if
not
hasattr
(
args
,
"hub_model_id"
)
or
args
.
hub_model_id
is
None
:
model_name
=
Path
(
args
.
output_dir
).
name
model_name
=
Path
(
args
.
output_dir
).
name
else
:
else
:
model_name
=
args
.
hub_model_id
.
split
(
"/"
)[
-
1
]
model_name
=
args
.
hub_model_id
.
split
(
"/"
)[
-
1
]
...
@@ -122,7 +123,7 @@ def push_to_hub(
...
@@ -122,7 +123,7 @@ def push_to_hub(
pipeline
.
save_pretrained
(
output_dir
)
pipeline
.
save_pretrained
(
output_dir
)
# Only push from one node.
# Only push from one node.
if
args
.
local_rank
not
in
[
-
1
,
0
]:
if
hasattr
(
args
,
"local_rank"
)
and
args
.
local_rank
not
in
[
-
1
,
0
]:
return
return
# Cancel any async push in progress if blocking=True. The commits will all be pushed together.
# Cancel any async push in progress if blocking=True. The commits will all be pushed together.
...
@@ -146,10 +147,11 @@ def push_to_hub(
...
@@ -146,10 +147,11 @@ def push_to_hub(
def
create_model_card
(
args
,
model_name
):
def
create_model_card
(
args
,
model_name
):
if
args
.
local_rank
not
in
[
-
1
,
0
]:
if
hasattr
(
args
,
"local_rank"
)
and
args
.
local_rank
not
in
[
-
1
,
0
]:
return
return
repo_name
=
get_full_repo_name
(
model_name
,
token
=
args
.
hub_token
)
hub_token
=
args
.
hub_token
if
hasattr
(
args
,
"hub_token"
)
else
None
repo_name
=
get_full_repo_name
(
model_name
,
token
=
hub_token
)
model_card
=
ModelCard
.
from_template
(
model_card
=
ModelCard
.
from_template
(
card_data
=
CardData
(
# Card metadata object that will be converted to YAML block
card_data
=
CardData
(
# Card metadata object that will be converted to YAML block
...
@@ -163,20 +165,22 @@ def create_model_card(args, model_name):
...
@@ -163,20 +165,22 @@ def create_model_card(args, model_name):
template_path
=
MODEL_CARD_TEMPLATE_PATH
,
template_path
=
MODEL_CARD_TEMPLATE_PATH
,
model_name
=
model_name
,
model_name
=
model_name
,
repo_name
=
repo_name
,
repo_name
=
repo_name
,
dataset_name
=
args
.
dataset
,
dataset_name
=
args
.
dataset
if
hasattr
(
args
,
"dataset"
)
else
None
,
learning_rate
=
args
.
learning_rate
,
learning_rate
=
args
.
learning_rate
,
train_batch_size
=
args
.
train_batch_size
,
train_batch_size
=
args
.
train_batch_size
,
eval_batch_size
=
args
.
eval_batch_size
,
eval_batch_size
=
args
.
eval_batch_size
,
gradient_accumulation_steps
=
args
.
gradient_accumulation_steps
,
gradient_accumulation_steps
=
args
.
gradient_accumulation_steps
adam_beta1
=
args
.
adam_beta1
,
if
hasattr
(
args
,
"gradient_accumulation_steps"
)
adam_beta2
=
args
.
adam_beta2
,
else
None
,
adam_weight_decay
=
args
.
adam_weight_decay
,
adam_beta1
=
args
.
adam_beta1
if
hasattr
(
args
,
"adam_beta1"
)
else
None
,
adam_epsilon
=
args
.
adam_epsilon
,
adam_beta2
=
args
.
adam_beta2
if
hasattr
(
args
,
"adam_beta2"
)
else
None
,
lr_scheduler
=
args
.
lr_scheduler
,
adam_weight_decay
=
args
.
adam_weight_decay
if
hasattr
(
args
,
"adam_weight_decay"
)
else
None
,
lr_warmup_steps
=
args
.
lr_warmup_steps
,
adam_epsilon
=
args
.
adam_epsilon
if
hasattr
(
args
,
"adam_weight_decay"
)
else
None
,
ema_inv_gamma
=
args
.
ema_inv_gamma
,
lr_scheduler
=
args
.
lr_scheduler
if
hasattr
(
args
,
"lr_scheduler"
)
else
None
,
ema_power
=
args
.
ema_power
,
lr_warmup_steps
=
args
.
lr_warmup_steps
if
hasattr
(
args
,
"lr_warmup_steps"
)
else
None
,
ema_max_decay
=
args
.
ema_max_decay
,
ema_inv_gamma
=
args
.
ema_inv_gamma
if
hasattr
(
args
,
"ema_inv_gamma"
)
else
None
,
ema_power
=
args
.
ema_power
if
hasattr
(
args
,
"ema_power"
)
else
None
,
ema_max_decay
=
args
.
ema_max_decay
if
hasattr
(
args
,
"ema_max_decay"
)
else
None
,
mixed_precision
=
args
.
mixed_precision
,
mixed_precision
=
args
.
mixed_precision
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment