Unverified Commit 7ce89e97 authored by Vinh H. Pham's avatar Vinh H. Pham Committed by GitHub
Browse files

Make text-to-image SD LoRA Training Script torch.compile compatible (#6555)

make compile compatible
parent 05faf326
......@@ -46,6 +46,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params, compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
......@@ -596,6 +597,11 @@ def main():
]
)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]]
examples["pixel_values"] = [train_transforms(image) for image in images]
......@@ -729,7 +735,7 @@ def main():
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
# Get the target for loss depending on the prediction type
if args.prediction_type is not None:
......@@ -744,7 +750,7 @@ def main():
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
......@@ -809,7 +815,7 @@ def main():
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
unwrapped_unet = accelerator.unwrap_model(unet)
unwrapped_unet = unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unwrapped_unet)
)
......@@ -837,7 +843,7 @@ def main():
# create pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
......@@ -878,7 +884,7 @@ def main():
if accelerator.is_main_process:
unet = unet.to(torch.float32)
unwrapped_unet = accelerator.unwrap_model(unet)
unwrapped_unet = unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment