Unverified Commit 60cb4432 authored by Vinh H. Pham, committed by GitHub

Make Dreambooth SD LoRA Training Script torch.compile compatible (#6534)

support compile
parent 1dd0ac94
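
For context: `torch.compile()` wraps an `nn.Module` in a dynamo `OptimizedModule` that keeps the original module at `_orig_mod`. That wrapper breaks the `isinstance`/`type` checks and direct attribute access this script relies on, which is what the changes below work around. A minimal sketch, not part of this commit, assuming PyTorch >= 2.0:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)

print(type(compiled).__name__)          # OptimizedModule, not Linear
print(isinstance(compiled, nn.Linear))  # False -- class checks now fail
print(compiled._orig_mod is model)      # True -- the real module lives here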
@@ -56,6 +56,7 @@ from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
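
The newly imported `is_compiled_module` detects that wrapper. A sketch of roughly what such a check does, under the assumption that it tests for dynamo's `OptimizedModule` (the library's exact implementation may differ):

import torch

def is_compiled_module_sketch(module) -> bool:
    # torch._dynamo only exists from PyTorch 2.0 onward
    if not hasattr(torch, "_dynamo"):
        return False
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)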
@@ -647,6 +648,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask
     prompt_embeds = text_encoder(
         text_input_ids,
         attention_mask=attention_mask,
+        return_dict=False,
     )
     prompt_embeds = prompt_embeds[0]
@@ -843,6 +845,11 @@ def main(args):
         )
         text_encoder.add_adapter(text_lora_config)

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
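
The new helper peels off two possible layers in order: first the accelerate wrapper (e.g. `DistributedDataParallel`) via `accelerator.unwrap_model`, then the `torch.compile` wrapper via `_orig_mod`. A hypothetical illustration of the layering (the module and sizes are made up):

import torch
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()
unet = torch.compile(nn.Linear(4, 4))   # inner layer: dynamo OptimizedModule
unet = accelerator.prepare(unet)        # outer layer: e.g. a DDP wrapper when distributed

inner = accelerator.unwrap_model(unet)  # removes the accelerate/DDP layer only
if hasattr(inner, "_orig_mod"):         # removes the torch.compile layer
    inner = inner._orig_mod
print(type(inner))                      # <class 'torch.nn.modules.linear.Linear'>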
@@ -852,9 +859,9 @@ def main(args):
             text_encoder_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                if isinstance(model, type(unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                elif isinstance(model, type(unwrap_model(text_encoder))):
                     text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
@@ -877,9 +884,9 @@ def main(args):
         while len(models) > 0:
             model = models.pop()

-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                 unet_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+            elif isinstance(model, type(unwrap_model(text_encoder))):
                 text_encoder_ = model
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
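
Both hooks compare each popped model's class against `type(unwrap_model(...))`: with a compiled model, `type(accelerator.unwrap_model(unet))` would be `OptimizedModule`, so every `isinstance` check would fail and the load hook would raise. A quick demonstration of the direction of the failure:

import torch
import torch.nn as nn

raw = nn.Linear(4, 4)
compiled = torch.compile(raw)

print(isinstance(raw, type(compiled)))            # False: type is OptimizedModule
print(isinstance(raw, type(compiled._orig_mod)))  # True: matches after unwrapping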
@@ -1118,7 +1125,7 @@ def main(args):
                         text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                     )

-                if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
+                if unwrap_model(unet).config.in_channels == channels * 2:
                     noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

                 if args.class_labels_conditioning == "timesteps":
@@ -1128,8 +1135,12 @@ def main(args):
                 # Predict the noise residual
-                model_pred = unet(
-                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
-                ).sample
+                model_pred = unet(
+                    noisy_model_input,
+                    timesteps,
+                    encoder_hidden_states,
+                    class_labels=class_labels,
+                    return_dict=False,
+                )[0]

                 # if model predicts variance, throw away the prediction. we will only train on the
                 # simplified training objective. This means that all schedulers using the fine tuned
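
`return_dict=False` makes the model return a plain tuple instead of a `ModelOutput` dataclass; indexing `[0]` then replaces the `.sample` attribute access, which traces more cleanly under `torch.compile`. A minimal, self-contained sketch of the two output styles (the tiny config below is an assumption, chosen only to make the snippet runnable):

import torch
from diffusers import UNet2DConditionModel

# Deliberately tiny config for illustration; real scripts load pretrained weights.
unet = UNet2DConditionModel(
    sample_size=8,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

x = torch.randn(1, 4, 8, 8)
t = torch.tensor([10])
ctx = torch.randn(1, 4, 32)

out = unet(x, t, ctx)                           # ModelOutput with a .sample field
out_tuple = unet(x, t, ctx, return_dict=False)  # plain tuple
assert torch.allclose(out.sample, out_tuple[0])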
@@ -1215,8 +1226,8 @@ def main(args):
                     # create pipeline
                     pipeline = DiffusionPipeline.from_pretrained(
                         args.pretrained_model_name_or_path,
-                        unet=accelerator.unwrap_model(unet),
-                        text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
+                        unet=unwrap_model(unet),
+                        text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder),
                         revision=args.revision,
                         variant=args.variant,
                         torch_dtype=weight_dtype,
@@ -1284,13 +1295,13 @@ def main(args):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         unet = unet.to(torch.float32)

         unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

         if args.train_text_encoder:
-            text_encoder = accelerator.unwrap_model(text_encoder)
+            text_encoder = unwrap_model(text_encoder)
             text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
         else:
             text_encoder_state_dict = None
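
These float32 state dicts are what the script goes on to persist. Based on the `LoraLoaderMixin` import visible in the first hunk, saving presumably looks roughly like the sketch below (the argument names are assumptions, reusing the script's own variables):

# Hypothetical continuation using the script's variables from the hunk above.
LoraLoaderMixin.save_lora_weights(
    save_directory=args.output_dir,
    unet_lora_layers=unet_lora_state_dict,
    text_encoder_lora_layers=text_encoder_state_dict,
)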