"vscode:/vscode.git/clone" did not exist on "25ec3f2617450ed7c5f848e18c0f10354e9786d6"
Unverified Commit 5729829c authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[From single file] Make accelerate optional (#4132)

* Make accelerate optional

* make accelerate optional
parent e27500b7
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
""" Conversion script for the Stable Diffusion checkpoints.""" """ Conversion script for the Stable Diffusion checkpoints."""
import re import re
from contextlib import nullcontext
from io import BytesIO from io import BytesIO
from typing import Optional from typing import Optional
...@@ -779,7 +780,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder ...@@ -779,7 +780,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
config_name = "openai/clip-vit-large-patch14" config_name = "openai/clip-vit-large-patch14"
config = CLIPTextConfig.from_pretrained(config_name) config = CLIPTextConfig.from_pretrained(config_name)
with init_empty_weights(): ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
text_model = CLIPTextModel(config) text_model = CLIPTextModel(config)
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
...@@ -793,8 +795,11 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder ...@@ -793,8 +795,11 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
if key.startswith(prefix): if key.startswith(prefix):
text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]
if is_accelerate_available():
for param_name, param in text_model_dict.items(): for param_name, param in text_model_dict.items():
set_module_tensor_to_device(text_model, param_name, "cpu", value=param) set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
else:
text_model.load_state_dict(text_model_dict)
return text_model return text_model
...@@ -900,7 +905,8 @@ def convert_open_clip_checkpoint( ...@@ -900,7 +905,8 @@ def convert_open_clip_checkpoint(
# ) # )
config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs) config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
with init_empty_weights(): ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
...@@ -950,8 +956,11 @@ def convert_open_clip_checkpoint( ...@@ -950,8 +956,11 @@ def convert_open_clip_checkpoint(
text_model_dict[new_key] = checkpoint[key] text_model_dict[new_key] = checkpoint[key]
if is_accelerate_available():
for param_name, param in text_model_dict.items(): for param_name, param in text_model_dict.items():
set_module_tensor_to_device(text_model, param_name, "cpu", value=param) set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
else:
text_model.load_state_dict(text_model_dict)
return text_model return text_model
...@@ -1172,11 +1181,6 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1172,11 +1181,6 @@ def download_from_original_stable_diffusion_ckpt(
StableUnCLIPPipeline, StableUnCLIPPipeline,
) )
if not is_accelerate_available():
raise ImportError(
"To correctly use `from_single_file`, please make sure that `accelerate` is installed. You can install it with `pip install accelerate`."
)
if pipeline_class is None: if pipeline_class is None:
pipeline_class = StableDiffusionPipeline pipeline_class = StableDiffusionPipeline
...@@ -1346,15 +1350,19 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1346,15 +1350,19 @@ def download_from_original_stable_diffusion_ckpt(
# Convert the UNet2DConditionModel model. # Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size) unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["upcast_attention"] = upcast_attention unet_config["upcast_attention"] = upcast_attention
with init_empty_weights():
unet = UNet2DConditionModel(**unet_config)
converted_unet_checkpoint = convert_ldm_unet_checkpoint( converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
) )
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
unet = UNet2DConditionModel(**unet_config)
if is_accelerate_available():
for param_name, param in converted_unet_checkpoint.items(): for param_name, param in converted_unet_checkpoint.items():
set_module_tensor_to_device(unet, param_name, "cpu", value=param) set_module_tensor_to_device(unet, param_name, "cpu", value=param)
else:
unet.load_state_dict(converted_unet_checkpoint)
# Convert the VAE model. # Convert the VAE model.
if vae_path is None: if vae_path is None:
...@@ -1372,11 +1380,15 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1372,11 +1380,15 @@ def download_from_original_stable_diffusion_ckpt(
vae_config["scaling_factor"] = vae_scaling_factor vae_config["scaling_factor"] = vae_scaling_factor
with init_empty_weights(): ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
vae = AutoencoderKL(**vae_config) vae = AutoencoderKL(**vae_config)
if is_accelerate_available():
for param_name, param in converted_vae_checkpoint.items(): for param_name, param in converted_vae_checkpoint.items():
set_module_tensor_to_device(vae, param_name, "cpu", value=param) set_module_tensor_to_device(vae, param_name, "cpu", value=param)
else:
vae.load_state_dict(converted_vae_checkpoint)
else: else:
vae = AutoencoderKL.from_pretrained(vae_path) vae = AutoencoderKL.from_pretrained(vae_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment