"references/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "24f16a338391d6f45aa6291c48eb6d5513771631"
Unverified Commit aef11cbf authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

add pipeline_class_name argument to Stable Diffusion conversion script (#4461)



* add pipeline class

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* style

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 71c82241
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
""" Conversion script for the LDM checkpoints. """ """ Conversion script for the LDM checkpoints. """
import argparse import argparse
import importlib
import torch import torch
...@@ -133,8 +134,22 @@ if __name__ == "__main__": ...@@ -133,8 +134,22 @@ if __name__ == "__main__":
required=False, required=False,
help="Set to a path, hub id to an already converted vae to not convert it again.", help="Set to a path, hub id to an already converted vae to not convert it again.",
) )
parser.add_argument(
"--pipeline_class_name",
type=str,
default=None,
required=False,
help="Specify the pipeline class name",
)
args = parser.parse_args() args = parser.parse_args()
if args.pipeline_class_name is not None:
library = importlib.import_module("diffusers")
class_obj = getattr(library, args.pipeline_class_name)
else:
pipeline_class = None
pipe = download_from_original_stable_diffusion_ckpt( pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=args.checkpoint_path, checkpoint_path=args.checkpoint_path,
original_config_file=args.original_config_file, original_config_file=args.original_config_file,
...@@ -152,6 +167,7 @@ if __name__ == "__main__": ...@@ -152,6 +167,7 @@ if __name__ == "__main__":
clip_stats_path=args.clip_stats_path, clip_stats_path=args.clip_stats_path,
controlnet=args.controlnet, controlnet=args.controlnet,
vae_path=args.vae_path, vae_path=args.vae_path,
pipeline_class=pipeline_class,
) )
if args.half: if args.half:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment