"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "2e8d18e6994a91f251e1a26fa3d05d6dad1be212"
Unverified Commit d5f444de authored by Lincoln Stein's avatar Lincoln Stein Committed by GitHub
Browse files

Update checkpoint_merger pipeline to pass the "variant" argument (#6670)



* make checkpoint_merger pipeline pass the "variant" argument to from_pretrained()

* make style

---------
Co-authored-by: default avatarLincoln Stein <lstein@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 5a54dc9e
...@@ -81,6 +81,8 @@ class CheckpointMergerPipeline(DiffusionPipeline): ...@@ -81,6 +81,8 @@ class CheckpointMergerPipeline(DiffusionPipeline):
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False. force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
variant - which variant of a pretrained model to load, e.g. "fp16" (None)
""" """
# Default kwargs from DiffusionPipeline # Default kwargs from DiffusionPipeline
cache_dir = kwargs.pop("cache_dir", None) cache_dir = kwargs.pop("cache_dir", None)
...@@ -89,6 +91,7 @@ class CheckpointMergerPipeline(DiffusionPipeline): ...@@ -89,6 +91,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None) token = kwargs.pop("token", None)
variant = kwargs.pop("variant", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None) torch_dtype = kwargs.pop("torch_dtype", None)
device_map = kwargs.pop("device_map", None) device_map = kwargs.pop("device_map", None)
...@@ -173,7 +176,10 @@ class CheckpointMergerPipeline(DiffusionPipeline): ...@@ -173,7 +176,10 @@ class CheckpointMergerPipeline(DiffusionPipeline):
# Step 3:- # Step 3:-
# Load the first checkpoint as a diffusion pipeline and modify its module state_dict in place # Load the first checkpoint as a diffusion pipeline and modify its module state_dict in place
final_pipe = DiffusionPipeline.from_pretrained( final_pipe = DiffusionPipeline.from_pretrained(
cached_folders[0], torch_dtype=torch_dtype, device_map=device_map cached_folders[0],
torch_dtype=torch_dtype,
device_map=device_map,
variant=variant,
) )
final_pipe.to(self.device) final_pipe.to(self.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment