Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
e185084a
Unverified
Commit
e185084a
authored
Dec 04, 2023
by
Levi McCallum
Committed by
GitHub
Dec 04, 2023
Browse files
Add variant argument to dreambooth lora sdxl advanced (#6021)
Co-authored-by:
Sayak Paul
<
spsayakpaul@gmail.com
>
parent
b2172922
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
9 deletions
+39
-9
examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
...diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+39
-9
No files found.
examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
View file @
e185084a
...
@@ -225,6 +225,12 @@ def parse_args(input_args=None):
...
@@ -225,6 +225,12 @@ def parse_args(input_args=None):
required
=
False
,
required
=
False
,
help
=
"Revision of pretrained model identifier from huggingface.co/models."
,
help
=
"Revision of pretrained model identifier from huggingface.co/models."
,
)
)
parser
.
add_argument
(
"--variant"
,
type
=
str
,
default
=
None
,
help
=
"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--dataset_name"
,
"--dataset_name"
,
type
=
str
,
type
=
str
,
...
@@ -1064,6 +1070,7 @@ def main(args):
...
@@ -1064,6 +1070,7 @@ def main(args):
args
.
pretrained_model_name_or_path
,
args
.
pretrained_model_name_or_path
,
torch_dtype
=
torch_dtype
,
torch_dtype
=
torch_dtype
,
revision
=
args
.
revision
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
)
)
pipeline
.
set_progress_bar_config
(
disable
=
True
)
pipeline
.
set_progress_bar_config
(
disable
=
True
)
...
@@ -1102,10 +1109,18 @@ def main(args):
...
@@ -1102,10 +1109,18 @@ def main(args):
# Load the tokenizers
# Load the tokenizers
tokenizer_one
=
AutoTokenizer
.
from_pretrained
(
tokenizer_one
=
AutoTokenizer
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"tokenizer"
,
revision
=
args
.
revision
,
use_fast
=
False
args
.
pretrained_model_name_or_path
,
subfolder
=
"tokenizer"
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
use_fast
=
False
,
)
)
tokenizer_two
=
AutoTokenizer
.
from_pretrained
(
tokenizer_two
=
AutoTokenizer
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"tokenizer_2"
,
revision
=
args
.
revision
,
use_fast
=
False
args
.
pretrained_model_name_or_path
,
subfolder
=
"tokenizer_2"
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
use_fast
=
False
,
)
)
# import correct text encoder classes
# import correct text encoder classes
...
@@ -1119,10 +1134,10 @@ def main(args):
...
@@ -1119,10 +1134,10 @@ def main(args):
# Load scheduler and models
# Load scheduler and models
noise_scheduler
=
DDPMScheduler
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"scheduler"
)
noise_scheduler
=
DDPMScheduler
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"scheduler"
)
text_encoder_one
=
text_encoder_cls_one
.
from_pretrained
(
text_encoder_one
=
text_encoder_cls_one
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder"
,
revision
=
args
.
revision
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder"
,
revision
=
args
.
revision
,
variant
=
args
.
variant
)
)
text_encoder_two
=
text_encoder_cls_two
.
from_pretrained
(
text_encoder_two
=
text_encoder_cls_two
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder_2"
,
revision
=
args
.
revision
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder_2"
,
revision
=
args
.
revision
,
variant
=
args
.
variant
)
)
vae_path
=
(
vae_path
=
(
args
.
pretrained_model_name_or_path
args
.
pretrained_model_name_or_path
...
@@ -1130,10 +1145,13 @@ def main(args):
...
@@ -1130,10 +1145,13 @@ def main(args):
else
args
.
pretrained_vae_model_name_or_path
else
args
.
pretrained_vae_model_name_or_path
)
)
vae
=
AutoencoderKL
.
from_pretrained
(
vae
=
AutoencoderKL
.
from_pretrained
(
vae_path
,
subfolder
=
"vae"
if
args
.
pretrained_vae_model_name_or_path
is
None
else
None
,
revision
=
args
.
revision
vae_path
,
subfolder
=
"vae"
if
args
.
pretrained_vae_model_name_or_path
is
None
else
None
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
)
)
unet
=
UNet2DConditionModel
.
from_pretrained
(
unet
=
UNet2DConditionModel
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"unet"
,
revision
=
args
.
revision
args
.
pretrained_model_name_or_path
,
subfolder
=
"unet"
,
revision
=
args
.
revision
,
variant
=
args
.
variant
)
)
if
args
.
train_text_encoder_ti
:
if
args
.
train_text_encoder_ti
:
...
@@ -1843,10 +1861,16 @@ def main(args):
...
@@ -1843,10 +1861,16 @@ def main(args):
# create pipeline
# create pipeline
if
freeze_text_encoder
:
if
freeze_text_encoder
:
text_encoder_one
=
text_encoder_cls_one
.
from_pretrained
(
text_encoder_one
=
text_encoder_cls_one
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder"
,
revision
=
args
.
revision
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder"
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
)
)
text_encoder_two
=
text_encoder_cls_two
.
from_pretrained
(
text_encoder_two
=
text_encoder_cls_two
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder_2"
,
revision
=
args
.
revision
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder_2"
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
)
)
pipeline
=
StableDiffusionXLPipeline
.
from_pretrained
(
pipeline
=
StableDiffusionXLPipeline
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
args
.
pretrained_model_name_or_path
,
...
@@ -1855,6 +1879,7 @@ def main(args):
...
@@ -1855,6 +1879,7 @@ def main(args):
text_encoder_2
=
accelerator
.
unwrap_model
(
text_encoder_two
),
text_encoder_2
=
accelerator
.
unwrap_model
(
text_encoder_two
),
unet
=
accelerator
.
unwrap_model
(
unet
),
unet
=
accelerator
.
unwrap_model
(
unet
),
revision
=
args
.
revision
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
torch_dtype
=
weight_dtype
,
torch_dtype
=
weight_dtype
,
)
)
...
@@ -1932,10 +1957,15 @@ def main(args):
...
@@ -1932,10 +1957,15 @@ def main(args):
vae_path
,
vae_path
,
subfolder
=
"vae"
if
args
.
pretrained_vae_model_name_or_path
is
None
else
None
,
subfolder
=
"vae"
if
args
.
pretrained_vae_model_name_or_path
is
None
else
None
,
revision
=
args
.
revision
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
torch_dtype
=
weight_dtype
,
torch_dtype
=
weight_dtype
,
)
)
pipeline
=
StableDiffusionXLPipeline
.
from_pretrained
(
pipeline
=
StableDiffusionXLPipeline
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
vae
=
vae
,
revision
=
args
.
revision
,
torch_dtype
=
weight_dtype
args
.
pretrained_model_name_or_path
,
vae
=
vae
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
torch_dtype
=
weight_dtype
,
)
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment