renzhc / diffusers_dcu / Commits / a2117cb7

Commit a2117cb7, authored Jun 21, 2022 by anton-l
add push_to_hub
parent 8c1f5197

Showing 4 changed files with 197 additions and 12 deletions (+197 -12)
- examples/README.md (+6 -6)
- examples/train_unconditional.py (+28 -6)
- src/diffusers/hub_utils.py (+149 -0)
- src/diffusers/modeling_utils.py (+14 -0)
examples/README.md (view file @ a2117cb7)

````diff
 ## Training examples

-### Flowers DDPM
+### Unconditional Flowers

 The command to train a DDPM UNet model on the Oxford Flowers dataset:

 ```bash
 python -m torch.distributed.launch \
     --nproc_per_node 4 \
-    train_ddpm.py \
+    train_unconditional.py \
     --dataset="huggan/flowers-102-categories" \
     --resolution=64 \
     --output_path="flowers-ddpm" \
@@ -19,19 +19,19 @@ python -m torch.distributed.launch \
     --mixed_precision=no
 ```

-A fulll training run takes 2 hours on 4xV100 GPUs.
+A full training run takes 2 hours on 4xV100 GPUs.

 <img src="https://user-images.githubusercontent.com/26864830/173855866-5628989f-856b-4725-a944-d6c09490b2df.png" width="500" />

-### Pokemon DDPM
+### Unconditional Pokemon

 The command to train a DDPM UNet model on the Pokemon dataset:

 ```bash
 python -m torch.distributed.launch \
     --nproc_per_node 4 \
-    train_ddpm.py \
+    train_unconditional.py \
     --dataset="huggan/pokemon" \
     --resolution=64 \
     --output_path="pokemon-ddpm" \
@@ -43,6 +43,6 @@ python -m torch.distributed.launch \
     --mixed_precision=no
 ```

-A fulll training run takes 2 hours on 4xV100 GPUs.
+A full training run takes 2 hours on 4xV100 GPUs.

 <img src="https://user-images.githubusercontent.com/26864830/173856733-4f117f8c-97bd-4f51-8002-56b488c96df9.png" width="500" />
````
examples/train_ddpm.py → examples/train_unconditional.py (view file @ a2117cb7)
```diff
@@ -19,6 +19,12 @@ from torchvision.transforms import (
 )
 from tqdm.auto import tqdm
 from transformers import get_linear_schedule_with_warmup
+from diffusers.modeling_utils import unwrap_model
+from diffusers.hub_utils import init_git_repo, push_to_hub
+from diffusers.utils import logging
+
+logger = logging.get_logger(__name__)
+
 
 
 def main(args):
```
```diff
@@ -64,6 +70,21 @@ def main(args):
         model, optimizer, train_dataloader, lr_scheduler
     )
 
+    if args.push_to_hub:
+        repo = init_git_repo(args, at_init=True)
+
+    # Train!
+    world_size = torch.distributed.get_world_size() if args.local_rank != -1 else 1
+    total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
+    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataloader.dataset)}")
+    logger.info(f"  Num Epochs = {args.num_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {max_steps}")
+
     for epoch in range(args.num_epochs):
         model.train()
         with tqdm(total=len(train_dataloader), unit="ba") as pbar:
```
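The three quantities computed above determine the numbers the new logging block reports. A quick worked example, not part of the commit, using the README launch command (4 processes) together with this commit's argparse defaults (`batch_size=4`, `gradient_accumulation_steps=1`, `num_epochs=100`) and a hypothetical dataloader length:

```python
# Illustration only: the same arithmetic as the logging block added above,
# plugged with the README launch settings (--nproc_per_node 4) and this
# commit's argparse defaults. The dataloader length is hypothetical.
world_size = 4                   # one process per GPU under torch.distributed.launch
batch_size = 4
gradient_accumulation_steps = 1
num_epochs = 100
steps_per_epoch = 1000           # stand-in for len(train_dataloader)

total_train_batch_size = batch_size * gradient_accumulation_steps * world_size   # 16 images per optimizer step
max_steps = steps_per_epoch // gradient_accumulation_steps * num_epochs          # 100000 optimization steps
print(total_train_batch_size, max_steps)  # 16 100000
```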
```diff
@@ -105,10 +126,10 @@ def main(args):
         if args.local_rank in [-1, 0]:
             model.eval()
             with torch.no_grad():
-                if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                    pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler)
+                pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler)
+                if args.push_to_hub:
+                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
                 else:
-                    pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
                     pipeline.save_pretrained(args.output_path)
 
                 generator = torch.manual_seed(0)
```
```diff
@@ -130,15 +151,16 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument("--local_rank", type=int)
+    parser.add_argument("--local_rank", type=int, default=-1)
     parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
     parser.add_argument("--resolution", type=int, default=64)
     parser.add_argument("--output_path", type=str, default="ddpm-model")
-    parser.add_argument("--batch_size", type=int, default=16)
+    parser.add_argument("--batch_size", type=int, default=4)
     parser.add_argument("--num_epochs", type=int, default=100)
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--warmup_steps", type=int, default=500)
+    parser.add_argument("--push_to_hub", action="store_true")
     parser.add_argument(
         "--mixed_precision",
         type=str,
```
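One thing to note when reading the new `src/diffusers/hub_utils.py` below: `init_git_repo` and `push_to_hub` also dereference `args.output_dir`, `args.hub_model_id`, `args.hub_token`, `args.hub_private_repo`, `args.overwrite_output_dir`, and `args.hub_strategy`, none of which are added as flags in this hunk. A minimal sketch of the attribute set those helpers expect; the attribute names come from the new module, the values here are placeholders only:

```python
from types import SimpleNamespace

# Placeholder values; the attribute names are the ones hub_utils.py reads from `args`.
hub_args = SimpleNamespace(
    local_rank=-1,               # only ranks -1/0 touch the Hub
    push_to_hub=True,            # the flag added in this commit
    output_dir="flowers-ddpm",   # local clone target (the training script itself exposes --output_path)
    hub_model_id=None,           # None -> derive the repo name from output_dir
    hub_token=None,              # None -> fall back to the cached `huggingface-cli login` token
    hub_private_repo=False,
    overwrite_output_dir=False,  # together with at_init=True, allows wiping output_dir and re-cloning
    hub_strategy="every_save",   # any value except "all_checkpoints" gets a checkpoint-*/ .gitignore
)
```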
src/diffusers/hub_utils.py, new file (0 → 100644), view file @ a2117cb7
```python
from typing import Optional

from .utils import logging
from huggingface_hub import HfFolder, Repository, whoami
import yaml
import os
from pathlib import Path
import shutil

from diffusers import DiffusionPipeline


logger = logging.get_logger(__name__)


AUTOGENERATED_TRAINER_COMMENT = """
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
"""


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"
```
```python
def init_git_repo(args, at_init: bool = False):
    """
    Initializes a git repo in `args.hub_model_id`.

    Args:
        at_init (`bool`, *optional*, defaults to `False`):
            Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
            `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
            out.
    """
    if args.local_rank not in [-1, 0]:
        return
    use_auth_token = True if args.hub_token is None else args.hub_token
    if args.hub_model_id is None:
        repo_name = Path(args.output_dir).absolute().name
    else:
        repo_name = args.hub_model_id
    if "/" not in repo_name:
        repo_name = get_full_repo_name(repo_name, token=args.hub_token)

    try:
        repo = Repository(
            args.output_dir,
            clone_from=repo_name,
            use_auth_token=use_auth_token,
            private=args.hub_private_repo,
        )
    except EnvironmentError:
        if args.overwrite_output_dir and at_init:
            # Try again after wiping output_dir
            shutil.rmtree(args.output_dir)
            repo = Repository(
                args.output_dir,
                clone_from=repo_name,
                use_auth_token=use_auth_token,
            )
        else:
            raise

    repo.git_pull()

    # By default, ignore the checkpoint folders
    if (
        not os.path.exists(os.path.join(args.output_dir, ".gitignore"))
        and args.hub_strategy != "all_checkpoints"
    ):
        with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
            writer.writelines(["checkpoint-*/"])

    return repo
```
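The repo that `init_git_repo` clones into `args.output_dir` is chosen by a small naming rule. A hand-written mirror of that rule, for illustration only, with `username` standing in for the `whoami` lookup and hypothetical values:

```python
from pathlib import Path
from typing import Optional

def resolve_repo_name(output_dir: str, hub_model_id: Optional[str], username: str) -> str:
    """Hand-written mirror of init_git_repo()'s naming logic; illustration only."""
    repo_name = Path(output_dir).absolute().name if hub_model_id is None else hub_model_id
    if "/" not in repo_name:
        repo_name = f"{username}/{repo_name}"  # what get_full_repo_name() produces for a bare name
    return repo_name

print(resolve_repo_name("flowers-ddpm", None, "my-user"))                   # my-user/flowers-ddpm
print(resolve_repo_name("flowers-ddpm", "my-org/flowers-ddpm", "my-user"))  # my-org/flowers-ddpm
```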
```python
def push_to_hub(
    args,
    pipeline: DiffusionPipeline,
    repo: Repository,
    commit_message: Optional[str] = "End of training",
    blocking: bool = True,
    **kwargs,
) -> str:
    """
    Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.

    Parameters:
        commit_message (`str`, *optional*, defaults to `"End of training"`):
            Message to commit while pushing.
        blocking (`bool`, *optional*, defaults to `True`):
            Whether the function should return only when the `git push` has finished.
        kwargs:
            Additional keyword arguments passed along to [`create_model_card`].

    Returns:
        The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of
        the commit and an object to track the progress of the commit if `blocking=True`
    """
    if args.hub_model_id is None:
        model_name = Path(args.output_dir).name
    else:
        model_name = args.hub_model_id.split("/")[-1]

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Saving pipeline checkpoint to {output_dir}")
    pipeline.save_pretrained(output_dir)

    # Only push from one node.
    if args.local_rank not in [-1, 0]:
        return

    # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
    if (
        blocking
        and len(repo.command_queue) > 0
        and repo.command_queue[-1] is not None
        and not repo.command_queue[-1].is_done
    ):
        repo.command_queue[-1]._process.kill()

    git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
    # push separately the model card to be independent from the rest of the model
    create_model_card(args, model_name=model_name)
    try:
        repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
    except EnvironmentError as exc:
        logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")

    return git_head_commit_url
```
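For the calling side: the training loop shown earlier queues a non-blocking push after every epoch, while the defaults above would give a final synchronous "End of training" commit. An outline of that pattern, not from the commit itself; it needs a Hub login and a real `pipeline` built by the training script to actually run:

```python
from types import SimpleNamespace
from diffusers.hub_utils import init_git_repo, push_to_hub

# Same placeholder attribute set as sketched after the argparse hunk above.
args = SimpleNamespace(
    local_rank=-1, push_to_hub=True, output_dir="flowers-ddpm", hub_model_id=None,
    hub_token=None, hub_private_repo=False, overwrite_output_dir=False, hub_strategy="every_save",
)

repo = init_git_repo(args, at_init=True)  # clone or create the Hub repo inside output_dir

# During training (per epoch, as in the diff above), queued asynchronously:
#     push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
#
# After training, blocking: a still-pending async push is killed and folded into this commit:
#     commit_url = push_to_hub(args, pipeline, repo)  # commit_message defaults to "End of training"
```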
```python
def create_model_card(args, model_name):
    if args.local_rank not in [-1, 0]:
        return

    # TODO: replace this placeholder model card generation
    model_card = ""

    metadata = {"license": "apache-2.0", "tags": ["pytorch", "diffusers"]}
    metadata = yaml.dump(metadata, sort_keys=False)
    if len(metadata) > 0:
        model_card = f"---\n{metadata}---\n"

    model_card += AUTOGENERATED_TRAINER_COMMENT
    model_card += f"\n# {model_name}\n\n"

    with open(os.path.join(args.output_dir, "README.md"), "w") as f:
        f.write(model_card)
```
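Following the placeholder logic above, the README.md written for a model name such as "flowers-ddpm" would look roughly as reconstructed below; this snippet rebuilds the card by hand and is not part of the commit:

```python
import yaml
from diffusers.hub_utils import AUTOGENERATED_TRAINER_COMMENT

# Hand-reconstructed from create_model_card() above, for illustration only.
metadata = yaml.dump({"license": "apache-2.0", "tags": ["pytorch", "diffusers"]}, sort_keys=False)
card = f"---\n{metadata}---\n"
card += AUTOGENERATED_TRAINER_COMMENT
card += "\n# flowers-ddpm\n\n"
print(card)
# ---
# license: apache-2.0
# tags:
# - pytorch
# - diffusers
# ---
#
# <!-- This model card has been generated automatically according to the information the Trainer had access to. You
# should probably proofread and complete it, then remove this comment. -->
#
# # flowers-ddpm
```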
src/diffusers/modeling_utils.py (view file @ a2117cb7)

```diff
@@ -572,3 +572,17 @@ class ModelMixin(torch.nn.Module):
             return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
         else:
             return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
+
+def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
+    """
+    Recursively unwraps a model from potential containers (as used in distributed training).
+
+    Args:
+        model (`torch.nn.Module`): The model to unwrap.
+    """
+    # since there could be multiple levels of wrapping, unwrap recursively
+    if hasattr(model, "module"):
+        return unwrap_model(model.module)
+    else:
+        return model
\ No newline at end of file
```
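A small sketch of what the new helper does, using `torch.nn.DataParallel` as a stand-in wrapper since it also exposes the wrapped model as `.module`; `DistributedDataParallel` behaves the same way but needs an initialized process group:

```python
import torch
from diffusers.modeling_utils import unwrap_model

net = torch.nn.Linear(4, 4)
wrapped = torch.nn.DataParallel(torch.nn.DataParallel(net))  # two levels of wrapping

assert unwrap_model(wrapped) is net  # recursion peels every `.module` layer
assert unwrap_model(net) is net      # plain modules pass through unchanged
```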