"tests/models/test_models_unet_2d.py" did not exist on "9727cda678abcfdf764a592aa35f5527d706ca34"
Unverified Commit d72184eb authored by Leo Jiang's avatar Leo Jiang Committed by GitHub
Browse files

[training] add ds support to lora hidream (#11737)



* [training] add ds support to lora hidream

* Apply style fixes

---------
Co-authored-by: default avatarJ石页 <jiangshuo9@h-partners.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 5ce4814a
......@@ -29,7 +29,7 @@ from pathlib import Path
import numpy as np
import torch
import transformers
from accelerate import Accelerator
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
......@@ -1181,12 +1181,14 @@ def main(args):
transformer_lora_layers_to_save = None
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
model = unwrap_model(model)
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
if weights:
weights.pop()
HiDreamImagePipeline.save_lora_weights(
......@@ -1197,13 +1199,20 @@ def main(args):
def load_model_hook(models, input_dir):
transformer_ = None
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(transformer))):
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
model = unwrap_model(model)
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
else:
transformer_ = HiDreamImageTransformer2DModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="transformer"
)
transformer_.add_adapter(transformer_lora_config)
lora_state_dict = HiDreamImagePipeline.lora_state_dict(input_dir)
......@@ -1655,7 +1664,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment