Unverified Commit 5f724735 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[training] add ds support to lora sd3. (#10378)



* add ds support to lora sd3.
Co-authored-by: default avatarleisuzz <jiangshuonb@gmail.com>

* style.

---------
Co-authored-by: default avatarleisuzz <jiangshuonb@gmail.com>
Co-authored-by: default avatarLinoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent 01780c3c
......@@ -29,7 +29,7 @@ import numpy as np
import torch
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
......@@ -1292,11 +1292,17 @@ def main(args):
text_encoder_two_lora_layers_to_save = None
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
model = unwrap_model(model)
if args.upcast_before_saving:
model = model.to(torch.float32)
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
elif isinstance(model, type(unwrap_model(text_encoder_one))): # or text_encoder_two
elif args.train_text_encoder and isinstance(
unwrap_model(model), type(unwrap_model(text_encoder_one))
): # or text_encoder_two
# both text encoders are of the same class, so we check hidden size to distinguish between the two
hidden_size = unwrap_model(model).config.hidden_size
model = unwrap_model(model)
hidden_size = model.config.hidden_size
if hidden_size == 768:
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
elif hidden_size == 1280:
......@@ -1305,6 +1311,7 @@ def main(args):
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
if weights:
weights.pop()
StableDiffusion3Pipeline.save_lora_weights(
......@@ -1319,18 +1326,32 @@ def main(args):
text_encoder_one_ = None
text_encoder_two_ = None
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_ = model
elif isinstance(model, type(unwrap_model(text_encoder_two))):
text_encoder_two_ = model
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
transformer_ = unwrap_model(model)
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
text_encoder_one_ = unwrap_model(model)
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
text_encoder_two_ = unwrap_model(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
else:
transformer_ = SD3Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="transformer"
)
transformer_.add_adapter(transformer_lora_config)
if args.train_text_encoder:
text_encoder_one_ = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder"
)
text_encoder_two_ = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2"
)
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
......@@ -1829,7 +1850,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment