Unverified commit 413ca29b, authored by Linoy Tsaban, committed by GitHub

[Flux Dreambooth LoRA] - te bug fixes & updates (#9139)

* add requirements + fix link to bghira's guide

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* style

* add tests

* fix encode_prompt call

* style

* unpack_latents test

* fix lora saving

* remove default val for max_sequence_length in encode_prompt

* remove default val for max_sequence_length in encode_prompt

* style

* testing

* style

* testing

* testing

* style

* fix sizing issue

* style

* revert scaling

* style

* style

* scaling test

* style

* scaling test

* remove model pred operation left from pre-conditioning

* remove model pred operation left from pre-conditioning

* fix trainable params

* remove te2 from casting

* transformer to accelerator

* remove prints

* empty commit
parent 10dc06c8
@@ -8,7 +8,7 @@ The `train_dreambooth_flux.py` script shows how to implement the training procedure
 >
 > Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
 > a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.
-> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](documentation/quickstart/FLUX.md)
+> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)
 > [!NOTE]
@@ -96,7 +96,7 @@ accelerate launch train_dreambooth_flux.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --instance_data_dir=$INSTANCE_DIR \
   --output_dir=$OUTPUT_DIR \
-  --mixed_precision="fp16" \
+  --mixed_precision="bf16" \
   --instance_prompt="a photo of sks dog" \
   --resolution=1024 \
   --train_batch_size=1 \
@@ -140,7 +140,7 @@ accelerate launch train_dreambooth_lora_flux.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --instance_data_dir=$INSTANCE_DIR \
   --output_dir=$OUTPUT_DIR \
-  --mixed_precision="fp16" \
+  --mixed_precision="bf16" \
   --instance_prompt="a photo of sks dog" \
   --resolution=512 \
   --train_batch_size=1 \
@@ -175,7 +175,7 @@ accelerate launch train_dreambooth_lora_flux.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --instance_data_dir=$INSTANCE_DIR \
   --output_dir=$OUTPUT_DIR \
-  --mixed_precision="fp16" \
+  --mixed_precision="bf16" \
   --train_text_encoder\
   --instance_prompt="a photo of sks dog" \
   --resolution=512 \
...
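For reference, a minimal inference sketch (not part of this diff) for trying out the LoRA weights the script saves to its output directory; the base model id, the output path, and the prompt below are placeholders:

import torch
from diffusers import FluxPipeline

# Sketch only: load the base model and attach the DreamBooth LoRA saved by the training script.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("trained-flux-lora")  # placeholder for the script's --output_dir
pipe.to("cuda")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=28).images[0]
image.save("sks_dog.png")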
accelerate>=0.31.0
torchvision
transformers>=4.41.2
ftfy
tensorboard
Jinja2
peft>=0.11.1
sentencepiece
\ No newline at end of file
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import shutil
import sys
import tempfile

from diffusers import DiffusionPipeline, FluxTransformer2DModel


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothFlux(ExamplesTestsAccelerate):
    instance_data_dir = "docs/source/en/imgs"
    instance_prompt = "photo"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
    script_path = "examples/dreambooth/train_dreambooth_flux.py"

    def test_dreambooth(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)

            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors")))
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

    def test_dreambooth_checkpointing(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            # Run training script with checkpointing
            # max_train_steps == 4, checkpointing_steps == 2
            # Should create checkpoints at steps 2, 4
            initial_run_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 4
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --checkpointing_steps=2
                --seed=0
                """.split()

            run_command(self._launch_args + initial_run_args)

            # check can run the original fully trained output pipeline
            pipe = DiffusionPipeline.from_pretrained(tmpdir)
            pipe(self.instance_prompt, num_inference_steps=1)

            # check checkpoint directories exist
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))

            # check can run an intermediate checkpoint
            transformer = FluxTransformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
            pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
            pipe(self.instance_prompt, num_inference_steps=1)

            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
            shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))

            # Run training script for 7 total steps resuming from checkpoint 4
            resume_run_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 6
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --seed=0
                """.split()

            run_command(self._launch_args + resume_run_args)

            # check can run new fully trained pipeline
            pipe = DiffusionPipeline.from_pretrained(tmpdir)
            pipe(self.instance_prompt, num_inference_steps=1)

            # check old checkpoints do not exist
            self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))

            # check new checkpoints exist
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))

    def test_dreambooth_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=6
                --checkpoints_total_limit=2
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4", "checkpoint-6"},
            )

    def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=4
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-2", "checkpoint-4"},
            )

            resume_run_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=8
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --checkpoints_total_limit=2
                """.split()

            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
    instance_data_dir = "docs/source/en/imgs"
    instance_prompt = "photo"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
    script_path = "examples/dreambooth/train_dreambooth_lora_flux.py"

    def test_dreambooth_lora_flux(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)

            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_text_encoder_flux(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --train_text_encoder
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)

            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            starts_with_expected_prefix = all(
                (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_expected_prefix)

    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=6
                --checkpoints_total_limit=2
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4", "checkpoint-6"},
            )

    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=4
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})

            resume_run_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=8
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --checkpoints_total_limit=2
                """.split()

            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
@@ -1505,6 +1505,9 @@ def main(args):
                 model_input = vae.encode(pixel_values).latent_dist.sample()
                 model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
+                vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
                     model_input.shape[2],
@@ -1583,16 +1586,11 @@ def main(args):
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2]) * 8,
-                    width=int(model_input.shape[3]) * 8,
-                    vae_scale_factor=2
-                    ** (
-                        len(vae.config.block_out_channels)
-                    ),  # should this be 2 ** (len(vae.config.block_out_channels))?
+                    height=int(model_input.shape[2]),
+                    width=int(model_input.shape[3]),
+                    vae_scale_factor=vae_scale_factor,
                 )
-                model_pred = model_pred * (-sigmas) + noisy_model_input
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
                 weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
...
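As a quick sanity check of the new vae_scale_factor bookkeeping, here is a sketch (not part of the diff), assuming a Flux-style VAE whose block_out_channels has four entries and a 512x512 training resolution; the numbers are illustrative:

# Sketch: with len(vae.config.block_out_channels) == 4, the factor is 2 ** 4 == 16.
vae_scale_factor = 2**4
# The VAE downsamples images spatially by vae_scale_factor / 2 == 8, so a 512x512
# image becomes a 64x64 latent and model_input.shape[2] == 64.
latent_height = 512 // (vae_scale_factor // 2)
assert latent_height == 64
# FluxPipeline._unpack_latents undoes the 2x2 patch packing, so the packed
# prediction holds (64 // 2) * (64 // 2) == 1024 tokens before it is unpacked.
packed_tokens = (latent_height // 2) ** 2
assert packed_tokens == 1024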
@@ -319,7 +319,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--max_sequence_length",
         type=int,
-        default=77,
+        default=512,
         help="Maximum sequence length to use with with the T5 text encoder",
     )
     parser.add_argument(
@@ -864,7 +864,7 @@ class PromptDataset(Dataset):
         return example
-def tokenize_prompt(tokenizer, prompt, max_sequence_length=512):
+def tokenize_prompt(tokenizer, prompt, max_sequence_length):
     text_inputs = tokenizer(
         prompt,
         padding="max_length",
@@ -885,20 +885,26 @@ def _encode_prompt_with_t5(
     prompt=None,
     num_images_per_prompt=1,
     device=None,
+    text_input_ids=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=max_sequence_length,
-        truncation=True,
-        return_length=False,
-        return_overflowing_tokens=False,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
+    if tokenizer is not None:
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            return_length=False,
+            return_overflowing_tokens=False,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+    else:
+        if text_input_ids is None:
+            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
     dtype = text_encoder.dtype
@@ -918,22 +924,28 @@ def _encode_prompt_with_clip(
     tokenizer,
     prompt: str,
     device=None,
+    text_input_ids=None,
     num_images_per_prompt: int = 1,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=77,
-        truncation=True,
-        return_overflowing_tokens=False,
-        return_length=False,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
+    if tokenizer is not None:
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=77,
+            truncation=True,
+            return_overflowing_tokens=False,
+            return_length=False,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+    else:
+        if text_input_ids is None:
+            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
     # Use pooled output of CLIPTextModel
@@ -954,6 +966,7 @@ def encode_prompt(
     max_sequence_length,
     device=None,
     num_images_per_prompt: int = 1,
+    text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
@@ -965,6 +978,7 @@ def encode_prompt(
         prompt=prompt,
         device=device if device is not None else text_encoders[0].device,
         num_images_per_prompt=num_images_per_prompt,
+        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
     )
     prompt_embeds = _encode_prompt_with_t5(
@@ -974,6 +988,7 @@ def encode_prompt(
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
         device=device if device is not None else text_encoders[1].device,
+        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
     )
     text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
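To make the new text_input_ids plumbing concrete, a minimal sketch (not part of the diff) of the two ways encode_prompt can now be driven; the tokenizer/encoder variable names and the prompt are illustrative:

# Sketch only. Path 1: frozen text encoders, let encode_prompt tokenize internally.
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
    text_encoders=[text_encoder_one, text_encoder_two],
    tokenizers=[tokenizer_one, tokenizer_two],
    prompt="a photo of sks dog",
    max_sequence_length=512,
)

# Path 2: text encoder training, tokenize once up front and pass the ids so the
# LoRA-augmented encoders are re-run on every step with the same token ids.
tokens_one = tokenize_prompt(tokenizer_one, "a photo of sks dog", max_sequence_length=77)
tokens_two = tokenize_prompt(tokenizer_two, "a photo of sks dog", max_sequence_length=512)
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
    text_encoders=[text_encoder_one, text_encoder_two],
    tokenizers=[None, None],
    text_input_ids_list=[tokens_one, tokens_two],
    max_sequence_length=512,
    prompt="a photo of sks dog",
)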
@@ -1127,14 +1142,11 @@ def main(args):
         args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
     )
+    # We only train the additional adapter LoRA layers
     transformer.requires_grad_(False)
     vae.requires_grad_(False)
-    if args.train_text_encoder:
-        text_encoder_one.requires_grad_(True)
-        text_encoder_two.requires_grad_(False)
-    else:
-        text_encoder_one.requires_grad_(False)
-        text_encoder_two.requires_grad_(False)
+    text_encoder_one.requires_grad_(False)
+    text_encoder_two.requires_grad_(False)
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
@@ -1151,9 +1163,9 @@ def main(args):
     )
     vae.to(accelerator.device, dtype=weight_dtype)
-    if not args.train_text_encoder:
-        text_encoder_one.to(accelerator.device, dtype=weight_dtype)
-        text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+    transformer.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_two.to(accelerator.device, dtype=weight_dtype)
     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
@@ -1168,6 +1180,14 @@ def main(args):
         target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )
     transformer.add_adapter(transformer_lora_config)
+    if args.train_text_encoder:
+        text_lora_config = LoraConfig(
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+        )
+        text_encoder_one.add_adapter(text_lora_config)
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
@@ -1257,15 +1277,16 @@ def main(args):
     if args.mixed_precision == "fp16":
         models = [transformer]
         if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
+            models.extend([text_encoder_one])
         # only upcast trainable parameters (LoRA) into fp32
         cast_training_params(models, dtype=torch.float32)
-    # Optimization parameters
-    transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate}
+    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
     if args.train_text_encoder:
         text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+    # Optimization parameters
+    transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
     if args.train_text_encoder:
         # different learning rate for text encoder and unet
         text_parameters_one_with_lr = {
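For context, a rough sketch of how the two parameter groups are typically combined into one optimizer; it assumes the script exposes a separate text-encoder learning rate flag (shown here as args.text_encoder_lr) and standard Adam hyperparameters, none of which appear in this hunk:

# Sketch only: one optimizer, two parameter groups with their own learning rates.
text_parameters_one_with_lr = {
    "params": text_lora_parameters_one,
    "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,  # assumed flag
}
params_to_optimize = (
    [transformer_parameters_with_lr, text_parameters_one_with_lr]
    if args.train_text_encoder
    else [transformer_parameters_with_lr]
)
optimizer = torch.optim.AdamW(
    params_to_optimize,
    betas=(0.9, 0.999),  # illustrative values
    weight_decay=1e-4,
    eps=1e-8,
)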
@@ -1420,14 +1441,18 @@ def main(args):
             prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
             pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
             text_ids = torch.cat([text_ids, class_text_ids], dim=0)
-    # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
-    # batch prompts on all training steps
+    # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts)
+    # we need to tokenize and encode the batch prompts on all training steps
     else:
         tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77)
-        tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt, max_sequence_length=512)
+        tokens_two = tokenize_prompt(
+            tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length
+        )
         if args.with_prior_preservation:
             class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77)
-            class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt, max_sequence_length=512)
+            class_tokens_two = tokenize_prompt(
+                tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length
+            )
             tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
             tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
@@ -1545,6 +1570,8 @@ def main(args):
         transformer.train()
         if args.train_text_encoder:
             text_encoder_one.train()
+            # set top parameter requires_grad = True for gradient checkpointing works
+            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
@@ -1562,12 +1589,33 @@ def main(args):
                         )
                     else:
                         tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
-                        tokens_two = tokenize_prompt(tokenizer_two, prompts, max_sequence_length=512)
+                        tokens_two = tokenize_prompt(
+                            tokenizer_two, prompts, max_sequence_length=args.max_sequence_length
+                        )
+                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+                            text_encoders=[text_encoder_one, text_encoder_two],
+                            tokenizers=[None, None],
+                            text_input_ids_list=[tokens_one, tokens_two],
+                            max_sequence_length=args.max_sequence_length,
+                            prompt=prompts,
+                        )
+                else:
+                    if args.train_text_encoder:
+                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+                            text_encoders=[text_encoder_one, text_encoder_two],
+                            tokenizers=[None, None],
+                            text_input_ids_list=[tokens_one, tokens_two],
+                            max_sequence_length=args.max_sequence_length,
+                            prompt=args.instance_prompt,
+                        )
                 # Convert images to latent space
                 model_input = vae.encode(pixel_values).latent_dist.sample()
                 model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
+                vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
                     model_input.shape[2],
@@ -1575,7 +1623,6 @@ def main(args):
                     accelerator.device,
                     weight_dtype,
                 )
-
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
                 bsz = model_input.shape[0]
@@ -1613,49 +1660,24 @@ def main(args):
                     guidance = None
                 # Predict the noise residual
-                if not args.train_text_encoder:
-                    model_pred = transformer(
-                        hidden_states=packed_noisy_model_input,
-                        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                        timestep=timesteps / 1000,
-                        guidance=guidance,
-                        pooled_projections=pooled_prompt_embeds,
-                        encoder_hidden_states=prompt_embeds,
-                        txt_ids=text_ids,
-                        img_ids=latent_image_ids,
-                        return_dict=False,
-                    )[0]
-                else:
-                    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
-                        text_encoders=[text_encoder_one, text_encoder_two],
-                        tokenizers=None,
-                        prompt=None,
-                        text_input_ids_list=[tokens_one, tokens_two],
-                    )
-                    model_pred = transformer(
-                        hidden_states=packed_noisy_model_input,
-                        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                        timestep=timesteps / 1000,
-                        guidance=guidance,
-                        pooled_projections=pooled_prompt_embeds,
-                        encoder_hidden_states=prompt_embeds,
-                        txt_ids=text_ids,
-                        img_ids=latent_image_ids,
-                        return_dict=False,
-                    )[0]
+                model_pred = transformer(
+                    hidden_states=packed_noisy_model_input,
+                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
+                    timestep=timesteps / 1000,
+                    guidance=guidance,
+                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
+                    return_dict=False,
+                )[0]
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2]) * 8,
-                    width=int(model_input.shape[3]) * 8,
-                    vae_scale_factor=2
-                    ** (
-                        len(vae.config.block_out_channels)
-                    ),  # should this be 2 ** (len(vae.config.block_out_channels))?
+                    height=int(model_input.shape[2] * vae_scale_factor / 2),
+                    width=int(model_input.shape[3] * vae_scale_factor / 2),
+                    vae_scale_factor=vae_scale_factor,
                 )
-                model_pred = model_pred * (-sigmas) + noisy_model_input
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
                 weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
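The loss computation itself sits outside the shown hunks; here is a minimal sketch of how the post-weighted flow-matching objective is typically assembled from weighting, assuming the target is the velocity noise - model_input (which is why the leftover pre-conditioning step above could be dropped):

# Sketch: compare the unpacked prediction directly against the flow-matching target,
# applying the per-sample weighting from compute_loss_weighting_for_sd3.
target = noise - model_input
loss = torch.mean(
    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
    1,
)
loss = loss.mean()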
@@ -1783,7 +1805,7 @@ def main(args):
         FluxPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
-            text_encoder_one_lora_layers=text_encoder_lora_layers,
+            text_encoder_lora_layers=text_encoder_lora_layers,
         )
     # Final inference
...