Unverified Commit 76e2727b authored by Sayak Paul, committed by GitHub

[SANA LoRA] sana lora training tests and misc. (#10296)



* sana lora training tests and misc.

* remove push to hub

* Update examples/dreambooth/train_dreambooth_lora_sana.py
Co-authored-by: Aryan <aryan@huggingface.co>

---------
Co-authored-by: Aryan <aryan@huggingface.co>
parent 02c777c0
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
import safetensors.torch
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRASANA(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
pretrained_model_name_or_path = "hf-internal-testing/tiny-sana-pipe"
script_path = "examples/dreambooth/train_dreambooth_lora_sana.py"
transformer_layer_type = "transformer_blocks.0.attn1.to_k"
def test_dreambooth_lora_sana(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
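Outside the test, a quick way to verify such a `pytorch_lora_weights.safetensors` file is to load it back into the pipeline. A minimal sketch, assuming the training run wrote to a hypothetical `output_dir` (`load_lora_weights` is the standard diffusers entry point for LoRA state dicts):

import torch
from diffusers import SanaPipeline

# Reload the tiny test checkpoint and attach the freshly trained LoRA.
pipe = SanaPipeline.from_pretrained("hf-internal-testing/tiny-sana-pipe", torch_dtype=torch.float32)
pipe.load_lora_weights("output_dir")  # hypothetical folder holding pytorch_lora_weights.safetensors

# A two-step generation is enough to confirm the adapter is wired in.
image = pipe(prompt="", num_inference_steps=2, max_sequence_length=16).images[0]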
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
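`--cache_latents` trades per-step VAE work for a one-off encoding pass: every training image goes through the encoder once before the loop, and the cached latents are reused afterwards. A rough sketch of the pattern, not the script's exact code, assuming `vae` and `train_dataloader` already exist and a deterministic DC-AE encoder whose output exposes `.latent`:

import torch

latents_cache = []
with torch.no_grad():
    for batch in train_dataloader:
        # Encode once, on the right device/dtype, and keep only the latents.
        pixel_values = batch["pixel_values"].to(vae.device, dtype=vae.dtype)
        latents_cache.append(vae.encode(pixel_values).latent)

# The VAE is no longer needed during the training loop and can be freed.
del vae
torch.cuda.empty_cache()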
def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names. In this test, only params of
            # `self.transformer_layer_type` should be in the state dict.
            only_targeted_layers = all(self.transformer_layer_type in key for key in lora_state_dict)
            self.assertTrue(only_targeted_layers)
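`--lora_layers` is a comma-separated list of module names that ends up as `target_modules` in the PEFT config, which is why only keys containing `transformer_blocks.0.attn1.to_k` survive in the state dict above. A sketch of that mapping (the rank value here is arbitrary, and `transformer` is assumed to be a loaded SanaTransformer2DModel):

from peft import LoraConfig

lora_layers = "transformer_blocks.0.attn1.to_k".split(",")

# Only the listed linear layers receive LoRA adapters; everything else stays frozen.
transformer_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=lora_layers,
)
transformer.add_adapter(transformer_lora_config)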
def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
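With `--checkpointing_steps 2` and `--max_train_steps 6`, three checkpoints would be written, so `--checkpoints_total_limit 2` must delete `checkpoint-2` along the way. The pruning the example scripts run just before each save looks roughly like this (hypothetical helper name):

import os
import shutil

def prune_checkpoints(output_dir: str, limit: int) -> None:
    # Sort existing checkpoints by step number, oldest first.
    checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    checkpoints = sorted(checkpoints, key=lambda name: int(name.split("-")[1]))
    # Remove enough of the oldest ones that saving one more stays within the limit.
    num_to_remove = len(checkpoints) - limit + 1
    for stale in checkpoints[: max(num_to_remove, 0)]:
        shutil.rmtree(os.path.join(output_dir, stale))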
def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
                --max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--max_sequence_length 16
""".split()
resume_run_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
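Resuming re-loads the full Accelerate state (model, optimizer, LR scheduler, RNG) from the chosen checkpoint and continues counting steps from there, which is why the second run ends at `checkpoint-8`. A sketch of the resolution logic the example scripts share (variable names such as `args` and `accelerator` are assumed from the script context):

import os

if args.resume_from_checkpoint == "latest":
    # Pick the newest checkpoint in the output directory.
    dirs = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint")]
    path = sorted(dirs, key=lambda name: int(name.split("-")[1]))[-1] if dirs else None
else:
    path = os.path.basename(args.resume_from_checkpoint)

if path is not None:
    accelerator.load_state(os.path.join(args.output_dir, path))
    global_step = int(path.split("-")[1])  # continue step counting from the checkpoint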
examples/dreambooth/train_dreambooth_lora_sana.py
@@ -943,7 +943,7 @@ def main(args):
     # Load scheduler and models
     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="scheduler"
+        args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
     text_encoder = Gemma2Model.from_pretrained(
@@ -964,15 +964,6 @@ def main(args):
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)

-    # Initialize a text encoding pipeline and keep it to CPU for now.
-    text_encoding_pipeline = SanaPipeline.from_pretrained(
-        args.pretrained_model_name_or_path,
-        vae=None,
-        transformer=None,
-        text_encoder=text_encoder,
-        tokenizer=tokenizer,
-    )
-
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
@@ -993,6 +984,15 @@ def main(args):
     # because Gemma2 is particularly suited for bfloat16.
     text_encoder.to(dtype=torch.bfloat16)

+    # Initialize a text encoding pipeline and keep it to CPU for now.
+    text_encoding_pipeline = SanaPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        vae=None,
+        transformer=None,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+    )
+
     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
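Passing `vae=None` and `transformer=None` yields a lightweight pipeline that carries only the tokenizer and text encoder, so it can be shuttled between CPU and GPU purely for prompt encoding. A sketch of how the script uses it; keyword names follow `SanaPipeline.encode_prompt`, and `accelerator`, `args`, and `text_encoding_pipeline` are assumed from the script context:

import torch

# Move the encoder-only pipeline to the accelerator just for this call.
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
with torch.no_grad():
    prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(
        "a photo of sks dog",
        do_classifier_free_guidance=False,
        max_sequence_length=args.max_sequence_length,
    )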
@@ -1182,6 +1182,7 @@ def main(args):
             )
         if args.offload:
             text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        prompt_embeds = prompt_embeds.to(transformer.dtype)
         return prompt_embeds, prompt_attention_mask

     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
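The one-line addition above fixes a dtype mismatch: the frozen Gemma2 encoder is kept in bfloat16 while the transformer may train in a different precision, so the embeddings are cast to the denoiser's dtype before the forward pass. In isolation, a sketch with tensor names assumed from the training loop:

# Without the cast, bf16 embeddings meeting an fp16/fp32 transformer raise a
# dtype-mismatch error inside the attention layers.
prompt_embeds = prompt_embeds.to(transformer.dtype)
model_pred = transformer(
    hidden_states=noisy_model_input,
    encoder_hidden_states=prompt_embeds,
    encoder_attention_mask=prompt_attention_mask,
    timestep=timesteps,
    return_dict=False,
)[0]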
@@ -1216,7 +1217,7 @@ def main(args):
     vae_config_scaling_factor = vae.config.scaling_factor
     if args.cache_latents:
         latents_cache = []
-        vae = vae.to("cuda")
+        vae = vae.to(accelerator.device)
         for batch in tqdm(train_dataloader, desc="Caching latents"):
             with torch.no_grad():
                 batch["pixel_values"] = batch["pixel_values"].to(
...
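Replacing the hard-coded `"cuda"` with `accelerator.device` keeps latent caching functional on CPU-only runners, Apple silicon (MPS), and distributed setups alike. A sketch of the idea, assuming `vae` was loaded earlier:

from accelerate import Accelerator

accelerator = Accelerator()
# Resolves to CUDA, MPS, or CPU depending on the environment.
vae = vae.to(accelerator.device)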
@@ -16,7 +16,7 @@ import sys
 import unittest

 import torch
-from transformers import Gemma2ForCausalLM, GemmaTokenizer
+from transformers import Gemma2Model, GemmaTokenizer

 from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
 from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
@@ -73,7 +73,7 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     }
     vae_cls = AutoencoderDC
     tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
-    text_encoder_cls, text_encoder_id = Gemma2ForCausalLM, "hf-internal-testing/dummy-gemma-for-diffusers"
+    text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"

     @property
     def output_shape(self):
@@ -105,34 +105,34 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         return noise, input_ids, pipeline_inputs

-    @unittest.skip("Not supported in Sana.")
+    @unittest.skip("Not supported in SANA.")
     def test_modify_padding_mode(self):
         pass

-    @unittest.skip("Not supported in Mochi.")
+    @unittest.skip("Not supported in SANA.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass

-    @unittest.skip("Not supported in Mochi.")
+    @unittest.skip("Not supported in SANA.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass

-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_partial_text_lora(self):
         pass

-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora(self):
         pass

-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora_and_scale(self):
         pass

-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora_fused(self):
         pass

-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
@@ -18,7 +18,7 @@ import unittest
 import numpy as np
 import torch
-from transformers import Gemma2Config, Gemma2ForCausalLM, GemmaTokenizer
+from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer

 from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
 from diffusers.utils.testing_utils import (
@@ -101,7 +101,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         torch.manual_seed(0)
         text_encoder_config = Gemma2Config(
             head_dim=16,
-            hidden_size=32,
+            hidden_size=8,
             initializer_range=0.02,
             intermediate_size=64,
             max_position_embeddings=8192,
@@ -112,7 +112,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             vocab_size=8,
             attn_implementation="eager",
         )
-        text_encoder = Gemma2ForCausalLM(text_encoder_config)
+        text_encoder = Gemma2Model(text_encoder_config)
         tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")

         components = {
...
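The swap from `Gemma2ForCausalLM` to `Gemma2Model` matters because a diffusion text encoder needs hidden states, not LM logits; the headless base model returns them directly. A small sketch using an abridged version of the tiny config above (extra config values are assumptions for a runnable example):

import torch
from transformers import Gemma2Config, Gemma2Model

config = Gemma2Config(
    head_dim=16,
    hidden_size=8,
    intermediate_size=64,
    num_attention_heads=2,
    num_hidden_layers=1,
    num_key_value_heads=2,
    vocab_size=8,
)
encoder = Gemma2Model(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
# last_hidden_state has shape (batch, seq_len, hidden_size) — the embedding
# the SANA transformer consumes as encoder_hidden_states.
hidden_states = encoder(input_ids).last_hidden_state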