Unverified Commit d3ce6f4b authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Support revision in Flax text-to-image training (#2567)

Support revision in Flax text-to-image training.
parent ff91f154
......@@ -48,6 +48,13 @@ def parse_args():
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--dataset_name",
type=str,
......@@ -386,15 +393,17 @@ def main():
weight_dtype = jnp.bfloat16
# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, subfolder="tokenizer"
)
text_encoder = FlaxCLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
args.pretrained_model_name_or_path, revision=args.revision, subfolder="text_encoder", dtype=weight_dtype
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
args.pretrained_model_name_or_path, revision=args.revision, subfolder="vae", dtype=weight_dtype
)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype
args.pretrained_model_name_or_path, revision=args.revision, subfolder="unet", dtype=weight_dtype
)
# Optimization
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment