Unverified Commit 69de9b2e authored by Patrick von Platen, committed by GitHub

[Textual Inversion] Do not update other embeddings (#1665)

parent 3ce6380d
@@ -548,6 +548,9 @@ def main():
     progress_bar.set_description("Steps")
     global_step = 0
 
+    # keep original embeddings as reference
+    orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()
+
     for epoch in range(args.num_train_epochs):
         text_encoder.train()
         for step, batch in enumerate(train_dataloader):
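The new orig_embeds_params tensor snapshots the full token-embedding matrix once, before any optimizer step, so the rows belonging to pre-existing tokens can later be restored verbatim. A minimal sketch of the setup, assuming a CLIP text encoder and tokenizer from transformers (the model name and placeholder token below are illustrative, not from this commit):

    # Sketch: snapshot the embedding matrix before training so every row
    # except the newly added placeholder token can be restored later.
    from transformers import CLIPTextModel, CLIPTokenizer

    model_name = "openai/clip-vit-large-patch14"  # illustrative choice
    tokenizer = CLIPTokenizer.from_pretrained(model_name)
    text_encoder = CLIPTextModel.from_pretrained(model_name)

    # Add the concept token and resize the embedding matrix to make room for it.
    tokenizer.add_tokens("<my-concept>")
    placeholder_token_id = tokenizer.convert_tokens_to_ids("<my-concept>")
    text_encoder.resize_token_embeddings(len(tokenizer))

    # .clone() detaches the snapshot from the live weight tensor; without it,
    # orig_embeds_params would alias the weights and track every update.
    orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()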
@@ -585,20 +588,15 @@ def main():
                 loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
                 accelerator.backward(loss)
 
-                # Zero out the gradients for all token embeddings except the newly added
-                # embeddings for the concept, as we only want to optimize the concept embeddings
-                if accelerator.num_processes > 1:
-                    grads = text_encoder.module.get_input_embeddings().weight.grad
-                else:
-                    grads = text_encoder.get_input_embeddings().weight.grad
-                # Get the index for tokens that we want to zero the grads for
-                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
-                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
-
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
 
+                # Let's make sure we don't update any embedding weights besides the newly added token
+                index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
+                with torch.no_grad():
+                    text_encoder.get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates]
+
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
                     progress_bar.update(1)
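The removed gradient-zeroing approach did not fully freeze the other embeddings: an optimizer with decoupled weight decay or momentum state (AdamW in this script) can still move a parameter whose gradient is exactly zero. Copying the frozen rows back from orig_embeds_params after optimizer.step() guarantees they never drift. A minimal, self-contained sketch of the failure mode and the fix (vocab_size, emb_dim, and the hyperparameters are illustrative):

    # Sketch: AdamW's weight decay moves a parameter even when its gradient is zero,
    # so zeroing grads does not freeze the non-placeholder embedding rows.
    import torch

    vocab_size, emb_dim = 8, 4
    embedding = torch.nn.Embedding(vocab_size, emb_dim)
    orig = embedding.weight.data.clone()
    placeholder_token_id = vocab_size - 1  # pretend the last row is the new token

    optimizer = torch.optim.AdamW(embedding.parameters(), lr=1e-2, weight_decay=1e-2)

    loss = embedding(torch.tensor([placeholder_token_id])).pow(2).sum()
    loss.backward()

    # Old approach: zero the gradients of every row except the placeholder...
    index_no_updates = torch.arange(vocab_size) != placeholder_token_id
    embedding.weight.grad[index_no_updates] = 0.0
    optimizer.step()
    # ...yet decoupled weight decay has still shrunk the "frozen" rows:
    print(torch.allclose(embedding.weight.data[index_no_updates], orig[index_no_updates]))  # False

    # New approach: copy the original rows back after the step, which freezes
    # them regardless of what the optimizer did.
    with torch.no_grad():
        embedding.weight[index_no_updates] = orig[index_no_updates]
    print(torch.allclose(embedding.weight.data[index_no_updates], orig[index_no_updates]))  # True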