"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b389f339ec016cb83f0975c1c9cc0d7965e411f8"
Unverified Commit 4497b3ec authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Training] make DreamBooth SDXL LoRA training script compatible with torch.compile (#6483)

* make it torch.compile compatible

* make the text encoder compatible too.

* style
parent fc63ebdd
...@@ -780,13 +780,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): ...@@ -780,13 +780,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
text_input_ids = text_input_ids_list[i] text_input_ids = text_input_ids_list[i]
prompt_embeds = text_encoder( prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device), text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
output_hidden_states=True,
) )
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds = prompt_embeds[-1][-2]
bs_embed, seq_len, _ = prompt_embeds.shape bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds) prompt_embeds_list.append(prompt_embeds)
...@@ -1429,7 +1428,8 @@ def main(args): ...@@ -1429,7 +1428,8 @@ def main(args):
timesteps, timesteps,
prompt_embeds_input, prompt_embeds_input,
added_cond_kwargs=unet_added_conditions, added_cond_kwargs=unet_added_conditions,
).sample return_dict=False,
)[0]
else: else:
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)} unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
prompt_embeds, pooled_prompt_embeds = encode_prompt( prompt_embeds, pooled_prompt_embeds = encode_prompt(
...@@ -1443,8 +1443,12 @@ def main(args): ...@@ -1443,8 +1443,12 @@ def main(args):
) )
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
model_pred = unet( model_pred = unet(
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions noisy_model_input,
).sample timesteps,
prompt_embeds_input,
added_cond_kwargs=unet_added_conditions,
return_dict=False,
)[0]
# Get the target for loss depending on the prediction type # Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon": if noise_scheduler.config.prediction_type == "epsilon":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment