Unverified commit d73e6ad0, authored by Will Berman and committed via GitHub
Browse files

guard save model hooks to only execute on main process (#4929)

parent d0cf681a
...@@ -785,6 +785,7 @@ def main(args): ...@@ -785,6 +785,7 @@ def main(args):
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
i = len(weights) - 1 i = len(weights) - 1
while len(weights) > 0: while len(weights) > 0:
......
...@@ -840,6 +840,7 @@ def main(args): ...@@ -840,6 +840,7 @@ def main(args):
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
i = len(weights) - 1 i = len(weights) - 1
while len(weights) > 0: while len(weights) > 0:
......
...@@ -920,6 +920,7 @@ def main(args): ...@@ -920,6 +920,7 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
for model in models: for model in models:
sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder" sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
model.save_pretrained(os.path.join(output_dir, sub_dir)) model.save_pretrained(os.path.join(output_dir, sub_dir))
......
...@@ -894,6 +894,7 @@ def main(args): ...@@ -894,6 +894,7 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
# there are only two options here. Either are just the unet attn processor layers # there are only two options here. Either are just the unet attn processor layers
# or there are the unet and text encoder atten layers # or there are the unet and text encoder atten layers
unet_lora_layers_to_save = None unet_lora_layers_to_save = None
......
...@@ -798,6 +798,7 @@ def main(args): ...@@ -798,6 +798,7 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
# there are only two options here. Either are just the unet attn processor layers # there are only two options here. Either are just the unet attn processor layers
# or there are the unet and text encoder atten layers # or there are the unet and text encoder atten layers
unet_lora_layers_to_save = None unet_lora_layers_to_save = None
......
...@@ -485,6 +485,7 @@ def main(): ...@@ -485,6 +485,7 @@ def main():
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
if args.use_ema: if args.use_ema:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
......
...@@ -528,6 +528,7 @@ def main(): ...@@ -528,6 +528,7 @@ def main():
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
if args.use_ema: if args.use_ema:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
......
...@@ -1010,6 +1010,7 @@ def main(args): ...@@ -1010,6 +1010,7 @@ def main(args):
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
i = len(weights) - 1 i = len(weights) - 1
while len(weights) > 0: while len(weights) > 0:
......
...@@ -552,6 +552,7 @@ def main(): ...@@ -552,6 +552,7 @@ def main():
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
if args.use_ema: if args.use_ema:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
......
...@@ -313,6 +313,7 @@ def main(args): ...@@ -313,6 +313,7 @@ def main(args):
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
if args.use_ema: if args.use_ema:
ema_model.save_pretrained(os.path.join(output_dir, "unet_ema")) ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
......
...@@ -629,6 +629,7 @@ def main(): ...@@ -629,6 +629,7 @@ def main():
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
if args.use_ema: if args.use_ema:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
......
...@@ -669,6 +669,7 @@ def main(args): ...@@ -669,6 +669,7 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
# there are only two options here. Either are just the unet attn processor layers # there are only two options here. Either are just the unet attn processor layers
# or there are the unet and text encoder atten layers # or there are the unet and text encoder atten layers
unet_lora_layers_to_save = None unet_lora_layers_to_save = None
......
...@@ -651,6 +651,7 @@ def main(args): ...@@ -651,6 +651,7 @@ def main(args):
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
if args.use_ema: if args.use_ema:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
......
...@@ -309,6 +309,7 @@ def main(args): ...@@ -309,6 +309,7 @@ def main(args):
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
if args.use_ema: if args.use_ema:
ema_model.save_pretrained(os.path.join(output_dir, "unet_ema")) ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment