Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
e185084a
Unverified
Commit
e185084a
authored
Dec 04, 2023
by
Levi McCallum
Committed by
GitHub
Dec 04, 2023
Browse files
Add variant argument to dreambooth lora sdxl advanced (#6021)
Co-authored-by:
Sayak Paul
<
spsayakpaul@gmail.com
>
parent
b2172922
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
9 deletions
+39
-9
examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
...diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+39
-9
No files found.
examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
View file @
e185084a
...
...
@@ -225,6 +225,12 @@ def parse_args(input_args=None):
required
=
False
,
help
=
"Revision of pretrained model identifier from huggingface.co/models."
,
)
parser
.
add_argument
(
"--variant"
,
type
=
str
,
default
=
None
,
help
=
"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16"
,
)
parser
.
add_argument
(
"--dataset_name"
,
type
=
str
,
...
...
@@ -1064,6 +1070,7 @@ def main(args):
args
.
pretrained_model_name_or_path
,
torch_dtype
=
torch_dtype
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
)
pipeline
.
set_progress_bar_config
(
disable
=
True
)
...
...
@@ -1102,10 +1109,18 @@ def main(args):
# Load the tokenizers
tokenizer_one
=
AutoTokenizer
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"tokenizer"
,
revision
=
args
.
revision
,
use_fast
=
False
args
.
pretrained_model_name_or_path
,
subfolder
=
"tokenizer"
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
use_fast
=
False
,
)
tokenizer_two
=
AutoTokenizer
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"tokenizer_2"
,
revision
=
args
.
revision
,
use_fast
=
False
args
.
pretrained_model_name_or_path
,
subfolder
=
"tokenizer_2"
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
use_fast
=
False
,
)
# import correct text encoder classes
...
...
@@ -1119,10 +1134,10 @@ def main(args):
# Load scheduler and models
noise_scheduler
=
DDPMScheduler
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"scheduler"
)
text_encoder_one
=
text_encoder_cls_one
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder"
,
revision
=
args
.
revision
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder"
,
revision
=
args
.
revision
,
variant
=
args
.
variant
)
text_encoder_two
=
text_encoder_cls_two
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder_2"
,
revision
=
args
.
revision
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder_2"
,
revision
=
args
.
revision
,
variant
=
args
.
variant
)
vae_path
=
(
args
.
pretrained_model_name_or_path
...
...
@@ -1130,10 +1145,13 @@ def main(args):
else
args
.
pretrained_vae_model_name_or_path
)
vae
=
AutoencoderKL
.
from_pretrained
(
vae_path
,
subfolder
=
"vae"
if
args
.
pretrained_vae_model_name_or_path
is
None
else
None
,
revision
=
args
.
revision
vae_path
,
subfolder
=
"vae"
if
args
.
pretrained_vae_model_name_or_path
is
None
else
None
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
)
unet
=
UNet2DConditionModel
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"unet"
,
revision
=
args
.
revision
args
.
pretrained_model_name_or_path
,
subfolder
=
"unet"
,
revision
=
args
.
revision
,
variant
=
args
.
variant
)
if
args
.
train_text_encoder_ti
:
...
...
@@ -1843,10 +1861,16 @@ def main(args):
# create pipeline
if
freeze_text_encoder
:
text_encoder_one
=
text_encoder_cls_one
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder"
,
revision
=
args
.
revision
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder"
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
)
text_encoder_two
=
text_encoder_cls_two
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder_2"
,
revision
=
args
.
revision
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder_2"
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
)
pipeline
=
StableDiffusionXLPipeline
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
...
...
@@ -1855,6 +1879,7 @@ def main(args):
text_encoder_2
=
accelerator
.
unwrap_model
(
text_encoder_two
),
unet
=
accelerator
.
unwrap_model
(
unet
),
revision
=
args
.
revision
,
variant
=
args
.
variant
,
torch_dtype
=
weight_dtype
,
)
...
...
@@ -1932,10 +1957,15 @@ def main(args):
vae_path
,
subfolder
=
"vae"
if
args
.
pretrained_vae_model_name_or_path
is
None
else
None
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
torch_dtype
=
weight_dtype
,
)
pipeline
=
StableDiffusionXLPipeline
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
vae
=
vae
,
revision
=
args
.
revision
,
torch_dtype
=
weight_dtype
args
.
pretrained_model_name_or_path
,
vae
=
vae
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
torch_dtype
=
weight_dtype
,
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment