Unverified Commit 08702fc1 authored by Vinh H. Pham's avatar Vinh H. Pham Committed by GitHub
Browse files

Make text-to-image SDXL LoRA Training Script torch.compile compatible (#6556)

make compile compatible
parent 7ce89e97
...@@ -54,6 +54,7 @@ from diffusers.optimization import get_scheduler ...@@ -54,6 +54,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params, compute_snr from diffusers.training_utils import cast_training_params, compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
...@@ -460,13 +461,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): ...@@ -460,13 +461,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
text_input_ids = text_input_ids_list[i] text_input_ids = text_input_ids_list[i]
prompt_embeds = text_encoder( prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device), text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds[-1][-2]
bs_embed, seq_len, _ = prompt_embeds.shape bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds) prompt_embeds_list.append(prompt_embeds)
...@@ -637,6 +637,11 @@ def main(args): ...@@ -637,6 +637,11 @@ def main(args):
# only upcast trainable parameters (LoRA) into fp32 # only upcast trainable parameters (LoRA) into fp32
cast_training_params(models, dtype=torch.float32) cast_training_params(models, dtype=torch.float32)
def unwrap_model(model):
    """Return the bare model, stripping Accelerate and torch.compile wrappers.

    First removes the Accelerate distributed wrapper, then — if the result is
    a ``torch.compile``d module — returns its underlying ``_orig_mod`` so
    state-dict and isinstance checks see the original module class.
    """
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        unwrapped = unwrapped._orig_mod
    return unwrapped
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process: if accelerator.is_main_process:
...@@ -647,13 +652,13 @@ def main(args): ...@@ -647,13 +652,13 @@ def main(args):
text_encoder_two_lora_layers_to_save = None text_encoder_two_lora_layers_to_save = None
for model in models: for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))): if isinstance(model, type(unwrap_model(unet))):
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model) get_peft_model_state_dict(model)
) )
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): elif isinstance(model, type(unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model) get_peft_model_state_dict(model)
) )
...@@ -678,11 +683,11 @@ def main(args): ...@@ -678,11 +683,11 @@ def main(args):
while len(models) > 0: while len(models) > 0:
model = models.pop() model = models.pop()
if isinstance(model, type(accelerator.unwrap_model(unet))): if isinstance(model, type(unwrap_model(unet))):
unet_ = model unet_ = model
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_ = model text_encoder_one_ = model
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): elif isinstance(model, type(unwrap_model(text_encoder_two))):
text_encoder_two_ = model text_encoder_two_ = model
else: else:
raise ValueError(f"unexpected save model: {model.__class__}") raise ValueError(f"unexpected save model: {model.__class__}")
...@@ -1031,8 +1036,12 @@ def main(args): ...@@ -1031,8 +1036,12 @@ def main(args):
) )
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
model_pred = unet( model_pred = unet(
noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions noisy_model_input,
).sample timesteps,
prompt_embeds,
added_cond_kwargs=unet_added_conditions,
return_dict=False,
)[0]
# Get the target for loss depending on the prediction type # Get the target for loss depending on the prediction type
if args.prediction_type is not None: if args.prediction_type is not None:
...@@ -1125,9 +1134,9 @@ def main(args): ...@@ -1125,9 +1134,9 @@ def main(args):
pipeline = StableDiffusionXLPipeline.from_pretrained( pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
vae=vae, vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one), text_encoder=unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two), text_encoder_2=unwrap_model(text_encoder_two),
unet=accelerator.unwrap_model(unet), unet=unwrap_model(unet),
revision=args.revision, revision=args.revision,
variant=args.variant, variant=args.variant,
torch_dtype=weight_dtype, torch_dtype=weight_dtype,
...@@ -1166,12 +1175,12 @@ def main(args): ...@@ -1166,12 +1175,12 @@ def main(args):
# Save the lora layers # Save the lora layers
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet) unet = unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
if args.train_text_encoder: if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one) text_encoder_one = unwrap_model(text_encoder_one)
text_encoder_two = accelerator.unwrap_model(text_encoder_two) text_encoder_two = unwrap_model(text_encoder_two)
text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one)) text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two)) text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment