Unverified Commit d73e6ad0 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

guard save model hooks to only execute on main process (#4929)

parent d0cf681a
...@@ -785,16 +785,17 @@ def main(args): ...@@ -785,16 +785,17 @@ def main(args):
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
i = len(weights) - 1 if accelerator.is_main_process:
i = len(weights) - 1
while len(weights) > 0: while len(weights) > 0:
weights.pop() weights.pop()
model = models[i] model = models[i]
sub_dir = "controlnet" sub_dir = "controlnet"
model.save_pretrained(os.path.join(output_dir, sub_dir)) model.save_pretrained(os.path.join(output_dir, sub_dir))
i -= 1 i -= 1
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
while len(models) > 0: while len(models) > 0:
......
...@@ -840,16 +840,17 @@ def main(args): ...@@ -840,16 +840,17 @@ def main(args):
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
i = len(weights) - 1 if accelerator.is_main_process:
i = len(weights) - 1
while len(weights) > 0: while len(weights) > 0:
weights.pop() weights.pop()
model = models[i] model = models[i]
sub_dir = "controlnet" sub_dir = "controlnet"
model.save_pretrained(os.path.join(output_dir, sub_dir)) model.save_pretrained(os.path.join(output_dir, sub_dir))
i -= 1 i -= 1
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
while len(models) > 0: while len(models) > 0:
......
...@@ -920,12 +920,13 @@ def main(args): ...@@ -920,12 +920,13 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
for model in models: if accelerator.is_main_process:
sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder" for model in models:
model.save_pretrained(os.path.join(output_dir, sub_dir)) sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
model.save_pretrained(os.path.join(output_dir, sub_dir))
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
weights.pop() weights.pop()
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
while len(models) > 0: while len(models) > 0:
......
...@@ -894,27 +894,28 @@ def main(args): ...@@ -894,27 +894,28 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
# there are only two options here. Either are just the unet attn processor layers if accelerator.is_main_process:
# or there are the unet and text encoder atten layers # there are only two options here. Either are just the unet attn processor layers
unet_lora_layers_to_save = None # or there are the unet and text encoder atten layers
text_encoder_lora_layers_to_save = None unet_lora_layers_to_save = None
text_encoder_lora_layers_to_save = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))): for model in models:
unet_lora_layers_to_save = unet_attn_processors_state_dict(model) if isinstance(model, type(accelerator.unwrap_model(unet))):
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model) elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
else: text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
raise ValueError(f"unexpected save model: {model.__class__}") else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
weights.pop() weights.pop()
LoraLoaderMixin.save_lora_weights( LoraLoaderMixin.save_lora_weights(
output_dir, output_dir,
unet_lora_layers=unet_lora_layers_to_save, unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save, text_encoder_lora_layers=text_encoder_lora_layers_to_save,
) )
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
unet_ = None unet_ = None
......
...@@ -798,31 +798,32 @@ def main(args): ...@@ -798,31 +798,32 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
# there are only two options here. Either are just the unet attn processor layers if accelerator.is_main_process:
# or there are the unet and text encoder atten layers # there are only two options here. Either are just the unet attn processor layers
unet_lora_layers_to_save = None # or there are the unet and text encoder atten layers
text_encoder_one_lora_layers_to_save = None unet_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))): for model in models:
unet_lora_layers_to_save = unet_attn_processors_state_dict(model) if isinstance(model, type(accelerator.unwrap_model(unet))):
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
else: text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
raise ValueError(f"unexpected save model: {model.__class__}") else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
weights.pop() weights.pop()
StableDiffusionXLPipeline.save_lora_weights( StableDiffusionXLPipeline.save_lora_weights(
output_dir, output_dir,
unet_lora_layers=unet_lora_layers_to_save, unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
) )
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
unet_ = None unet_ = None
......
...@@ -485,14 +485,15 @@ def main(): ...@@ -485,14 +485,15 @@ def main():
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if args.use_ema: if accelerator.is_main_process:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) if args.use_ema:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
for i, model in enumerate(models): for i, model in enumerate(models):
model.save_pretrained(os.path.join(output_dir, "unet")) model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
weights.pop() weights.pop()
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
if args.use_ema: if args.use_ema:
......
...@@ -528,14 +528,15 @@ def main(): ...@@ -528,14 +528,15 @@ def main():
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if args.use_ema: if accelerator.is_main_process:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) if args.use_ema:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
for i, model in enumerate(models): for i, model in enumerate(models):
model.save_pretrained(os.path.join(output_dir, "unet")) model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
weights.pop() weights.pop()
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
if args.use_ema: if args.use_ema:
......
...@@ -1010,16 +1010,17 @@ def main(args): ...@@ -1010,16 +1010,17 @@ def main(args):
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
i = len(weights) - 1 if accelerator.is_main_process:
i = len(weights) - 1
while len(weights) > 0: while len(weights) > 0:
weights.pop() weights.pop()
model = models[i] model = models[i]
sub_dir = "controlnet" sub_dir = "controlnet"
model.save_pretrained(os.path.join(output_dir, sub_dir)) model.save_pretrained(os.path.join(output_dir, sub_dir))
i -= 1 i -= 1
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
while len(models) > 0: while len(models) > 0:
......
...@@ -552,14 +552,15 @@ def main(): ...@@ -552,14 +552,15 @@ def main():
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if args.use_ema: if accelerator.is_main_process:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) if args.use_ema:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
for i, model in enumerate(models): for i, model in enumerate(models):
model.save_pretrained(os.path.join(output_dir, "unet")) model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
weights.pop() weights.pop()
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
if args.use_ema: if args.use_ema:
......
...@@ -313,14 +313,15 @@ def main(args): ...@@ -313,14 +313,15 @@ def main(args):
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if args.use_ema: if accelerator.is_main_process:
ema_model.save_pretrained(os.path.join(output_dir, "unet_ema")) if args.use_ema:
ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
for i, model in enumerate(models): for i, model in enumerate(models):
model.save_pretrained(os.path.join(output_dir, "unet")) model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
weights.pop() weights.pop()
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
if args.use_ema: if args.use_ema:
......
...@@ -629,14 +629,15 @@ def main(): ...@@ -629,14 +629,15 @@ def main():
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if args.use_ema: if accelerator.is_main_process:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) if args.use_ema:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
for i, model in enumerate(models): for i, model in enumerate(models):
model.save_pretrained(os.path.join(output_dir, "unet")) model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
weights.pop() weights.pop()
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
if args.use_ema: if args.use_ema:
......
...@@ -669,31 +669,32 @@ def main(args): ...@@ -669,31 +669,32 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
# there are only two options here. Either are just the unet attn processor layers if accelerator.is_main_process:
# or there are the unet and text encoder atten layers # there are only two options here. Either are just the unet attn processor layers
unet_lora_layers_to_save = None # or there are the unet and text encoder atten layers
text_encoder_one_lora_layers_to_save = None unet_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))): for model in models:
unet_lora_layers_to_save = unet_attn_processors_state_dict(model) if isinstance(model, type(accelerator.unwrap_model(unet))):
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
else: text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
raise ValueError(f"unexpected save model: {model.__class__}") else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
weights.pop() weights.pop()
StableDiffusionXLPipeline.save_lora_weights( StableDiffusionXLPipeline.save_lora_weights(
output_dir, output_dir,
unet_lora_layers=unet_lora_layers_to_save, unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
) )
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
unet_ = None unet_ = None
......
...@@ -651,14 +651,15 @@ def main(args): ...@@ -651,14 +651,15 @@ def main(args):
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if args.use_ema: if accelerator.is_main_process:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) if args.use_ema:
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
for i, model in enumerate(models): for i, model in enumerate(models):
model.save_pretrained(os.path.join(output_dir, "unet")) model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
weights.pop() weights.pop()
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
if args.use_ema: if args.use_ema:
......
...@@ -309,14 +309,15 @@ def main(args): ...@@ -309,14 +309,15 @@ def main(args):
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if args.use_ema: if accelerator.is_main_process:
ema_model.save_pretrained(os.path.join(output_dir, "unet_ema")) if args.use_ema:
ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
for i, model in enumerate(models): for i, model in enumerate(models):
model.save_pretrained(os.path.join(output_dir, "unet")) model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
weights.pop() weights.pop()
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
if args.use_ema: if args.use_ema:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment