Unverified Commit 6683f979 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Training] Add `datasets` version of LCM LoRA SDXL (#5778)

* add: script to train lcm lora for sdxl with 🤗

 datasets

* suit up the args.

* remove comments.

* fix num_update_steps

* fix batch unmarshalling

* fix num_update_steps_per_epoch

* fix; dataloading.

* fix microconditions.

* unconditional predictions debug

* fix batch size.

* no need to use use_auth_token

* Apply suggestions from code review
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* make vae encoding batch size an arg

* final serialization in kohya

* style

* state dict rejigging

* feat: no separate teacher unet.

* debug

* fix state dict serialization

* debug

* debug

* debug

* remove prints.

* remove kohya utility and make style

* fix serialization

* fix

* add test

* add peft dependency.

* add: peft

* remove peft

* autocast device determination from accelerator

* autocast

* reduce lora rank.

* remove unneeded space

* Apply suggestions...
parent 4e7b0cb3
......@@ -161,6 +161,8 @@ tags:
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
widget:
- text: '{validation_prompt if validation_prompt else instance_prompt}'
---
"""
......
......@@ -112,3 +112,37 @@ accelerate launch train_lcm_distill_lora_sdxl_wds.py \
--seed=453645634 \
--push_to_hub \
```
We provide another version for LCM LoRA SDXL that follows best practices of `peft` and leverages the `datasets` library for quick experimentation. The script doesn't load two UNets unlike `train_lcm_distill_lora_sdxl_wds.py` which reduces the memory requirements quite a bit.
Below is an example training command that trains an LCM LoRA on the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions):
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
accelerate launch train_lcm_distill_lora_sdxl.py \
--pretrained_teacher_model=${MODEL_NAME} \
--pretrained_vae_model_name_or_path=${VAE_PATH} \
--output_dir="pokemons-lora-lcm-sdxl" \
--mixed_precision="fp16" \
--dataset_name=$DATASET_NAME \
--resolution=1024 \
--train_batch_size=24 \
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--use_8bit_adam \
--lora_rank=64 \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=3000 \
--checkpointing_steps=500 \
--validation_steps=50 \
--seed="0" \
--report_to="wandb" \
--push_to_hub
```
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
import safetensors
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class TextToImageLCM(ExamplesTestsAccelerate):
def test_text_to_image_lcm_lora_sdxl(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
--pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--lora_rank 4
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
def test_text_to_image_lcm_lora_sdxl_checkpointing(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
--pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--lora_rank 4
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--checkpointing_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6"},
)
test_args = f"""
examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
--pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--lora_rank 4
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 9
--checkpointing_steps 2
--resume_from_checkpoint latest
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment