"src/turbomind/vscode:/vscode.git/clone" did not exist on "d5cb0be2cd16e6c5eefd4d266a38357fde83a660"
Unverified Commit 76b7d86a authored by SahilCarterr's avatar SahilCarterr Committed by GitHub
Browse files

Updated _encode_prompt_with_clip and encode_prompt in train_dreambooth_sd3 (#9800)



* Updated encode_prompt and the CLIP encode-prompt helper


---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent e2b3c248
...@@ -902,11 +902,13 @@ def _encode_prompt_with_clip( ...@@ -902,11 +902,13 @@ def _encode_prompt_with_clip(
tokenizer, tokenizer,
prompt: str, prompt: str,
device=None, device=None,
text_input_ids=None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
): ):
prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) batch_size = len(prompt)
if tokenizer is not None:
text_inputs = tokenizer( text_inputs = tokenizer(
prompt, prompt,
padding="max_length", padding="max_length",
...@@ -916,6 +918,10 @@ def _encode_prompt_with_clip( ...@@ -916,6 +918,10 @@ def _encode_prompt_with_clip(
) )
text_input_ids = text_inputs.input_ids text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
pooled_prompt_embeds = prompt_embeds[0] pooled_prompt_embeds = prompt_embeds[0]
...@@ -937,6 +943,7 @@ def encode_prompt( ...@@ -937,6 +943,7 @@ def encode_prompt(
max_sequence_length, max_sequence_length,
device=None, device=None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
text_input_ids_list=None,
): ):
prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt] if isinstance(prompt, str) else prompt
...@@ -945,13 +952,14 @@ def encode_prompt( ...@@ -945,13 +952,14 @@ def encode_prompt(
clip_prompt_embeds_list = [] clip_prompt_embeds_list = []
clip_pooled_prompt_embeds_list = [] clip_pooled_prompt_embeds_list = []
for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders): for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip( prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoder, text_encoder=text_encoder,
tokenizer=tokenizer, tokenizer=tokenizer,
prompt=prompt, prompt=prompt,
device=device if device is not None else text_encoder.device, device=device if device is not None else text_encoder.device,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
) )
clip_prompt_embeds_list.append(prompt_embeds) clip_prompt_embeds_list.append(prompt_embeds)
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds) clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment