"vscode:/vscode.git/clone" did not exist on "c1efda70b52dc05857ad214106754d5e2028fc26"
Unverified Commit 5729829c authored by Patrick von Platen, committed by GitHub

[From single file] Make accelerate optional (#4132)

* Make accelerate optional

* make accelerate optional
parent e27500b7
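Every hunk below applies the same pattern: the accelerate-only utilities (`init_empty_weights`, `set_module_tensor_to_device`) are used only when `is_accelerate_available()` returns True; otherwise the code falls back to `contextlib.nullcontext` and a plain `load_state_dict`. A minimal, self-contained sketch of that pattern, using a toy `torch.nn` module instead of the real CLIP/UNet/VAE models (the import locations match what the conversion file itself uses):

```python
# Minimal sketch of the optional-accelerate pattern applied in every hunk below.
# Assumptions: `is_accelerate_available` is the diffusers.utils helper, and
# `init_empty_weights` / `set_module_tensor_to_device` are the accelerate
# utilities used by the real conversion code; the module and state dict are toys.
from contextlib import nullcontext

import torch
from diffusers.utils import is_accelerate_available

if is_accelerate_available():
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

state_dict = {"linear.weight": torch.randn(4, 4), "linear.bias": torch.zeros(4)}

# With accelerate, the module is created on the "meta" device (no real allocation);
# without it, nullcontext makes this an ordinary construction with real tensors.
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
    model = torch.nn.Sequential()
    model.add_module("linear", torch.nn.Linear(4, 4))

if is_accelerate_available():
    # Materialize each converted tensor directly on CPU, one parameter at a time.
    for param_name, param in state_dict.items():
        set_module_tensor_to_device(model, param_name, "cpu", value=param)
else:
    # Plain PyTorch fallback: the module already holds real tensors, so load normally.
    model.load_state_dict(state_dict)
```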
@@ -15,6 +15,7 @@
 """ Conversion script for the Stable Diffusion checkpoints."""
 import re
+from contextlib import nullcontext
 from io import BytesIO
 from typing import Optional
@@ -779,7 +780,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
     config_name = "openai/clip-vit-large-patch14"
     config = CLIPTextConfig.from_pretrained(config_name)
 
-    with init_empty_weights():
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
         text_model = CLIPTextModel(config)
 
     keys = list(checkpoint.keys())
@@ -793,8 +795,11 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
             if key.startswith(prefix):
                 text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]
 
-    for param_name, param in text_model_dict.items():
-        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+    if is_accelerate_available():
+        for param_name, param in text_model_dict.items():
+            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+    else:
+        text_model.load_state_dict(text_model_dict)
 
     return text_model
@@ -900,7 +905,8 @@ def convert_open_clip_checkpoint(
     # )
     config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
 
-    with init_empty_weights():
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
         text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
 
     keys = list(checkpoint.keys())
@@ -950,8 +956,11 @@ def convert_open_clip_checkpoint(
             text_model_dict[new_key] = checkpoint[key]
 
-    for param_name, param in text_model_dict.items():
-        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+    if is_accelerate_available():
+        for param_name, param in text_model_dict.items():
+            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+    else:
+        text_model.load_state_dict(text_model_dict)
 
     return text_model
@@ -1172,11 +1181,6 @@ def download_from_original_stable_diffusion_ckpt(
         StableUnCLIPPipeline,
     )
 
-    if not is_accelerate_available():
-        raise ImportError(
-            "To correctly use `from_single_file`, please make sure that `accelerate` is installed. You can install it with `pip install accelerate`."
-        )
-
     if pipeline_class is None:
         pipeline_class = StableDiffusionPipeline
@@ -1346,15 +1350,19 @@ def download_from_original_stable_diffusion_ckpt(
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
-    with init_empty_weights():
-        unet = UNet2DConditionModel(**unet_config)
 
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
     )
 
-    for param_name, param in converted_unet_checkpoint.items():
-        set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
+        unet = UNet2DConditionModel(**unet_config)
+
+    if is_accelerate_available():
+        for param_name, param in converted_unet_checkpoint.items():
+            set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+    else:
+        unet.load_state_dict(converted_unet_checkpoint)
 
     # Convert the VAE model.
     if vae_path is None:
@@ -1372,11 +1380,15 @@ def download_from_original_stable_diffusion_ckpt(
         vae_config["scaling_factor"] = vae_scaling_factor
 
-        with init_empty_weights():
+        ctx = init_empty_weights if is_accelerate_available() else nullcontext
+        with ctx():
             vae = AutoencoderKL(**vae_config)
 
-        for param_name, param in converted_vae_checkpoint.items():
-            set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+        if is_accelerate_available():
+            for param_name, param in converted_vae_checkpoint.items():
+                set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+        else:
+            vae.load_state_dict(converted_vae_checkpoint)
     else:
         vae = AutoencoderKL.from_pretrained(vae_path)
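With the hard ImportError removed from `download_from_original_stable_diffusion_ckpt`, single-file loading should also work in environments without accelerate installed, only losing the empty-weights fast path. A hypothetical usage sketch (the checkpoint path is a placeholder):

```python
# Hypothetical check that single-file loading no longer requires accelerate:
# without it, the models above are built normally and filled via load_state_dict.
# The local checkpoint path below is a placeholder.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.safetensors")
```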