"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ea893a9ae73fa3913472f1056358869fa33c46a3"
Unverified Commit ac4c695d authored by Duong A. Nguyen's avatar Duong A. Nguyen Committed by GitHub
Browse files

[Flax examples] Load text encoder from subfolder (#1147)

load text encoder from subfolder
parent 01733238
...@@ -452,7 +452,9 @@ def main(): ...@@ -452,7 +452,9 @@ def main():
weight_dtype = jnp.bfloat16 weight_dtype = jnp.bfloat16
# Load models and create wrapper for stable diffusion # Load models and create wrapper for stable diffusion
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", dtype=weight_dtype) text_encoder = FlaxCLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained( vae, vae_params = FlaxAutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
) )
......
...@@ -379,7 +379,9 @@ def main(): ...@@ -379,7 +379,9 @@ def main():
# Load models and create wrapper for stable diffusion # Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", dtype=weight_dtype) text_encoder = FlaxCLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained( vae, vae_params = FlaxAutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
) )
......
...@@ -391,7 +391,7 @@ def main(): ...@@ -391,7 +391,7 @@ def main():
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
# Load models and create wrapper for stable diffusion # Load models and create wrapper for stable diffusion
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment