Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
8fcd52fe
Unverified
Commit
8fcd52fe
authored
Nov 14, 2023
by
Thuan H. Nguyen
Committed by
GitHub
Nov 13, 2023
Browse files
Correct code for distributed training of RealFill (#5740)
Correct code for distributed training
parent
0488810f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
5 deletions
+5
-5
examples/research_projects/realfill/train_realfill.py
examples/research_projects/realfill/train_realfill.py
+5
-5
No files found.
examples/research_projects/realfill/train_realfill.py
View file @
8fcd52fe
...
@@ -639,7 +639,7 @@ def main(args):
...
@@ -639,7 +639,7 @@ def main(args):
for
model
in
models
:
for
model
in
models
:
sub_dir
=
(
sub_dir
=
(
"unet"
"unet"
if
isinstance
(
model
.
base_model
.
model
,
type
(
accelerator
.
unwrap_model
(
unet
.
base_model
.
model
))
)
if
isinstance
(
model
.
base_model
.
model
,
type
(
accelerator
.
unwrap_model
(
unet
)
.
base_model
.
model
))
else
"text_encoder"
else
"text_encoder"
)
)
model
.
save_pretrained
(
os
.
path
.
join
(
output_dir
,
sub_dir
))
model
.
save_pretrained
(
os
.
path
.
join
(
output_dir
,
sub_dir
))
...
@@ -654,12 +654,12 @@ def main(args):
...
@@ -654,12 +654,12 @@ def main(args):
sub_dir
=
(
sub_dir
=
(
"unet"
"unet"
if
isinstance
(
model
.
base_model
.
model
,
type
(
accelerator
.
unwrap_model
(
unet
.
base_model
.
model
))
)
if
isinstance
(
model
.
base_model
.
model
,
type
(
accelerator
.
unwrap_model
(
unet
)
.
base_model
.
model
))
else
"text_encoder"
else
"text_encoder"
)
)
model_cls
=
(
model_cls
=
(
UNet2DConditionModel
UNet2DConditionModel
if
isinstance
(
model
.
base_model
.
model
,
type
(
accelerator
.
unwrap_model
(
unet
.
base_model
.
model
))
)
if
isinstance
(
model
.
base_model
.
model
,
type
(
accelerator
.
unwrap_model
(
unet
)
.
base_model
.
model
))
else
CLIPTextModel
else
CLIPTextModel
)
)
...
@@ -937,8 +937,8 @@ def main(args):
...
@@ -937,8 +937,8 @@ def main(args):
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
pipeline
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
pipeline
=
StableDiffusionInpaintPipeline
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
args
.
pretrained_model_name_or_path
,
unet
=
accelerator
.
unwrap_model
(
unet
.
merge_and_unload
()
,
keep_fp32_wrapper
=
True
),
unet
=
accelerator
.
unwrap_model
(
unet
,
keep_fp32_wrapper
=
True
)
.
merge_and_unload
()
,
text_encoder
=
accelerator
.
unwrap_model
(
text_encoder
.
merge_and_unload
()
,
keep_fp32_wrapper
=
True
),
text_encoder
=
accelerator
.
unwrap_model
(
text_encoder
,
keep_fp32_wrapper
=
True
)
.
merge_and_unload
()
,
revision
=
args
.
revision
,
revision
=
args
.
revision
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment