Unverified Commit e031caf4 authored by Linoy Tsaban, committed by GitHub

[flux lora training] fix t5 training bug (#10845)



* fix t5 training bug

* Apply style fixes

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 08f74a8b
@@ -880,9 +880,7 @@ class TokenEmbeddingsHandler:
         idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
         for idx, text_encoder in enumerate(self.text_encoders):
             train_ids = self.train_ids if idx == 0 else self.train_ids_t5
-            embeds = (
-                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
-            )
+            embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
             assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
             new_token_embeddings = embeds.weight.data[train_ids]
@@ -904,9 +902,7 @@ class TokenEmbeddingsHandler:
     @torch.no_grad()
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
-            embeds = (
-                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
-            )
+            embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
             embeds.weight.data[index_no_updates] = (
                 self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
@@ -1749,7 +1745,7 @@ def main(args):
         if args.enable_t5_ti:  # whether to do pivotal tuning/textual inversion for T5 as well
             text_lora_parameters_two = []
             for name, param in text_encoder_two.named_parameters():
-                if "token_embedding" in name:
+                if "shared" in name:
                     # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
                     param.data = param.to(dtype=torch.float32)
                     param.requires_grad = True
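Context for the fix (not part of the commit, stated as an assumption about the Hugging Face transformers T5 implementation): the T5 encoder keeps its input embedding table on the top-level `shared` module, and `encoder.embed_tokens` is the same tied embedding, so `named_parameters()` reports the weight as `shared.weight`. A check for `"token_embedding"` therefore never matched any T5 parameter, and the T5 embeddings were never cast to float32 or unfrozen; matching on `"shared"` (and reading `text_encoder.shared` directly) fixes that. A minimal sketch, using an arbitrary tiny config purely for illustration:

```python
# Sketch only: tiny config values are made up, not taken from the training script.
from transformers import T5Config, T5EncoderModel

config = T5Config(vocab_size=32, d_model=8, d_kv=4, d_ff=16, num_layers=1, num_heads=2)
text_encoder_two = T5EncoderModel(config)

# Both attribute paths reach the same tied embedding weight.
assert text_encoder_two.shared.weight is text_encoder_two.encoder.embed_tokens.weight

# The tied weight is listed under "shared.weight"; no T5 parameter name contains
# "token_embedding", which is why the old string match silently skipped T5.
names = [name for name, _ in text_encoder_two.named_parameters()]
assert any("shared" in name for name in names)
assert not any("token_embedding" in name for name in names)
```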