Unverified Commit 9be94d9c authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[textual_inversion] unwrap_model text encoder before accessing weights (#1816)

* unwrap_model text encoder before accessing weights

* fix another call

* fix the right call
parent f2acfb67
......@@ -592,7 +592,7 @@ def main():
progress_bar.set_description("Steps")
# keep original embeddings as reference
orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
for epoch in range(first_epoch, args.num_train_epochs):
text_encoder.train()
......@@ -644,7 +644,9 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
with torch.no_grad():
text_encoder.get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates]
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment