Unverified commit 05faf326, authored by gzguevara, committed by GitHub

SDXL text-to-image torch compatible (#6550)

* torch compatible

* code quality fix

* ruff style

* ruff format
parent a080f0d3
@@ -44,16 +44,12 @@
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig

 import diffusers
-from diffusers import (
-    AutoencoderKL,
-    DDPMScheduler,
-    StableDiffusionXLPipeline,
-    UNet2DConditionModel,
-)
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel, compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
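The newly imported `is_compiled_module` helper lets the script detect the wrapper that `torch.compile()` puts around a module. A rough sketch of what such a check does (the real implementation lives in `diffusers.utils.torch_utils`, so treat this as an approximation):

```python
import torch


def is_compiled_module_sketch(module) -> bool:
    """Approximation of diffusers' is_compiled_module(): torch.compile()
    wraps a module in torch._dynamo's OptimizedModule, so an isinstance
    check against that wrapper type identifies compiled modules."""
    try:
        from torch._dynamo.eval_frame import OptimizedModule  # PyTorch >= 2.0
    except ImportError:
        return False
    return isinstance(module, OptimizedModule)
```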
@@ -508,11 +504,12 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column
             prompt_embeds = text_encoder(
                 text_input_ids.to(text_encoder.device),
                 output_hidden_states=True,
+                return_dict=False,
             )

             # We are only ALWAYS interested in the pooled output of the final text encoder
             pooled_prompt_embeds = prompt_embeds[0]
-            prompt_embeds = prompt_embeds.hidden_states[-2]
+            prompt_embeds = prompt_embeds[-1][-2]
             bs_embed, seq_len, _ = prompt_embeds.shape
             prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
             prompt_embeds_list.append(prompt_embeds)
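With `return_dict=False`, the text encoder returns a plain tuple instead of a `ModelOutput` object, avoiding the attribute lookups (`.hidden_states`) that can cause graph breaks under `torch.compile`. For SDXL's second text encoder (`CLIPTextModelWithProjection`), the tuple is ordered `(text_embeds, last_hidden_state, hidden_states)`, so `[0]` is the pooled embedding and `[-1][-2]` is the penultimate hidden state. A standalone sketch (the model ID is illustrative):

```python
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection

# Illustrative checkpoint; any SDXL repo with a second text encoder works.
repo = "stabilityai/stable-diffusion-xl-base-1.0"
tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="tokenizer_2")
text_encoder = CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2")

ids = tokenizer(
    "a photo of a cat",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids

with torch.no_grad():
    out = text_encoder(ids, output_hidden_states=True, return_dict=False)

pooled_prompt_embeds = out[0]      # projected/pooled embedding (text_embeds)
hidden_states = out[-1]            # tuple of hidden states, one per layer
prompt_embeds = hidden_states[-2]  # penultimate layer, as SDXL expects
```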
@@ -955,6 +952,12 @@ def main(args):
     if accelerator.is_main_process:
         accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))

+    # Function for unwrapping the model if torch.compile() was used via accelerate.
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
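The reason a plain `accelerator.unwrap_model` is no longer enough: `torch.compile()` returns an `OptimizedModule` wrapper that keeps the original module on `_orig_mod`, and saving the wrapper's state dict prefixes every key with `_orig_mod.`. A minimal demonstration (assumes PyTorch >= 2.0):

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 4)
compiled = torch.compile(net)

# torch.compile() keeps the original module on `_orig_mod`; parameters are
# shared, but state-dict keys from the wrapper carry an extra prefix.
assert compiled._orig_mod is net
print(list(compiled.state_dict())[0])  # _orig_mod.weight
print(list(net.state_dict())[0])       # weight
```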
@@ -1054,8 +1057,12 @@ def main(args):
                 pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
                 unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
-                model_pred = unet(
-                    noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
-                ).sample
+                model_pred = unet(
+                    noisy_model_input,
+                    timesteps,
+                    prompt_embeds,
+                    added_cond_kwargs=unet_added_conditions,
+                    return_dict=False,
+                )[0]

                 # Get the target for loss depending on the prediction type
                 if args.prediction_type is not None:
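For context, compilation in this script typically comes from Accelerate's dynamo integration rather than a direct `torch.compile` call, and `return_dict=False` makes the UNet return a plain tuple instead of a `UNet2DConditionOutput` dataclass, which traces more cleanly. A rough sketch of switching compilation on (the backend choice is illustrative):

```python
from accelerate import Accelerator

# Same effect as choosing "inductor" at the dynamo-backend question in
# `accelerate config`: models passed through accelerator.prepare() are
# then wrapped with torch.compile().
accelerator = Accelerator(dynamo_backend="inductor")
```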
@@ -1206,7 +1213,7 @@ def main(args):

     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())
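After unwrapping, the EMA-averaged weights are copied into the live UNet before saving. A self-contained sketch of diffusers' `EMAModel` hand-off, with a toy model standing in for the UNet:

```python
import torch.nn as nn
from diffusers.training_utils import EMAModel

model = nn.Linear(8, 8)
ema = EMAModel(model.parameters(), decay=0.9999)

# During training, after each optimizer step:
ema.step(model.parameters())

# At save time, overwrite the live parameters with their EMA averages:
ema.copy_to(model.parameters())
```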