"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "37a4b2219f8f344d3158f3c5ce52ca7fa49fed75"
Unverified Commit 7ce89e97 authored by Vinh H. Pham, committed by GitHub
Browse files

Make text-to-image SD LoRA Training Script torch.compile compatible (#6555)

make compile compatible
parent 05faf326
...@@ -46,6 +46,7 @@ from diffusers.optimization import get_scheduler ...@@ -46,6 +46,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params, compute_snr from diffusers.training_utils import cast_training_params, compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
...@@ -596,6 +597,11 @@ def main(): ...@@ -596,6 +597,11 @@ def main():
] ]
) )
def unwrap_model(model):
    """Strip accelerate and torch.compile wrappers from *model*.

    First removes the accelerate distributed/AMP wrapper, then — if the
    result is a torch.compile'd module — returns its underlying original
    module via ``_orig_mod`` so state-dict utilities see plain parameters.
    """
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        unwrapped = unwrapped._orig_mod
    return unwrapped
def preprocess_train(examples): def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]] images = [image.convert("RGB") for image in examples[image_column]]
examples["pixel_values"] = [train_transforms(image) for image in images] examples["pixel_values"] = [train_transforms(image) for image in images]
...@@ -729,7 +735,7 @@ def main(): ...@@ -729,7 +735,7 @@ def main():
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning # Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0] encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
# Get the target for loss depending on the prediction type # Get the target for loss depending on the prediction type
if args.prediction_type is not None: if args.prediction_type is not None:
...@@ -744,7 +750,7 @@ def main(): ...@@ -744,7 +750,7 @@ def main():
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss # Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
if args.snr_gamma is None: if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
...@@ -809,7 +815,7 @@ def main(): ...@@ -809,7 +815,7 @@ def main():
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path) accelerator.save_state(save_path)
unwrapped_unet = accelerator.unwrap_model(unet) unwrapped_unet = unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers( unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unwrapped_unet) get_peft_model_state_dict(unwrapped_unet)
) )
...@@ -837,7 +843,7 @@ def main(): ...@@ -837,7 +843,7 @@ def main():
# create pipeline # create pipeline
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet), unet=unwrap_model(unet),
revision=args.revision, revision=args.revision,
variant=args.variant, variant=args.variant,
torch_dtype=weight_dtype, torch_dtype=weight_dtype,
...@@ -878,7 +884,7 @@ def main(): ...@@ -878,7 +884,7 @@ def main():
if accelerator.is_main_process: if accelerator.is_main_process:
unet = unet.to(torch.float32) unet = unet.to(torch.float32)
unwrapped_unet = accelerator.unwrap_model(unet) unwrapped_unet = unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet)) unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
StableDiffusionPipeline.save_lora_weights( StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir, save_directory=args.output_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment