Unverified Commit d3ce6f4b authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Support revision in Flax text-to-image training (#2567)

Support revision in Flax text-to-image training.
parent ff91f154
...@@ -48,6 +48,13 @@ def parse_args(): ...@@ -48,6 +48,13 @@ def parse_args():
required=True, required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.", help="Path to pretrained model or model identifier from huggingface.co/models.",
) )
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument( parser.add_argument(
"--dataset_name", "--dataset_name",
type=str, type=str,
...@@ -386,15 +393,17 @@ def main(): ...@@ -386,15 +393,17 @@ def main():
weight_dtype = jnp.bfloat16 weight_dtype = jnp.bfloat16
# Load models and create wrapper for stable diffusion # Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, subfolder="tokenizer"
)
text_encoder = FlaxCLIPTextModel.from_pretrained( text_encoder = FlaxCLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype args.pretrained_model_name_or_path, revision=args.revision, subfolder="text_encoder", dtype=weight_dtype
) )
vae, vae_params = FlaxAutoencoderKL.from_pretrained( vae, vae_params = FlaxAutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype args.pretrained_model_name_or_path, revision=args.revision, subfolder="vae", dtype=weight_dtype
) )
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained( unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype args.pretrained_model_name_or_path, revision=args.revision, subfolder="unet", dtype=weight_dtype
) )
# Optimization # Optimization
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment