"docs/vscode:/vscode.git/clone" did not exist on "fb2f74e2b9f1639c0db1464d56c906a7df3d6865"
Unverified commit 4497b3ec, authored by Sayak Paul and committed by GitHub

[Training] make DreamBooth SDXL LoRA training script compatible with torch.compile (#6483)

* make it torch.compile compatible

* make the text encoder compatible too.

* style
parent fc63ebdd
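
The pattern the commit adopts, in short: with the default `return_dict=True`, diffusers and transformers models return output dataclasses and the script accessed their fields via `.sample` / `.hidden_states`; passing `return_dict=False` makes the same calls return plain tuples, which is intended to keep the forward passes tracing cleanly under `torch.compile`. Below is a minimal sketch of that calling convention only, using a hypothetical `TinyDenoiser` stand-in rather than the real SDXL UNet.

```python
# Minimal sketch (not the training script): the calling convention the commit
# moves to. `TinyDenoiser` is a hypothetical stand-in for the diffusers UNet.
import torch
import torch.nn as nn


class TinyDenoiser(nn.Module):
    """Returns a plain tuple when called with return_dict=False."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, sample, timestep, encoder_hidden_states, return_dict=True):
        out = self.proj(sample) + encoder_hidden_states + timestep
        if not return_dict:
            return (out,)  # tuple output: callers index it with [0]
        return {"sample": out}  # dict output standing in for the library's ModelOutput dataclass


model = torch.compile(TinyDenoiser(), fullgraph=True)  # fullgraph=True: fail loudly on graph breaks

sample = torch.randn(2, 16)
timestep = torch.tensor(10.0)
encoder_hidden_states = torch.randn(2, 16)

# Tuple convention used by the updated script: pass return_dict=False, take [0].
model_pred = model(sample, timestep, encoder_hidden_states, return_dict=False)[0]
print(model_pred.shape)
```

In the training script itself a user would still opt in explicitly, e.g. by wrapping the model with `unet = torch.compile(unet)` before the training loop; the tuple-style calls in the diff below are what this commit changes to support that.
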
@@ -780,13 +780,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
             text_input_ids = text_input_ids_list[i]
 
         prompt_embeds = text_encoder(
-            text_input_ids.to(text_encoder.device),
-            output_hidden_states=True,
+            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
         )
 
         # We are only ALWAYS interested in the pooled output of the final text encoder
         pooled_prompt_embeds = prompt_embeds[0]
-        prompt_embeds = prompt_embeds.hidden_states[-2]
+        prompt_embeds = prompt_embeds[-1][-2]
         bs_embed, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
         prompt_embeds_list.append(prompt_embeds)
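
For reference, the indexing in this hunk follows the tuple layout a text encoder returns with `return_dict=False`: `[0]` is the projected pooled output and `[-1][-2]` is the penultimate hidden state. The sketch below illustrates that with a tiny, randomly initialised `CLIPTextModelWithProjection` (the class used as SDXL's second text encoder) and made-up config values, not the SDXL checkpoint.

```python
# Sketch of the new indexing with a toy, randomly initialised text encoder.
import torch
from transformers import CLIPTextConfig, CLIPTextModelWithProjection

config = CLIPTextConfig(  # assumed toy values, just large enough to run
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=4,
    num_attention_heads=4,
    max_position_embeddings=77,
)
text_encoder = CLIPTextModelWithProjection(config)

text_input_ids = torch.randint(0, config.vocab_size, (2, 77))

# With return_dict=False the output is a plain tuple:
# (text_embeds, last_hidden_state, hidden_states)
prompt_embeds = text_encoder(
    text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
)

pooled_prompt_embeds = prompt_embeds[0]  # projected pooled output, shape (batch, projection_dim)
prompt_embeds = prompt_embeds[-1][-2]    # penultimate hidden state, shape (batch, 77, hidden_size)
print(pooled_prompt_embeds.shape, prompt_embeds.shape)
```
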
@@ -1429,7 +1428,8 @@ def main(args):
                         timesteps,
                         prompt_embeds_input,
                         added_cond_kwargs=unet_added_conditions,
-                    ).sample
+                        return_dict=False,
+                    )[0]
                 else:
                     unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
@@ -1443,8 +1443,12 @@
                     )
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                     model_pred = unet(
-                        noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
-                    ).sample
+                        noisy_model_input,
+                        timesteps,
+                        prompt_embeds_input,
+                        added_cond_kwargs=unet_added_conditions,
+                        return_dict=False,
+                    )[0]
 
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
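
The UNet hunks apply the same convention: `.sample` on the returned `UNet2DConditionOutput` becomes `[0]` on the tuple returned with `return_dict=False`, and the two are the same tensor. A small sanity-check sketch using a tiny, randomly initialised `UNet2DConditionModel` with an assumed toy config (not the SDXL UNet, and omitting the SDXL-specific `added_cond_kwargs`):

```python
import torch
from diffusers import UNet2DConditionModel

# Toy config (assumed values), just big enough to run a forward pass on CPU.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)

noisy_model_input = torch.randn(1, 4, 32, 32)
timesteps = torch.tensor([10])
prompt_embeds_input = torch.randn(1, 77, 32)

with torch.no_grad():
    # Old style: attribute access on the UNet2DConditionOutput dataclass.
    old_pred = unet(noisy_model_input, timesteps, prompt_embeds_input).sample
    # New style from the commit: plain tuple output, indexed with [0].
    new_pred = unet(noisy_model_input, timesteps, prompt_embeds_input, return_dict=False)[0]

print(new_pred.shape, torch.allclose(old_pred, new_pred))  # same prediction either way
```
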