Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
6ca9c4af
Unverified
Commit
6ca9c4af
authored
Dec 21, 2023
by
lvzi
Committed by
GitHub
Dec 21, 2023
Browse files
fix: unscale fp16 gradient problem & potential error (#6086) (#6231)
Co-authored-by:
Sayak Paul
<
spsayakpaul@gmail.com
>
parent
0532cece
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
0 deletions
+14
-0
examples/text_to_image/train_text_to_image_lora_sdxl.py
examples/text_to_image/train_text_to_image_lora_sdxl.py
+14
-0
No files found.
examples/text_to_image/train_text_to_image_lora_sdxl.py
View file @
6ca9c4af
...
@@ -640,6 +640,17 @@ def main(args):
...
@@ -640,6 +640,17 @@ def main(args):
text_encoder_one
.
add_adapter
(
text_lora_config
)
text_encoder_one
.
add_adapter
(
text_lora_config
)
text_encoder_two
.
add_adapter
(
text_lora_config
)
text_encoder_two
.
add_adapter
(
text_lora_config
)
# Make sure the trainable params are in float32.
if
args
.
mixed_precision
==
"fp16"
:
models
=
[
unet
]
if
args
.
train_text_encoder
:
models
.
extend
([
text_encoder_one
,
text_encoder_two
])
for
model
in
models
:
for
param
in
model
.
parameters
():
# only upcast trainable parameters (LoRA) into fp32
if
param
.
requires_grad
:
param
.
data
=
param
.
to
(
torch
.
float32
)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def
save_model_hook
(
models
,
weights
,
output_dir
):
def
save_model_hook
(
models
,
weights
,
output_dir
):
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
...
@@ -1187,6 +1198,9 @@ def main(args):
...
@@ -1187,6 +1198,9 @@ def main(args):
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
# Final inference
# Final inference
# Make sure vae.dtype is consistent with the unet.dtype
if
args
.
mixed_precision
==
"fp16"
:
vae
.
to
(
weight_dtype
)
# Load previous pipeline
# Load previous pipeline
pipeline
=
StableDiffusionXLPipeline
.
from_pretrained
(
pipeline
=
StableDiffusionXLPipeline
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
args
.
pretrained_model_name_or_path
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment