Unverified Commit 8959c5b9 authored by Ran Ran's avatar Ran Ran Committed by GitHub
Browse files

Add from_pt flag to enable model from PT (#5501)

* Add from_pt flag to enable model from PT

* Format the file

* Reformat the file
parent bc8a08f6
...@@ -208,6 +208,12 @@ def parse_args(): ...@@ -208,6 +208,12 @@ def parse_args():
), ),
) )
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--from_pt",
action="store_true",
default=False,
help="Flag to indicate whether to convert models from PyTorch.",
)
args = parser.parse_args() args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
...@@ -374,16 +380,31 @@ def main(): ...@@ -374,16 +380,31 @@ def main():
# Load models and create wrapper for stable diffusion # Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained( tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, subfolder="tokenizer" args.pretrained_model_name_or_path,
from_pt=args.from_pt,
revision=args.revision,
subfolder="tokenizer",
) )
text_encoder = FlaxCLIPTextModel.from_pretrained( text_encoder = FlaxCLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, subfolder="text_encoder", dtype=weight_dtype args.pretrained_model_name_or_path,
from_pt=args.from_pt,
revision=args.revision,
subfolder="text_encoder",
dtype=weight_dtype,
) )
vae, vae_params = FlaxAutoencoderKL.from_pretrained( vae, vae_params = FlaxAutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, subfolder="vae", dtype=weight_dtype args.pretrained_model_name_or_path,
from_pt=args.from_pt,
revision=args.revision,
subfolder="vae",
dtype=weight_dtype,
) )
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained( unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, subfolder="unet", dtype=weight_dtype args.pretrained_model_name_or_path,
from_pt=args.from_pt,
revision=args.revision,
subfolder="unet",
dtype=weight_dtype,
) )
# Optimization # Optimization
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment