Unverified Commit 7f31142c authored by Isamu Isozaki, committed by GitHub

Added script to save during textual inversion training. Issue 524 (#645)

* Added script to save during training

* Suggested changes
parent 765506ce
@@ -29,8 +29,21 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 logger = get_logger(__name__)
 
 
+def save_progress(text_encoder, placeholder_token_id, accelerator, args):
+    logger.info("Saving embeddings")
+    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
+    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+    torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--save_steps",
+        type=int,
+        default=500,
+        help="Save learned_embeds.bin every X updates steps.",
+    )
     parser.add_argument(
         "--pretrained_model_name_or_path",
         type=str,
@@ -542,6 +555,8 @@ def main():
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1
+                if global_step % args.save_steps == 0:
+                    save_progress(text_encoder, placeholder_token_id, accelerator, args)
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
@@ -567,9 +582,7 @@ def main():
         )
         pipeline.save_pretrained(args.output_dir)
         # Also save the newly trained embeddings
-        learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
-        learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
-        torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
+        save_progress(text_encoder, placeholder_token_id, accelerator, args)
 
         if args.push_to_hub:
             repo.push_to_hub(
...
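For anyone picking up one of these intermediate checkpoints: save_progress writes learned_embeds.bin every --save_steps optimizer updates (default 500), as a dict mapping the placeholder token string to its embedding tensor. Below is a minimal sketch (not part of the commit) of loading such a file back into a fresh text encoder; the openai/clip-vit-large-patch14 checkpoint and the output_dir/ path are illustrative assumptions, only the file name and dict layout come from the diff above.

# Minimal sketch: load a learned_embeds.bin checkpoint back into a text encoder.
# Model name and output path below are assumptions, not taken from the commit.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# The file maps the placeholder token string to its trained embedding vector.
learned_embeds_dict = torch.load("output_dir/learned_embeds.bin", map_location="cpu")
placeholder_token, embeds = next(iter(learned_embeds_dict.items()))

# Register the placeholder token and grow the embedding matrix to make room for it.
tokenizer.add_tokens(placeholder_token)
text_encoder.resize_token_embeddings(len(tokenizer))

# Copy the trained vector into the new token's embedding row.
token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
with torch.no_grad():
    text_encoder.get_input_embeddings().weight[token_id] = embeds

A nice side effect of the third hunk: routing the end-of-training save through the same save_progress call guarantees the periodic and final files share this exact format.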