"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "d168481c3258d4b8c3cfcc67a8e1fe7a4f79452a"
Unverified commit f3f626d5, authored by haixinxu, committed by GitHub

Allow textual_inversion_flax script to use save_steps and revision flags (#2075)

* Update textual_inversion_flax.py

* Update textual_inversion_flax.py

* Typo

sorry.

* Format source
parent b7b4683b
@@ -121,6 +121,12 @@ def parse_args():
         default=5000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
+    parser.add_argument(
+        "--save_steps",
+        type=int,
+        default=500,
+        help="Save the learned embeddings every X update steps.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -136,6 +142,13 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
     parser.add_argument(
         "--lr_scheduler",
         type=str,
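For quick experimentation with just the two new flags outside the full script, here is a minimal, self-contained sketch; the flag names and defaults mirror the diff above, while the sample values and the inline launch command are illustrative only:

# Standalone sketch of the two flags this commit adds to parse_args().
# In the real script they are passed on the command line, e.g.:
#   python textual_inversion_flax.py ... --save_steps 250 --revision bf16
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--save_steps", type=int, default=500)  # checkpoint every X update steps
parser.add_argument("--revision", type=str, default=None, required=False)  # Hub branch/tag/commit

args = parser.parse_args(["--save_steps", "250", "--revision", "bf16"])
assert args.save_steps == 250 and args.revision == "bf16"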
@@ -420,9 +433,9 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Load models and create wrapper for stable diffusion
-    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
-    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
-    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
+    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision)
 
     # Create sampling rng
     rng = jax.random.PRNGKey(args.seed)
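For context, `revision` is forwarded to `from_pretrained`, which pins the weights to a branch, tag, or commit hash on the Hugging Face Hub. A hedged, standalone example (the model id and the "bf16" branch are an illustrative combination from the Hub, not part of this diff):

# `revision` selects which Hub revision the weights are loaded from.
# CompVis/stable-diffusion-v1-4 publishes bfloat16 Flax weights on its
# "bf16" branch, a common choice for TPU training with this script.
from diffusers import FlaxUNet2DConditionModel

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", revision="bf16"
)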
@@ -619,6 +632,12 @@ def main():
             if global_step >= args.max_train_steps:
                 break
+            if global_step % args.save_steps == 0:
+                learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
+                    placeholder_token_id
+                ]
+                learned_embeds_dict = {args.placeholder_token: learned_embeds}
+                jnp.save(os.path.join(args.output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict)
 
         train_metric = jax_utils.unreplicate(train_metric)
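Since the dict above goes through jnp.save, it lands on disk as a pickled zero-dimensional object array, so reading it back needs allow_pickle=True plus .item(). A hedged sketch of reloading one of the periodic checkpoints (the output path and step number are placeholders):

# Reload a checkpoint written by the save_steps block above.
import numpy as np

learned_embeds_dict = np.load("output/learned_embeds-500.npy", allow_pickle=True).item()
for token, embedding in learned_embeds_dict.items():
    print(token, embedding.shape)  # e.g. "<my-placeholder>" with shape (768,)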