Unverified commit c9779665 authored by Linoy Tsaban, committed by GitHub

[Dreambooth flux] bug fix for dreambooth script (align with dreambooth lora) (#9257)

* fix shape

* fix prompt encoding

* style

* fix device

* add comment
parent 1ca0a755
@@ -842,7 +842,7 @@ class PromptDataset(Dataset):
         return example
 
 
-def tokenize_prompt(tokenizer, prompt, max_sequence_length=512):
+def tokenize_prompt(tokenizer, prompt, max_sequence_length):
     text_inputs = tokenizer(
         prompt,
         padding="max_length",
@@ -863,20 +863,26 @@ def _encode_prompt_with_t5(
     prompt=None,
     num_images_per_prompt=1,
     device=None,
+    text_input_ids=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
 
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=max_sequence_length,
-        truncation=True,
-        return_length=False,
-        return_overflowing_tokens=False,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
+    if tokenizer is not None:
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            return_length=False,
+            return_overflowing_tokens=False,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+    else:
+        if text_input_ids is None:
+            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
     dtype = text_encoder.dtype
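With the `text_input_ids` escape hatch, `_encode_prompt_with_t5` can run without a tokenizer as long as precomputed ids are supplied; otherwise it raises the `ValueError` added above. A hedged sketch of the two paths (it assumes `text_encoder_two`, `tokenizer_two`, and `device` are set up as elsewhere in the script):

```python
prompt = "a photo of sks dog"

# Path 1: tokenizer provided - ids are computed inside the helper.
prompt_embeds = _encode_prompt_with_t5(
    text_encoder=text_encoder_two,
    tokenizer=tokenizer_two,
    max_sequence_length=512,
    prompt=prompt,
    device=device,
)

# Path 2: tokenizer=None - reuse ids tokenized once up front.
tokens_two = tokenize_prompt(tokenizer_two, prompt, max_sequence_length=512)
prompt_embeds = _encode_prompt_with_t5(
    text_encoder=text_encoder_two,
    tokenizer=None,
    max_sequence_length=512,
    prompt=prompt,  # still needed to derive batch_size
    device=device,
    text_input_ids=tokens_two,
)
```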
@@ -896,22 +902,28 @@ def _encode_prompt_with_clip(
     tokenizer,
     prompt: str,
     device=None,
+    text_input_ids=None,
     num_images_per_prompt: int = 1,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
 
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=77,
-        truncation=True,
-        return_overflowing_tokens=False,
-        return_length=False,
-        return_tensors="pt",
-    )
-
-    text_input_ids = text_inputs.input_ids
+    if tokenizer is not None:
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=77,
+            truncation=True,
+            return_overflowing_tokens=False,
+            return_length=False,
+            return_tensors="pt",
+        )
+
+        text_input_ids = text_inputs.input_ids
+    else:
+        if text_input_ids is None:
+            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
     # Use pooled output of CLIPTextModel
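The CLIP helper gets the same guard pattern; the differences are the fixed 77-token cap and that Flux consumes CLIP's pooled output (one vector per prompt) rather than the per-token sequence. A shape sketch, assuming the standard FLUX.1 text encoders (CLIP-L and T5-XXL; the stated dimensions are an assumption based on those checkpoints):

```python
# Assumed output shapes for the standard FLUX.1 checkpoints:
#   _encode_prompt_with_clip -> [batch, 768]                       (pooled_projections)
#   _encode_prompt_with_t5   -> [batch, max_sequence_length, 4096] (encoder_hidden_states)
pooled_prompt_embeds = _encode_prompt_with_clip(
    text_encoder=text_encoder_one,
    tokenizer=None,
    prompt=prompt,
    device=device,
    text_input_ids=tokens_one,  # precomputed 77-token CLIP ids
)
```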
@@ -932,17 +944,19 @@ def encode_prompt(
     max_sequence_length,
     device=None,
     num_images_per_prompt: int = 1,
+    text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
     dtype = text_encoders[0].dtype
+    device = device if device is not None else text_encoders[1].device
 
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
         tokenizer=tokenizers[0],
         prompt=prompt,
-        device=device if device is not None else text_encoders[0].device,
+        device=device,
         num_images_per_prompt=num_images_per_prompt,
+        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
     )
 
     prompt_embeds = _encode_prompt_with_t5(
@@ -951,7 +965,8 @@ def encode_prompt(
         max_sequence_length=max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
-        device=device if device is not None else text_encoders[1].device,
+        device=device,
+        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
     )
 
     text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
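Taken together, `encode_prompt` now has two calling conventions, and the pre-fix crash becomes clear: the old training-loop call passed `tokenizers=None` and `prompt=None`, which breaks both the `tokenizers[0]` indexing and `batch_size = len(prompt)`. The fix passes `tokenizers=[None, None]` plus the real prompt, as sketched below (a hedged example reusing the script's names):

```python
# Convention A: tokenize inside encode_prompt.
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
    text_encoders=[text_encoder_one, text_encoder_two],
    tokenizers=[tokenizer_one, tokenizer_two],
    prompt=prompts,
    max_sequence_length=args.max_sequence_length,
)

# Convention B: reuse precomputed ids (CLIP ids first, T5 ids second,
# matching the order of text_encoders).
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
    text_encoders=[text_encoder_one, text_encoder_two],
    tokenizers=[None, None],
    text_input_ids_list=[tokens_one, tokens_two],
    max_sequence_length=args.max_sequence_length,
    prompt=prompts,  # still required for batch_size
)
```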
@@ -1499,7 +1514,25 @@ def main(args):
                         )
                     else:
                         tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
-                        tokens_two = tokenize_prompt(tokenizer_two, prompts, max_sequence_length=512)
+                        tokens_two = tokenize_prompt(
+                            tokenizer_two, prompts, max_sequence_length=args.max_sequence_length
+                        )
+                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+                            text_encoders=[text_encoder_one, text_encoder_two],
+                            tokenizers=[None, None],
+                            text_input_ids_list=[tokens_one, tokens_two],
+                            max_sequence_length=args.max_sequence_length,
+                            prompt=prompts,
+                        )
+                else:
+                    if args.train_text_encoder:
+                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+                            text_encoders=[text_encoder_one, text_encoder_two],
+                            tokenizers=[None, None],
+                            text_input_ids_list=[tokens_one, tokens_two],
+                            max_sequence_length=args.max_sequence_length,
+                            prompt=args.instance_prompt,
+                        )
 
                 # Convert images to latent space
                 model_input = vae.encode(pixel_values).latent_dist.sample()
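This restructuring moves all text-embedding work to one place, before the transformer forward pass. The reasoning: tokenization is deterministic and can happen once, but when `--train_text_encoder` is set the encoder weights change every optimizer step, so embeddings must be recomputed per step from the cached ids. A condensed sketch of that pattern (the loop is illustrative, not the script's actual loop body):

```python
# Tokenize once: the ids never change for a fixed instance prompt.
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77)
tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length)

for step, batch in enumerate(train_dataloader):  # illustrative loop
    # Re-encode every step so the embeddings track the updated text encoders.
    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
        text_encoders=[text_encoder_one, text_encoder_two],
        tokenizers=[None, None],
        text_input_ids_list=[tokens_one, tokens_two],
        max_sequence_length=args.max_sequence_length,
        prompt=args.instance_prompt,
    )
```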
@@ -1553,41 +1586,22 @@ def main(args):
                     guidance = None
 
                 # Predict the noise residual
-                if not args.train_text_encoder:
-                    model_pred = transformer(
-                        hidden_states=packed_noisy_model_input,
-                        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                        timestep=timesteps / 1000,
-                        guidance=guidance,
-                        pooled_projections=pooled_prompt_embeds,
-                        encoder_hidden_states=prompt_embeds,
-                        txt_ids=text_ids,
-                        img_ids=latent_image_ids,
-                        return_dict=False,
-                    )[0]
-                else:
-                    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
-                        text_encoders=[text_encoder_one, text_encoder_two],
-                        tokenizers=None,
-                        prompt=None,
-                        text_input_ids_list=[tokens_one, tokens_two],
-                    )
-                    model_pred = transformer(
-                        hidden_states=packed_noisy_model_input,
-                        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                        timestep=timesteps / 1000,
-                        guidance=guidance,
-                        pooled_projections=pooled_prompt_embeds,
-                        encoder_hidden_states=prompt_embeds,
-                        txt_ids=text_ids,
-                        img_ids=latent_image_ids,
-                        return_dict=False,
-                    )[0]
-
+                model_pred = transformer(
+                    hidden_states=packed_noisy_model_input,
+                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
+                    timestep=timesteps / 1000,
+                    guidance=guidance,
+                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
+                    return_dict=False,
+                )[0]
+                # upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2]),
-                    width=int(model_input.shape[3]),
+                    height=int(model_input.shape[2] * vae_scale_factor / 2),
+                    width=int(model_input.shape[3] * vae_scale_factor / 2),
                     vae_scale_factor=vae_scale_factor,
                 )
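On the shape fix: `FluxPipeline._unpack_latents` expects pixel-space height/width and converts down internally, but `model_input` holds VAE latents, so its spatial dims must be upscaled first. A worked example with illustrative numbers (it assumes a 1024x1024 training resolution and that the script computes `vae_scale_factor = 2 ** len(vae.config.block_out_channels)`, i.e. 16 for the Flux VAE; both numbers are assumptions, not quoted from the diff):

```python
pixel_h = 1024
latent_h = pixel_h // 8  # Flux VAE downsamples 8x -> model_input.shape[2] == 128
vae_scale_factor = 16    # assumed: 2 ** len(vae.config.block_out_channels)

# Before the fix: height=128 was passed, and _unpack_latents divided it again,
# producing a mis-shaped prediction ("fix shape" in the commit message).
# After the fix: upscale back to pixel space before unpacking.
height = int(latent_h * vae_scale_factor / 2)  # 128 * 16 / 2 == 1024 == pixel_h
assert height == pixel_h
```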