Unverified Commit d73e6ad0 authored by Will Berman, committed by GitHub

guard save model hooks to only execute on main process (#4929)

parent d0cf681a
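
The change is the same in every affected training script: the body of `save_model_hook` is wrapped in an `if accelerator.is_main_process:` guard so that only one process writes checkpoint files during distributed training, while `load_model_hook` stays unguarded because every process must restore its own weights. A minimal sketch of the pattern, assuming Accelerate >= 0.16.0 and using a single ControlNet model purely for illustration:

import os

import accelerate
from accelerate import Accelerator
from diffusers import ControlNetModel
from packaging import version

accelerator = Accelerator()

if version.parse(accelerate.__version__) >= version.parse("0.16.0"):

    def save_model_hook(models, weights, output_dir):
        # the fix: only the main process writes checkpoint files, so
        # multi-GPU runs do not race to write the same paths
        if accelerator.is_main_process:
            i = len(weights) - 1
            while len(weights) > 0:
                # pop the weight so accelerate does not serialize it again
                weights.pop()
                models[i].save_pretrained(os.path.join(output_dir, "controlnet"))
                i -= 1

    def load_model_hook(models, input_dir):
        # loading stays unguarded: every process must restore its own copy
        while len(models) > 0:
            model = models.pop()
            load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
            model.load_state_dict(load_model.state_dict())
            del load_model

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

The per-script diffs follow.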
@@ -785,16 +785,17 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            i = len(weights) - 1
-
-            while len(weights) > 0:
-                weights.pop()
-                model = models[i]
-
-                sub_dir = "controlnet"
-                model.save_pretrained(os.path.join(output_dir, sub_dir))
-
-                i -= 1
+            if accelerator.is_main_process:
+                i = len(weights) - 1
+
+                while len(weights) > 0:
+                    weights.pop()
+                    model = models[i]
+
+                    sub_dir = "controlnet"
+                    model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+                    i -= 1

         def load_model_hook(models, input_dir):
             while len(models) > 0:
...
@@ -840,16 +840,17 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            i = len(weights) - 1
-
-            while len(weights) > 0:
-                weights.pop()
-                model = models[i]
-
-                sub_dir = "controlnet"
-                model.save_pretrained(os.path.join(output_dir, sub_dir))
-
-                i -= 1
+            if accelerator.is_main_process:
+                i = len(weights) - 1
+
+                while len(weights) > 0:
+                    weights.pop()
+                    model = models[i]
+
+                    sub_dir = "controlnet"
+                    model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+                    i -= 1

         def load_model_hook(models, input_dir):
             while len(models) > 0:
...
@@ -920,12 +920,13 @@ def main(args):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            for model in models:
-                sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
-                model.save_pretrained(os.path.join(output_dir, sub_dir))
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+            if accelerator.is_main_process:
+                for model in models:
+                    sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
+                    model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             while len(models) > 0:
...
@@ -894,27 +894,28 @@ def main(args):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            # there are only two options here. Either are just the unet attn processor layers
-            # or there are the unet and text encoder atten layers
-            unet_lora_layers_to_save = None
-            text_encoder_lora_layers_to_save = None
-
-            for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
-                    text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
-
-            LoraLoaderMixin.save_lora_weights(
-                output_dir,
-                unet_lora_layers=unet_lora_layers_to_save,
-                text_encoder_lora_layers=text_encoder_lora_layers_to_save,
-            )
+            if accelerator.is_main_process:
+                # there are only two options here. Either are just the unet attn processor layers
+                # or there are the unet and text encoder atten layers
+                unet_lora_layers_to_save = None
+                text_encoder_lora_layers_to_save = None
+
+                for model in models:
+                    if isinstance(model, type(accelerator.unwrap_model(unet))):
+                        unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+                    elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                        text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    else:
+                        raise ValueError(f"unexpected save model: {model.__class__}")
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()
+
+                LoraLoaderMixin.save_lora_weights(
+                    output_dir,
+                    unet_lora_layers=unet_lora_layers_to_save,
+                    text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+                )

         def load_model_hook(models, input_dir):
             unet_ = None
...
@@ -798,31 +798,32 @@ def main(args):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            # there are only two options here. Either are just the unet attn processor layers
-            # or there are the unet and text encoder atten layers
-            unet_lora_layers_to_save = None
-            text_encoder_one_lora_layers_to_save = None
-            text_encoder_two_lora_layers_to_save = None
-
-            for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
-
-            StableDiffusionXLPipeline.save_lora_weights(
-                output_dir,
-                unet_lora_layers=unet_lora_layers_to_save,
-                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
-                text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
-            )
+            if accelerator.is_main_process:
+                # there are only two options here. Either are just the unet attn processor layers
+                # or there are the unet and text encoder atten layers
+                unet_lora_layers_to_save = None
+                text_encoder_one_lora_layers_to_save = None
+                text_encoder_two_lora_layers_to_save = None
+
+                for model in models:
+                    if isinstance(model, type(accelerator.unwrap_model(unet))):
+                        unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+                    elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                        text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                        text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    else:
+                        raise ValueError(f"unexpected save model: {model.__class__}")
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()
+
+                StableDiffusionXLPipeline.save_lora_weights(
+                    output_dir,
+                    unet_lora_layers=unet_lora_layers_to_save,
+                    text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                    text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+                )

         def load_model_hook(models, input_dir):
             unet_ = None
...
@@ -485,14 +485,15 @@ def main():
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
-
-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
...
@@ -528,14 +528,15 @@ def main():
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
-
-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
...
@@ -1010,16 +1010,17 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            i = len(weights) - 1
-
-            while len(weights) > 0:
-                weights.pop()
-                model = models[i]
-
-                sub_dir = "controlnet"
-                model.save_pretrained(os.path.join(output_dir, sub_dir))
-
-                i -= 1
+            if accelerator.is_main_process:
+                i = len(weights) - 1
+
+                while len(weights) > 0:
+                    weights.pop()
+                    model = models[i]
+
+                    sub_dir = "controlnet"
+                    model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+                    i -= 1

         def load_model_hook(models, input_dir):
             while len(models) > 0:
...
@@ -552,14 +552,15 @@ def main():
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
-
-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
...
@@ -313,14 +313,15 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
-
-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
...
@@ -629,14 +629,15 @@ def main():
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
-
-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
...
@@ -669,31 +669,32 @@ def main(args):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            # there are only two options here. Either are just the unet attn processor layers
-            # or there are the unet and text encoder atten layers
-            unet_lora_layers_to_save = None
-            text_encoder_one_lora_layers_to_save = None
-            text_encoder_two_lora_layers_to_save = None
-
-            for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
-
-            StableDiffusionXLPipeline.save_lora_weights(
-                output_dir,
-                unet_lora_layers=unet_lora_layers_to_save,
-                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
-                text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
-            )
+            if accelerator.is_main_process:
+                # there are only two options here. Either are just the unet attn processor layers
+                # or there are the unet and text encoder atten layers
+                unet_lora_layers_to_save = None
+                text_encoder_one_lora_layers_to_save = None
+                text_encoder_two_lora_layers_to_save = None
+
+                for model in models:
+                    if isinstance(model, type(accelerator.unwrap_model(unet))):
+                        unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+                    elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                        text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                        text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    else:
+                        raise ValueError(f"unexpected save model: {model.__class__}")
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()
+
+                StableDiffusionXLPipeline.save_lora_weights(
+                    output_dir,
+                    unet_lora_layers=unet_lora_layers_to_save,
+                    text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                    text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+                )

         def load_model_hook(models, input_dir):
             unet_ = None
...
@@ -651,14 +651,15 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
-
-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
...
@@ -309,14 +309,15 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
-
-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
...
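
These hooks only fire when a checkpoint is taken or restored via `accelerator.save_state(...)` / `accelerator.load_state(...)`. A brief usage sketch; the `checkpoint-500` directory name is illustrative:

# inside the training loop, typically every `checkpointing_steps` steps
save_path = os.path.join(args.output_dir, "checkpoint-500")
accelerator.save_state(save_path)  # invokes save_model_hook on every process; after this fix only rank 0 writes

# when resuming
accelerator.load_state(save_path)  # invokes load_model_hook on every process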