"src/diffusers/models/modeling_flax_utils.py" did not exist on "4b8880a30660e24d5815f940c5e5f70d05ba7e04"
Unverified Commit 8fcd52fe authored by Thuan H. Nguyen's avatar Thuan H. Nguyen Committed by GitHub
Browse files

Correct code for distributed training of RealFill (#5740)

Correct code for distributed training
parent 0488810f
......@@ -639,7 +639,7 @@ def main(args):
for model in models:
sub_dir = (
"unet"
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model)))
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
else "text_encoder"
)
model.save_pretrained(os.path.join(output_dir, sub_dir))
......@@ -654,12 +654,12 @@ def main(args):
sub_dir = (
"unet"
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model)))
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
else "text_encoder"
)
model_cls = (
UNet2DConditionModel
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model)))
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
else CLIPTextModel
)
......@@ -937,8 +937,8 @@ def main(args):
if accelerator.is_main_process:
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet.merge_and_unload(), keep_fp32_wrapper=True),
text_encoder=accelerator.unwrap_model(text_encoder.merge_and_unload(), keep_fp32_wrapper=True),
unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True).merge_and_unload(),
text_encoder=accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True).merge_and_unload(),
revision=args.revision,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment