Unverified Commit ac84c2fa authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[textual-inversion] fix saving embeds (#387)

fix saving embeds
parent 5a38033d
......@@ -564,7 +564,7 @@ def main():
pipeline.save_pretrained(args.output_dir)
# Also save the newly trained embeddings
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {args.placeholder_token: learned_embeds}
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
if args.push_to_hub:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment