Commit b8bfef2a authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make style

parent f3f626d5
...@@ -433,9 +433,15 @@ def main(): ...@@ -433,9 +433,15 @@ def main():
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
# Load models and create wrapper for stable diffusion # Load models and create wrapper for stable diffusion
text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder",revision=args.revision) text_encoder = FlaxCLIPTextModel.from_pretrained(
vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae",revision=args.revision) args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet",revision=args.revision) )
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# Create sampling rng # Create sampling rng
rng = jax.random.PRNGKey(args.seed) rng = jax.random.PRNGKey(args.seed)
...@@ -633,11 +639,13 @@ def main(): ...@@ -633,11 +639,13 @@ def main():
if global_step >= args.max_train_steps: if global_step >= args.max_train_steps:
break break
if global_step % args.save_steps == 0: if global_step % args.save_steps == 0:
learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][ learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][
placeholder_token_id "embedding"
] ][placeholder_token_id]
learned_embeds_dict = {args.placeholder_token: learned_embeds} learned_embeds_dict = {args.placeholder_token: learned_embeds}
jnp.save(os.path.join(args.output_dir, "learned_embeds-"+str(global_step)+".npy"), learned_embeds_dict) jnp.save(
os.path.join(args.output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict
)
train_metric = jax_utils.unreplicate(train_metric) train_metric = jax_utils.unreplicate(train_metric)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment