Unverified Commit be0b4257 authored by Sayak Paul, committed by GitHub

[Training] make checkpointing compatible when using `torch.compile` (part II) (#6511)

Make checkpointing compatible when using `torch.compile`.
parent da843b3d
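Background on the fix: `torch.compile` wraps an `nn.Module` in `torch._dynamo.eval_frame.OptimizedModule` and keeps the original module at `_orig_mod`; `diffusers.utils.torch_utils.is_compiled_module` checks for exactly that wrapper type. The `isinstance(model, type(...))` dispatch in the checkpoint hooks below therefore misfires on a compiled model until it is unwrapped, which is why the new `unwrap_model` helper first strips Accelerate's distributed wrapper and then the compile wrapper. A minimal standalone sketch (not part of the diff; `hasattr(m, "_orig_mod")` stands in for `is_compiled_module`) illustrates the failure mode:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)

print(type(compiled).__name__)          # OptimizedModule, not Linear
print(isinstance(compiled, nn.Linear))  # False: the wrapper breaks type dispatch
print(compiled._orig_mod is model)      # True: the original module is preserved

def unwrap(m):
    # Same idea as the `unwrap_model` helper added in this commit.
    return m._orig_mod if hasattr(m, "_orig_mod") else m

print(isinstance(unwrap(compiled), nn.Linear))  # True again after unwrapping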
@@ -56,6 +56,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1007,6 +1008,11 @@ def main(args):
                 if param.requires_grad:
                     param.data = param.to(torch.float32)

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
@@ -1017,13 +1023,13 @@ def main(args):
             text_encoder_two_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                if isinstance(model, type(unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                elif isinstance(model, type(unwrap_model(text_encoder_two))):
                     text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
@@ -1048,11 +1054,11 @@ def main(args):
         while len(models) > 0:
             model = models.pop()

-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                 unet_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+            elif isinstance(model, type(unwrap_model(text_encoder_one))):
                 text_encoder_one_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+            elif isinstance(model, type(unwrap_model(text_encoder_two))):
                 text_encoder_two_ = model
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1621,16 +1627,16 @@ def main(args):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         unet = unet.to(torch.float32)

         unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

         if args.train_text_encoder:
-            text_encoder_one = accelerator.unwrap_model(text_encoder_one)
+            text_encoder_one = unwrap_model(text_encoder_one)
             text_encoder_lora_layers = convert_state_dict_to_diffusers(
                 get_peft_model_state_dict(text_encoder_one.to(torch.float32))
             )
-            text_encoder_two = accelerator.unwrap_model(text_encoder_two)
+            text_encoder_two = unwrap_model(text_encoder_two)
             text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
                 get_peft_model_state_dict(text_encoder_two.to(torch.float32))
             )
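For reference, the hooks shown above only take effect because the script registers them with Accelerate; that wiring predates this commit and is elided from the diff, but the pattern is roughly:

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

# From here on, `accelerator.save_state(path)` serializes the LoRA layers
# through `save_model_hook`, and `accelerator.load_state(path)` restores them
# through `load_model_hook`, whether or not the models were compiled.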