Unverified Commit 8959c5b9 authored by Ran Ran's avatar Ran Ran Committed by GitHub
Browse files

Add from_pt flag to enable model from PT (#5501)

* Add from_pt flag to enable model from PT

* Format the file

* Reformat the file
parent bc8a08f6
......@@ -208,6 +208,12 @@ def parse_args():
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--from_pt",
action="store_true",
default=False,
help="Flag to indicate whether to convert models from PyTorch.",
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
......@@ -374,16 +380,31 @@ def main():
# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, subfolder="tokenizer"
args.pretrained_model_name_or_path,
from_pt=args.from_pt,
revision=args.revision,
subfolder="tokenizer",
)
text_encoder = FlaxCLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, subfolder="text_encoder", dtype=weight_dtype
args.pretrained_model_name_or_path,
from_pt=args.from_pt,
revision=args.revision,
subfolder="text_encoder",
dtype=weight_dtype,
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, subfolder="vae", dtype=weight_dtype
args.pretrained_model_name_or_path,
from_pt=args.from_pt,
revision=args.revision,
subfolder="vae",
dtype=weight_dtype,
)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, subfolder="unet", dtype=weight_dtype
args.pretrained_model_name_or_path,
from_pt=args.from_pt,
revision=args.revision,
subfolder="unet",
dtype=weight_dtype,
)
# Optimization
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment