Unverified Commit e031caf4 authored by Linoy Tsaban's avatar Linoy Tsaban Committed by GitHub
Browse files

[flux lora training] fix t5 training bug (#10845)



* fix t5 training bug

* Apply style fixes

---------
Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 08f74a8b
......@@ -880,9 +880,7 @@ class TokenEmbeddingsHandler:
idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
for idx, text_encoder in enumerate(self.text_encoders):
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
embeds = (
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
)
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
new_token_embeddings = embeds.weight.data[train_ids]
......@@ -904,9 +902,7 @@ class TokenEmbeddingsHandler:
@torch.no_grad()
def retract_embeddings(self):
for idx, text_encoder in enumerate(self.text_encoders):
embeds = (
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
)
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
embeds.weight.data[index_no_updates] = (
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
......@@ -1749,7 +1745,7 @@ def main(args):
if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well
text_lora_parameters_two = []
for name, param in text_encoder_two.named_parameters():
if "token_embedding" in name:
if "shared" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment