Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
007c914c
Unverified
Commit
007c914c
authored
Jan 19, 2023
by
Patrick von Platen
Committed by
GitHub
Jan 19, 2023
Browse files
[Lora] Model card (#2032)
* [Lora] up lora training * finish * finish * finish model card
parent
3c07840b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
62 additions
and
26 deletions
+62
-26
examples/dreambooth/train_dreambooth_lora.py
examples/dreambooth/train_dreambooth_lora.py
+62
-26
No files found.
examples/dreambooth/train_dreambooth_lora.py
View file @
007c914c
...
@@ -58,6 +58,34 @@ check_min_version("0.12.0.dev0")
...
@@ -58,6 +58,34 @@ check_min_version("0.12.0.dev0")
logger
=
get_logger
(
__name__
)
logger
=
get_logger
(
__name__
)
def
save_model_card
(
repo_name
,
images
=
None
,
base_model
=
str
,
prompt
=
str
,
repo_folder
=
None
):
img_str
=
""
for
i
,
image
in
enumerate
(
images
):
image
.
save
(
os
.
path
.
join
(
repo_folder
,
f
"image_
{
i
}
.png"
))
img_str
+=
f
"
\n
"
yaml
=
f
"""
---
license: creativeml-openrail-m
base_model:
{
base_model
}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
inference: true
---
"""
model_card
=
f
"""
# LoRA DreamBooth -
{
repo_name
}
These are LoRA adaption weights for
{
repo_name
}
. The weights were trained on
{
prompt
}
using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following.
\n
{
img_str
}
"""
with
open
(
os
.
path
.
join
(
repo_folder
,
"README.md"
),
"w"
)
as
f
:
f
.
write
(
yaml
+
model_card
)
def
import_model_class_from_model_name_or_path
(
pretrained_model_name_or_path
:
str
,
revision
:
str
):
def
import_model_class_from_model_name_or_path
(
pretrained_model_name_or_path
:
str
,
revision
:
str
):
text_encoder_config
=
PretrainedConfig
.
from_pretrained
(
text_encoder_config
=
PretrainedConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
...
@@ -913,34 +941,42 @@ def main(args):
...
@@ -913,34 +941,42 @@ def main(args):
unet
=
unet
.
to
(
torch
.
float32
)
unet
=
unet
.
to
(
torch
.
float32
)
unet
.
save_attn_procs
(
args
.
output_dir
)
unet
.
save_attn_procs
(
args
.
output_dir
)
if
args
.
push_to_hub
:
# Final inference
repo
.
push_to_hub
(
commit_message
=
"End of training"
,
blocking
=
False
,
auto_lfs_prune
=
True
)
# Load previous pipeline
pipeline
=
DiffusionPipeline
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
revision
=
args
.
revision
,
torch_dtype
=
weight_dtype
)
pipeline
.
scheduler
=
DPMSolverMultistepScheduler
.
from_config
(
pipeline
.
scheduler
.
config
)
pipeline
=
pipeline
.
to
(
accelerator
.
device
)
# load attention processors
pipeline
.
unet
.
load_attn_procs
(
args
.
output_dir
)
# run inference
generator
=
torch
.
Generator
(
device
=
accelerator
.
device
).
manual_seed
(
args
.
seed
)
prompt
=
args
.
num_validation_images
*
[
args
.
validation_prompt
]
images
=
pipeline
(
prompt
,
num_inference_steps
=
25
,
generator
=
generator
).
images
for
tracker
in
accelerator
.
trackers
:
if
tracker
.
name
==
"wandb"
:
tracker
.
log
(
{
"test"
:
[
wandb
.
Image
(
image
,
caption
=
f
"
{
i
}
:
{
args
.
validation_prompt
}
"
)
for
i
,
image
in
enumerate
(
images
)
]
}
)
# Final inference
if
args
.
push_to_hub
:
# Load previous pipeline
save_model_card
(
pipeline
=
DiffusionPipeline
.
from_pretrained
(
repo_name
,
args
.
pretrained_model_name_or_path
,
revision
=
args
.
revision
,
torch_dtype
=
weight_dtype
images
=
images
,
)
base_model
=
args
.
pretrained_model_name_or_path
,
pipeline
.
scheduler
=
DPMSolverMultistepScheduler
.
from_config
(
pipeline
.
scheduler
.
config
)
prompt
=
args
.
instance_prompt
,
pipeline
=
pipeline
.
to
(
accelerator
.
device
)
repo_folder
=
args
.
output_dir
,
# load attention processors
pipeline
.
unet
.
load_attn_procs
(
args
.
output_dir
)
# run inference
generator
=
torch
.
Generator
(
device
=
accelerator
.
device
).
manual_seed
(
args
.
seed
)
prompt
=
args
.
num_validation_images
*
[
args
.
validation_prompt
]
images
=
pipeline
(
prompt
,
num_inference_steps
=
25
,
generator
=
generator
).
images
for
tracker
in
accelerator
.
trackers
:
if
tracker
.
name
==
"wandb"
:
tracker
.
log
(
{
"test"
:
[
wandb
.
Image
(
image
,
caption
=
f
"
{
i
}
:
{
args
.
validation_prompt
}
"
)
for
i
,
image
in
enumerate
(
images
)
]
}
)
)
repo
.
push_to_hub
(
commit_message
=
"End of training"
,
blocking
=
False
,
auto_lfs_prune
=
True
)
accelerator
.
end_training
()
accelerator
.
end_training
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment