"git@developer.sourcefind.cn:change/sglang.git" did not exist on "7b69d91b4f94a73f6b8fa3a86de3a910a16dc645"
Unverified Commit 5f724735 authored by Sayak Paul, committed by GitHub

[training] add ds support to lora sd3. (#10378)



* add ds support to lora sd3.
Co-authored-by: leisuzz <jiangshuonb@gmail.com>

* style.

---------
Co-authored-by: leisuzz <jiangshuonb@gmail.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent 01780c3c
@@ -29,7 +29,7 @@ import numpy as np
 import torch
 import torch.utils.checkpoint
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
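The only functional change in this hunk is the extra `DistributedType` import; the hunks below compare `accelerator.distributed_type` against it to detect a DeepSpeed launch. A minimal, self-contained sketch of that check (not part of the training script itself):

```python
# Minimal sketch of how the new import is used: DistributedType lets the
# script detect whether accelerate was configured with the DeepSpeed plugin
# and branch accordingly.
from accelerate import Accelerator, DistributedType

accelerator = Accelerator()
if accelerator.distributed_type == DistributedType.DEEPSPEED:
    print("DeepSpeed backend: checkpoint hooks take the DeepSpeed code paths")
else:
    print(f"Other backend: {accelerator.distributed_type}")
```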
@@ -1292,11 +1292,17 @@ def main(args):
             text_encoder_two_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
+                    if args.upcast_before_saving:
+                        model = model.to(torch.float32)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):  # or text_encoder_two
+                elif args.train_text_encoder and isinstance(
+                    unwrap_model(model), type(unwrap_model(text_encoder_one))
+                ):  # or text_encoder_two
                     # both text encoders are of the same class, so we check hidden size to distinguish between the two
-                    hidden_size = unwrap_model(model).config.hidden_size
+                    model = unwrap_model(model)
+                    hidden_size = model.config.hidden_size
                     if hidden_size == 768:
                         text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                     elif hidden_size == 1280:
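The save hook now runs its `isinstance` checks on `unwrap_model(model)` because under DeepSpeed (or DDP) the objects accelerate hands to the hook are wrapper modules whose type never matches the bare transformer or text-encoder class; it also upcasts the unwrapped transformer to float32 when `--upcast_before_saving` is set and only inspects the text encoders when `--train_text_encoder` is enabled. A toy illustration of the unwrapping point, using stand-in classes rather than accelerate's own wrappers:

```python
# Toy illustration (stand-in classes, not accelerate's own) of why the hook
# unwraps before the isinstance check: a wrapped module's type never matches
# the underlying model class, so the old check silently skipped it.
import torch

class EngineLikeWrapper(torch.nn.Module):  # stand-in for a DeepSpeed/DDP wrapper
    def __init__(self, module):
        super().__init__()
        self.module = module

def unwrap(model):  # simplified stand-in for unwrap_model()
    return model.module if hasattr(model, "module") else model

base = torch.nn.Linear(4, 4)
wrapped = EngineLikeWrapper(base)

print(isinstance(wrapped, type(base)))          # False -> old check misses the model
print(isinstance(unwrap(wrapped), type(base)))  # True  -> new check still matches
```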
@@ -1305,6 +1311,7 @@ def main(args):
                     raise ValueError(f"unexpected save model: {model.__class__}")

                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()

             StableDiffusion3Pipeline.save_lora_weights(
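The `if weights:` guard presumably covers the DeepSpeed case, where the hook can be handed an empty `weights` list because DeepSpeed persists the model weights itself; without the guard the unconditional `pop()` would raise `IndexError`. A plain-Python illustration of that assumption:

```python
# Plain-Python illustration of the guard: an empty weights list (assumed to be
# what the hook receives on a DeepSpeed launch) no longer crashes the hook.
weights = []          # assumed input under DeepSpeed
if weights:           # guard added in this change
    weights.pop()
print("hook exits cleanly:", weights)
```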
@@ -1319,18 +1326,32 @@ def main(args):
         text_encoder_one_ = None
         text_encoder_two_ = None

-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()

-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_two))):
-                text_encoder_two_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
+                    text_encoder_two_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+        else:
+            transformer_ = SD3Transformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
+            if args.train_text_encoder:
+                text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder"
+                )
+                text_encoder_two_ = text_encoder_cls_two.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder_2"
+                )

         lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)

         transformer_state_dict = {
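In the load hook, the DeepSpeed branch cannot pop unwrapped modules off `models`, so it rebuilds fresh base models from `args.pretrained_model_name_or_path` and re-attaches the LoRA adapter before the saved LoRA state dict is applied. A hedged sketch of that rebuild pattern; the checkpoint id and LoRA hyperparameters below are illustrative placeholders, not values taken from this diff:

```python
# Hedged sketch of the DeepSpeed branch of the load hook. MODEL_ID and the
# LoraConfig values are placeholders; the script itself uses
# args.pretrained_model_name_or_path and the transformer_lora_config it
# built earlier during setup.
from diffusers import SD3Transformer2DModel
from peft import LoraConfig

MODEL_ID = "stabilityai/stable-diffusion-3-medium-diffusers"  # assumed checkpoint

transformer_ = SD3Transformer2DModel.from_pretrained(MODEL_ID, subfolder="transformer")
transformer_.add_adapter(
    LoraConfig(
        r=4,                       # illustrative rank
        lora_alpha=4,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
)
```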
@@ -1829,7 +1850,7 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
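Finally, checkpointing is no longer gated on `accelerator.is_main_process` alone: with DeepSpeed, saving state is a collective operation (optimizer state may be sharded across ranks), so every process has to reach the checkpointing call. A minimal sketch of the relaxed gate, runnable on a single process; `checkpointing_steps` and `global_step` stand in for the script's args and loop counter:

```python
# Minimal sketch of the relaxed checkpointing gate added in this change.
from accelerate import Accelerator, DistributedType

accelerator = Accelerator()
checkpointing_steps, global_step = 500, 500  # illustrative stand-ins

should_save = accelerator.is_main_process or (
    accelerator.distributed_type == DistributedType.DEEPSPEED
)
if should_save and global_step % checkpointing_steps == 0:
    print("would call accelerator.save_state(...) here")
```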